Linalg Tiling and Fusion on Attention
Attention mechanisms are central to modern machine learning, but executing them efficiently requires careful scheduling. Our LinalgExt dialect, designed to complement upstream Linalg, provides enhanced tiling and fusion support for exactly this purpose. This document shows how we handle attention and its variants.
Flash attention example
Here we demonstrate how flash attention is obtained from a regular self-attention. In ByteIR, a regular self-attention expressed in linalg or linalg-ext ops, such as matmul and softmax, is turned into flash attention by a single linalg-ext fuse transformation, given proper tiling parameters: tile_sizes chosen to fully utilize on-chip memory and a tile_interchange of [2, 1, 0]. A sketch of the transform script follows the input below.
// input.mlir
func.func @dot_attention(%arg0: tensor<1024x32xf32>, %arg1: tensor<32x512xf32>, %arg2: tensor<512x32xf32>) -> tensor<1024x32xf32> {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 0xFF800000 : f32
%0 = tensor.empty() : tensor<1024x512xf32>
%1 = tensor.empty() : tensor<1024x32xf32>
%2 = tensor.empty() : tensor<1024x512xf32>
%3 = tensor.empty() : tensor<1024xf32>
%4 = tensor.empty() : tensor<1024xf32>
%5 = tensor.empty() : tensor<1024xf32>
%6 = linalg.fill ins(%cst_0 : f32) outs(%3 : tensor<1024xf32>) -> tensor<1024xf32>
%7 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1024x512xf32>) -> tensor<1024x512xf32>
%8 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1024x32xf32>) -> tensor<1024x32xf32>
%9 = linalg.fill ins(%cst : f32) outs(%4 : tensor<1024xf32>) -> tensor<1024xf32>
%10 = linalg.matmul ins(%arg0, %arg1 : tensor<1024x32xf32>, tensor<32x512xf32>) outs(%7 : tensor<1024x512xf32>) -> tensor<1024x512xf32>
%11:4 = linalg_ext.softmax dimension(1) ins(%10 : tensor<1024x512xf32>) outs(%2, %6, %9, %5 : tensor<1024x512xf32>, tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) : tensor<1024x512xf32>, tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>
%12 = linalg.matmul {__root__} ins(%11#0, %arg2 : tensor<1024x512xf32>, tensor<512x32xf32>) outs(%8 : tensor<1024x32xf32>) -> tensor<1024x32xf32>
return %12 : tensor<1024x32xf32>
}
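The fusion can be driven by a transform script along the following lines, a sketch mirroring the explicit transform.sequence blocks shown for the later examples and using the tile parameters quoted in the result comment below:
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
  // Match the op tagged {__root__} and fuse its producers while tiling.
  %0 = transform.structured.match attributes {__root__} in %arg0 : (!pdl.operation) -> !pdl.operation
  %1, %loops:2 = transform.structured.fuse_ext %0 {tile_interchange = [2, 1, 0], tile_sizes = [4, 0, 8]}
  cleanup
}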
// result after transform.structured.fuse_ext {tile_interchange = [2, 1, 0], tile_sizes = [4, 0, 8]}
func.func @dot_attention(%arg0: tensor<1024x32xf32>, %arg1: tensor<32x512xf32>, %arg2: tensor<512x32xf32>) -> tensor<1024x32xf32> {
%c1024 = arith.constant 1024 : index
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 0xFF800000 : f32
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
%0 = tensor.empty() : tensor<1024x32xf32>
%1 = tensor.empty() : tensor<1024xf32>
%2 = tensor.empty() : tensor<1024xf32>
%3 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor<1024xf32>) -> tensor<1024xf32>
%4 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1024x32xf32>) -> tensor<1024x32xf32>
%5 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1024xf32>) -> tensor<1024xf32>
%6:3 = scf.for %arg3 = %c0 to %c512 step %c8 iter_args(%arg4 = %4, %arg5 = %3, %arg6 = %5) -> (tensor<1024x32xf32>, tensor<1024xf32>, tensor<1024xf32>) {
%7:3 = scf.for %arg7 = %c0 to %c1024 step %c4 iter_args(%arg8 = %arg4, %arg9 = %arg5, %arg10 = %arg6) -> (tensor<1024x32xf32>, tensor<1024xf32>, tensor<1024xf32>) {
%extracted_slice = tensor.extract_slice %arg0[%arg7, 0] [4, 32] [1, 1] : tensor<1024x32xf32> to tensor<4x32xf32>
%extracted_slice_1 = tensor.extract_slice %arg1[0, %arg3] [32, 8] [1, 1] : tensor<32x512xf32> to tensor<32x8xf32>
%8 = tensor.empty() : tensor<4x8xf32>
%9 = linalg.fill ins(%cst : f32) outs(%8 : tensor<4x8xf32>) -> tensor<4x8xf32>
%10 = linalg.matmul ins(%extracted_slice, %extracted_slice_1 : tensor<4x32xf32>, tensor<32x8xf32>) outs(%9 : tensor<4x8xf32>) -> tensor<4x8xf32>
%11 = tensor.empty() : tensor<4x8xf32>
%extracted_slice_2 = tensor.extract_slice %arg9[%arg7] [4] [1] : tensor<1024xf32> to tensor<4xf32>
%extracted_slice_3 = tensor.extract_slice %arg10[%arg7] [4] [1] : tensor<1024xf32> to tensor<4xf32>
%12 = tensor.empty() : tensor<4xf32>
%13:4 = linalg_ext.softmax dimension(1) ins(%10 : tensor<4x8xf32>) outs(%11, %extracted_slice_2, %extracted_slice_3, %12 : tensor<4x8xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) : tensor<4x8xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>
%extracted_slice_4 = tensor.extract_slice %arg2[%arg3, 0] [8, 32] [1, 1] : tensor<512x32xf32> to tensor<8x32xf32>
%extracted_slice_5 = tensor.extract_slice %arg8[%arg7, 0] [4, 32] [1, 1] : tensor<1024x32xf32> to tensor<4x32xf32>
%14 = tensor.empty() : tensor<4x4xf32>
%15 = linalg_ext.diag ins(%13#3 : tensor<4xf32>) outs(%14 : tensor<4x4xf32>) : tensor<4x4xf32>
%16 = tensor.empty() : tensor<4x32xf32>
%17 = linalg.fill ins(%cst : f32) outs(%16 : tensor<4x32xf32>) -> tensor<4x32xf32>
%18 = linalg.matmul ins(%15, %extracted_slice_5 : tensor<4x4xf32>, tensor<4x32xf32>) outs(%17 : tensor<4x32xf32>) -> tensor<4x32xf32>
%19 = linalg.matmul {__root__} ins(%13#0, %extracted_slice_4 : tensor<4x8xf32>, tensor<8x32xf32>) outs(%18 : tensor<4x32xf32>) -> tensor<4x32xf32>
%inserted_slice = tensor.insert_slice %19 into %arg8[%arg7, 0] [4, 32] [1, 1] : tensor<4x32xf32> into tensor<1024x32xf32>
%inserted_slice_6 = tensor.insert_slice %13#1 into %arg9[%arg7] [4] [1] : tensor<4xf32> into tensor<1024xf32>
%inserted_slice_7 = tensor.insert_slice %13#2 into %arg10[%arg7] [4] [1] : tensor<4xf32> into tensor<1024xf32>
scf.yield %inserted_slice, %inserted_slice_6, %inserted_slice_7 : tensor<1024x32xf32>, tensor<1024xf32>, tensor<1024xf32>
}
scf.yield %7#0, %7#1, %7#2 : tensor<1024x32xf32>, tensor<1024xf32>, tensor<1024xf32>
}
return %6#0 : tensor<1024x32xf32>
}
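Note how the tiled linalg_ext.softmax carries running state through the loop iter_args: a running max (%13#1), a running accumulator (%13#2), and a rescaling factor (%13#3). linalg_ext.diag expands the rescaling factor into a diagonal matrix that rescales the previously accumulated output tile (%18) before the current partial matmul is added (%19). This accumulator-rescaling scheme is what makes the result flash attention: the full 1024x512 score matrix is never materialized, only 4x8 tiles of it.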
Multi-head attention is also supported; the corresponding transform script is sketched after the input below.
// input.mlir
func.func @fuse_multihead_attention(%arg0: tensor<128x16x1024x32xf32>, %arg1: tensor<128x16x32x512xf32>, %arg2: tensor<128x16x512x32xf32>) -> tensor<128x16x1024x32xf32> {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 0xFF800000 : f32
%0 = tensor.empty() : tensor<128x16x1024x512xf32>
%1 = tensor.empty() : tensor<128x16x1024x32xf32>
%2 = tensor.empty() : tensor<128x16x1024x512xf32>
%3 = tensor.empty() : tensor<128x16x1024xf32>
%4 = tensor.empty() : tensor<128x16x1024xf32>
%5 = tensor.empty() : tensor<128x16x1024xf32>
%6 = linalg.fill ins(%cst_0 : f32) outs(%3 : tensor<128x16x1024xf32>) -> tensor<128x16x1024xf32>
%7 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x16x1024x512xf32>) -> tensor<128x16x1024x512xf32>
%8 = linalg.fill ins(%cst : f32) outs(%1 : tensor<128x16x1024x32xf32>) -> tensor<128x16x1024x32xf32>
%9 = linalg.fill ins(%cst : f32) outs(%4 : tensor<128x16x1024xf32>) -> tensor<128x16x1024xf32>
%10 = linalg_ext.batch_matmul ins(%arg0, %arg1 : tensor<128x16x1024x32xf32>, tensor<128x16x32x512xf32>) outs(%7 : tensor<128x16x1024x512xf32>) layout = "nn"
%11:4 = linalg_ext.softmax dimension(3) ins(%10 : tensor<128x16x1024x512xf32>) outs(%2, %6, %9, %5 : tensor<128x16x1024x512xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>) : tensor<128x16x1024x512xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>
%12 = linalg_ext.batch_matmul ins(%11#0, %arg2 : tensor<128x16x1024x512xf32>, tensor<128x16x512x32xf32>) outs(%8 : tensor<128x16x1024x32xf32>) layout = "nn" {__root__}
return %12 : tensor<128x16x1024x32xf32>
}
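The result below can be produced by a transform script along these lines, a sketch analogous to the explicit transform.sequence in the next example; transform.structured.tile_loop_hint marks the parallelizable loop with __byteir_parallel__:
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
  %0 = transform.structured.match attributes {__root__} in %arg0 : (!pdl.operation) -> !pdl.operation
  %1, %loops:3 = transform.structured.fuse_ext %0 {tile_sizes = [2, 0, 8, 0, 4], tile_interchange = [0, 1, 4, 3, 2]}
  transform.structured.tile_loop_hint %1 : !pdl.operation
  cleanup
}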
// result after transform.structured.fuse_ext {tile_sizes = [2, 0, 8, 0, 4], tile_interchange = [0, 1, 4, 3, 2]}
func.func @fuse_multihead_attention(%arg0: tensor<128x16x1024x32xf32>, %arg1: tensor<128x16x32x512xf32>, %arg2: tensor<128x16x512x32xf32>) -> tensor<128x16x1024x32xf32> {
%c1024 = arith.constant 1024 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 0xFF800000 : f32
%c2 = arith.constant 2 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%0 = tensor.empty() : tensor<128x16x1024x32xf32>
%1 = tensor.empty() : tensor<128x16x1024xf32>
%2 = tensor.empty() : tensor<128x16x1024xf32>
%3 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor<128x16x1024xf32>) -> tensor<128x16x1024xf32>
%4 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x16x1024x32xf32>) -> tensor<128x16x1024x32xf32>
%5 = linalg.fill ins(%cst : f32) outs(%2 : tensor<128x16x1024xf32>) -> tensor<128x16x1024xf32>
%6:3 = scf.for %arg3 = %c0 to %c128 step %c2 iter_args(%arg4 = %4, %arg5 = %3, %arg6 = %5) -> (tensor<128x16x1024x32xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>) {
%7:3 = scf.for %arg7 = %c0 to %c512 step %c4 iter_args(%arg8 = %arg4, %arg9 = %arg5, %arg10 = %arg6) -> (tensor<128x16x1024x32xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>) {
%8:3 = scf.for %arg11 = %c0 to %c1024 step %c8 iter_args(%arg12 = %arg8, %arg13 = %arg9, %arg14 = %arg10) -> (tensor<128x16x1024x32xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>) {
%extracted_slice = tensor.extract_slice %arg0[%arg3, 0, %arg11, 0] [2, 16, 8, 32] [1, 1, 1, 1] : tensor<128x16x1024x32xf32> to tensor<2x16x8x32xf32>
%extracted_slice_1 = tensor.extract_slice %arg1[%arg3, 0, 0, %arg7] [2, 16, 32, 4] [1, 1, 1, 1] : tensor<128x16x32x512xf32> to tensor<2x16x32x4xf32>
%9 = tensor.empty() : tensor<2x16x8x4xf32>
%10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<2x16x8x4xf32>) -> tensor<2x16x8x4xf32>
%11 = linalg_ext.batch_matmul ins(%extracted_slice, %extracted_slice_1 : tensor<2x16x8x32xf32>, tensor<2x16x32x4xf32>) outs(%10 : tensor<2x16x8x4xf32>) layout = "nn"
%12 = tensor.empty() : tensor<2x16x8x4xf32>
%extracted_slice_2 = tensor.extract_slice %arg13[%arg3, 0, %arg11] [2, 16, 8] [1, 1, 1] : tensor<128x16x1024xf32> to tensor<2x16x8xf32>
%extracted_slice_3 = tensor.extract_slice %arg14[%arg3, 0, %arg11] [2, 16, 8] [1, 1, 1] : tensor<128x16x1024xf32> to tensor<2x16x8xf32>
%13 = tensor.empty() : tensor<2x16x8xf32>
%14:4 = linalg_ext.softmax dimension(3) ins(%11 : tensor<2x16x8x4xf32>) outs(%12, %extracted_slice_2, %extracted_slice_3, %13 : tensor<2x16x8x4xf32>, tensor<2x16x8xf32>, tensor<2x16x8xf32>, tensor<2x16x8xf32>) : tensor<2x16x8x4xf32>, tensor<2x16x8xf32>, tensor<2x16x8xf32>, tensor<2x16x8xf32>
%extracted_slice_4 = tensor.extract_slice %arg2[%arg3, 0, %arg7, 0] [2, 16, 4, 32] [1, 1, 1, 1] : tensor<128x16x512x32xf32> to tensor<2x16x4x32xf32>
%extracted_slice_5 = tensor.extract_slice %arg12[%arg3, 0, %arg11, 0] [2, 16, 8, 32] [1, 1, 1, 1] : tensor<128x16x1024x32xf32> to tensor<2x16x8x32xf32>
%15 = tensor.empty() : tensor<2x16x8x8xf32>
%16 = linalg_ext.diag ins(%14#3 : tensor<2x16x8xf32>) outs(%15 : tensor<2x16x8x8xf32>) : tensor<2x16x8x8xf32>
%17 = tensor.empty() : tensor<2x16x8x32xf32>
%18 = linalg.fill ins(%cst : f32) outs(%17 : tensor<2x16x8x32xf32>) -> tensor<2x16x8x32xf32>
%19 = linalg_ext.batch_matmul ins(%16, %extracted_slice_5 : tensor<2x16x8x8xf32>, tensor<2x16x8x32xf32>) outs(%18 : tensor<2x16x8x32xf32>) layout = "nn"
%20 = linalg_ext.batch_matmul ins(%14#0, %extracted_slice_4 : tensor<2x16x8x4xf32>, tensor<2x16x4x32xf32>) outs(%19 : tensor<2x16x8x32xf32>) layout = "nn" {__root__}
%inserted_slice = tensor.insert_slice %20 into %arg12[%arg3, 0, %arg11, 0] [2, 16, 8, 32] [1, 1, 1, 1] : tensor<2x16x8x32xf32> into tensor<128x16x1024x32xf32>
%inserted_slice_6 = tensor.insert_slice %14#1 into %arg13[%arg3, 0, %arg11] [2, 16, 8] [1, 1, 1] : tensor<2x16x8xf32> into tensor<128x16x1024xf32>
%inserted_slice_7 = tensor.insert_slice %14#2 into %arg14[%arg3, 0, %arg11] [2, 16, 8] [1, 1, 1] : tensor<2x16x8xf32> into tensor<128x16x1024xf32>
scf.yield %inserted_slice, %inserted_slice_6, %inserted_slice_7 : tensor<128x16x1024x32xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>
}
scf.yield %8#0, %8#1, %8#2 : tensor<128x16x1024x32xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>
}
scf.yield %7#0, %7#1, %7#2 : tensor<128x16x1024x32xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>
} {__byteir_parallel__}
return %6#0 : tensor<128x16x1024x32xf32>
}
Multi-head attention example with tiling on 3 dimensions
// input.mlir
func.func @fuse_multihead_attention_tile_3d(%arg0: tensor<128x16x1024x32xf32>, %arg1: tensor<128x16x32x512xf32>, %arg2: tensor<128x16x512x32xf32>) -> tensor<128x16x1024x32xf32> {
%0 = tensor.empty() : tensor<128x16x1024x512xf32>
%1 = tensor.empty() : tensor<128x16x1024x32xf32>
%2 = tensor.empty() : tensor<128x16x1024x512xf32>
%3 = tensor.empty() : tensor<128x16x1024xf32>
%4 = tensor.empty() : tensor<128x16x1024xf32>
%5 = tensor.empty() : tensor<128x16x1024xf32>
%6 = tensor.empty() : tensor<128x16x32x512xf32>
%cst = arith.constant 0xFF800000 : f32
%7 = linalg.fill ins(%cst : f32) outs(%3 : tensor<128x16x1024xf32>) -> tensor<128x16x1024xf32>
%cst_0 = arith.constant 0.000000e+00 : f32
%8 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x16x1024x512xf32>) -> tensor<128x16x1024x512xf32>
%9 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor<128x16x1024x32xf32>) -> tensor<128x16x1024x32xf32>
%10 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<128x16x1024xf32>) -> tensor<128x16x1024xf32>
%11 = linalg_ext.batch_matmul ins(%arg0, %arg1 : tensor<128x16x1024x32xf32>, tensor<128x16x32x512xf32>) outs(%8 : tensor<128x16x1024x512xf32>) layout = "nn"
%12:4 = linalg_ext.softmax dimension(3) ins(%11 : tensor<128x16x1024x512xf32>) outs(%2, %7, %10, %5 : tensor<128x16x1024x512xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>) : tensor<128x16x1024x512xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>
%13 = linalg_ext.batch_matmul ins(%12#0, %arg2 : tensor<128x16x1024x512xf32>, tensor<128x16x512x32xf32>) outs(%9 : tensor<128x16x1024x32xf32>) layout = "nn" {__root__}
return %13 : tensor<128x16x1024x32xf32>
}
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match attributes{"__root__"} in %arg1 : (!pdl.operation) -> !pdl.operation
%1, %loops:3 = transform.structured.fuse_ext %0 {tile_sizes = [2, 0, 8, 0, 4], tile_interchange = [0, 1, 4, 3, 2]}
transform.structured.tile_loop_hint %1 : !pdl.operation
cleanup
}
// result after transform.structured.fuse_ext
func.func @fuse_multihead_attention_tile_3d(%arg0: tensor<128x16x1024x32xf32>, %arg1: tensor<128x16x32x512xf32>, %arg2: tensor<128x16x512x32xf32>) -> tensor<128x16x1024x32xf32> {
%c1024 = arith.constant 1024 : index
%c512 = arith.constant 512 : index
%c128 = arith.constant 128 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 0xFF800000 : f32
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
%c2 = arith.constant 2 : index
%0 = tensor.empty() : tensor<128x16x1024x32xf32>
%1 = tensor.empty() : tensor<128x16x1024xf32>
%2 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor<128x16x1024xf32>) -> tensor<128x16x1024xf32>
%3 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x16x1024x32xf32>) -> tensor<128x16x1024x32xf32>
%4 = linalg.fill ins(%cst : f32) outs(%1 : tensor<128x16x1024xf32>) -> tensor<128x16x1024xf32>
%5:3 = scf.for %arg3 = %c0 to %c128 step %c2 iter_args(%arg4 = %3, %arg5 = %2, %arg6 = %4) -> (tensor<128x16x1024x32xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>) {
%6:3 = scf.for %arg7 = %c0 to %c512 step %c4 iter_args(%arg8 = %arg4, %arg9 = %arg5, %arg10 = %arg6) -> (tensor<128x16x1024x32xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>) {
%7:3 = scf.for %arg11 = %c0 to %c1024 step %c8 iter_args(%arg12 = %arg8, %arg13 = %arg9, %arg14 = %arg10) -> (tensor<128x16x1024x32xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>) {
%extracted_slice = tensor.extract_slice %arg0[%arg3, 0, %arg11, 0] [2, 16, 8, 32] [1, 1, 1, 1] : tensor<128x16x1024x32xf32> to tensor<2x16x8x32xf32>
%extracted_slice_1 = tensor.extract_slice %arg1[%arg3, 0, 0, %arg7] [2, 16, 32, 4] [1, 1, 1, 1] : tensor<128x16x32x512xf32> to tensor<2x16x32x4xf32>
%8 = tensor.empty() : tensor<2x16x8x4xf32>
%9 = linalg.fill ins(%cst : f32) outs(%8 : tensor<2x16x8x4xf32>) -> tensor<2x16x8x4xf32>
%10 = linalg_ext.batch_matmul ins(%extracted_slice, %extracted_slice_1 : tensor<2x16x8x32xf32>, tensor<2x16x32x4xf32>) outs(%9 : tensor<2x16x8x4xf32>) layout = "nn"
%extracted_slice_2 = tensor.extract_slice %arg13[%arg3, 0, %arg11] [2, 16, 8] [1, 1, 1] : tensor<128x16x1024xf32> to tensor<2x16x8xf32>
%extracted_slice_3 = tensor.extract_slice %arg14[%arg3, 0, %arg11] [2, 16, 8] [1, 1, 1] : tensor<128x16x1024xf32> to tensor<2x16x8xf32>
%11 = tensor.empty() : tensor<2x16x8xf32>
%12:4 = linalg_ext.softmax dimension(3) ins(%10 : tensor<2x16x8x4xf32>) outs(%8, %extracted_slice_2, %extracted_slice_3, %11 : tensor<2x16x8x4xf32>, tensor<2x16x8xf32>, tensor<2x16x8xf32>, tensor<2x16x8xf32>) : tensor<2x16x8x4xf32>, tensor<2x16x8xf32>, tensor<2x16x8xf32>, tensor<2x16x8xf32>
%extracted_slice_4 = tensor.extract_slice %arg2[%arg3, 0, %arg7, 0] [2, 16, 4, 32] [1, 1, 1, 1] : tensor<128x16x512x32xf32> to tensor<2x16x4x32xf32>
%extracted_slice_5 = tensor.extract_slice %arg12[%arg3, 0, %arg11, 0] [2, 16, 8, 32] [1, 1, 1, 1] : tensor<128x16x1024x32xf32> to tensor<2x16x8x32xf32>
%13 = tensor.empty() : tensor<2x16x8x8xf32>
%14 = linalg_ext.diag ins(%12#3 : tensor<2x16x8xf32>) outs(%13 : tensor<2x16x8x8xf32>) : tensor<2x16x8x8xf32>
%15 = tensor.empty() : tensor<2x16x8x32xf32>
%16 = linalg.fill ins(%cst : f32) outs(%15 : tensor<2x16x8x32xf32>) -> tensor<2x16x8x32xf32>
%17 = linalg_ext.batch_matmul ins(%14, %extracted_slice_5 : tensor<2x16x8x8xf32>, tensor<2x16x8x32xf32>) outs(%16 : tensor<2x16x8x32xf32>) layout = "nn"
%18 = linalg_ext.batch_matmul ins(%12#0, %extracted_slice_4 : tensor<2x16x8x4xf32>, tensor<2x16x4x32xf32>) outs(%17 : tensor<2x16x8x32xf32>) layout = "nn" {__root__}
%inserted_slice = tensor.insert_slice %18 into %arg12[%arg3, 0, %arg11, 0] [2, 16, 8, 32] [1, 1, 1, 1] : tensor<2x16x8x32xf32> into tensor<128x16x1024x32xf32>
%inserted_slice_6 = tensor.insert_slice %12#1 into %arg13[%arg3, 0, %arg11] [2, 16, 8] [1, 1, 1] : tensor<2x16x8xf32> into tensor<128x16x1024xf32>
%inserted_slice_7 = tensor.insert_slice %12#2 into %arg14[%arg3, 0, %arg11] [2, 16, 8] [1, 1, 1] : tensor<2x16x8xf32> into tensor<128x16x1024xf32>
scf.yield %inserted_slice, %inserted_slice_6, %inserted_slice_7 : tensor<128x16x1024x32xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>
}
scf.yield %7#0, %7#1, %7#2 : tensor<128x16x1024x32xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>
}
scf.yield %6#0, %6#1, %6#2 : tensor<128x16x1024x32xf32>, tensor<128x16x1024xf32>, tensor<128x16x1024xf32>
} {__byteir_parallel__}
return %5#0 : tensor<128x16x1024x32xf32>
}
Split-head attention with prologue
Split-head attention with a prologue (the Q/K/V projection matmuls, followed by reshapes and transposes into per-head layout) is also supported; see https://arxiv.org/abs/1909.08053. Here only the batch and head dimensions are tiled (tile_sizes = [2, 4, 0, 0, 0]), so the prologue projections are fused into the loop nest.
// input.mlir
func.func @multihead_attention_with_prologue_proj(%arg0: tensor<4x1024x512xf32>, %arg1: tensor<512x512xf32>, %arg2: tensor<512x512xf32>, %arg3: tensor<512x512xf32>) -> tensor<4x8x1024x64xf32> {
%cst = arith.constant 0xFF800000 : f32
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<4x512x512xf32>
%broadcasted = linalg.broadcast ins(%arg1 : tensor<512x512xf32>) outs(%0 : tensor<4x512x512xf32>) dimensions = [0]
%1 = tensor.empty() : tensor<4x1024x512xf32>
%2 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor<4x1024x512xf32>) -> tensor<4x1024x512xf32>
%3 = linalg_ext.batch_matmul ins(%arg0, %broadcasted : tensor<4x1024x512xf32>, tensor<4x512x512xf32>) outs(%2 : tensor<4x1024x512xf32>) layout = "nn"
%broadcasted_1 = linalg.broadcast ins(%arg2 : tensor<512x512xf32>) outs(%0 : tensor<4x512x512xf32>) dimensions = [0]
%4 = linalg_ext.batch_matmul ins(%arg0, %broadcasted_1 : tensor<4x1024x512xf32>, tensor<4x512x512xf32>) outs(%2 : tensor<4x1024x512xf32>) layout = "nn"
%broadcasted_2 = linalg.broadcast ins(%arg3 : tensor<512x512xf32>) outs(%0 : tensor<4x512x512xf32>) dimensions = [0]
%5 = linalg_ext.batch_matmul ins(%arg0, %broadcasted_2 : tensor<4x1024x512xf32>, tensor<4x512x512xf32>) outs(%2 : tensor<4x1024x512xf32>) layout = "nn"
%expanded = tensor.expand_shape %3 [[0], [1], [2, 3]] : tensor<4x1024x512xf32> into tensor<4x1024x8x64xf32>
%expanded_3 = tensor.expand_shape %4 [[0], [1], [2, 3]] : tensor<4x1024x512xf32> into tensor<4x1024x8x64xf32>
%expanded_4 = tensor.expand_shape %5 [[0], [1], [2, 3]] : tensor<4x1024x512xf32> into tensor<4x1024x8x64xf32>
%6 = tensor.empty() : tensor<4x8x1024x64xf32>
%7 = tensor.empty() : tensor<4x8x64x1024xf32>
%transposed = linalg.transpose ins(%expanded : tensor<4x1024x8x64xf32>) outs(%6 : tensor<4x8x1024x64xf32>) permutation = [0, 2, 1, 3]
%transposed_5 = linalg.transpose ins(%expanded_3 : tensor<4x1024x8x64xf32>) outs(%7 : tensor<4x8x64x1024xf32>) permutation = [0, 2, 3, 1]
%transposed_6 = linalg.transpose ins(%expanded_4 : tensor<4x1024x8x64xf32>) outs(%6 : tensor<4x8x1024x64xf32>) permutation = [0, 2, 1, 3]
%8 = tensor.empty() : tensor<4x8x1024x1024xf32>
%9 = linalg.fill ins(%cst_0 : f32) outs(%8 : tensor<4x8x1024x1024xf32>) -> tensor<4x8x1024x1024xf32>
%10 = tensor.empty() : tensor<4x8x1024xf32>
%11 = linalg.fill ins(%cst : f32) outs(%10 : tensor<4x8x1024xf32>) -> tensor<4x8x1024xf32>
%12 = linalg.fill ins(%cst_0 : f32) outs(%10 : tensor<4x8x1024xf32>) -> tensor<4x8x1024xf32>
%13 = linalg.fill ins(%cst_0 : f32) outs(%6 : tensor<4x8x1024x64xf32>) -> tensor<4x8x1024x64xf32>
%14 = linalg_ext.batch_matmul ins(%transposed, %transposed_5 : tensor<4x8x1024x64xf32>, tensor<4x8x64x1024xf32>) outs(%9 : tensor<4x8x1024x1024xf32>) layout = "nn"
%15:4 = linalg_ext.softmax dimension(3) ins(%14 : tensor<4x8x1024x1024xf32>) outs(%8, %11, %12, %10 : tensor<4x8x1024x1024xf32>, tensor<4x8x1024xf32>, tensor<4x8x1024xf32>, tensor<4x8x1024xf32>) : tensor<4x8x1024x1024xf32>, tensor<4x8x1024xf32>, tensor<4x8x1024xf32>, tensor<4x8x1024xf32>
%16 = linalg_ext.batch_matmul ins(%15#0, %transposed_6 : tensor<4x8x1024x1024xf32>, tensor<4x8x1024x64xf32>) outs(%13 : tensor<4x8x1024x64xf32>) layout = "nn" {__root__}
return %16 : tensor<4x8x1024x64xf32>
}
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
%0 = transform.structured.match attributes {__root__} in %arg0 : (!pdl.operation) -> !pdl.operation
%transformed, %loops:2 = transform.structured.fuse_ext %0 {tile_interchange = [0, 1, 4, 3, 2], tile_sizes = [2, 4, 0, 0, 0]}
cleanup
}
// result after transform.structured.fuse_ext
#map = affine_map<(d0) -> (d0 * 64)>
func.func @multihead_attention_with_prologue_proj(%arg0: tensor<4x1024x512xf32>, %arg1: tensor<512x512xf32>, %arg2: tensor<512x512xf32>, %arg3: tensor<512x512xf32>) -> tensor<4x8x1024x64xf32> {
%c8 = arith.constant 8 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 0xFF800000 : f32
%c4 = arith.constant 4 : index
%c2 = arith.constant 2 : index
%0 = tensor.empty() : tensor<4x1024x512xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x1024x512xf32>) -> tensor<4x1024x512xf32>
%2 = tensor.empty() : tensor<4x8x1024x64xf32>
%3:5 = scf.for %arg4 = %c0 to %c4 step %c2 iter_args(%arg5 = %2, %arg6 = %0, %arg7 = %1, %arg8 = %1, %arg9 = %0) -> (tensor<4x8x1024x64xf32>, tensor<4x1024x512xf32>, tensor<4x1024x512xf32>, tensor<4x1024x512xf32>, tensor<4x1024x512xf32>) {
%4:5 = scf.for %arg10 = %c0 to %c8 step %c4 iter_args(%arg11 = %arg5, %arg12 = %arg6, %arg13 = %arg7, %arg14 = %arg8, %arg15 = %arg9) -> (tensor<4x8x1024x64xf32>, tensor<4x1024x512xf32>, tensor<4x1024x512xf32>, tensor<4x1024x512xf32>, tensor<4x1024x512xf32>) {
%5 = affine.apply #map(%arg10)
%extracted_slice = tensor.extract_slice %arg0[%arg4, 0, 0] [2, 1024, 512] [1, 1, 1] : tensor<4x1024x512xf32> to tensor<2x1024x512xf32>
%extracted_slice_1 = tensor.extract_slice %arg1[0, %5] [512, 256] [1, 1] : tensor<512x512xf32> to tensor<512x256xf32>
%6 = tensor.empty() : tensor<2x512x256xf32>
%broadcasted = linalg.broadcast ins(%extracted_slice_1 : tensor<512x256xf32>) outs(%6 : tensor<2x512x256xf32>) dimensions = [0]
%7 = tensor.empty() : tensor<2x4x1024x64xf32>
%extracted_slice_2 = tensor.extract_slice %arg2[0, %5] [512, 256] [1, 1] : tensor<512x512xf32> to tensor<512x256xf32>
%broadcasted_3 = linalg.broadcast ins(%extracted_slice_2 : tensor<512x256xf32>) outs(%6 : tensor<2x512x256xf32>) dimensions = [0]
%8 = tensor.empty() : tensor<2x4x64x1024xf32>
%9 = tensor.empty() : tensor<2x4x1024x1024xf32>
%10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<2x4x1024x1024xf32>) -> tensor<2x4x1024x1024xf32>
%11 = tensor.empty() : tensor<2x4x1024xf32>
%12 = linalg.fill ins(%cst_0 : f32) outs(%11 : tensor<2x4x1024xf32>) -> tensor<2x4x1024xf32>
%13 = linalg.fill ins(%cst : f32) outs(%11 : tensor<2x4x1024xf32>) -> tensor<2x4x1024xf32>
%extracted_slice_4 = tensor.extract_slice %arg3[0, %5] [512, 256] [1, 1] : tensor<512x512xf32> to tensor<512x256xf32>
%broadcasted_5 = linalg.broadcast ins(%extracted_slice_4 : tensor<512x256xf32>) outs(%6 : tensor<2x512x256xf32>) dimensions = [0]
%14 = tensor.empty() : tensor<2x1024x256xf32>
%15 = linalg.fill ins(%cst : f32) outs(%14 : tensor<2x1024x256xf32>) -> tensor<2x1024x256xf32>
%16 = linalg_ext.batch_matmul ins(%extracted_slice, %broadcasted_5 : tensor<2x1024x512xf32>, tensor<2x512x256xf32>) outs(%15 : tensor<2x1024x256xf32>) layout = "nn"
%expanded = tensor.expand_shape %16 [[0], [1], [2, 3]] : tensor<2x1024x256xf32> into tensor<2x1024x4x64xf32>
%transposed = linalg.transpose ins(%expanded : tensor<2x1024x4x64xf32>) outs(%7 : tensor<2x4x1024x64xf32>) permutation = [0, 2, 1, 3]
%17 = linalg.fill ins(%cst : f32) outs(%7 : tensor<2x4x1024x64xf32>) -> tensor<2x4x1024x64xf32>
%inserted_slice = tensor.insert_slice %16 into %arg12[%arg4, 0, %5] [2, 1024, 256] [1, 1, 1] : tensor<2x1024x256xf32> into tensor<4x1024x512xf32>
%inserted_slice_6 = tensor.insert_slice %15 into %arg15[%arg4, 0, %5] [2, 1024, 256] [1, 1, 1] : tensor<2x1024x256xf32> into tensor<4x1024x512xf32>
%18 = linalg_ext.batch_matmul ins(%extracted_slice, %broadcasted : tensor<2x1024x512xf32>, tensor<2x512x256xf32>) outs(%15 : tensor<2x1024x256xf32>) layout = "nn"
%expanded_7 = tensor.expand_shape %18 [[0], [1], [2, 3]] : tensor<2x1024x256xf32> into tensor<2x1024x4x64xf32>
%transposed_8 = linalg.transpose ins(%expanded_7 : tensor<2x1024x4x64xf32>) outs(%7 : tensor<2x4x1024x64xf32>) permutation = [0, 2, 1, 3]
%19 = linalg_ext.batch_matmul ins(%extracted_slice, %broadcasted_3 : tensor<2x1024x512xf32>, tensor<2x512x256xf32>) outs(%15 : tensor<2x1024x256xf32>) layout = "nn"
%expanded_9 = tensor.expand_shape %19 [[0], [1], [2, 3]] : tensor<2x1024x256xf32> into tensor<2x1024x4x64xf32>
%transposed_10 = linalg.transpose ins(%expanded_9 : tensor<2x1024x4x64xf32>) outs(%8 : tensor<2x4x64x1024xf32>) permutation = [0, 2, 3, 1]
%20 = linalg_ext.batch_matmul ins(%transposed_8, %transposed_10 : tensor<2x4x1024x64xf32>, tensor<2x4x64x1024xf32>) outs(%10 : tensor<2x4x1024x1024xf32>) layout = "nn"
%21:4 = linalg_ext.softmax dimension(3) ins(%20 : tensor<2x4x1024x1024xf32>) outs(%9, %12, %13, %11 : tensor<2x4x1024x1024xf32>, tensor<2x4x1024xf32>, tensor<2x4x1024xf32>, tensor<2x4x1024xf32>) : tensor<2x4x1024x1024xf32>, tensor<2x4x1024xf32>, tensor<2x4x1024xf32>, tensor<2x4x1024xf32>
%22 = linalg_ext.batch_matmul ins(%21#0, %transposed : tensor<2x4x1024x1024xf32>, tensor<2x4x1024x64xf32>) outs(%17 : tensor<2x4x1024x64xf32>) layout = "nn" {__root__}
%inserted_slice_11 = tensor.insert_slice %22 into %arg11[%arg4, %arg10, 0, 0] [2, 4, 1024, 64] [1, 1, 1, 1] : tensor<2x4x1024x64xf32> into tensor<4x8x1024x64xf32>
%inserted_slice_12 = tensor.insert_slice %18 into %arg13[%arg4, 0, %5] [2, 1024, 256] [1, 1, 1] : tensor<2x1024x256xf32> into tensor<4x1024x512xf32>
%inserted_slice_13 = tensor.insert_slice %19 into %arg14[%arg4, 0, %5] [2, 1024, 256] [1, 1, 1] : tensor<2x1024x256xf32> into tensor<4x1024x512xf32>
scf.yield %inserted_slice_11, %inserted_slice, %inserted_slice_12, %inserted_slice_13, %inserted_slice_6 : tensor<4x8x1024x64xf32>, tensor<4x1024x512xf32>, tensor<4x1024x512xf32>, tensor<4x1024x512xf32>, tensor<4x1024x512xf32>
}
scf.yield %4#0, %4#1, %4#2, %4#3, %4#4 : tensor<4x8x1024x64xf32>, tensor<4x1024x512xf32>, tensor<4x1024x512xf32>, tensor<4x1024x512xf32>, tensor<4x1024x512xf32>
}
return %3#0 : tensor<4x8x1024x64xf32>
}
Split-head multi-query-attention
Multi-query attention shares a single key/value head across all query heads; see https://arxiv.org/pdf/1911.02150.pdf. The key/value projections are tagged {__stop__} and passed to fuse_ext as an additional handle, so fusion stops at them and they stay outside the fused loop nest.
// input.mlir
// batch size = 4
// sequence length = 1024
// hidden size = 512
// head number = 8
// head dimension = 64
func.func @multiquery_attention_with_prologue_proj(%arg0: tensor<4x1024x512xf32>, %arg1: tensor<512x512xf32>, %arg2: tensor<512x64xf32>, %arg3: tensor<512x64xf32>) -> tensor<4x8x1024x64xf32> {
%cst = arith.constant 0xFF800000 : f32
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<4x512x512xf32>
%1 = tensor.empty() : tensor<4x512x64xf32>
%broadcasted = linalg.broadcast ins(%arg1 : tensor<512x512xf32>) outs(%0 : tensor<4x512x512xf32>) dimensions = [0]
%2 = tensor.empty() : tensor<4x1024x512xf32>
%3 = tensor.empty() : tensor<4x1024x64xf32>
%4 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<4x1024x512xf32>) -> tensor<4x1024x512xf32>
%5 = linalg.fill ins(%cst_0 : f32) outs(%3 : tensor<4x1024x64xf32>) -> tensor<4x1024x64xf32>
%6 = linalg_ext.batch_matmul ins(%arg0, %broadcasted : tensor<4x1024x512xf32>, tensor<4x512x512xf32>) outs(%4 : tensor<4x1024x512xf32>) layout = "nn"
%broadcasted_1 = linalg.broadcast ins(%arg2 : tensor<512x64xf32>) outs(%1 : tensor<4x512x64xf32>) dimensions = [0]
%7 = linalg_ext.batch_matmul ins(%arg0, %broadcasted_1 : tensor<4x1024x512xf32>, tensor<4x512x64xf32>) outs(%5 : tensor<4x1024x64xf32>) layout = "nn" {__stop__}
%broadcasted_2 = linalg.broadcast ins(%arg3 : tensor<512x64xf32>) outs(%1 : tensor<4x512x64xf32>) dimensions = [0]
%8 = linalg_ext.batch_matmul ins(%arg0, %broadcasted_2 : tensor<4x1024x512xf32>, tensor<4x512x64xf32>) outs(%5 : tensor<4x1024x64xf32>) layout = "nn" {__stop__}
%expanded = tensor.expand_shape %6 [[0], [1], [2, 3]] : tensor<4x1024x512xf32> into tensor<4x1024x8x64xf32>
%9 = tensor.empty() : tensor<4x8x1024x64xf32>
%10 = tensor.empty() : tensor<4x64x1024xf32>
%transposed = linalg.transpose ins(%expanded : tensor<4x1024x8x64xf32>) outs(%9 : tensor<4x8x1024x64xf32>) permutation = [0, 2, 1, 3]
%transposed_3 = linalg.transpose ins(%7 : tensor<4x1024x64xf32>) outs(%10 : tensor<4x64x1024xf32>) permutation = [0, 2, 1]
%11 = tensor.empty() : tensor<4x8x1024x1024xf32>
%12 = linalg.fill ins(%cst_0 : f32) outs(%11 : tensor<4x8x1024x1024xf32>) -> tensor<4x8x1024x1024xf32>
%13 = tensor.empty() : tensor<4x8x1024xf32>
%14 = linalg.fill ins(%cst : f32) outs(%13 : tensor<4x8x1024xf32>) -> tensor<4x8x1024xf32>
%15 = linalg.fill ins(%cst_0 : f32) outs(%13 : tensor<4x8x1024xf32>) -> tensor<4x8x1024xf32>
%16 = linalg.fill ins(%cst_0 : f32) outs(%9 : tensor<4x8x1024x64xf32>) -> tensor<4x8x1024x64xf32>
%17 = tensor.empty() : tensor<4x8x64x1024xf32>
%broadcasted_4 = linalg.broadcast ins(%transposed_3 : tensor<4x64x1024xf32>) outs(%17 : tensor<4x8x64x1024xf32>) dimensions = [1]
%broadcasted_5 = linalg.broadcast ins(%8 : tensor<4x1024x64xf32>) outs(%9 : tensor<4x8x1024x64xf32>) dimensions = [1]
%18 = linalg_ext.batch_matmul ins(%transposed, %broadcasted_4 : tensor<4x8x1024x64xf32>, tensor<4x8x64x1024xf32>) outs(%12 : tensor<4x8x1024x1024xf32>) layout = "nn"
%19:4 = linalg_ext.softmax dimension(3) ins(%18 : tensor<4x8x1024x1024xf32>) outs(%11, %14, %15, %13 : tensor<4x8x1024x1024xf32>, tensor<4x8x1024xf32>, tensor<4x8x1024xf32>, tensor<4x8x1024xf32>) : tensor<4x8x1024x1024xf32>, tensor<4x8x1024xf32>, tensor<4x8x1024xf32>, tensor<4x8x1024xf32>
%20 = linalg_ext.batch_matmul ins(%19#0, %broadcasted_5 : tensor<4x8x1024x1024xf32>, tensor<4x8x1024x64xf32>) outs(%16 : tensor<4x8x1024x64xf32>) layout = "nn" {__root__}
return %20 : tensor<4x8x1024x64xf32>
}
transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
%0 = transform.structured.match attributes {__root__} in %arg0 : (!pdl.operation) -> !pdl.operation
%stop = transform.structured.match attributes {__stop__} in %arg0 : (!pdl.operation) -> !pdl.operation
%transformed, %loops:2 = transform.structured.fuse_ext %0, %stop {tile_interchange = [], tile_sizes = [2, 4, 0, 0, 0]}
cleanup
}
// result after transform.structured.fuse_ext
#map = affine_map<(d0) -> (d0 * 64)>
func.func @multiquery_attention_with_prologue_proj(%arg0: tensor<4x1024x512xf32>, %arg1: tensor<512x512xf32>, %arg2: tensor<512x64xf32>, %arg3: tensor<512x64xf32>) -> tensor<4x8x1024x64xf32> {
%c8 = arith.constant 8 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 0xFF800000 : f32
%c4 = arith.constant 4 : index
%c2 = arith.constant 2 : index
%0 = tensor.empty() : tensor<4x512x64xf32>
%1 = tensor.empty() : tensor<4x1024x512xf32>
%2 = tensor.empty() : tensor<4x1024x64xf32>
%3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<4x1024x64xf32>) -> tensor<4x1024x64xf32>
%broadcasted = linalg.broadcast ins(%arg2 : tensor<512x64xf32>) outs(%0 : tensor<4x512x64xf32>) dimensions = [0]
%4 = linalg_ext.batch_matmul ins(%arg0, %broadcasted : tensor<4x1024x512xf32>, tensor<4x512x64xf32>) outs(%3 : tensor<4x1024x64xf32>) layout = "nn" {__stop__}
%broadcasted_1 = linalg.broadcast ins(%arg3 : tensor<512x64xf32>) outs(%0 : tensor<4x512x64xf32>) dimensions = [0]
%5 = linalg_ext.batch_matmul ins(%arg0, %broadcasted_1 : tensor<4x1024x512xf32>, tensor<4x512x64xf32>) outs(%3 : tensor<4x1024x64xf32>) layout = "nn" {__stop__}
%6 = tensor.empty() : tensor<4x8x1024x64xf32>
%7 = tensor.empty() : tensor<4x64x1024xf32>
%8:3 = scf.for %arg4 = %c0 to %c4 step %c2 iter_args(%arg5 = %6, %arg6 = %7, %arg7 = %1) -> (tensor<4x8x1024x64xf32>, tensor<4x64x1024xf32>, tensor<4x1024x512xf32>) {
%9:3 = scf.for %arg8 = %c0 to %c8 step %c4 iter_args(%arg9 = %arg5, %arg10 = %arg6, %arg11 = %arg7) -> (tensor<4x8x1024x64xf32>, tensor<4x64x1024xf32>, tensor<4x1024x512xf32>) {
%10 = affine.apply #map(%arg8)
%extracted_slice = tensor.extract_slice %arg0[%arg4, 0, 0] [2, 1024, 512] [1, 1, 1] : tensor<4x1024x512xf32> to tensor<2x1024x512xf32>
%extracted_slice_2 = tensor.extract_slice %arg1[0, %10] [512, 256] [1, 1] : tensor<512x512xf32> to tensor<512x256xf32>
%11 = tensor.empty() : tensor<2x512x256xf32>
%broadcasted_3 = linalg.broadcast ins(%extracted_slice_2 : tensor<512x256xf32>) outs(%11 : tensor<2x512x256xf32>) dimensions = [0]
%12 = tensor.empty() : tensor<2x1024x256xf32>
%13 = linalg.fill ins(%cst : f32) outs(%12 : tensor<2x1024x256xf32>) -> tensor<2x1024x256xf32>
%14 = linalg_ext.batch_matmul ins(%extracted_slice, %broadcasted_3 : tensor<2x1024x512xf32>, tensor<2x512x256xf32>) outs(%13 : tensor<2x1024x256xf32>) layout = "nn"
%expanded = tensor.expand_shape %14 [[0], [1], [2, 3]] : tensor<2x1024x256xf32> into tensor<2x1024x4x64xf32>
%15 = tensor.empty() : tensor<2x4x1024x64xf32>
%transposed = linalg.transpose ins(%expanded : tensor<2x1024x4x64xf32>) outs(%15 : tensor<2x4x1024x64xf32>) permutation = [0, 2, 1, 3]
%extracted_slice_4 = tensor.extract_slice %4[%arg4, 0, 0] [2, 1024, 64] [1, 1, 1] : tensor<4x1024x64xf32> to tensor<2x1024x64xf32>
%extracted_slice_5 = tensor.extract_slice %arg10[%arg4, 0, 0] [2, 64, 1024] [1, 1, 1] : tensor<4x64x1024xf32> to tensor<2x64x1024xf32>
%transposed_6 = linalg.transpose ins(%extracted_slice_4 : tensor<2x1024x64xf32>) outs(%extracted_slice_5 : tensor<2x64x1024xf32>) permutation = [0, 2, 1]
%16 = tensor.empty() : tensor<2x4x64x1024xf32>
%broadcasted_7 = linalg.broadcast ins(%transposed_6 : tensor<2x64x1024xf32>) outs(%16 : tensor<2x4x64x1024xf32>) dimensions = [1]
%17 = tensor.empty() : tensor<2x4x1024x1024xf32>
%18 = linalg.fill ins(%cst : f32) outs(%17 : tensor<2x4x1024x1024xf32>) -> tensor<2x4x1024x1024xf32>
%19 = linalg_ext.batch_matmul ins(%transposed, %broadcasted_7 : tensor<2x4x1024x64xf32>, tensor<2x4x64x1024xf32>) outs(%18 : tensor<2x4x1024x1024xf32>) layout = "nn"
%20 = tensor.empty() : tensor<2x4x1024xf32>
%21 = linalg.fill ins(%cst_0 : f32) outs(%20 : tensor<2x4x1024xf32>) -> tensor<2x4x1024xf32>
%22 = linalg.fill ins(%cst : f32) outs(%20 : tensor<2x4x1024xf32>) -> tensor<2x4x1024xf32>
%23:4 = linalg_ext.softmax dimension(3) ins(%19 : tensor<2x4x1024x1024xf32>) outs(%17, %21, %22, %20 : tensor<2x4x1024x1024xf32>, tensor<2x4x1024xf32>, tensor<2x4x1024xf32>, tensor<2x4x1024xf32>) : tensor<2x4x1024x1024xf32>, tensor<2x4x1024xf32>, tensor<2x4x1024xf32>, tensor<2x4x1024xf32>
%extracted_slice_8 = tensor.extract_slice %5[%arg4, 0, 0] [2, 1024, 64] [1, 1, 1] : tensor<4x1024x64xf32> to tensor<2x1024x64xf32>
%broadcasted_9 = linalg.broadcast ins(%extracted_slice_8 : tensor<2x1024x64xf32>) outs(%15 : tensor<2x4x1024x64xf32>) dimensions = [1]
%24 = linalg.fill ins(%cst : f32) outs(%15 : tensor<2x4x1024x64xf32>) -> tensor<2x4x1024x64xf32>
%25 = linalg_ext.batch_matmul ins(%23#0, %broadcasted_9 : tensor<2x4x1024x1024xf32>, tensor<2x4x1024x64xf32>) outs(%24 : tensor<2x4x1024x64xf32>) layout = "nn" {__root__}
%inserted_slice = tensor.insert_slice %25 into %arg9[%arg4, %arg8, 0, 0] [2, 4, 1024, 64] [1, 1, 1, 1] : tensor<2x4x1024x64xf32> into tensor<4x8x1024x64xf32>
%inserted_slice_10 = tensor.insert_slice %transposed_6 into %arg10[%arg4, 0, 0] [2, 64, 1024] [1, 1, 1] : tensor<2x64x1024xf32> into tensor<4x64x1024xf32>
%inserted_slice_11 = tensor.insert_slice %14 into %arg11[%arg4, 0, %10] [2, 1024, 256] [1, 1, 1] : tensor<2x1024x256xf32> into tensor<4x1024x512xf32>
scf.yield %inserted_slice, %inserted_slice_10, %inserted_slice_11 : tensor<4x8x1024x64xf32>, tensor<4x64x1024xf32>, tensor<4x1024x512xf32>
}
scf.yield %9#0, %9#1, %9#2 : tensor<4x8x1024x64xf32>, tensor<4x64x1024xf32>, tensor<4x1024x512xf32>
}
return %8#0 : tensor<4x8x1024x64xf32>
}