Skip to content

Commit

Permalink
Update attention template
Browse files Browse the repository at this point in the history
This commit updates the attention
template to include `promote_operands`
and `decomposition_config`.

Signed-off-by: Manupa Karunaratne <[email protected]>
  • Loading branch information
manupak committed Oct 31, 2024
1 parent 4621947 commit 1ee7856
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions attentionbench/attention_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def get_lowering_config(self) -> str:
f"#iree_gpu.lowering_config<"
+ "{ "
+ f"workgroup = [{', '.join(map(str, self.wg_tiles))}], "
+ f"reduction = [{', '.join(map(str, self.reduction_tiles))}]"
+ f"reduction = [{', '.join(map(str, self.reduction_tiles))}],"
+ f"promote_operands = [0, 1, 2]"
+ " }"
+ f">"
)
Expand All @@ -93,7 +94,7 @@ def get_translation_info(self) -> str:
return (
f"#iree_codegen.translation_info<"
+ f"LLVMGPUVectorDistribute"
+ f" workgroup_size = [{self.N_warp * 64}, {self.M_warp}]"
+ f" workgroup_size = [{self.N_warp * self.M_warp * 64}]"
+ f" subgroup_size = 64"
+ f" ,{{mma_schedule = {self.get_mma_schedule()}"
+ f" , llvm_func_attrs = {{ {','.join(llvm_func_attrs)} }}"
Expand Down Expand Up @@ -137,6 +138,10 @@ def generate_mlir(config: AttentionConfig, tuning: Optional[TuningSpec] = None):
%empty = tensor.empty() : !O
%O = iree_linalg_ext.attention
{{ indexing_maps = [#Q, #K, #V, #S, #O]
,decomposition_config = {{
qk_attrs = {{attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{{promote_operands = [0, 1]}}>}},
pv_attrs = {{attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{{promote_operands = [1]}}>}}
}}
{",compilation_info = #tuning" if tuning and config.dtype == "f16" else ""}
}}
ins(%Q, %K, %V, %scale : !Q, !K, !V, !dtype) outs(%empty : !O) {{
Expand Down

0 comments on commit 1ee7856

Please sign in to comment.