Skip to content

Commit

Permalink
Add bf16 and fp16 support for llama-2
Browse files Browse the repository at this point in the history
  • Loading branch information
AyiStar committed Jul 12, 2024
1 parent 6ac34b9 commit b5271b1
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
2 changes: 2 additions & 0 deletions frontend/Python/ops/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def param_extract(
dtype_mapping = {
TensorDType.Float32: ir.F32Type.get(),
TensorDType.Int64: ir.IntegerType.get_signless(64),
TensorDType.BFloat16: ir.BF16Type.get(),
TensorDType.Float16: ir.F16Type.get(),
}
memref_element_type = dtype_mapping[node.tensor_meta["dtype"]]
if(len(node.tensor_meta['shape'])== 0):
Expand Down
4 changes: 2 additions & 2 deletions frontend/Python/ops/tosa.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _scalar_to_tensor(
doesn't support operation between scalers and tensors."""
element = (
ir.FloatAttr.get(element_type, float(scalar))
if str(element_type) == "f32"
if str(element_type) in ("f32", "bf16", "f16")
else ir.IntegerAttr.get(element_type, int(scalar))
)
attr = ir.DenseElementsAttr.get_splat(
Expand Down Expand Up @@ -804,7 +804,7 @@ def expand_op(node: ExpandOp, symbol_table) -> ir.Operation:
).element_type
if result_element_type == ir.IntegerType.get_signless(1):
element = ir.IntegerAttr.get(result_element_type, 0)
elif result_element_type == ir.F32Type.get():
elif result_element_type in (ir.F32Type.get(), ir.BF16Type.get(), ir.F16Type.get()):
element = ir.FloatAttr.get(result_element_type, 0.0)
else:
raise NotImplementedError("Unsupported element type!")
Expand Down
4 changes: 2 additions & 2 deletions frontend/Python/ops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def mlir_element_attr_get(type_name, value):
return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value)
case TensorDType.Bool:
return ir.IntegerAttr.get(ir.IntegerType.get_signless(1), value)
case TensorDType.Floaf16:
case TensorDType.Float16:
return ir.FloatAttr.get(ir.F16Type.get(), value)
case TensorDType.BFloaf16:
case TensorDType.BFloat16:
return ir.FloatAttr.get(ir.BF16Type.get(), value)

0 comments on commit b5271b1

Please sign in to comment.