diff options
author | Jason Henline <jhen@google.com> | 2016-09-01 18:35:37 +0000 |
---|---|---|
committer | Jason Henline <jhen@google.com> | 2016-09-01 18:35:37 +0000 |
commit | 2d2da4b5d51063f5cf4914fdfff268b4884a9c77 (patch) | |
tree | 4fa21e8ed3353a9b9b93682db5f83546a48da2b5 /parallel-libs | |
parent | f2f148636c8ef954c6eca8b207cb91b2a0ac7935 (diff) |
[SE] Make Stream movable
Summary:
The example code makes it clear that this is a much better design
decision.
Reviewers: jlebar
Subscribers: jprice, parallel_libs-commits
Differential Revision: https://reviews.llvm.org/D24142
Diffstat (limited to 'parallel-libs')
5 files changed, 17 insertions, 14 deletions
diff --git a/parallel-libs/streamexecutor/examples/Example.cpp b/parallel-libs/streamexecutor/examples/Example.cpp index af1994da415..8f42ffa0a3b 100644 --- a/parallel-libs/streamexecutor/examples/Example.cpp +++ b/parallel-libs/streamexecutor/examples/Example.cpp @@ -121,13 +121,13 @@ int main() { getOrDie(Device->allocateDeviceMemory<float>(ArraySize)); // Run operations on a stream. - std::unique_ptr<se::Stream> Stream = getOrDie(Device->createStream()); - Stream->thenCopyH2D<float>(HostX, X) + se::Stream Stream = getOrDie(Device->createStream()); + Stream.thenCopyH2D<float>(HostX, X) .thenCopyH2D<float>(HostY, Y) .thenLaunch(ArraySize, 1, *Kernel, A, X, Y) .thenCopyD2H<float>(X, HostX); // Wait for the stream to complete. - se::dieIfError(Stream->blockHostUntilDone()); + se::dieIfError(Stream.blockHostUntilDone()); // Process output data in HostX. std::vector<float> ExpectedX = {4, 47, 90, 133}; diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Device.h b/parallel-libs/streamexecutor/include/streamexecutor/Device.h index c37f9b1affb..48ecf22ae76 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/Device.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/Device.h @@ -50,7 +50,8 @@ public: std::move(*MaybeKernelHandle)); } - Expected<std::unique_ptr<Stream>> createStream(); + /// Creates a stream object for this device. + Expected<Stream> createStream(); /// Allocates an array of ElementCount entries of type T in device memory. template <typename T> diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h index d1c82f9e5ea..1acb18139d8 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h @@ -61,19 +61,22 @@ class Stream { public: explicit Stream(std::unique_ptr<PlatformStreamHandle> PStream); + Stream(Stream &&Other) = default; + Stream &operator=(Stream &&Other) = default; + ~Stream(); /// Returns whether any error has occurred while entraining work on this /// stream. bool isOK() const { - llvm::sys::ScopedReader ReaderLock(ErrorMessageMutex); + llvm::sys::ScopedReader ReaderLock(*ErrorMessageMutex); return !ErrorMessage; } /// Returns the status created by the first error that occurred while /// entraining work on this stream. Error getStatus() const { - llvm::sys::ScopedReader ReaderLock(ErrorMessageMutex); + llvm::sys::ScopedReader ReaderLock(*ErrorMessageMutex); if (ErrorMessage) return make_error(*ErrorMessage); else @@ -315,7 +318,7 @@ private: /// Does not overwrite the error if it is already set. void setError(Error &&E) { if (E) { - llvm::sys::ScopedWriter WriterLock(ErrorMessageMutex); + llvm::sys::ScopedWriter WriterLock(*ErrorMessageMutex); if (!ErrorMessage) ErrorMessage = consumeAndGetMessage(std::move(E)); } @@ -325,7 +328,7 @@ private: /// /// Does not overwrite the error if it is already set. void setError(llvm::Twine Message) { - llvm::sys::ScopedWriter WriterLock(ErrorMessageMutex); + llvm::sys::ScopedWriter WriterLock(*ErrorMessageMutex); if (!ErrorMessage) ErrorMessage = Message.str(); } @@ -337,9 +340,7 @@ private: std::unique_ptr<PlatformStreamHandle> ThePlatformStream; /// Mutex that guards the error state flags. - /// - /// Mutable so that it can be obtained via const reader lock. - mutable llvm::sys::RWMutex ErrorMessageMutex; + std::unique_ptr<llvm::sys::RWMutex> ErrorMessageMutex; /// First error message for an operation in this stream or empty if there have /// been no errors. diff --git a/parallel-libs/streamexecutor/lib/Device.cpp b/parallel-libs/streamexecutor/lib/Device.cpp index 4a5ec11997d..54f03849c68 100644 --- a/parallel-libs/streamexecutor/lib/Device.cpp +++ b/parallel-libs/streamexecutor/lib/Device.cpp @@ -27,7 +27,7 @@ Device::Device(PlatformDevice *PDevice) : PDevice(PDevice) {} Device::~Device() = default; -Expected<std::unique_ptr<Stream>> Device::createStream() { +Expected<Stream> Device::createStream() { Expected<std::unique_ptr<PlatformStreamHandle>> MaybePlatformStream = PDevice->createStream(); if (!MaybePlatformStream) { @@ -35,7 +35,7 @@ Expected<std::unique_ptr<Stream>> Device::createStream() { } assert((*MaybePlatformStream)->getDevice() == PDevice && "an executor created a stream with a different stored executor"); - return llvm::make_unique<Stream>(std::move(*MaybePlatformStream)); + return Stream(std::move(*MaybePlatformStream)); } } // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/Stream.cpp b/parallel-libs/streamexecutor/lib/Stream.cpp index 20a817c2715..e1fca58cc19 100644 --- a/parallel-libs/streamexecutor/lib/Stream.cpp +++ b/parallel-libs/streamexecutor/lib/Stream.cpp @@ -17,7 +17,8 @@ namespace streamexecutor { Stream::Stream(std::unique_ptr<PlatformStreamHandle> PStream) - : PDevice(PStream->getDevice()), ThePlatformStream(std::move(PStream)) {} + : PDevice(PStream->getDevice()), ThePlatformStream(std::move(PStream)), + ErrorMessageMutex(llvm::make_unique<llvm::sys::RWMutex>()) {} Stream::~Stream() = default; |