diff --git a/attentionbench/attention_utils.py b/attentionbench/attention_utils.py index 6397f01..d8d8d18 100644 --- a/attentionbench/attention_utils.py +++ b/attentionbench/attention_utils.py @@ -139,8 +139,10 @@ def generate_mlir(config: AttentionConfig, tuning: Optional[TuningSpec] = None): {{ indexing_maps = [#Q, #K, #V, #S, #O] {",compilation_info = #tuning" if tuning and config.dtype == "f16" else ""} }} - ins(%Q, %K, %V, %scale : !Q, !K, !V, !dtype) - outs(%empty : !O) -> !O + ins(%Q, %K, %V, %scale : !Q, !K, !V, !dtype) outs(%empty : !O) {{ + ^bb0(%score: f32): + iree_linalg_ext.yield %score : f32 + }} -> !O return %O : !O }} """