summaryrefslogtreecommitdiff
path: root/parallel-libs
diff options
context:
space:
mode:
authorJason Henline <jhen@google.com>2016-09-01 18:35:37 +0000
committerJason Henline <jhen@google.com>2016-09-01 18:35:37 +0000
commit2d2da4b5d51063f5cf4914fdfff268b4884a9c77 (patch)
tree4fa21e8ed3353a9b9b93682db5f83546a48da2b5 /parallel-libs
parentf2f148636c8ef954c6eca8b207cb91b2a0ac7935 (diff)
[SE] Make Stream movable
Summary: The example code makes it clear that this is a much better design decision. Reviewers: jlebar Subscribers: jprice, parallel_libs-commits Differential Revision: https://reviews.llvm.org/D24142
Diffstat (limited to 'parallel-libs')
-rw-r--r--parallel-libs/streamexecutor/examples/Example.cpp6
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/Device.h3
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/Stream.h15
-rw-r--r--parallel-libs/streamexecutor/lib/Device.cpp4
-rw-r--r--parallel-libs/streamexecutor/lib/Stream.cpp3
5 files changed, 17 insertions, 14 deletions
diff --git a/parallel-libs/streamexecutor/examples/Example.cpp b/parallel-libs/streamexecutor/examples/Example.cpp
index af1994da415..8f42ffa0a3b 100644
--- a/parallel-libs/streamexecutor/examples/Example.cpp
+++ b/parallel-libs/streamexecutor/examples/Example.cpp
@@ -121,13 +121,13 @@ int main() {
getOrDie(Device->allocateDeviceMemory<float>(ArraySize));
// Run operations on a stream.
- std::unique_ptr<se::Stream> Stream = getOrDie(Device->createStream());
- Stream->thenCopyH2D<float>(HostX, X)
+ se::Stream Stream = getOrDie(Device->createStream());
+ Stream.thenCopyH2D<float>(HostX, X)
.thenCopyH2D<float>(HostY, Y)
.thenLaunch(ArraySize, 1, *Kernel, A, X, Y)
.thenCopyD2H<float>(X, HostX);
// Wait for the stream to complete.
- se::dieIfError(Stream->blockHostUntilDone());
+ se::dieIfError(Stream.blockHostUntilDone());
// Process output data in HostX.
std::vector<float> ExpectedX = {4, 47, 90, 133};
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Device.h b/parallel-libs/streamexecutor/include/streamexecutor/Device.h
index c37f9b1affb..48ecf22ae76 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/Device.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/Device.h
@@ -50,7 +50,8 @@ public:
std::move(*MaybeKernelHandle));
}
- Expected<std::unique_ptr<Stream>> createStream();
+ /// Creates a stream object for this device.
+ Expected<Stream> createStream();
/// Allocates an array of ElementCount entries of type T in device memory.
template <typename T>
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h
index d1c82f9e5ea..1acb18139d8 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h
@@ -61,19 +61,22 @@ class Stream {
public:
explicit Stream(std::unique_ptr<PlatformStreamHandle> PStream);
+ Stream(Stream &&Other) = default;
+ Stream &operator=(Stream &&Other) = default;
+
~Stream();
/// Returns whether any error has occurred while entraining work on this
/// stream.
bool isOK() const {
- llvm::sys::ScopedReader ReaderLock(ErrorMessageMutex);
+ llvm::sys::ScopedReader ReaderLock(*ErrorMessageMutex);
return !ErrorMessage;
}
/// Returns the status created by the first error that occurred while
/// entraining work on this stream.
Error getStatus() const {
- llvm::sys::ScopedReader ReaderLock(ErrorMessageMutex);
+ llvm::sys::ScopedReader ReaderLock(*ErrorMessageMutex);
if (ErrorMessage)
return make_error(*ErrorMessage);
else
@@ -315,7 +318,7 @@ private:
/// Does not overwrite the error if it is already set.
void setError(Error &&E) {
if (E) {
- llvm::sys::ScopedWriter WriterLock(ErrorMessageMutex);
+ llvm::sys::ScopedWriter WriterLock(*ErrorMessageMutex);
if (!ErrorMessage)
ErrorMessage = consumeAndGetMessage(std::move(E));
}
@@ -325,7 +328,7 @@ private:
///
/// Does not overwrite the error if it is already set.
void setError(llvm::Twine Message) {
- llvm::sys::ScopedWriter WriterLock(ErrorMessageMutex);
+ llvm::sys::ScopedWriter WriterLock(*ErrorMessageMutex);
if (!ErrorMessage)
ErrorMessage = Message.str();
}
@@ -337,9 +340,7 @@ private:
std::unique_ptr<PlatformStreamHandle> ThePlatformStream;
/// Mutex that guards the error state flags.
- ///
- /// Mutable so that it can be obtained via const reader lock.
- mutable llvm::sys::RWMutex ErrorMessageMutex;
+ std::unique_ptr<llvm::sys::RWMutex> ErrorMessageMutex;
/// First error message for an operation in this stream or empty if there have
/// been no errors.
diff --git a/parallel-libs/streamexecutor/lib/Device.cpp b/parallel-libs/streamexecutor/lib/Device.cpp
index 4a5ec11997d..54f03849c68 100644
--- a/parallel-libs/streamexecutor/lib/Device.cpp
+++ b/parallel-libs/streamexecutor/lib/Device.cpp
@@ -27,7 +27,7 @@ Device::Device(PlatformDevice *PDevice) : PDevice(PDevice) {}
Device::~Device() = default;
-Expected<std::unique_ptr<Stream>> Device::createStream() {
+Expected<Stream> Device::createStream() {
Expected<std::unique_ptr<PlatformStreamHandle>> MaybePlatformStream =
PDevice->createStream();
if (!MaybePlatformStream) {
@@ -35,7 +35,7 @@ Expected<std::unique_ptr<Stream>> Device::createStream() {
}
assert((*MaybePlatformStream)->getDevice() == PDevice &&
"an executor created a stream with a different stored executor");
- return llvm::make_unique<Stream>(std::move(*MaybePlatformStream));
+ return Stream(std::move(*MaybePlatformStream));
}
} // namespace streamexecutor
diff --git a/parallel-libs/streamexecutor/lib/Stream.cpp b/parallel-libs/streamexecutor/lib/Stream.cpp
index 20a817c2715..e1fca58cc19 100644
--- a/parallel-libs/streamexecutor/lib/Stream.cpp
+++ b/parallel-libs/streamexecutor/lib/Stream.cpp
@@ -17,7 +17,8 @@
namespace streamexecutor {
Stream::Stream(std::unique_ptr<PlatformStreamHandle> PStream)
- : PDevice(PStream->getDevice()), ThePlatformStream(std::move(PStream)) {}
+ : PDevice(PStream->getDevice()), ThePlatformStream(std::move(PStream)),
+ ErrorMessageMutex(llvm::make_unique<llvm::sys::RWMutex>()) {}
Stream::~Stream() = default;