diff options
11 files changed, 393 insertions, 302 deletions
diff --git a/parallel-libs/streamexecutor/examples/CUDASaxpy.cpp b/parallel-libs/streamexecutor/examples/CUDASaxpy.cpp index eab0cbe69d6..5fb3dba26a7 100644 --- a/parallel-libs/streamexecutor/examples/CUDASaxpy.cpp +++ b/parallel-libs/streamexecutor/examples/CUDASaxpy.cpp @@ -115,6 +115,11 @@ int main() { cg::SaxpyKernel Kernel = getOrDie(Device->createKernel<cg::SaxpyKernel>(cg::SaxpyLoaderSpec)); + se::RegisteredHostMemory<float> RegisteredX = + getOrDie(Device->registerHostMemory<float>(HostX)); + se::RegisteredHostMemory<float> RegisteredY = + getOrDie(Device->registerHostMemory<float>(HostY)); + // Allocate memory on the device. se::GlobalDeviceMemory<float> X = getOrDie(Device->allocateDeviceMemory<float>(ArraySize)); @@ -123,10 +128,10 @@ int main() { // Run operations on a stream. se::Stream Stream = getOrDie(Device->createStream()); - Stream.thenCopyH2D<float>(HostX, X) - .thenCopyH2D<float>(HostY, Y) + Stream.thenCopyH2D(RegisteredX, X) + .thenCopyH2D(RegisteredY, Y) .thenLaunch(ArraySize, 1, Kernel, A, X, Y) - .thenCopyD2H<float>(X, HostX); + .thenCopyD2H(X, RegisteredX); // Wait for the stream to complete. se::dieIfError(Stream.blockHostUntilDone()); diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Device.h b/parallel-libs/streamexecutor/include/streamexecutor/Device.h index 3e8c2c892ff..83840e82d01 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/Device.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/Device.h @@ -18,6 +18,7 @@ #include <type_traits> #include "streamexecutor/Error.h" +#include "streamexecutor/HostMemory.h" #include "streamexecutor/KernelSpec.h" #include "streamexecutor/PlatformDevice.h" @@ -58,36 +59,19 @@ public: return GlobalDeviceMemory<T>(this, *MaybeMemory, ElementCount); } - /// Allocates an array of ElementCount entries of type T in host memory. - /// - /// Host memory allocated by this function can be used for asynchronous memory - /// copies on streams. 
See Stream::thenCopyD2H and Stream::thenCopyH2D. - template <typename T> Expected<T *> allocateHostMemory(size_t ElementCount) { - Expected<void *> MaybeMemory = - PDevice->allocateHostMemory(ElementCount * sizeof(T)); - if (!MaybeMemory) - return MaybeMemory.takeError(); - return static_cast<T *>(*MaybeMemory); - } - - /// Frees memory previously allocated with allocateHostMemory. - template <typename T> Error freeHostMemory(T *Memory) { - return PDevice->freeHostMemory(Memory); - } - /// Registers a previously allocated host array of type T for asynchronous /// memory operations. /// /// Host memory registered by this function can be used for asynchronous /// memory copies on streams. See Stream::thenCopyD2H and Stream::thenCopyH2D. template <typename T> - Error registerHostMemory(T *Memory, size_t ElementCount) { - return PDevice->registerHostMemory(Memory, ElementCount * sizeof(T)); - } - - /// Unregisters host memory previously registered by registerHostMemory. - template <typename T> Error unregisterHostMemory(T *Memory) { - return PDevice->unregisterHostMemory(Memory); + Expected<RegisteredHostMemory<T>> + registerHostMemory(llvm::MutableArrayRef<T> Memory) { + if (Error E = PDevice->registerHostMemory(Memory.data(), + Memory.size() * sizeof(T))) { + return std::move(E); + } + return RegisteredHostMemory<T>(this, Memory.data(), Memory.size()); } /// \anchor DeviceHostSyncCopyGroup @@ -98,9 +82,8 @@ public: /// device calls. /// /// There are no restrictions on the host memory that is used as a source or - /// destination in these copy methods, so there is no need to allocate that - /// host memory using allocateHostMemory or register it with - /// registerHostMemory. + /// destination in these copy methods, so there is no need to register that + /// host memory with registerHostMemory. /// /// Each of these methods has a single template parameter, T, that specifies /// the type of data being copied. 
The ElementCount arguments specify the @@ -303,6 +286,12 @@ private: return PDevice->freeDeviceMemory(Memory.getHandle()); } + // Only destroyRegisteredHostMemoryInternals may unregister host memory. + friend void internal::destroyRegisteredHostMemoryInternals(Device *, void *); + Error unregisterHostMemory(const void *Pointer) { + return PDevice->unregisterHostMemory(Pointer); + } + PlatformDevice *PDevice; }; diff --git a/parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h b/parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h index 21931512b0b..ee4ebb621e9 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h @@ -46,6 +46,8 @@ template <typename ElemT> class GlobalDeviceMemory; /// memory, and an element count for the size of the slice. template <typename ElemT> class GlobalDeviceMemorySlice { public: + using ElementTy = ElemT; + /// Intentionally implicit so GlobalDeviceMemory<T> can be passed to functions /// expecting GlobalDeviceMemorySlice<T> arguments. GlobalDeviceMemorySlice(const GlobalDeviceMemory<ElemT> &Memory) @@ -171,6 +173,8 @@ protected: template <typename ElemT> class GlobalDeviceMemory : public GlobalDeviceMemoryBase { public: + using ElementTy = ElemT; + GlobalDeviceMemory(GlobalDeviceMemory &&Other) = default; GlobalDeviceMemory &operator=(GlobalDeviceMemory &&Other) = default; diff --git a/parallel-libs/streamexecutor/include/streamexecutor/HostMemory.h b/parallel-libs/streamexecutor/include/streamexecutor/HostMemory.h new file mode 100644 index 00000000000..2e8e961aca1 --- /dev/null +++ b/parallel-libs/streamexecutor/include/streamexecutor/HostMemory.h @@ -0,0 +1,185 @@ +//===-- HostMemory.h - Types for registered host memory ---------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +/// +/// \file +/// +/// This file defines types that represent registered host memory buffers. Host +/// memory must be registered to participate in asynchronous copies to or from +/// device memory. +/// +//===----------------------------------------------------------------------===// + +#ifndef STREAMEXECUTOR_HOSTMEMORY_H +#define STREAMEXECUTOR_HOSTMEMORY_H + +#include <cassert> +#include <cstddef> +#include <type_traits> + +#include "llvm/ADT/ArrayRef.h" + +namespace streamexecutor { + +class Device; +template <typename ElemT> class RegisteredHostMemory; + +/// A mutable slice of registered host memory. +/// +/// The memory is registered in the sense of +/// streamexecutor::Device::registerHostMemory. +/// +/// Holds a reference to an underlying registered host memory buffer. Must not +/// be used after the underlying buffer is freed or unregistered. +template <typename ElemT> class MutableRegisteredHostMemorySlice { +public: + using ElementTy = ElemT; + + MutableRegisteredHostMemorySlice(RegisteredHostMemory<ElemT> &Registered) + : MutableArrayRef(Registered.getPointer(), Registered.getElementCount()) { + } + + ElemT *getPointer() const { return MutableArrayRef.data(); } + size_t getElementCount() const { return MutableArrayRef.size(); } + + /// Chops off the first N elements of the slice. + MutableRegisteredHostMemorySlice slice(size_t N) const { + return MutableRegisteredHostMemorySlice(MutableArrayRef.slice(N)); + } + + /// Chops off the first N elements of the slice and keeps the next M elements. + MutableRegisteredHostMemorySlice slice(size_t N, size_t M) const { + return MutableRegisteredHostMemorySlice(MutableArrayRef.slice(N, M)); + } + + /// Chops off the last N elements of the slice. 
+ MutableRegisteredHostMemorySlice drop_back(size_t N) const { + return MutableRegisteredHostMemorySlice(MutableArrayRef.drop_back(N)); + } + +private: + MutableRegisteredHostMemorySlice(llvm::MutableArrayRef<ElemT> MutableArrayRef) + : MutableArrayRef(MutableArrayRef) {} + + llvm::MutableArrayRef<ElemT> MutableArrayRef; +}; + +/// An immutable slice of registered host memory. +/// +/// The memory is registered in the sense of +/// streamexecutor::Device::registerHostMemory. +/// +/// Holds a reference to an underlying registered host memory buffer. Must not +/// be used after the underlying buffer is freed or unregistered. +template <typename ElemT> class RegisteredHostMemorySlice { +public: + using ElementTy = ElemT; + + RegisteredHostMemorySlice(const RegisteredHostMemory<ElemT> &Registered) + : ArrayRef(Registered.getPointer(), Registered.getElementCount()) {} + + RegisteredHostMemorySlice( + MutableRegisteredHostMemorySlice<ElemT> MutableSlice) + : ArrayRef(MutableSlice.getPointer(), MutableSlice.getElementCount()) {} + + const ElemT *getPointer() const { return ArrayRef.data(); } + size_t getElementCount() const { return ArrayRef.size(); } + + /// Chops off the first N elements of the slice. + RegisteredHostMemorySlice slice(size_t N) const { + return RegisteredHostMemorySlice(ArrayRef.slice(N)); + } + + /// Chops off the first N elements of the slice and keeps the next M elements. + RegisteredHostMemorySlice slice(size_t N, size_t M) const { + return RegisteredHostMemorySlice(ArrayRef.slice(N, M)); + } + + /// Chops off the last N elements of the slice. + RegisteredHostMemorySlice drop_back(size_t N) const { + return RegisteredHostMemorySlice(ArrayRef.drop_back(N)); + } + +private: + /// Wraps an existing view of the buffer. Needed by slice() and drop_back(), + /// which build their result from an llvm::ArrayRef; mirrors the private + /// constructor of MutableRegisteredHostMemorySlice. + explicit RegisteredHostMemorySlice(llvm::ArrayRef<ElemT> ArrayRef) + : ArrayRef(ArrayRef) {} + + llvm::ArrayRef<ElemT> ArrayRef; +}; + +namespace internal { + +/// Helper function to unregister host memory. +/// +/// This is a thin wrapper around streamexecutor::Device::unregisterHostMemory. 
+/// It is defined so this operation can be performed from the destructor of the +/// template class RegisteredHostMemory without including Device.h in this +/// header and creating a header inclusion cycle. +void destroyRegisteredHostMemoryInternals(Device *TheDevice, void *Pointer); + +} // namespace internal + +/// Registered host memory that knows how to unregister itself upon destruction. +/// +/// The memory is registered in the sense of +/// streamexecutor::Device::registerHostMemory. +/// +/// ElemT is the type of element stored in the host buffer. +template <typename ElemT> class RegisteredHostMemory { +public: + using ElementTy = ElemT; + + RegisteredHostMemory(Device *TheDevice, ElemT *Pointer, size_t ElementCount) + : TheDevice(TheDevice), Pointer(Pointer), ElementCount(ElementCount) { + assert(TheDevice != nullptr && "cannot construct a " + "RegisteredHostMemory with a null " + "device"); + } + + RegisteredHostMemory(const RegisteredHostMemory &) = delete; + RegisteredHostMemory &operator=(const RegisteredHostMemory &) = delete; + + RegisteredHostMemory(RegisteredHostMemory &&Other) + : TheDevice(Other.TheDevice), Pointer(Other.Pointer), + ElementCount(Other.ElementCount) { + Other.TheDevice = nullptr; + Other.Pointer = nullptr; + } + + /// Move assignment. + /// + /// Unregisters the buffer this object currently holds (if any) before taking + /// over Other's registration, then nulls Other's pointer so that Other's + /// destructor does not unregister the transferred buffer. Self-assignment is + /// a no-op. + RegisteredHostMemory &operator=(RegisteredHostMemory &&Other) { + if (this != &Other) { + internal::destroyRegisteredHostMemoryInternals(TheDevice, Pointer); + TheDevice = Other.TheDevice; + Pointer = Other.Pointer; + ElementCount = Other.ElementCount; + Other.TheDevice = nullptr; + Other.Pointer = nullptr; + } + return *this; + } + + ~RegisteredHostMemory() { + internal::destroyRegisteredHostMemoryInternals(TheDevice, Pointer); + } + + ElemT *getPointer() { return static_cast<ElemT *>(Pointer); } + const ElemT *getPointer() const { return static_cast<ElemT *>(Pointer); } + size_t getElementCount() const { return ElementCount; } + + /// Creates an immutable slice for the entire contents of this memory. 
+ RegisteredHostMemorySlice<ElemT> asSlice() const { + return RegisteredHostMemorySlice<ElemT>(*this); + } + + /// Creates a mutable slice for the entire contents of this memory. + MutableRegisteredHostMemorySlice<ElemT> asSlice() { + return MutableRegisteredHostMemorySlice<ElemT>(*this); + } + +private: + Device *TheDevice; + void *Pointer; + size_t ElementCount; +}; + +} // namespace streamexecutor + +#endif // STREAMEXECUTOR_HOSTMEMORY_H diff --git a/parallel-libs/streamexecutor/include/streamexecutor/PlatformDevice.h b/parallel-libs/streamexecutor/include/streamexecutor/PlatformDevice.h index 6437760203b..cc1ae405bbb 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/PlatformDevice.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/PlatformDevice.h @@ -68,8 +68,7 @@ public: /// Copies data from the device to the host. /// - /// HostDst should have been allocated by allocateHostMemory or registered - /// with registerHostMemory. + /// HostDst should have been registered with registerHostMemory. virtual Error copyD2H(const void *PlatformStreamHandle, const void *DeviceSrcHandle, size_t SrcByteOffset, void *HostDst, size_t DstByteOffset, size_t ByteCount) { @@ -78,8 +77,7 @@ public: /// Copies data from the host to the device. /// - /// HostSrc should have been allocated by allocateHostMemory or registered - /// with registerHostMemory. + /// HostSrc should have been registered with registerHostMemory. virtual Error copyH2D(const void *PlatformStreamHandle, const void *HostSrc, size_t SrcByteOffset, const void *DeviceDstHandle, size_t DstByteOffset, size_t ByteCount) { @@ -113,21 +111,6 @@ public: getName()); } - /// Allocates untyped host memory of a given size in bytes. - /// - /// Host memory allocated via this method is suitable for use with copyH2D and - /// copyD2H. 
- virtual Expected<void *> allocateHostMemory(size_t ByteCount) { - return make_error("allocateHostMemory not implemented for platform " + - getName()); - } - - /// Frees host memory allocated by allocateHostMemory. - virtual Error freeHostMemory(void *Memory) { - return make_error("freeHostMemory not implemented for platform " + - getName()); - } - /// Registers previously allocated host memory so it can be used with copyH2D /// and copyD2H. virtual Error registerHostMemory(void *Memory, size_t ByteCount) { @@ -136,7 +119,7 @@ public: } /// Unregisters host memory previously registered with registerHostMemory. - virtual Error unregisterHostMemory(void *Memory) { + virtual Error unregisterHostMemory(const void *Memory) { return make_error("unregisterHostMemory not implemented for platform " + getName()); } @@ -144,8 +127,8 @@ public: /// Copies the given number of bytes from device memory to host memory. /// /// Blocks the calling host thread until the copy is completed. Can operate on - /// any host memory, not just registered host memory or host memory allocated - /// by allocateHostMemory. Does not block any ongoing device calls. + /// any host memory, not just registered host memory. Does not block any + /// ongoing device calls. 
virtual Error synchronousCopyD2H(const void *DeviceSrcHandle, size_t SrcByteOffset, void *HostDst, size_t DstByteOffset, size_t ByteCount) { diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h index 704a81538e0..c293464d364 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h @@ -33,9 +33,11 @@ #include <cassert> #include <memory> #include <string> +#include <type_traits> #include "streamexecutor/DeviceMemory.h" #include "streamexecutor/Error.h" +#include "streamexecutor/HostMemory.h" #include "streamexecutor/Kernel.h" #include "streamexecutor/LaunchDimensions.h" #include "streamexecutor/PackedKernelArgumentArray.h" @@ -118,202 +120,154 @@ public: /// These methods enqueue a device memory copy operation on the stream and /// return without waiting for the operation to complete. /// - /// Any host memory used as a source or destination for one of these - /// operations must be allocated with Device::allocateHostMemory or registered - /// with Device::registerHostMemory. Otherwise, the enqueuing operation may - /// block until the copy operation is fully complete. - /// /// The arguments and bounds checking for these methods match the API of the /// \ref DeviceHostSyncCopyGroup /// "host-synchronous device memory copying functions" of Device. 
+ /// + /// The template types SrcTy and DstTy must match the following constraints: + /// * Must define typename ElementTy (the type of element stored in the + /// memory); + /// * ElementTy for the source argument must be the same as ElementTy for + /// the destination argument; + /// * Must be convertible to the correct slice type: + /// * GlobalDeviceMemorySlice<ElementTy> for device memory arguments, + /// * RegisteredHostMemorySlice<ElementTy> for host memory source + /// arguments, + /// * MutableRegisteredHostMemorySlice<ElementTy> for host memory + /// destination arguments. ///@{ - template <typename T> - Stream &thenCopyD2H(GlobalDeviceMemorySlice<T> Src, - llvm::MutableArrayRef<T> Dst, size_t ElementCount) { + // D2H + + template <typename SrcTy, typename DstTy> + Stream &thenCopyD2H(SrcTy &&Src, DstTy &&Dst, size_t ElementCount) { + using SrcElemTy = typename std::remove_reference<SrcTy>::type::ElementTy; + using DstElemTy = typename std::remove_reference<DstTy>::type::ElementTy; + static_assert(std::is_same<SrcElemTy, DstElemTy>::value, + "src/dst element type mismatch for thenCopyD2H"); + GlobalDeviceMemorySlice<SrcElemTy> SrcSlice(Src); + MutableRegisteredHostMemorySlice<DstElemTy> DstSlice(Dst); - if (ElementCount > Src.getElementCount()) + // Bounds-check against the converted slice: SrcTy is only required to be + // convertible to a slice, not to provide getElementCount() itself. + if (ElementCount > SrcSlice.getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + ", from a device array of element count " + - llvm::Twine(Src.getElementCount())); - else if (ElementCount > Dst.size()) + llvm::Twine(SrcSlice.getElementCount())); + else if (ElementCount > DstSlice.getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + - ", to a host array of element count " + llvm::Twine(Dst.size())); + ", to a host array of element count " + + llvm::Twine(DstSlice.getElementCount())); else - setError(PDevice->copyD2H(PlatformStreamHandle, - Src.getBaseMemory().getHandle(), - Src.getElementOffset() * sizeof(T), Dst.data(), - 0, ElementCount * sizeof(T))); + setError(PDevice->copyD2H( 
PlatformStreamHandle, SrcSlice.getBaseMemory().getHandle(), + SrcSlice.getElementOffset() * sizeof(SrcElemTy), + DstSlice.getPointer(), 0, ElementCount * sizeof(DstElemTy))); return *this; } - template <typename T> - Stream &thenCopyD2H(GlobalDeviceMemorySlice<T> Src, - llvm::MutableArrayRef<T> Dst) { - if (Src.getElementCount() != Dst.size()) + template <typename SrcTy, typename DstTy> + Stream &thenCopyD2H(SrcTy &&Src, DstTy &&Dst) { + using SrcElemTy = typename std::remove_reference<SrcTy>::type::ElementTy; + using DstElemTy = typename std::remove_reference<DstTy>::type::ElementTy; + static_assert(std::is_same<SrcElemTy, DstElemTy>::value, + "src/dst element type mismatch for thenCopyD2H"); + GlobalDeviceMemorySlice<SrcElemTy> SrcSlice(Src); + MutableRegisteredHostMemorySlice<DstElemTy> DstSlice(Dst); + if (SrcSlice.getElementCount() != DstSlice.getElementCount()) setError("array size mismatch for D2H, device source has element count " + - llvm::Twine(Src.getElementCount()) + + llvm::Twine(SrcSlice.getElementCount()) + " but host destination has element count " + - llvm::Twine(Dst.size())); + llvm::Twine(DstSlice.getElementCount())); else - thenCopyD2H(Src, Dst, Src.getElementCount()); - return *this; - } - - template <typename T> - Stream &thenCopyD2H(GlobalDeviceMemorySlice<T> Src, T *Dst, - size_t ElementCount) { - thenCopyD2H(Src, llvm::MutableArrayRef<T>(Dst, ElementCount), ElementCount); - return *this; - } - - template <typename T> - Stream &thenCopyD2H(const GlobalDeviceMemory<T> &Src, - llvm::MutableArrayRef<T> Dst, size_t ElementCount) { - thenCopyD2H(Src.asSlice(), Dst, ElementCount); - return *this; - } - - template <typename T> - Stream &thenCopyD2H(const GlobalDeviceMemory<T> &Src, - llvm::MutableArrayRef<T> Dst) { - thenCopyD2H(Src.asSlice(), Dst); + thenCopyD2H(SrcSlice, DstSlice, SrcSlice.getElementCount()); return *this; } - template <typename T> - Stream &thenCopyD2H(const GlobalDeviceMemory<T> &Src, T *Dst, - size_t ElementCount) { - 
thenCopyD2H(Src.asSlice(), Dst, ElementCount); - return *this; - } - - template <typename T> - Stream &thenCopyH2D(llvm::ArrayRef<T> Src, GlobalDeviceMemorySlice<T> Dst, - size_t ElementCount) { - if (ElementCount > Src.size()) + // H2D + + template <typename SrcTy, typename DstTy> + Stream &thenCopyH2D(SrcTy &&Src, DstTy &&Dst, size_t ElementCount) { + using SrcElemTy = typename std::remove_reference<SrcTy>::type::ElementTy; + using DstElemTy = typename std::remove_reference<DstTy>::type::ElementTy; + static_assert(std::is_same<SrcElemTy, DstElemTy>::value, + "src/dst element type mismatch for thenCopyH2D"); + RegisteredHostMemorySlice<SrcElemTy> SrcSlice(Src); + GlobalDeviceMemorySlice<DstElemTy> DstSlice(Dst); + if (ElementCount > SrcSlice.getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + ", from a host array of element count " + - llvm::Twine(Src.size())); - else if (ElementCount > Dst.getElementCount()) + llvm::Twine(SrcSlice.getElementCount())); + else if (ElementCount > DstSlice.getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + ", to a device array of element count " + - llvm::Twine(Dst.getElementCount())); + llvm::Twine(DstSlice.getElementCount())); else - setError(PDevice->copyH2D( - PlatformStreamHandle, Src.data(), 0, Dst.getBaseMemory().getHandle(), - Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T))); + setError(PDevice->copyH2D(PlatformStreamHandle, SrcSlice.getPointer(), 0, + DstSlice.getBaseMemory().getHandle(), + DstSlice.getElementOffset() * sizeof(DstElemTy), + ElementCount * sizeof(SrcElemTy))); return *this; } - template <typename T> - Stream &thenCopyH2D(llvm::ArrayRef<T> Src, GlobalDeviceMemorySlice<T> Dst) { - if (Src.size() != Dst.getElementCount()) + template <typename SrcTy, typename DstTy> + Stream &thenCopyH2D(SrcTy &&Src, DstTy &&Dst) { + using SrcElemTy = typename std::remove_reference<SrcTy>::type::ElementTy; + using DstElemTy = typename 
std::remove_reference<DstTy>::type::ElementTy; + static_assert(std::is_same<SrcElemTy, DstElemTy>::value, + "src/dst element type mismatch for thenCopyH2D"); + RegisteredHostMemorySlice<SrcElemTy> SrcSlice(Src); + GlobalDeviceMemorySlice<DstElemTy> DstSlice(Dst); + if (SrcSlice.getElementCount() != DstSlice.getElementCount()) setError("array size mismatch for H2D, host source has element count " + - llvm::Twine(Src.size()) + + llvm::Twine(SrcSlice.getElementCount()) + " but device destination has element count " + - llvm::Twine(Dst.getElementCount())); + llvm::Twine(DstSlice.getElementCount())); else - thenCopyH2D(Src, Dst, Dst.getElementCount()); - return *this; - } - - template <typename T> - Stream &thenCopyH2D(T *Src, GlobalDeviceMemorySlice<T> Dst, - size_t ElementCount) { - thenCopyH2D(llvm::ArrayRef<T>(Src, ElementCount), Dst, ElementCount); - return *this; - } - - template <typename T> - Stream &thenCopyH2D(llvm::ArrayRef<T> Src, GlobalDeviceMemory<T> &Dst, - size_t ElementCount) { - thenCopyH2D(Src, Dst.asSlice(), ElementCount); - return *this; - } - - template <typename T> - Stream &thenCopyH2D(llvm::ArrayRef<T> Src, GlobalDeviceMemory<T> &Dst) { - thenCopyH2D(Src, Dst.asSlice()); - return *this; - } - - template <typename T> - Stream &thenCopyH2D(T *Src, GlobalDeviceMemory<T> &Dst, size_t ElementCount) { - thenCopyH2D(Src, Dst.asSlice(), ElementCount); + thenCopyH2D(SrcSlice, DstSlice, DstSlice.getElementCount()); return *this; } - template <typename T> - Stream &thenCopyD2D(GlobalDeviceMemorySlice<T> Src, - GlobalDeviceMemorySlice<T> Dst, size_t ElementCount) { - if (ElementCount > Src.getElementCount()) + // D2D + + template <typename SrcTy, typename DstTy> + Stream &thenCopyD2D(SrcTy &&Src, DstTy &&Dst, size_t ElementCount) { + using SrcElemTy = typename std::remove_reference<SrcTy>::type::ElementTy; + using DstElemTy = typename std::remove_reference<DstTy>::type::ElementTy; + static_assert(std::is_same<SrcElemTy, DstElemTy>::value, + "src/dst element 
type mismatch for thenCopyD2D"); + GlobalDeviceMemorySlice<SrcElemTy> SrcSlice(Src); + GlobalDeviceMemorySlice<DstElemTy> DstSlice(Dst); + if (ElementCount > SrcSlice.getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + ", from a device array of element count " + - llvm::Twine(Src.getElementCount())); - else if (ElementCount > Dst.getElementCount()) + llvm::Twine(SrcSlice.getElementCount())); + else if (ElementCount > DstSlice.getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + ", to a device array of element count " + - llvm::Twine(Dst.getElementCount())); + llvm::Twine(DstSlice.getElementCount())); else - setError(PDevice->copyD2D( - PlatformStreamHandle, Src.getBaseMemory().getHandle(), - Src.getElementOffset() * sizeof(T), Dst.getBaseMemory().getHandle(), - Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T))); + setError(PDevice->copyD2D(PlatformStreamHandle, + SrcSlice.getBaseMemory().getHandle(), + SrcSlice.getElementOffset() * sizeof(SrcElemTy), + DstSlice.getBaseMemory().getHandle(), + DstSlice.getElementOffset() * sizeof(DstElemTy), + ElementCount * sizeof(SrcElemTy))); return *this; } - template <typename T> - Stream &thenCopyD2D(GlobalDeviceMemorySlice<T> Src, - GlobalDeviceMemorySlice<T> Dst) { - if (Src.getElementCount() != Dst.getElementCount()) + template <typename SrcTy, typename DstTy> + Stream &thenCopyD2D(SrcTy &&Src, DstTy &&Dst) { + using SrcElemTy = typename std::remove_reference<SrcTy>::type::ElementTy; + using DstElemTy = typename std::remove_reference<DstTy>::type::ElementTy; + static_assert(std::is_same<SrcElemTy, DstElemTy>::value, + "src/dst element type mismatch for thenCopyD2D"); + GlobalDeviceMemorySlice<SrcElemTy> SrcSlice(Src); + GlobalDeviceMemorySlice<DstElemTy> DstSlice(Dst); + if (SrcSlice.getElementCount() != DstSlice.getElementCount()) setError("array size mismatch for D2D, device source has element count " + - llvm::Twine(Src.getElementCount()) + 
+ llvm::Twine(SrcSlice.getElementCount()) + " but device destination has element count " + - llvm::Twine(Dst.getElementCount())); + llvm::Twine(DstSlice.getElementCount())); else - thenCopyD2D(Src, Dst, Src.getElementCount()); - return *this; - } - - template <typename T> - Stream &thenCopyD2D(const GlobalDeviceMemory<T> &Src, - GlobalDeviceMemorySlice<T> Dst, size_t ElementCount) { - thenCopyD2D(Src.asSlice(), Dst, ElementCount); - return *this; - } - - template <typename T> - Stream &thenCopyD2D(const GlobalDeviceMemory<T> &Src, - GlobalDeviceMemorySlice<T> Dst) { - thenCopyD2D(Src.asSlice(), Dst); - return *this; - } - - template <typename T> - Stream &thenCopyD2D(GlobalDeviceMemorySlice<T> Src, - GlobalDeviceMemory<T> &Dst, size_t ElementCount) { - thenCopyD2D(Src, Dst.asSlice(), ElementCount); - return *this; - } - - template <typename T> - Stream &thenCopyD2D(GlobalDeviceMemorySlice<T> Src, - GlobalDeviceMemory<T> &Dst) { - thenCopyD2D(Src, Dst.asSlice()); - return *this; - } - - template <typename T> - Stream &thenCopyD2D(const GlobalDeviceMemory<T> &Src, - GlobalDeviceMemory<T> &Dst, size_t ElementCount) { - thenCopyD2D(Src.asSlice(), Dst.asSlice(), ElementCount); - return *this; - } - - template <typename T> - Stream &thenCopyD2D(const GlobalDeviceMemory<T> &Src, - GlobalDeviceMemory<T> &Dst) { - thenCopyD2D(Src.asSlice(), Dst.asSlice()); + thenCopyD2D(SrcSlice, DstSlice, SrcSlice.getElementCount()); return *this; } diff --git a/parallel-libs/streamexecutor/lib/CMakeLists.txt b/parallel-libs/streamexecutor/lib/CMakeLists.txt index be94cbabc46..fb3c0482762 100644 --- a/parallel-libs/streamexecutor/lib/CMakeLists.txt +++ b/parallel-libs/streamexecutor/lib/CMakeLists.txt @@ -8,6 +8,7 @@ add_se_library( Device.cpp DeviceMemory.cpp Error.cpp + HostMemory.cpp Kernel.cpp KernelSpec.cpp PackedKernelArgumentArray.cpp diff --git a/parallel-libs/streamexecutor/lib/HostMemory.cpp b/parallel-libs/streamexecutor/lib/HostMemory.cpp new file mode 100644 index 
00000000000..f7fbe044aa3 --- /dev/null +++ b/parallel-libs/streamexecutor/lib/HostMemory.cpp @@ -0,0 +1,29 @@ +//===-- HostMemory.cpp - HostMemory implementation ------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Implementation of HostMemory internals. +/// +//===----------------------------------------------------------------------===// + +#include "streamexecutor/HostMemory.h" +#include "streamexecutor/Device.h" + +namespace streamexecutor { +namespace internal { + +void destroyRegisteredHostMemoryInternals(Device *TheDevice, void *Pointer) { + // TODO(jhen): How to handle errors here? + if (Pointer) { + consumeError(TheDevice->unregisterHostMemory(Pointer)); + } +} + +} // namespace internal +} // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/unittests/CoreTests/DeviceTest.cpp b/parallel-libs/streamexecutor/unittests/CoreTests/DeviceTest.cpp index 5b16c3c865c..897399c02a8 100644 --- a/parallel-libs/streamexecutor/unittests/CoreTests/DeviceTest.cpp +++ b/parallel-libs/streamexecutor/unittests/CoreTests/DeviceTest.cpp @@ -84,16 +84,11 @@ TEST_F(DeviceTest, AllocateAndFreeDeviceMemory) { EXPECT_TRUE(static_cast<bool>(MaybeMemory)); } -TEST_F(DeviceTest, AllocateAndFreeHostMemory) { - se::Expected<int *> MaybeMemory = Device.allocateHostMemory<int>(10); - EXPECT_TRUE(static_cast<bool>(MaybeMemory)); - EXPECT_NO_ERROR(Device.freeHostMemory(*MaybeMemory)); -} - TEST_F(DeviceTest, RegisterAndUnregisterHostMemory) { std::vector<int> Data(10); - EXPECT_NO_ERROR(Device.registerHostMemory(Data.data(), 10)); - EXPECT_NO_ERROR(Device.unregisterHostMemory(Data.data())); + se::Expected<se::RegisteredHostMemory<int>> MaybeMemory = + Device.registerHostMemory<int>(Data); + 
EXPECT_TRUE(static_cast<bool>(MaybeMemory)); } // D2H tests diff --git a/parallel-libs/streamexecutor/unittests/CoreTests/SimpleHostPlatformDevice.h b/parallel-libs/streamexecutor/unittests/CoreTests/SimpleHostPlatformDevice.h index 5c5953098c4..60064804fbe 100644 --- a/parallel-libs/streamexecutor/unittests/CoreTests/SimpleHostPlatformDevice.h +++ b/parallel-libs/streamexecutor/unittests/CoreTests/SimpleHostPlatformDevice.h @@ -48,22 +48,12 @@ public: return streamexecutor::Error::success(); } - streamexecutor::Expected<void *> - allocateHostMemory(size_t ByteCount) override { - return std::malloc(ByteCount); - } - - streamexecutor::Error freeHostMemory(void *Memory) override { - std::free(const_cast<void *>(Memory)); - return streamexecutor::Error::success(); - } - streamexecutor::Error registerHostMemory(void *Memory, size_t ByteCount) override { return streamexecutor::Error::success(); } - streamexecutor::Error unregisterHostMemory(void *Memory) override { + streamexecutor::Error unregisterHostMemory(const void *Memory) override { return streamexecutor::Error::success(); } diff --git a/parallel-libs/streamexecutor/unittests/CoreTests/StreamTest.cpp b/parallel-libs/streamexecutor/unittests/CoreTests/StreamTest.cpp index 65598540d67..34516e2c2da 100644 --- a/parallel-libs/streamexecutor/unittests/CoreTests/StreamTest.cpp +++ b/parallel-libs/streamexecutor/unittests/CoreTests/StreamTest.cpp @@ -39,6 +39,10 @@ public: HostB5{5, 6, 7, 8, 9}, HostA7{10, 11, 12, 13, 14, 15, 16}, HostB7{17, 18, 19, 20, 21, 22, 23}, Host5{24, 25, 26, 27, 28}, Host7{29, 30, 31, 32, 33, 34, 35}, + RegisteredHost5(getOrDie( + Device.registerHostMemory(llvm::MutableArrayRef<int>(Host5)))), + RegisteredHost7(getOrDie( + Device.registerHostMemory(llvm::MutableArrayRef<int>(Host7)))), DeviceA5(getOrDie(Device.allocateDeviceMemory<int>(5))), DeviceB5(getOrDie(Device.allocateDeviceMemory<int>(5))), DeviceA7(getOrDie(Device.allocateDeviceMemory<int>(7))), @@ -66,6 +70,9 @@ protected: int 
Host5[5]; int Host7[7]; + se::RegisteredHostMemory<int> RegisteredHost5; + se::RegisteredHostMemory<int> RegisteredHost7; + // Device memory. se::GlobalDeviceMemory<int> DeviceA5; se::GlobalDeviceMemory<int> DeviceB5; @@ -78,166 +85,119 @@ using llvm::MutableArrayRef; // D2H tests -TEST_F(StreamTest, CopyD2HToMutableArrayRefByCount) { - Stream.thenCopyD2H(DeviceA5, MutableArrayRef<int>(Host5), 5); +TEST_F(StreamTest, CopyD2HToRegisteredRefByCount) { + Stream.thenCopyD2H(DeviceA5, RegisteredHost5, 5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { EXPECT_EQ(HostA5[I], Host5[I]); } - Stream.thenCopyD2H(DeviceB5, MutableArrayRef<int>(Host5), 2); + Stream.thenCopyD2H(DeviceB5, RegisteredHost5, 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { EXPECT_EQ(HostB5[I], Host5[I]); } - Stream.thenCopyD2H(DeviceA7, MutableArrayRef<int>(Host5), 7); - EXPECT_FALSE(Stream.isOK()); -} - -TEST_F(StreamTest, CopyD2HToMutableArrayRef) { - Stream.thenCopyD2H(DeviceA5, MutableArrayRef<int>(Host5)); - EXPECT_TRUE(Stream.isOK()); - for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); - } - - Stream.thenCopyD2H(DeviceA5, MutableArrayRef<int>(Host7)); + Stream.thenCopyD2H(DeviceA7, RegisteredHost5, 7); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyD2HToPointer) { - Stream.thenCopyD2H(DeviceA5, Host5, 5); +TEST_F(StreamTest, CopyD2HToRegistered) { + Stream.thenCopyD2H(DeviceA5, RegisteredHost5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { EXPECT_EQ(HostA5[I], Host5[I]); } - Stream.thenCopyD2H(DeviceA5, Host7, 7); + Stream.thenCopyD2H(DeviceA5, RegisteredHost7); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyD2HSliceToMutableArrayRefByCount) { +TEST_F(StreamTest, CopyD2HSliceToRegiseredSliceByCount) { Stream.thenCopyD2H(DeviceA5.asSlice().drop_front(1), - MutableArrayRef<int>(Host5 + 1, 4), 4); + RegisteredHost5.asSlice().slice(1, 4), 4); EXPECT_TRUE(Stream.isOK()); for (int I = 1; I < 5; ++I) { EXPECT_EQ(HostA5[I], Host5[I]); } 
- Stream.thenCopyD2H(DeviceB5.asSlice().drop_back(1), - MutableArrayRef<int>(Host5), 2); + Stream.thenCopyD2H(DeviceB5.asSlice().drop_back(1), RegisteredHost5, 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { EXPECT_EQ(HostB5[I], Host5[I]); } - Stream.thenCopyD2H(DeviceA5.asSlice(), MutableArrayRef<int>(Host7), 7); + Stream.thenCopyD2H(DeviceA5.asSlice(), RegisteredHost7, 7); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyD2HSliceToMutableArrayRef) { - Stream.thenCopyD2H(DeviceA7.asSlice().slice(1, 5), - MutableArrayRef<int>(Host5)); +TEST_F(StreamTest, CopyD2HSliceToRegistered) { + Stream.thenCopyD2H(DeviceA7.asSlice().slice(1, 5), RegisteredHost5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { EXPECT_EQ(HostA7[I + 1], Host5[I]); } - Stream.thenCopyD2H(DeviceA5.asSlice(), MutableArrayRef<int>(Host7)); - EXPECT_FALSE(Stream.isOK()); -} - -TEST_F(StreamTest, CopyD2HSliceToPointer) { - Stream.thenCopyD2H(DeviceA5.asSlice().drop_front(1), Host5 + 1, 4); - EXPECT_TRUE(Stream.isOK()); - for (int I = 1; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); - } - - Stream.thenCopyD2H(DeviceA5.asSlice(), Host7, 7); + Stream.thenCopyD2H(DeviceA5.asSlice(), RegisteredHost7); EXPECT_FALSE(Stream.isOK()); } // H2D tests -TEST_F(StreamTest, CopyH2DToArrayRefByCount) { - Stream.thenCopyH2D(ArrayRef<int>(Host5), DeviceA5, 5); +TEST_F(StreamTest, CopyH2DFromRegisterdByCount) { + Stream.thenCopyH2D(RegisteredHost5, DeviceA5, 5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } - Stream.thenCopyH2D(ArrayRef<int>(Host5), DeviceB5, 2); + Stream.thenCopyH2D(RegisteredHost5, DeviceB5, 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { EXPECT_EQ(getDeviceValue(DeviceB5, I), Host5[I]); } - Stream.thenCopyH2D(ArrayRef<int>(Host7), DeviceA5, 7); - EXPECT_FALSE(Stream.isOK()); -} - -TEST_F(StreamTest, CopyH2DToArrayRef) { - Stream.thenCopyH2D(ArrayRef<int>(Host5), DeviceA5); - 
EXPECT_TRUE(Stream.isOK()); - for (int I = 0; I < 5; ++I) { - EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); - } - - Stream.thenCopyH2D(ArrayRef<int>(Host7), DeviceA5); + Stream.thenCopyH2D(RegisteredHost7, DeviceA5, 7); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyH2DToPointer) { - Stream.thenCopyH2D(Host5, DeviceA5, 5); +TEST_F(StreamTest, CopyH2DFromRegistered) { + Stream.thenCopyH2D(RegisteredHost5, DeviceA5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } - Stream.thenCopyH2D(Host7, DeviceA5, 7); + Stream.thenCopyH2D(RegisteredHost7, DeviceA5); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyH2DSliceToArrayRefByCount) { - Stream.thenCopyH2D(ArrayRef<int>(Host5 + 1, 4), +TEST_F(StreamTest, CopyH2DFromRegisteredSliceToSlice) { + Stream.thenCopyH2D(RegisteredHost5.asSlice().slice(1, 4), DeviceA5.asSlice().drop_front(1), 4); EXPECT_TRUE(Stream.isOK()); for (int I = 1; I < 5; ++I) { EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } - Stream.thenCopyH2D(ArrayRef<int>(Host5), DeviceB5.asSlice().drop_back(1), 2); + Stream.thenCopyH2D(RegisteredHost5, DeviceB5.asSlice().drop_back(1), 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { EXPECT_EQ(getDeviceValue(DeviceB5, I), Host5[I]); } - Stream.thenCopyH2D(ArrayRef<int>(Host5), DeviceA5.asSlice(), 7); + Stream.thenCopyH2D(RegisteredHost5, DeviceA5.asSlice(), 7); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyH2DSliceToArrayRef) { - - Stream.thenCopyH2D(ArrayRef<int>(Host5), DeviceA5.asSlice()); +TEST_F(StreamTest, CopyH2DRegisteredToSlice) { + Stream.thenCopyH2D(RegisteredHost5, DeviceA5.asSlice()); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } - Stream.thenCopyH2D(ArrayRef<int>(Host7), DeviceA5.asSlice()); - EXPECT_FALSE(Stream.isOK()); -} - -TEST_F(StreamTest, CopyH2DSliceToPointer) { - Stream.thenCopyH2D(Host5, DeviceA5.asSlice(), 5); - 
EXPECT_TRUE(Stream.isOK()); - for (int I = 0; I < 5; ++I) { - EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); - } - - Stream.thenCopyH2D(Host7, DeviceA5.asSlice(), 7); + Stream.thenCopyH2D(RegisteredHost7, DeviceA5.asSlice()); EXPECT_FALSE(Stream.isOK()); } @@ -289,7 +249,6 @@ TEST_F(StreamTest, CopySliceD2DByCount) { } TEST_F(StreamTest, CopySliceD2D) { - Stream.thenCopyD2D(DeviceA7.asSlice().drop_back(2), DeviceB5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { @@ -318,7 +277,6 @@ TEST_F(StreamTest, CopyD2DSliceByCount) { } TEST_F(StreamTest, CopyD2DSlice) { - Stream.thenCopyD2D(DeviceA5, DeviceB7.asSlice().drop_back(2)); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { @@ -330,7 +288,6 @@ TEST_F(StreamTest, CopyD2DSlice) { } TEST_F(StreamTest, CopySliceD2DSliceByCount) { - Stream.thenCopyD2D(DeviceA5.asSlice(), DeviceB5.asSlice(), 5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { @@ -348,7 +305,6 @@ TEST_F(StreamTest, CopySliceD2DSliceByCount) { } TEST_F(StreamTest, CopySliceD2DSlice) { - Stream.thenCopyD2D(DeviceA5.asSlice(), DeviceB5.asSlice()); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { |