summaryrefslogtreecommitdiff
path: root/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h
blob: 6ea7c36180314b2f7ef0db1d3daf2cc79f2cc60a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
//===-- Kernel.h - StreamExecutor kernel types ------------------*- C++ -*-===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Types to represent device kernels (code compiled to run on a GPU or other
/// accelerator).
///
/// See the \ref index "main page" for an example of how a compiler-generated
/// specialization of the Kernel class template can be used along with the
/// streamexecutor::Stream::thenLaunch method to create a typesafe interface for
/// kernel launches.
///
//===----------------------------------------------------------------------===//

#ifndef STREAMEXECUTOR_KERNEL_H
#define STREAMEXECUTOR_KERNEL_H

#include "streamexecutor/KernelSpec.h"
#include "streamexecutor/Utils/Error.h"

#include <memory>
#include <string>

namespace streamexecutor {

class PlatformDevice;

/// The base class for all kernel types.
///
/// Holds the state shared by every Kernel<...> specialization: the platform
/// device pointer, the opaque platform kernel handle, and the kernel's name in
/// both mangled and demangled forms.
///
/// KernelBase is move-only: its copy constructor and copy assignment are
/// deleted. The constructor, move operations, and destructor are defined out
/// of line.
/// NOTE(review): the user-declared destructor presumably releases
/// PlatformKernelHandle through PDevice, which would explain why copying is
/// forbidden — confirm against the implementation file.
class KernelBase {
public:
  /// \param D the platform device this kernel belongs to. It is stored as a
  ///        raw pointer; ownership is not established here (presumably
  ///        non-owning — confirm).
  /// \param PlatformKernelHandle opaque, platform-specific kernel handle.
  /// \param Name the kernel name. The demangled form returned by
  ///        getDemangledName() is presumably derived from it by the
  ///        out-of-line constructor — confirm.
  KernelBase(PlatformDevice *D, const void *PlatformKernelHandle,
             llvm::StringRef Name);

  KernelBase(const KernelBase &Other) = delete;
  KernelBase &operator=(const KernelBase &Other) = delete;

  // NOTE(review): these move operations are not declared noexcept. If their
  // out-of-line definitions cannot throw, mark both the declarations and the
  // definitions noexcept. Standard containers such as std::vector would then
  // keep their strong exception guarantee when relocating elements.
  KernelBase(KernelBase &&Other);
  KernelBase &operator=(KernelBase &&Other);

  ~KernelBase();

  /// Returns the stored opaque platform kernel handle.
  const void *getPlatformHandle() const { return PlatformKernelHandle; }

  /// Returns the kernel name as stored (the mangled form).
  const std::string &getName() const { return Name; }

  /// Returns the demangled form of the kernel name.
  const std::string &getDemangledName() const { return DemangledName; }

private:
  PlatformDevice *PDevice;
  const void *PlatformKernelHandle;

  std::string Name;
  std::string DemangledName;
};

/// A StreamExecutor kernel whose parameter types are part of its C++ type.
///
/// \tparam ParameterTs the types of the device kernel function's parameters,
/// in declaration order. Carrying the signature in the type is what allows
/// launch interfaces such as streamexecutor::Stream::thenLaunch to be
/// typesafe.
///
/// Kernel adds no state of its own; everything lives in KernelBase. Like its
/// base, it is move-only.
template <typename... ParameterTs>
class Kernel : public KernelBase {
public:
  /// Forwards every argument unchanged to the KernelBase constructor.
  Kernel(PlatformDevice *D, const void *PlatformKernelHandle,
         llvm::StringRef Name)
      : KernelBase(D, PlatformKernelHandle, Name) {}

  // Copying is spelled out as deleted. It was already implicitly deleted,
  // both because of the user-declared move constructor and because
  // KernelBase is non-copyable.
  Kernel(const Kernel &) = delete;
  Kernel &operator=(const Kernel &) = delete;

  Kernel(Kernel &&) = default;
  Kernel &operator=(Kernel &&) = default;
};

} // namespace streamexecutor

#endif // STREAMEXECUTOR_KERNEL_H