Skip to content

Commit

Permalink
add operands to airrt.herd_load operation
Browse files Browse the repository at this point in the history
  • Loading branch information
fifield committed Jun 21, 2024
1 parent 92b967b commit f4f758c
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 17 deletions.
4 changes: 2 additions & 2 deletions mlir/include/air/Dialect/AIRRt/AIRRtOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ def AIRRt_HerdMetadataTerminatorOp

def AIRRt_HerdLoadOp : AIRRt_Op<"herd_load", []> {
let summary = "load a herd";
let arguments = (ins StrAttr:$sym_name);
let arguments = (ins StrAttr:$sym_name, Variadic<AnyType>:$rtp);
let results = (outs I64:$h, Optional<AIRRt_Event>:$event);
let assemblyFormat = [{
$sym_name attr-dict `:` type($h) (`,` type($event)^)?
$sym_name ` ``(` $rtp `)` attr-dict `:` functional-type($rtp, results)
}];
}

Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/AIRLoweringPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,8 @@ class AIRHerdConversion : public ConversionPattern {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(op->getBlock());
rewriter.create<airrt::HerdLoadOp>(op->getLoc(), rewriter.getI64Type(),
herd_name_attr.getValue().str());
herd_name_attr.getValue().str(),
/* operands */ SmallVector<Value>());
}

SmallVector<Value, 4> deps;
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/AIRLowering/air_L2L3_to_airrt.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//

// RUN: air-opt %s -air-to-std | FileCheck %s
// CHECK: %{{.*}} = airrt.herd_load "herd_0" : i64
// CHECK: %{{.*}} = airrt.herd_load "herd_0" () : () -> i64
// CHECK: airrt.memcpy_nd({{.*}}) : (memref<64x64xi32, 1>, memref<64x64xi32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64])
// CHECK: airrt.memcpy_nd({{.*}}) : (memref<64x64xi32, 1>, memref<64x64xi32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64])
// CHECK: airrt.memcpy_nd({{.*}}) : (memref<64x64xi32, 1>, memref<64x64xi32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64])
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Conversion/AIRLowering/air_channel_get_put.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ module {
// CHECK: airrt.segment_load "segment_0" : i64
// CHECK: airrt.dma_memcpy_nd(%c3_i32, %{{.*}}, %{{.*}}, %arg0[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], [%{{.*}}, %{{.*}}, %{{.*}}]) : (i32, i64, i64, memref<32x16xi32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64])
// CHECK: airrt.dma_memcpy_nd(%c4_i32, %{{.*}}, %{{.*}}, %arg1[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], [%{{.*}}, %{{.*}}, %{{.*}}]) : (i32, i64, i64, memref<32x16xi32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64])
// CHECK: airrt.herd_load "herd_0" : i64
// CHECK: airrt.herd_load "herd_0" () : () -> i64

module {
air.channel @channel_3 [2, 2]
Expand Down Expand Up @@ -128,7 +128,7 @@ module {
// CHECK: scf.for
// CHECK: airrt.dma_memcpy_nd(%{{.*}}, %{{.*}}, %{{.*}}, %arg1[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], [%{{.*}}, %{{.*}}, %{{.*}}]) : (i32, i64, i64, memref<32x16xi32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64])
// CHECK: scf.yield
// CHECK: airrt.herd_load "herd_0" : i64
// CHECK: airrt.herd_load "herd_0" () : () -> i64
module {
air.channel @channel_5 [2, 2]
air.channel @channel_4 [2, 2]
Expand Down Expand Up @@ -198,7 +198,7 @@ module {
// CHECK: affine.for
// CHECK: airrt.segment_load "segment_0" : i64
// CHECK: airrt.dma_memcpy_nd(%{{.*}}, %{{.*}}, %{{.*}}, %arg0[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], [%{{.*}}, %{{.*}}, %{{.*}}]) : (i32, i64, i64, memref<128xf32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64]) : !airrt.event
// CHECK: airrt.herd_load "herd_0" : i64
// CHECK: airrt.herd_load "herd_0" () : () -> i64

#map = affine_map<()[s0] -> (s0 * 64)>
module {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/AIRRtToLLVM/airrt_L2L3cpy_to_std.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ module attributes {torch.debug_module_name = "mmult"} {
airrt.memcpy_nd(%2, %arg0, [%c0_i64, %c0_i64, %c0_i64, %c0_i64], [%c1_i64, %c1_i64, %c64_i64, %c64_i64], [%c0_i64, %c0_i64, %c64_i64]) : (memref<64x64xi32, 1>, memref<64x64xi32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64])
airrt.memcpy_nd(%3, %arg1, [%c0_i64, %c0_i64, %c0_i64, %c0_i64], [%c1_i64, %c1_i64, %c64_i64, %c64_i64], [%c0_i64, %c0_i64, %c64_i64]) : (memref<64x64xi32, 1>, memref<64x64xi32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64])
airrt.memcpy_nd(%4, %1, [%c0_i64, %c0_i64, %c0_i64, %c0_i64], [%c1_i64, %c1_i64, %c64_i64, %c64_i64], [%c0_i64, %c0_i64, %c64_i64]) : (memref<64x64xi32, 1>, memref<64x64xi32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64])
%5 = airrt.herd_load "herd_0" : i64
%5 = airrt.herd_load "herd_0" () : () -> i64
affine.for %arg3 = 0 to 2 {
affine.for %arg4 = 0 to 2 {
%6 = arith.muli %arg3, %c32 : index
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/AIRRtToLLVM/airrt_herd_load.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ module {
}
func.func @f() {
%ret0 = airrt.segment_load "plot" : i64
%ret1 = airrt.herd_load "elk" : i64
airrt.herd_load "deer" : i64
%ret1 = airrt.herd_load "elk" () : () -> i64
airrt.herd_load "deer" () : () -> i64
return
}
}
2 changes: 1 addition & 1 deletion mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ module {
%c2_i32 = arith.constant 2 : i32
%c7_i32 = arith.constant 7 : i32
%c64_i64 = arith.constant 64 : i64
%p = airrt.herd_load "herd" : i64
%p = airrt.herd_load "herd" () : () -> i64
airrt.dma_memcpy_nd(%c7_i32, %c0_i64, %c0_i64, %arg1[%c0_i64, %c0_i64, %c0_i64, %c0_i64], [%c1_i64, %c1_i64, %c1_i64, %c64_i64], [%c0_i64, %c0_i64, %c0_i64]) {metadata = @airMemcpyId7} : (i32, i64, i64, memref<64xi32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64])
return
}
Expand Down
28 changes: 22 additions & 6 deletions mlir/test/Dialect/AIRRt/airrt_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,35 @@
//===----------------------------------------------------------------------===//

// RUN: air-opt %s | FileCheck %s
// CHECK: %[[E0:.*]] = airrt.wait_all : !airrt.event
// CHECK: %[[E1:.*]] = airrt.wait_all : !airrt.event
// CHECK: %[[E2:.*]] = airrt.wait_all %[[E0]], %[[E1]] : !airrt.event
// CHECK: airrt.wait_all %[[E2]]
// CHECK-LABEL: func.func @f0()
// CHECK: %[[VAL_0:.*]] = airrt.wait_all : !airrt.event
// CHECK: %[[VAL_1:.*]] = airrt.wait_all : !airrt.event
// CHECK: %[[VAL_2:.*]] = airrt.wait_all %[[VAL_0]], %[[VAL_1]] : !airrt.event
// CHECK: airrt.wait_all %[[VAL_2]]
// CHECK: %[[VAL_3:.*]] = airrt.herd_load "herd_0" () : () -> i64
// CHECK: %[[VAL_4:.*]], %[[VAL_5:.*]] = airrt.herd_load "herd_0" () : () -> (i64, !airrt.event)
// CHECK: %[[VAL_6:.*]] = arith.constant 64 : i64
// CHECK: %[[VAL_7:.*]] = arith.constant 42 : i64
// CHECK: %[[VAL_8:.*]], %[[VAL_9:.*]] = airrt.herd_load "herd_1" (%[[VAL_7]]) : (i64) -> (i64, !airrt.event)
// CHECK: %[[VAL_10:.*]] = airrt.herd_load "herd_2" (%[[VAL_7]], %[[VAL_6]]) : (i64, i64) -> i64
module {

func.func @f0() {
%event1 = airrt.wait_all : !airrt.event
%event2 = airrt.wait_all : !airrt.event
%event3 = airrt.wait_all %event1, %event2 : !airrt.event
airrt.wait_all %event3
%herd_load = airrt.herd_load "herd_0" : i64
%h, %e = airrt.herd_load "herd_0" : i64, !airrt.event

// load herd without runtime parameters
%herd_load = airrt.herd_load "herd_0" () : () -> i64
%h0, %e0 = airrt.herd_load "herd_0" () : () -> (i64, !airrt.event)

// load herd with runtime parameters
%c64 = arith.constant 64 : i64
%c42 = arith.constant 42 : i64
%h1, %e1 = airrt.herd_load "herd_1" (%c42) : (i64) -> (i64, !airrt.event)
%h2 = airrt.herd_load "herd_2" (%c42, %c64) : (i64, i64) -> (i64)

return
}

Expand Down

0 comments on commit f4f758c

Please sign in to comment.