// Source: mlir/test/Dialect/Linalg/tile-and-distribute.mlir
// (git blob 6178aa393ee0e5645bfe330f348ee6effe951039)
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-and-distribute-options -split-input-file | FileCheck %s

// Distribution "distribute1": both parallel dimensions of the matmul are
// mapped straight onto block ids. The expected output computes each tile
// offset as `block_id * 8` (MAP0) and uses it directly as a subview offset;
// the expected sequence contains no loop and no bounds check over the M/N
// dimensions. The scf.for that remains walks the shared K dimension: its
// induction variable indexes A's columns and B's rows.
// NOTE(review): the factor 8 presumably is the tile size the test pass
// configures for this marker — confirm against the pattern registration.
func.func @gemm1(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
{
  // The marker attribute presumably selects which test pattern (and which
  // distribution options) is applied to this op.
  linalg.matmul {__internal_linalg_transform__ = "distribute1"}
    ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
   outs(%c: memref<?x?xf32>)
  return
}
//  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
//      CHECK: func @gemm1(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
//  CHECK-DAG: %[[BIDY:.*]] = gpu.block_id y
//  CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x
//      CHECK: scf.for %[[ARG3:.*]] =
// Offsets are recomputed from the block ids inside the K loop (twice each,
// once for the input subviews and once for the output subview).
//      CHECK:   %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
//      CHECK:   %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
//      CHECK:   %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
//      CHECK:   %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
//      CHECK:   %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
//      CHECK:   %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
//      CHECK:   %[[SV3:.*]] = memref.subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
//      CHECK:   linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]

// -----

// Distribution "distribute2": each block is mapped to a single tile per
// parallel dimension, guarded by a bounds check. The expected output computes
// the Y/X offsets as `block_id * 8` (MAP0) and compares each against an upper
// bound with `arith.cmpi slt`. It ANDs the two results and runs the loop nest
// under scf.if on that value, so out-of-range blocks do no work. Inside the
// guard, the scf.for walks the shared K dimension.
// NOTE(review): the factor 8 presumably is the tile size the test pass
// configures for this marker — confirm against the pattern registration.
func.func @gemm2(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
{
  // The marker attribute presumably selects which test pattern (and which
  // distribution options) is applied to this op.
  linalg.matmul  {__internal_linalg_transform__ = "distribute2"}
    ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
   outs(%c:memref<?x?xf32>)
  return
}
//  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
//      CHECK: func @gemm2(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
//  CHECK-DAG: %[[BIDY:.*]] = gpu.block_id y
//  CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x
// Bounds check on both tile offsets, combined into one guard.
//      CHECK: %[[ITERY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
//      CHECK: %[[ITERX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
//      CHECK: %[[INBOUNDSY:.*]] = arith.cmpi slt, %[[ITERY]], %{{.*}}
//      CHECK: %[[INBOUNDSX:.*]] = arith.cmpi slt, %[[ITERX]], %{{.*}}
//      CHECK: %[[INBOUNDS:.*]] = arith.andi %[[INBOUNDSY]], %[[INBOUNDSX]]
//      CHECK: scf.if %[[INBOUNDS]]
//      CHECK:   scf.for %[[ARG3:.*]] =
//      CHECK:     %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
//      CHECK:     %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
//      CHECK:     %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
//      CHECK:     %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
//      CHECK:     %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
//      CHECK:     %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
//      CHECK:     %[[SV3:.*]] = memref.subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
//      CHECK:     linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]

// -----

// Distribution "distribute3": both parallel dimensions become one 2-D
// scf.parallel. Its lower bounds are `block_id * 8` and its steps are
// `grid_dim * 8` (MAP0), so a block may visit several tiles per dimension.
// The parallel induction variables are used directly as the subview
// offsets, and an inner scf.for walks the shared K dimension.
// NOTE(review): the factor 8 presumably is the tile size the test pass
// configures for this marker — confirm against the pattern registration.
func.func @gemm3(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
{
  // The marker attribute presumably selects which test pattern (and which
  // distribution options) is applied to this op.
  linalg.matmul {__internal_linalg_transform__ = "distribute3"}
    ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
   outs(%c: memref<?x?xf32>)
  return
}
//  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
//      CHECK: func @gemm3(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
//  CHECK-DAG: %[[BIDY:.*]] = gpu.block_id y
//  CHECK-DAG: %[[NBLOCKSY:.*]] = gpu.grid_dim y
//  CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x
//  CHECK-DAG: %[[NBLOCKSX:.*]] = gpu.grid_dim x
// Lower bounds come from the block ids; steps come from the grid dims.
//      CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
//      CHECK: %[[STEPY:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
//      CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
//      CHECK: %[[STEPX:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSX]]]
//      CHECK: scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) = (%[[LBY]], %[[LBX]]) to (%{{.*}}, %{{.*}}) step (%[[STEPY]], %[[STEPX]])
//      CHECK:   scf.for %[[ARG5:.*]] =
//      CHECK:     %[[SV1:.*]] = memref.subview %[[ARG0]][%[[ARG3]], %[[ARG5]]]
//      CHECK:     %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG5]], %[[ARG4]]]
//      CHECK:     %[[SV3:.*]] = memref.subview %[[ARG2]][%[[ARG3]], %[[ARG4]]]
//      CHECK:     linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]

// -----

// Distribution "distribute4": the strategy differs per dimension.
//  - Y: the offset `block_id y * 8` is used directly. The expected sequence
//    has no loop and no bounds check over Y.
//  - X: the offset `block_id x * 8` is compared against an upper bound
//    (`arith.cmpi slt`), and the loop nest runs under scf.if on that result.
// Inside the guard, the scf.for walks the shared K dimension.
// NOTE(review): the factor 8 presumably is the tile size the test pass
// configures for this marker — confirm against the pattern registration.
func.func @gemm4(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
{
  // The marker attribute presumably selects which test pattern (and which
  // distribution options) is applied to this op.
  linalg.matmul {__internal_linalg_transform__ = "distribute4"}
    ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
   outs(%c: memref<?x?xf32>)
  return
}
//  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
//      CHECK: func @gemm4(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
//  CHECK-DAG: %[[BIDY:.*]] = gpu.block_id y
//  CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x
// Only the X offset is bounds-checked.
//      CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
//      CHECK: %[[INBOUNDS:.*]] = arith.cmpi slt, %[[LBX]], %{{.*}}
//      CHECK: scf.if %[[INBOUNDS]]
//      CHECK:   scf.for %[[ARG3:.*]] =
//      CHECK:     %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
//      CHECK:     %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
//      CHECK:     %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
//      CHECK:     %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
//      CHECK:     %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
//      CHECK:     %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
//      CHECK:     %[[SV3:.*]] = memref.subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
//      CHECK:     linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]

// -----

// Distribution "distribute5": the strategy differs per dimension.
//  - Y: each block handles a single tile at offset `block_id y * 8`. That
//    offset is compared against an upper bound (`arith.cmpi slt`), and the
//    rest of the nest runs under scf.if on that result.
//  - X: a 1-D scf.parallel starts at `block_id x * 8` with step
//    `grid_dim x * 8`, so a block may visit several X tiles. Its induction
//    variable is used directly as the column offset of B and C.
// The scf.for nested inside walks the shared K dimension.
// NOTE(review): the factor 8 presumably is the tile size the test pass
// configures for this marker — confirm against the pattern registration.
func.func @gemm5(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
{
  // The marker attribute presumably selects which test pattern (and which
  // distribution options) is applied to this op.
  linalg.matmul {__internal_linalg_transform__ = "distribute5"}
    ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
   outs(%c: memref<?x?xf32>)
  return
}
//  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
//      CHECK: func @gemm5(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
//  CHECK-DAG: %[[BIDY:.*]] = gpu.block_id y
//  CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x
//  CHECK-DAG: %[[NBLOCKSX:.*]] = gpu.grid_dim x
//      CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
//      CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
//      CHECK: %[[STEPX:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSX]]]
// Only the Y offset is bounds-checked; X is covered by the scf.parallel.
//      CHECK: %[[INBOUNDS:.*]] = arith.cmpi slt, %[[LBY]], %{{.*}}
//      CHECK: scf.if %[[INBOUNDS]]
//      CHECK:   scf.parallel (%[[ARG3:.*]]) = (%[[LBX]]) to (%{{.*}}) step (%[[STEPX]])
//      CHECK:     scf.for %[[ARG4:.*]] =
//      CHECK:       %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
//      CHECK:       %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
//      CHECK:       %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG4]]]
//      CHECK:       %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG4]], %[[ARG3]]]
//      CHECK:       %[[SV3:.*]] = memref.subview %[[ARG2]][%[[OFFSETY_2]], %[[ARG3]]]
//      CHECK:       linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]

// -----

// Distribution "distribute6": the strategy differs per dimension.
//  - Y: a 1-D scf.parallel starts at `block_id y * 8` with step
//    `grid_dim y * 8`. Its induction variable is the row offset of A and C.
//  - X: the offset `block_id x * 8` is used directly. The expected sequence
//    has no loop and no bounds check over X.
// The scf.for nested inside walks the shared K dimension.
// NOTE(review): the factor 8 presumably is the tile size the test pass
// configures for this marker — confirm against the pattern registration.
func.func @gemm6(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
{
  // The marker attribute presumably selects which test pattern (and which
  // distribution options) is applied to this op.
  linalg.matmul {__internal_linalg_transform__ = "distribute6"}
    ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
   outs(%c: memref<?x?xf32>)
  return
}
//  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
//      CHECK: func @gemm6(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
//  CHECK-DAG: %[[BIDY:.*]] = gpu.block_id y
//  CHECK-DAG: %[[NBLOCKSY:.*]] = gpu.grid_dim y
//  CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x
//      CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
//      CHECK: %[[STEPY:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
//      CHECK: scf.parallel (%[[ARG3:.*]]) = (%[[LBY]]) to (%{{.*}}) step (%[[STEPY]])
//      CHECK:   scf.for %[[ARG4:.*]] =
// The X offset is recomputed from the block id inside the K loop.
//      CHECK:     %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
//      CHECK:     %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
//      CHECK:     %[[SV1:.*]] = memref.subview %[[ARG0]][%[[ARG3]], %[[ARG4]]]
//      CHECK:     %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG4]], %[[OFFSETX]]]
//      CHECK:     %[[SV3:.*]] = memref.subview %[[ARG2]][%[[ARG3]], %[[OFFSETX_2]]]
//      CHECK:     linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]

// -----

// Tensor variant ("tensors_distribute1"): tiling yields a three-deep scf.for
// nest threaded through iter_args. The two outer loops are distributed
// cyclically:
//  - their lower bound becomes `block_id * 8 + 0` (MULMAP then ADDMAP);
//  - their step becomes `grid_dim * 8` (MULMAP);
//  - Y drives the outer loop and X the middle one.
// The X bound/step values are computed inside the outer loop, just before
// the loop that consumes them. The loop headers below pin those computed
// values, so the test fails if the distributed bounds are not actually
// applied to the loops.
// NOTE(review): 8 and 0 presumably are the undistributed step and lower
// bound chosen by the test pass — confirm against its configuration.
//      CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
//      CHECK: #[[ADDMAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
//      CHECK: func @matmul_tensors(
// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
func.func @matmul_tensors(
  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
    -> tensor<?x?xf32> {
//  CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
//  CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
//  CHECK-DAG: %[[BIDY:.*]] = gpu.block_id y
//  CHECK-DAG: %[[NBLOCKSY:.*]] = gpu.grid_dim y
//  CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x
//  CHECK-DAG: %[[NBLOCKSX:.*]] = gpu.grid_dim x
//      CHECK: %[[MULY:.+]] = affine.apply #[[MULMAP]]()[%[[BIDY]], %[[C8]]]
//      CHECK: %[[LBY:.+]] = affine.apply #[[ADDMAP]]()[%[[MULY]], %[[C0]]]
//      CHECK: %[[STEPY:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSY]], %[[C8]]]
//      CHECK: %[[TD0:.*]] = scf.for %{{.+}} = %[[LBY]] to %{{.+}} step %[[STEPY]] iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xf32>) {
//      CHECK:   %[[MULX:.+]] = affine.apply #[[MULMAP]]()[%[[BIDX]], %[[C8]]]
//      CHECK:   %[[LBX:.+]] = affine.apply #[[ADDMAP]]()[%[[MULX]], %[[C0]]]
//      CHECK:   %[[STEPX:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSX]], %[[C8]]]
//      CHECK:   %[[TD1:.*]] = scf.for %{{.+}} = %[[LBX]] to %{{.+}} step %[[STEPX]] iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
//      CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xf32>) {
//      CHECK:       %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
//      CHECK:       %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
//      CHECK:       %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
//      CHECK:       %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK-SAME:                                  outs(%[[sTC]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
//      CHECK:       %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}]  : tensor<?x?xf32> into tensor<?x?xf32>
//      CHECK:       scf.yield %[[TD]] : tensor<?x?xf32>
//      CHECK:     scf.yield %[[TD2]] : tensor<?x?xf32>
//      CHECK:   scf.yield %[[TD1]] : tensor<?x?xf32>
  // The marker attribute presumably selects which test pattern (and which
  // distribution options) is applied to this op.
  %0 = linalg.matmul {__internal_linalg_transform__ = "tensors_distribute1"}
       ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%arg2: tensor<?x?xf32>)
    -> tensor<?x?xf32>

//      CHECK: return %[[TD0]] : tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}