diff --git a/frontend/Python/ops/func.py b/frontend/Python/ops/func.py index a7dcc5e11b..b58ba92d97 100644 --- a/frontend/Python/ops/func.py +++ b/frontend/Python/ops/func.py @@ -104,6 +104,8 @@ def param_extract( dtype_mapping = { TensorDType.Float32: ir.F32Type.get(), TensorDType.Int64: ir.IntegerType.get_signless(64), + TensorDType.BFloat16: ir.BF16Type.get(), + TensorDType.Float16: ir.F16Type.get(), } memref_element_type = dtype_mapping[node.tensor_meta["dtype"]] if(len(node.tensor_meta['shape'])== 0): diff --git a/frontend/Python/ops/tosa.py b/frontend/Python/ops/tosa.py index 9562d95d83..470078509f 100644 --- a/frontend/Python/ops/tosa.py +++ b/frontend/Python/ops/tosa.py @@ -117,7 +117,7 @@ def _scalar_to_tensor( doesn't support operation between scalers and tensors.""" element = ( ir.FloatAttr.get(element_type, float(scalar)) - if str(element_type) == "f32" + if str(element_type) in ("f32", "bf16", "f16") else ir.IntegerAttr.get(element_type, int(scalar)) ) attr = ir.DenseElementsAttr.get_splat( @@ -804,7 +804,7 @@ def expand_op(node: ExpandOp, symbol_table) -> ir.Operation: ).element_type if result_element_type == ir.IntegerType.get_signless(1): element = ir.IntegerAttr.get(result_element_type, 0) - elif result_element_type == ir.F32Type.get(): + elif result_element_type in (ir.F32Type.get(), ir.BF16Type.get(), ir.F16Type.get()): element = ir.FloatAttr.get(result_element_type, 0.0) else: raise NotImplementedError("Unsupported element type!") diff --git a/frontend/Python/ops/utils.py b/frontend/Python/ops/utils.py index 822cb4f535..0979822065 100644 --- a/frontend/Python/ops/utils.py +++ b/frontend/Python/ops/utils.py @@ -56,7 +56,7 @@ def mlir_element_attr_get(type_name, value): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value) case TensorDType.Bool: return ir.IntegerAttr.get(ir.IntegerType.get_signless(1), value) - case TensorDType.Floaf16: + case TensorDType.Float16: return ir.FloatAttr.get(ir.F16Type.get(), value) - case TensorDType.BFloaf16: + case TensorDType.BFloat16: return ir.FloatAttr.get(ir.BF16Type.get(), value)