Linalg Extension
The ByteIR compiler extends the MLIR linalg dialect to support several non-trivial patterns. ByteIR implements this by introducing a linalg-ext dialect on top of the existing linalg dialect. The ops and transformations in linalg-ext are expected to work interchangeably with the existing ops and transformations in linalg, and will hopefully be upstreamed to LLVM eventually.
Rationale
The need for non-trivial patterns in linalg
Several performance-critical patterns are still not well covered by upstream linalg.
- Some patterns cannot be easily expressed in the linalg dialect through a generic op or by relying only on the existing linalg interfaces; top-k and scan (cumsum) fall into this category.
- Some can be expressed by composing several generic ops, but the lack of proper interfaces may hinder the desired transformations; softmax falls into this category.
- Some are intended to be more generic replacements for existing upstream versions; linalg_ext.batch_matmul falls into this category.
Implementation by introducing linalg-ext
Introducing linalg-ext provides several benefits:
- It clearly separates the extended ops and transformations from the existing linalg ones, avoiding misuse.
- It naturally accommodates patterns that require introducing new interfaces.
Transformation Extension
Several transformations are enhanced or introduced in ByteIR linalg-ext.
A collapse-dims transformation is introduced
- to collapse dimensions of linalg.generic ops.
A fold-unit-extent-dims transformation is introduced
- to remove unit-extent dimensions from linalg ops.
A lower-to-loops transformation is introduced
- to lower ops to loops.
A linalg op outlining transformation is introduced
- to outline linalg ops into named functions, where libcall indicates an external library call. If libcall is set to false, each outlined function gets a unique name, and func_name only provides a naming hint in this case; otherwise, all transformed function calls refer to the same external function named func_name.
A tile-label transformation is introduced
- to label loop types (parallel or reduction) through attributes.
Note that this tile-label transformation also works with the existing linalg tiling and fusion transformations.
A shared-output-to-distributed transformation is introduced
- to convert parallel tiling from a shared-output form to a distributed form.
The tiling transformation is enhanced
- to support linalg-ext ops.
The fusion transformation is enhanced
- to support linalg-ext ops,
- to correctly support tiling along reduction axes,
- to support intermediate results as outputs in fusion,
- to support dimension simplification of intermediate tensors,
- to support diamond structures (if-else),
- to support an optional stop attribute,
- to support fusion with the tensor dialect.
A fuse-operands transformation is introduced
- to support multiple root nodes in fusion,
- to support checking whether all ops in a func op have been fused.
Note that this transformation will eventually be merged into the fusion transformation.
Here we show the difference when tiling along a reduction axis.
// input.mlir
func.func @tile_matmul(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<128x128xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
%2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%1 : tensor<128x128xf32>) -> tensor<128x128xf32>
return %2 : tensor<128x128xf32>
}
// result after transform.structured.fuse, wrong tiling result
func.func @tile_matmul(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
%c128 = arith.constant 128 : index
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<128x128xf32>
%1 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %0) -> (tensor<128x128xf32>) {
%extracted_slice = tensor.extract_slice %arg0[0, %arg2] [128, 8] [1, 1] : tensor<128x128xf32> to tensor<128x8xf32>
%extracted_slice_0 = tensor.extract_slice %arg1[%arg2, 0] [8, 128] [1, 1] : tensor<128x128xf32> to tensor<8x128xf32>
%2 = linalg.fill ins(%cst : f32) outs(%arg3 : tensor<128x128xf32>) -> tensor<128x128xf32> // shouldn't fill to zero every step
%3 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<128x8xf32>, tensor<8x128xf32>) outs(%2 : tensor<128x128xf32>) -> tensor<128x128xf32>
scf.yield %3 : tensor<128x128xf32>
}
return %1 : tensor<128x128xf32>
}
// result after transform.structured.fuse_ext
func.func @tile_matmul(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
%c128 = arith.constant 128 : index
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<128x128xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
%2 = scf.for %arg2 = %c0 to %c128 step %c8 iter_args(%arg3 = %1) -> (tensor<128x128xf32>) {
%extracted_slice = tensor.extract_slice %arg0[0, %arg2] [128, 8] [1, 1] : tensor<128x128xf32> to tensor<128x8xf32>
%extracted_slice_0 = tensor.extract_slice %arg1[%arg2, 0] [8, 128] [1, 1] : tensor<128x128xf32> to tensor<8x128xf32>
%3 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<128x8xf32>, tensor<8x128xf32>) outs(%arg3 : tensor<128x128xf32>) -> tensor<128x128xf32>
scf.yield %3 : tensor<128x128xf32>
}
return %2 : tensor<128x128xf32>
}
Here we show the difference when an intermediate result is used as an output.
// input.mlir
func.func @fuse_element_static(%arg0: tensor<512x128xf32>, %arg1: tensor<512x128xf32>) -> (tensor<512x128xf32>, tensor<512x128xf32>) {
%0 = linalg.elemwise_unary ins(%arg0 : tensor<512x128xf32>)
outs(%arg1: tensor<512x128xf32>) -> tensor<512x128xf32>
%1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<512x128xf32>, tensor<512x128xf32>)
outs(%arg1: tensor<512x128xf32>) -> tensor<512x128xf32>
return %0, %1 : tensor<512x128xf32>, tensor<512x128xf32>
}
// result after transform.structured.fuse
func.func @fuse_element_static(%arg0: tensor<512x128xf32>, %arg1: tensor<512x128xf32>) -> (tensor<512x128xf32>, tensor<512x128xf32>) {
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%0 = linalg.elemwise_unary ... // duplicate producer elemwise_unary
%1 = scf.for %arg2 = %c0 to %c512 step %c32 iter_args(%arg3 = %arg1) -> (tensor<512x128xf32>) {
%2 = scf.for %arg4 = %c0 to %c128 step %c32 iter_args(%arg5 = %arg3) -> (tensor<512x128xf32>) {
...
%3 = linalg.elemwise_unary ... // producer fusion
...
%4 = linalg.elemwise_binary ...
%inserted_slice = tensor.insert_slice ...
scf.yield %inserted_slice : tensor<512x128xf32>
}
scf.yield %2 : tensor<512x128xf32>
}
return %0, %1 : tensor<512x128xf32>, tensor<512x128xf32>
}
// result after transform.structured.fuse_ext
func.func @fuse_element_static(%arg0: tensor<512x128xf32>, %arg1: tensor<512x128xf32>) -> (tensor<512x128xf32>, tensor<512x128xf32>) {
%c128 = arith.constant 128 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%0:2 = scf.for %arg2 = %c0 to %c512 step %c32 iter_args(%arg3 = %arg1, %arg4 = %arg1) -> (tensor<512x128xf32>, tensor<512x128xf32>) {
%1:2 = scf.for %arg5 = %c0 to %c128 step %c32 iter_args(%arg6 = %arg1, %arg7 = %arg1) -> (tensor<512x128xf32>, tensor<512x128xf32>) {
...
%2 = linalg.elemwise_unary ... // producer fusion
...
%3 = linalg.elemwise_binary ...
%inserted_slice = tensor.insert_slice ...
%inserted_slice_3 = tensor.insert_slice ...
scf.yield %inserted_slice, %inserted_slice_3 : tensor<512x128xf32>, tensor<512x128xf32>
}
scf.yield %1#0, %1#1 : tensor<512x128xf32>, tensor<512x128xf32>
}
return %0#1, %0#0 : tensor<512x128xf32>, tensor<512x128xf32>
}
Here we show the difference when tiling and fusing residual blocks. Incidentally, the upstream version runs very slowly on a chain of residual blocks, since some nodes are visited 2^N times, where N is the number of residual blocks.
// input.mlir
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @resnet_block(%arg0: tensor<1x56x56x256xf16>) -> tensor<1x56x56x256xf16> {
%cst = arith.constant dense_resource<__elided__> : tensor<256x1x1x64xf32>
%cst_0 = arith.constant dense_resource<__elided__> : tensor<64x3x3x64xf32>
%cst_1 = arith.constant dense_resource<__elided__> : tensor<64x1x1x256xf32>
%cst_2 = arith.constant 0.000000e+00 : f16
%0 = tensor.empty() : tensor<1x56x56x256xf16>
%1 = linalg.fill ins(%cst_2 : f16) outs(%0 : tensor<1x56x56x256xf16>) -> tensor<1x56x56x256xf16>
%2 = linalg.elemwise_unary {__revisited__} ins(%arg0 : tensor<1x56x56x256xf16>) outs(%1 : tensor<1x56x56x256xf16>) -> tensor<1x56x56x256xf16>
%3 = tensor.empty() : tensor<1x56x56x64xf16>
%4 = linalg.fill ins(%cst_2 : f16) outs(%3 : tensor<1x56x56x64xf16>) -> tensor<1x56x56x64xf16>
%5 = linalg.conv_2d_nhwc_fhwc {__conv_0__, dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%2, %cst_1 : tensor<1x56x56x256xf16>, tensor<64x1x1x256xf32>) outs(%4 : tensor<1x56x56x64xf16>) -> tensor<1x56x56x64xf16>
%padded = tensor.pad %5 nofold low[0, 1, 1, 0] high[0, 1, 1, 0] {
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
tensor.yield %cst_2 : f16
} : tensor<1x56x56x64xf16> to tensor<1x58x58x64xf16>
%6 = tensor.empty() : tensor<1x56x56x64xf16>
%7 = linalg.fill ins(%cst_2 : f16) outs(%6 : tensor<1x56x56x64xf16>) -> tensor<1x56x56x64xf16>
%8 = linalg.conv_2d_nhwc_fhwc {__conv_1__, dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%padded, %cst_0 : tensor<1x58x58x64xf16>, tensor<64x3x3x64xf32>) outs(%7 : tensor<1x56x56x64xf16>) -> tensor<1x56x56x64xf16>
%9 = tensor.empty() : tensor<1x56x56x256xf16>
%10 = linalg.fill ins(%cst_2 : f16) outs(%9 : tensor<1x56x56x256xf16>) -> tensor<1x56x56x256xf16>
%11 = linalg.conv_2d_nhwc_fhwc {__conv_2__, dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%8, %cst : tensor<1x56x56x64xf16>, tensor<256x1x1x64xf32>) outs(%10 : tensor<1x56x56x256xf16>) -> tensor<1x56x56x256xf16>
%12 = tensor.empty() : tensor<1x56x56x256xf16>
%13 = linalg.generic {__root__, indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2, %11 : tensor<1x56x56x256xf16>, tensor<1x56x56x256xf16>) outs(%12 : tensor<1x56x56x256xf16>) {
^bb0(%in: f16, %in_3: f16, %out: f16):
%14 = arith.addf %in, %in_3 : f16
linalg.yield %14 : f16
} -> tensor<1x56x56x256xf16>
return %13 : tensor<1x56x56x256xf16>
}
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
%0 = transform.structured.match attributes {__root__} in %arg0 : (!pdl.operation) -> !pdl.operation
%transformed, %loops:2 = transform.structured.fuse_ext %0 {tile_interchange = [], tile_sizes = [0, 8, 0, 32]}
cleanup
}
// result after transform.structured.fuse_ext, `linalg.elemwise_unary {__revisited__}` is tiled only once
// and its tile size is calculated by getting the maximum of two paths
#map = affine_map<(d0) -> (-d0 + 1, 0)>
#map1 = affine_map<(d0) -> (0, d0 - 1)>
#map2 = affine_map<(d0) -> (56, d0)>
#map3 = affine_map<(d0) -> (56, d0 + 9)>
#map4 = affine_map<(d0, d1) -> (d0 - d1)>
#map5 = affine_map<(d0, d1, d2) -> (-d0 - d1 + d2 + 10)>
#map6 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @resnet_block(%arg0: tensor<1x56x56x256xf16>) -> tensor<1x56x56x256xf16> {
%c256 = arith.constant 256 : index
%c56 = arith.constant 56 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%cst_0 = arith.constant dense_resource<__elided__> : tensor<64x1x1x256xf32>
%cst_1 = arith.constant dense_resource<__elided__> : tensor<64x3x3x64xf32>
%cst_2 = arith.constant dense_resource<__elided__> : tensor<256x1x1x64xf32>
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%0 = tensor.empty() : tensor<1x56x56x256xf16>
%1 = tensor.empty() : tensor<1x56x56x64xf16>
%2:7 = scf.for %arg1 = %c0 to %c56 step %c8 iter_args(%arg2 = %0, %arg3 = %1, %arg4 = %1, %arg5 = %1, %arg6 = %0, %arg7 = %1, %arg8 = %0) -> (tensor<1x56x56x256xf16>, tensor<1x56x56x64xf16>, tensor<1x56x56x64xf16>, tensor<1x56x56x64xf16>, tensor<1x56x56x256xf16>, tensor<1x56x56x64xf16>, tensor<1x56x56x256xf16>) {
%3:7 = scf.for %arg9 = %c0 to %c256 step %c32 iter_args(%arg10 = %arg2, %arg11 = %arg3, %arg12 = %arg4, %arg13 = %arg5, %arg14 = %arg6, %arg15 = %arg7, %arg16 = %arg8) -> (tensor<1x56x56x256xf16>, tensor<1x56x56x64xf16>, tensor<1x56x56x64xf16>, tensor<1x56x56x64xf16>, tensor<1x56x56x256xf16>, tensor<1x56x56x64xf16>, tensor<1x56x56x256xf16>) {
%4 = arith.addi %arg1, %c8 : index
%5 = arith.addi %arg9, %c32 : index
%6 = arith.minsi %arg9, %c0 : index
%7 = arith.maxsi %5, %c256 : index
%8 = arith.subi %7, %6 : index
%9 = arith.subi %arg9, %6 : index
%10 = arith.subi %c0, %6 : index
%11 = affine.max #map(%arg1)
%12 = affine.max #map1(%arg1)
%13 = affine.min #map2(%12)
%14 = affine.min #map3(%arg1)
%15 = affine.apply #map4(%14, %13)
%16 = affine.apply #map5(%11, %14, %13)
%extracted_slice = tensor.extract_slice %arg13[0, %13, 0, 0] [1, %15, 56, 64] [1, 1, 1, 1] : tensor<1x56x56x64xf16> to tensor<1x?x56x64xf16>
%17 = linalg.fill ins(%cst : f16) outs(%extracted_slice : tensor<1x?x56x64xf16>) -> tensor<1x?x56x64xf16>
%extracted_slice_3 = tensor.extract_slice %arg11[0, %arg1, 0, 0] [1, 8, 56, 64] [1, 1, 1, 1] : tensor<1x56x56x64xf16> to tensor<1x8x56x64xf16>
%18 = linalg.fill ins(%cst : f16) outs(%extracted_slice_3 : tensor<1x8x56x64xf16>) -> tensor<1x8x56x64xf16>
%extracted_slice_4 = tensor.extract_slice %cst_2[%arg9, 0, 0, 0] [32, 1, 1, 64] [1, 1, 1, 1] : tensor<256x1x1x64xf32> to tensor<32x1x1x64xf32>
%19 = tensor.empty() : tensor<1x8x56x32xf16>
%20 = linalg.fill ins(%cst : f16) outs(%19 : tensor<1x8x56x32xf16>) -> tensor<1x8x56x32xf16>
%inserted_slice = tensor.insert_slice %18 into %arg12[0, %arg1, 0, 0] [1, 8, 56, 64] [1, 1, 1, 1] : tensor<1x8x56x64xf16> into tensor<1x56x56x64xf16>
%inserted_slice_5 = tensor.insert_slice %17 into %arg15[0, %13, 0, 0] [1, %15, 56, 64] [1, 1, 1, 1] : tensor<1x?x56x64xf16> into tensor<1x56x56x64xf16>
%21 = arith.minsi %arg1, %13 : index
%22 = arith.addi %13, %15 : index
%23 = arith.maxsi %4, %22 : index
%24 = arith.subi %23, %21 : index
%extracted_slice_6 = tensor.extract_slice %arg0[0, %21, 0, %6] [1, %24, 56, %8] [1, 1, 1, 1] : tensor<1x56x56x256xf16> to tensor<1x?x56x?xf16>
%extracted_slice_7 = tensor.extract_slice %arg14[0, %21, 0, %6] [1, %24, 56, %8] [1, 1, 1, 1] : tensor<1x56x56x256xf16> to tensor<1x?x56x?xf16>
%25 = linalg.fill ins(%cst : f16) outs(%extracted_slice_7 : tensor<1x?x56x?xf16>) -> tensor<1x?x56x?xf16>
%26 = linalg.elemwise_unary {__revisited__} ins(%extracted_slice_6 : tensor<1x?x56x?xf16>) outs(%25 : tensor<1x?x56x?xf16>) -> tensor<1x?x56x?xf16>
%27 = arith.subi %arg1, %21 : index
%extracted_slice_8 = tensor.extract_slice %26[0, %27, 0, %9] [1, 8, 56, 32] [1, 1, 1, 1] : tensor<1x?x56x?xf16> to tensor<1x8x56x32xf16>
%28 = arith.subi %13, %21 : index
%extracted_slice_9 = tensor.extract_slice %26[0, %28, 0, %10] [1, %15, 56, 256] [1, 1, 1, 1] : tensor<1x?x56x?xf16> to tensor<1x?x56x256xf16>
%29 = linalg.conv_2d_nhwc_fhwc {__conv_0__, dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%extracted_slice_9, %cst_0 : tensor<1x?x56x256xf16>, tensor<64x1x1x256xf32>) outs(%17 : tensor<1x?x56x64xf16>) -> tensor<1x?x56x64xf16>
%padded = tensor.pad %29 nofold low[0, %11, 1, 0] high[0, %16, 1, 0] {
^bb0(%arg17: index, %arg18: index, %arg19: index, %arg20: index):
tensor.yield %cst : f16
} : tensor<1x?x56x64xf16> to tensor<1x10x58x64xf16>
%30 = linalg.conv_2d_nhwc_fhwc {__conv_1__, dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%padded, %cst_1 : tensor<1x10x58x64xf16>, tensor<64x3x3x64xf32>) outs(%18 : tensor<1x8x56x64xf16>) -> tensor<1x8x56x64xf16>
%31 = linalg.conv_2d_nhwc_fhwc {__conv_2__, dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%30, %extracted_slice_4 : tensor<1x8x56x64xf16>, tensor<32x1x1x64xf32>) outs(%20 : tensor<1x8x56x32xf16>) -> tensor<1x8x56x32xf16>
%32 = linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extracted_slice_8, %31 : tensor<1x8x56x32xf16>, tensor<1x8x56x32xf16>) outs(%19 : tensor<1x8x56x32xf16>) attrs = {__root__} {
^bb0(%in: f16, %in_15: f16, %out: f16):
%33 = arith.addf %in, %in_15 : f16
linalg.yield %33 : f16
} -> tensor<1x8x56x32xf16>
%inserted_slice_10 = tensor.insert_slice %32 into %arg10[0, %arg1, 0, %arg9] [1, 8, 56, 32] [1, 1, 1, 1] : tensor<1x8x56x32xf16> into tensor<1x56x56x256xf16>
%inserted_slice_11 = tensor.insert_slice %30 into %arg11[0, %arg1, 0, 0] [1, 8, 56, 64] [1, 1, 1, 1] : tensor<1x8x56x64xf16> into tensor<1x56x56x64xf16>
%inserted_slice_12 = tensor.insert_slice %29 into %arg13[0, %13, 0, 0] [1, %15, 56, 64] [1, 1, 1, 1] : tensor<1x?x56x64xf16> into tensor<1x56x56x64xf16>
%inserted_slice_13 = tensor.insert_slice %26 into %arg14[0, %21, 0, %6] [1, %24, 56, %8] [1, 1, 1, 1] : tensor<1x?x56x?xf16> into tensor<1x56x56x256xf16>
%inserted_slice_14 = tensor.insert_slice %25 into %arg16[0, %21, 0, %6] [1, %24, 56, %8] [1, 1, 1, 1] : tensor<1x?x56x?xf16> into tensor<1x56x56x256xf16>
scf.yield %inserted_slice_10, %inserted_slice_11, %inserted_slice, %inserted_slice_12, %inserted_slice_13, %inserted_slice_5, %inserted_slice_14 : tensor<1x56x56x256xf16>, tensor<1x56x56x64xf16>, tensor<1x56x56x64xf16>, tensor<1x56x56x64xf16>, tensor<1x56x56x256xf16>, tensor<1x56x56x64xf16>, tensor<1x56x56x256xf16>
}
scf.yield %3#0, %3#1, %3#2, %3#3, %3#4, %3#5, %3#6 : tensor<1x56x56x256xf16>, tensor<1x56x56x64xf16>, tensor<1x56x56x64xf16>, tensor<1x56x56x64xf16>, tensor<1x56x56x256xf16>, tensor<1x56x56x64xf16>, tensor<1x56x56x256xf16>
}
return %2#0 : tensor<1x56x56x256xf16>
}
Please refer to the attention example to understand the tiling and fusion capability of FuseExt.
The elementwise-op fusion transformation is enhanced
- to support intermediate results as outputs in fusion,
- to support producer-consumer fusion as well as input-sharing fusion,
- to support dimension simplification of intermediate tensors,
- to support fusing map ops by automatically converting them into generic ops,
- to support indexing maps with constant results.
Here we show the difference of fusion with input sharing.
// input.mlir
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @input_sharing(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%0 = tensor.empty(%dim, %dim_0) : tensor<?x?xf32>
%1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%3 = arith.addf %in, %in_1 : f32
linalg.yield %3 : f32
} -> tensor<?x?xf32>
%2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%3 = arith.mulf %in, %in_1 : f32
linalg.yield %3 : f32
} -> tensor<?x?xf32>
return %1, %2 : tensor<?x?xf32>, tensor<?x?xf32>
}
// result after linalg-fuse-elementwise-ops, unchanged
func.func @input_sharing(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%0 = tensor.empty(%dim, %dim_0) : tensor<?x?xf32>
%1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%3 = arith.addf %in, %in_1 : f32
linalg.yield %3 : f32
} -> tensor<?x?xf32>
%2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%3 = arith.mulf %in, %in_1 : f32
linalg.yield %3 : f32
} -> tensor<?x?xf32>
return %1, %2 : tensor<?x?xf32>, tensor<?x?xf32>
}
// result after linalg-fuse-elementwise-ext="shared-input"
func.func @input_sharing(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%0 = tensor.empty(%dim, %dim_0) : tensor<?x?xf32>
%1:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) outs(%0, %0 : tensor<?x?xf32>, tensor<?x?xf32>) {
^bb0(%in: f32, %in_1: f32, %in_2: f32, %out: f32, %out_3: f32):
%2 = arith.addf %in, %in_1 : f32
%3 = arith.mulf %in, %in_2 : f32
linalg.yield %2, %3 : f32, f32
} -> (tensor<?x?xf32>, tensor<?x?xf32>)
return %1#0, %1#1 : tensor<?x?xf32>, tensor<?x?xf32>
}
Here we show the difference of supporting dimension simplification of intermediate tensors.
// input.mlir
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @may_more_break_outs_dependency(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x?xf32>) outs(%arg0 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%2 = arith.addf %in, %in : f32
linalg.yield %2 : f32
} -> tensor<?x?xf32>
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%2 = arith.mulf %in, %in : f32
linalg.yield %2 : f32
} -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// result after linalg-fuse-elementwise-ops, no fusion
func.func @may_more_break_outs_dependency(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%0 = tensor.empty(%dim, %dim_0) : tensor<?x?xf32>
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%4 = arith.addf %in, %in : f32
linalg.yield %4 : f32
} -> tensor<?x?xf32>
%dim_1 = tensor.dim %1, %c0 : tensor<?x?xf32>
%dim_2 = tensor.dim %1, %c1 : tensor<?x?xf32>
%2 = tensor.empty(%dim_1, %dim_2) : tensor<?x?xf32>
%3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%1 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%4 = arith.mulf %in, %in : f32
linalg.yield %4 : f32
} -> tensor<?x?xf32>
return %3 : tensor<?x?xf32>
}
// result after linalg-fuse-elementwise-ext, perfect fusion
func.func @may_more_break_outs_dependency(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%0 = tensor.empty(%dim, %dim_0) : tensor<?x?xf32>
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%2 = arith.addf %in, %in : f32
%3 = arith.mulf %2, %2 : f32
linalg.yield %3 : f32
} -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
Here we show the difference of supporting indexing maps with constant results.
// input.mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1) -> (d0, 0)>
#map2 = affine_map<(d0, d1) -> (d0, d1)>
func.func @constant_in_affine_map_with_collapse_shape(%arg0: tensor<1x256x1024xf32>, %arg1: tensor<256x1024xf16>, %arg2: tensor<256x1xf32>, %arg3: tensor<256x1xf32>) -> tensor<256x1024xf32> {
%expanded = tensor.expand_shape %arg1 [[0, 1], [2]] : tensor<256x1024xf16> into tensor<1x256x1024xf16>
%0 = tensor.empty() : tensor<1x256x1024xf32>
%1 = tensor.empty() : tensor<256x1024xf32>
%2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded, %arg0 : tensor<1x256x1024xf16>, tensor<1x256x1024xf32>) outs(%0 : tensor<1x256x1024xf32>) {
^bb0(%in: f16, %in_0: f32, %out: f32):
%4 = arith.extf %in : f16 to f32
%5 = arith.addf %in_0, %4 : f32
linalg.yield %5 : f32
} -> tensor<1x256x1024xf32>
%collapsed = tensor.collapse_shape %2 [[0, 1], [2]] : tensor<1x256x1024xf32> into tensor<256x1024xf32>
%3 = linalg.generic {indexing_maps = [#map1, #map1, #map2, #map2], iterator_types = ["parallel", "parallel"]} ins(%arg3, %arg2, %collapsed : tensor<256x1xf32>, tensor<256x1xf32>, tensor<256x1024xf32>) outs(%1 : tensor<256x1024xf32>) {
^bb0(%in: f32, %in_0: f32, %in_1: f32, %out: f32):
%4 = arith.subf %in_1, %in_0 : f32
%5 = arith.mulf %4, %in : f32
linalg.yield %5 : f32
} -> tensor<256x1024xf32>
return %3 : tensor<256x1024xf32>
}
// result after linalg-fuse-elementwise-ops, no fusion
func.func @constant_in_affine_map_with_collapse_shape(%arg0: tensor<1x256x1024xf32>, %arg1: tensor<256x1024xf16>, %arg2: tensor<256x1xf32>, %arg3: tensor<256x1xf32>) -> tensor<256x1024xf32> {
%expanded = tensor.expand_shape %arg1 [[0, 1], [2]] : tensor<256x1024xf16> into tensor<1x256x1024xf16>
%0 = tensor.empty() : tensor<1x256x1024xf32>
%1 = tensor.empty() : tensor<256x1024xf32>
%2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded, %arg0 : tensor<1x256x1024xf16>, tensor<1x256x1024xf32>) outs(%0 : tensor<1x256x1024xf32>) {
^bb0(%in: f16, %in_0: f32, %out: f32):
%4 = arith.extf %in : f16 to f32
%5 = arith.addf %in_0, %4 : f32
linalg.yield %5 : f32
} -> tensor<1x256x1024xf32>
%collapsed = tensor.collapse_shape %2 [[0, 1], [2]] : tensor<1x256x1024xf32> into tensor<256x1024xf32>
%3 = linalg.generic {indexing_maps = [#map1, #map1, #map2, #map2], iterator_types = ["parallel", "parallel"]} ins(%arg3, %arg2, %collapsed : tensor<256x1xf32>, tensor<256x1xf32>, tensor<256x1024xf32>) outs(%1 : tensor<256x1024xf32>) {
^bb0(%in: f32, %in_0: f32, %in_1: f32, %out: f32):
%4 = arith.subf %in_1, %in_0 : f32
%5 = arith.mulf %4, %in : f32
linalg.yield %5 : f32
} -> tensor<256x1024xf32>
return %3 : tensor<256x1024xf32>
}
// result after linalg-fuse-elementwise-ext, perfect fusion
func.func @constant_in_affine_map_with_collapse_shape(%arg0: tensor<1x256x1024xf32>, %arg1: tensor<256x1024xf16>, %arg2: tensor<256x1xf32>, %arg3: tensor<256x1xf32>) -> tensor<256x1024xf32> {
%expanded = tensor.expand_shape %arg1 [[0, 1], [2]] : tensor<256x1024xf16> into tensor<1x256x1024xf16>
%expanded_0 = tensor.expand_shape %arg3 [[0, 1], [2]] : tensor<256x1xf32> into tensor<1x256x1xf32>
%expanded_1 = tensor.expand_shape %arg2 [[0, 1], [2]] : tensor<256x1xf32> into tensor<1x256x1xf32>
%0 = tensor.empty() : tensor<1x256x1024xf32>
%1 = linalg.generic {indexing_maps = [#map, #map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded, %arg0, %expanded_0, %expanded_1 : tensor<1x256x1024xf16>, tensor<1x256x1024xf32>, tensor<1x256x1xf32>, tensor<1x256x1xf32>) outs(%0 : tensor<1x256x1024xf32>) {
^bb0(%in: f16, %in_2: f32, %in_3: f32, %in_4: f32, %out: f32):
%2 = arith.extf %in : f16 to f32
%3 = arith.addf %in_2, %2 : f32
%4 = arith.subf %3, %in_4 : f32
%5 = arith.mulf %4, %in_3 : f32
linalg.yield %5 : f32
} -> tensor<1x256x1024xf32>
%collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<1x256x1024xf32> into tensor<256x1024xf32>
return %collapsed : tensor<256x1024xf32>
}
Linalg-ext Op Extension
Alias Op
The linalg-ext alias op is introduced as a helper op to assist input-sharing fusion.
It is created internally within a pass, typically inside a generic op.
It is not removed by canonicalization; it is only removed by calling populateRemoveLinalgExtAliasPattern.
Note: the alias op is not a structured op, and it does not implement interfaces such as LoopIteratorType.
Diag Op
The linalg-ext diag op is introduced to represent a diagonal matrix. It is a structured op, but it is currently only used as an intermediate representation of an output, typically together with a matmul op. Depending on the backend, a matmul with a diag op can typically be rewritten into:
- a matmul with a load-reduced matrix,
- a sparse matmul, or
- an elementwise multiplication with broadcast.
Definition:
- Operands:
  - input: a tensor of shape N
- Init/Result:
  - output: a tensor of shape N x N
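The last rewrite follows directly from the definition: multiplying by a materialized N x N diagonal is the same as a broadcasted elementwise multiplication. A minimal NumPy sketch of that equivalence (illustrative only, not ByteIR code):
# diag_semantics.py: NumPy illustration of rewriting a matmul with a
# diagonal matrix into an elementwise multiplication with broadcast.
import numpy as np

v = np.random.rand(4).astype(np.float32)     # diag input: tensor of shape N
a = np.random.rand(4, 8).astype(np.float32)

dense = np.diag(v) @ a      # matmul against the materialized N x N diagonal
broadcast = v[:, None] * a  # elementwise multiplication with broadcast

assert np.allclose(dense, broadcast)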
Scan Op
The linalg-ext scan op is introduced to represent patterns such as scan, prefix sum, or cumsum. It is a structured op.
Definition:
- Operands:
  - input: a tensor of rank N
- Attributes:
  - dimension: I64ArrayAttr
  - inclusive: BoolAttr
- Init/Result:
  - output: a tensor of rank N
  - accumulator: a tensor of rank N - 1
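A minimal NumPy sketch of these semantics for a sum scan (illustrative only; the assumption that the accumulator holds the final running value of each scanned slice is ours, not taken from ByteIR):
# scan_semantics.py: NumPy model of the scan op for a sum scan.
import numpy as np

def scan(inp, dimension=0, inclusive=True):
    out = np.cumsum(inp, axis=dimension)
    if not inclusive:
        out = out - inp                    # exclusive scan: drop the current element
    accumulator = inp.sum(axis=dimension)  # assumed: final running value, rank N - 1
    return out, accumulator

x = np.arange(6, dtype=np.float32).reshape(2, 3)
out, acc = scan(x, dimension=1, inclusive=True)
# out == [[0, 1, 3], [3, 7, 12]], acc == [3, 12]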
Scatter Op
The linalg-ext scatter op is introduced to represent the pattern of updating a tensor by indices. It is a structured op.
Definition:
- Operands:
  - indices: a tensor
  - updates: a tensor
- Init/Result:
  - src: a tensor
Here, the first rank(indices) - 1 dimensions of indices and updates match, and the last rank(updates) - rank(indices) + 1 dimensions of updates and src match. The last dimension of indices, dim(indices, rank(indices) - 1), must be static. The rank of src equals dim(indices, rank(indices) - 1) + rank(updates) - rank(indices) + 1.
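These shape constraints can be checked mechanically. A NumPy sketch follows (the overwrite-style update is our assumption for illustration; the assertions restate the rules above):
# scatter_shapes.py: NumPy model of the scatter op's shape constraints.
import numpy as np

indices = np.array([[0], [2]], dtype=np.int64)  # rank 2; static last dim = 1
updates = np.ones((2, 4), dtype=np.float32)
src = np.zeros((8, 4), dtype=np.float32)

# the first rank(indices) - 1 dims of indices and updates match
assert indices.shape[:-1] == updates.shape[:indices.ndim - 1]
# the last rank(updates) - rank(indices) + 1 dims of updates and src match
assert updates.shape[indices.ndim - 1:] == src.shape[indices.shape[-1]:]
# rank(src) == dim(indices, rank(indices) - 1) + rank(updates) - rank(indices) + 1
assert src.ndim == indices.shape[-1] + updates.ndim - indices.ndim + 1

for i in range(indices.shape[0]):
    src[tuple(indices[i])] = updates[i]  # assumed overwrite semantics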
Softmax Op
The linalg-ext softmax op is introduced to represent the softmax pattern. It is a structured op.
Definition:
- Operands:
  - input: a tensor of rank N
- Attributes:
  - dimension: I64ArrayAttr
- Init/Result:
  - output: a tensor of rank N, where output_result = exp(input - max_result) / accumulator_result
  - max: a tensor of rank N - 1, where max_result = max(max(input, dimension), max_init)
  - accumulator: a tensor of rank N - 1, where accumulator_result = accumulator_init * exp(max_init - max_result) + sum(exp(input - max_result), dimension)
  - scale: a tensor of rank N - 1, where scale_result = accumulator_init * exp(max_init - max_result) / accumulator_result
Here, result 1, max, is defined as max_result = max(max(input, dimension), max_init); essentially, it is the result of a reduce_max over input along dimension, with max_init as the initial value. Result 2, accumulator, is defined as accumulator_result = accumulator_init * exp(max_init - max_result) + sum(exp(input - max_result), dimension); essentially, it is the result of a reduce_sum over exp(input - max_result) along dimension, with accumulator_init * exp(max_init - max_result) as the initial value. Result 3, scale, is defined as scale_result = accumulator_init * exp(max_init - max_result) / accumulator_result. Finally, result 0, output, is defined as output_result = exp(input - max_result) / accumulator_result.
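A NumPy sketch that computes the four results directly from these formulas (illustrative only, not ByteIR code). With neutral initial values, max_init = -inf and accumulator_init = 0, it reduces to a plain softmax:
# softmax_semantics.py: NumPy model of the linalg_ext softmax results.
import numpy as np

def softmax_ext(inp, dimension, max_init, accumulator_init):
    max_result = np.maximum(inp.max(axis=dimension), max_init)
    exp_in = np.exp(inp - np.expand_dims(max_result, dimension))
    correction = accumulator_init * np.exp(max_init - max_result)
    accumulator_result = correction + exp_in.sum(axis=dimension)
    scale_result = correction / accumulator_result
    output_result = exp_in / np.expand_dims(accumulator_result, dimension)
    return output_result, max_result, accumulator_result, scale_result

x = np.random.rand(2, 5).astype(np.float32)
out, m, acc, scale = softmax_ext(
    x, dimension=1,
    max_init=np.full((2,), -np.inf, dtype=np.float32),
    accumulator_init=np.zeros((2,), dtype=np.float32))
assert np.allclose(out.sum(axis=1), 1.0)
The max_init and accumulator_init values are what make the op composable under tiling along dimension: each tile can fold its partial max and partial sum into the running values, which is presumably what enables attention-style tiling such as the attention example referenced above.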
Topk Op
The linalg-ext topk op is introduced to represent the topk pattern. It is a structured op.
Definition:
- Operands:
  - input_values: a tensor of rank N
  - input_indices: an optional tensor of rank N
- Attributes:
  - dimension: I64ArrayAttr
- Init/Result:
  - output_values: a tensor of rank N
  - output_indices: a tensor of rank N
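A NumPy sketch of the intended semantics (the following are our assumptions for illustration: k is implied by the output shape along dimension, values are returned in descending order, and input_indices defaults to an iota along dimension when absent):
# topk_semantics.py: NumPy model of the topk op.
import numpy as np

def topk(input_values, k, dimension, input_indices=None):
    if input_indices is None:
        # assumed default: iota along `dimension`, broadcast to the input shape
        shape = [1] * input_values.ndim
        shape[dimension] = input_values.shape[dimension]
        input_indices = np.broadcast_to(
            np.arange(input_values.shape[dimension]).reshape(shape),
            input_values.shape)
    order = np.argsort(-input_values, axis=dimension)   # descending order
    top = np.take(order, np.arange(k), axis=dimension)  # keep the first k
    output_values = np.take_along_axis(input_values, top, axis=dimension)
    output_indices = np.take_along_axis(input_indices, top, axis=dimension)
    return output_values, output_indices

vals = np.array([[3.0, 1.0, 4.0, 1.5]], dtype=np.float32)
v, i = topk(vals, k=2, dimension=1)  # v == [[4.0, 3.0]], i == [[2, 0]]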