Skip to content

Commit

Permalink
Fix duplicate name when lowering scf.if (#30)
Browse files Browse the repository at this point in the history
Fixes a bug where variables are not automatically renamed when lowered
from inside the region. This, for example, comes up when lowering
```mlir
func.func @main() -> i64 {
  %x = arith.constant false
  scf.if %x {
    %0 = arith.constant 2 : i64
  } else {
    %0 = arith.constant 3 : i64
  }
  %1 = arith.constant 0 : i64
  return %1 : i64
}
```
to 
```mlir
func.func @main() -> i64 {
  %0 = arith.constant false
  cf.cond_br %0, ^bb1, ^bb2
^bb1: 
  %1 = arith.constant 2 : i64
  cf.br ^bb3
^bb2:
  %2 = arith.constant 3 : i64
  cf.br ^bb3
^bb3:
  %3 = arith.constant 0 : i64
  return %3 : i64
}
```
Notice that the second definition of `%0` has to be renamed.

Now it makes sense why MLIR always renames variables. If you don't, any
variable that is inlined and has a name that is the same as another
variable needs to have a new name. But say you have `%0`, `%1` and now
you inline another `%0`. What name would you give it? `%0_0`? You cannot
just call it `%2` because, if I'm not mistaken, LLVM will then start to
complain (fair enough) because `%2` occurs before `%1`. Just ignoring
names and giving everything a new ("fresh") name leads to much cleaner
generated code.
  • Loading branch information
rikhuijzer authored Dec 10, 2024
1 parent 3baeddf commit d87c4c0
Show file tree
Hide file tree
Showing 21 changed files with 510 additions and 175 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# xrcf

Tools to build your own compiler.
<!-- When updating this README also update the README.md in xrcf/ -->

The eXtensible and Reusable Compiler Framework (xrcf) is a framework for building compilers.

You may be looking for:

Expand Down
2 changes: 1 addition & 1 deletion arnoldc/src/arnold_to_mlir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ impl ModuleLowering {
let typ = IntegerType::new(32);
let value = APInt::new(32, 0, true);
let integer = IntegerAttr::new(typ, value);
let name = parent.try_read().unwrap().unique_value_name();
let name = parent.try_read().unwrap().unique_value_name("%");
let result_type = Arc::new(RwLock::new(typ));
let result = constant.add_new_op_result(&name, result_type.clone());
let constant = arith::ConstantOp::from_operation(constant);
Expand Down
16 changes: 8 additions & 8 deletions arnoldc/src/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,11 @@ mod tests {
let expected = indoc! {r#"
module {
func.func @main() -> i32 {
%x = arith.constant 1 : i1
%0 = arith.constant 1 : i1
experimental.printf("x: ")
experimental.printf("%d", %x)
%0 = arith.constant 0 : i32
return %0 : i32
experimental.printf("%d", %0)
%1 = arith.constant 0 : i32
return %1 : i32
}
}
"#}
Expand Down Expand Up @@ -310,14 +310,14 @@ mod tests {
.trim();
let expected = indoc! {r#"
func.func @main() -> i32 {
%x = arith.constant 0 : i1
scf.if %x {
%0 = arith.constant 0 : i1
scf.if %0 {
experimental.printf("x was true")
} else {
experimental.printf("x was false")
}
%0 = arith.constant 0 : i32
return %0 : i32
%1 = arith.constant 0 : i32
return %1 : i32
}
"#}
.trim();
Expand Down
11 changes: 11 additions & 0 deletions xrcf/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# xrcf

<!-- This README shows up at https://crates.io/crates/xrcf -->
<!-- When updating this README also update the README in the root -->

The eXtensible and Reusable Compiler Framework (xrcf) is a framework for building compilers.

You may be looking for:

- [An high-level overview of xrcf](https://docs.rs/xrcf/latest/xrcf/)
- [An example compiler built with xrcf](https://xrcf.org/blog/basic-arnoldc/)
18 changes: 13 additions & 5 deletions xrcf/src/convert/experimental_to_mlir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::dialect::llvm;
use crate::dialect::llvm::PointerType;
use crate::ir::APInt;
use crate::ir::Block;
use crate::ir::BlockArgumentName;
use crate::ir::GuardedBlock;
use crate::ir::GuardedOp;
use crate::ir::GuardedOperation;
Expand All @@ -36,7 +37,7 @@ impl PrintLowering {
let text = op.text().clone();
let text = text.c_string();
let len = text.len();
let name = parent.try_read().unwrap().unique_value_name();
let name = parent.try_read().unwrap().unique_value_name("%");
let typ = llvm::ArrayType::for_bytes(&text);
let typ = Arc::new(RwLock::new(typ));
let result = const_operation.add_new_op_result(&name, typ);
Expand All @@ -52,7 +53,7 @@ impl PrintLowering {
let mut operation = Operation::default();
operation.set_parent(Some(parent.clone()));
let typ = IntegerType::from_str("i16");
let name = parent.try_read().unwrap().unique_value_name();
let name = parent.try_read().unwrap().unique_value_name("%");
let result_type = Arc::new(RwLock::new(typ));
let result = operation.add_new_op_result(&name, result_type);
let op = arith::ConstantOp::from_operation(operation);
Expand All @@ -66,7 +67,7 @@ impl PrintLowering {
let mut operation = Operation::default();
operation.set_parent(Some(parent.clone()));
let typ = llvm::PointerType::new();
let name = parent.try_read().unwrap().unique_value_name();
let name = parent.try_read().unwrap().unique_value_name("%");
let result_type = Arc::new(RwLock::new(typ));
let result = operation.add_new_op_result(&name, result_type);
let array_size = len.result(0);
Expand Down Expand Up @@ -119,7 +120,7 @@ impl PrintLowering {
operation.set_operand(1, var);
}
let typ = IntegerType::from_str("i32");
let name = parent.unique_value_name();
let name = parent.unique_value_name("%");
let result_type = Arc::new(RwLock::new(typ));
let result = operation.add_new_op_result(&name, result_type);

Expand Down Expand Up @@ -187,7 +188,14 @@ impl PrintLowering {
{
let arg_type = PointerType::new();
let arg_type = Arc::new(RwLock::new(arg_type));
op.set_argument_from_type(0, arg_type)?;

let name = BlockArgumentName::Anonymous;
let name = Arc::new(RwLock::new(name));
let argument = crate::ir::BlockArgument::new(name, arg_type);
let value = Value::BlockArgument(argument);
let value = Arc::new(RwLock::new(value));
let operation = op.operation();
operation.set_argument(0, value);
}
if set_varargs {
let value = Value::Variadic;
Expand Down
15 changes: 13 additions & 2 deletions xrcf/src/convert/mlir_to_llvmir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::dialect::func::Func;
use crate::ir;
use crate::ir::Block;
use crate::ir::BlockArgument;
use crate::ir::BlockArgumentName;
use crate::ir::Constant;
use crate::ir::GuardedBlock;
use crate::ir::GuardedOp;
Expand Down Expand Up @@ -289,7 +290,9 @@ fn lower_block_argument_types(operation: &mut Operation) {
if typ.as_any().is::<dialect::llvm::PointerType>() {
let typ = targ3t::llvmir::PointerType::from_str("ptr");
let typ = Arc::new(RwLock::new(typ));
let arg = Value::BlockArgument(BlockArgument::new(None, typ));
let name = BlockArgumentName::Unset;
let name = Arc::new(RwLock::new(name));
let arg = Value::BlockArgument(BlockArgument::new(name, typ));
new_arguments.push(Arc::new(RwLock::new(arg)));
} else {
new_arguments.push(argument.clone());
Expand Down Expand Up @@ -444,7 +447,15 @@ fn set_phi_result(phi: Arc<RwLock<dyn Op>>, argument: &Arc<RwLock<Value>>) {
if let Value::BlockArgument(arg) = &*argument {
let typ = Some(arg.typ());
let defining_op = Some(phi.clone());
let res = OpResult::new(arg.name(), typ, defining_op);
let name = arg.name();
let name = name.try_read().unwrap();
let name = match &*name {
BlockArgumentName::Name(name) => name.to_string(),
BlockArgumentName::Anonymous => panic!("Expected a named block argument"),
BlockArgumentName::Unset => panic!("Block argument has no name"),
};
let name = Arc::new(RwLock::new(Some(name)));
let res = OpResult::new(name, typ, defining_op);
let new = Value::OpResult(res);
let new = Arc::new(RwLock::new(new));
operation.set_results(Values::from_vec(vec![new.clone()]));
Expand Down
66 changes: 60 additions & 6 deletions xrcf/src/convert/scf_to_cf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::convert::RewriteResult;
use crate::dialect;
use crate::ir::Block;
use crate::ir::BlockArgument;
use crate::ir::BlockArgumentName;
use crate::ir::BlockDest;
use crate::ir::BlockLabel;
use crate::ir::GuardedBlock;
Expand All @@ -16,12 +17,44 @@ use crate::ir::Op;
use crate::ir::OpOperand;
use crate::ir::Operation;
use crate::ir::Region;
use crate::ir::Users;
use crate::ir::Value;
use crate::ir::Values;
use anyhow::Result;
use std::sync::Arc;
use std::sync::RwLock;

/// Lower `scf.if` to `cf.cond_br`.
///
/// For example, this rewrites:
/// ```mlir
/// %result = scf.if %0 -> (i32) {
/// %1 = arith.constant 3 : i32
/// scf.yield %c1_i32 : i32
/// } else {
/// %2 = arith.constant 4 : i32
/// scf.yield %2 : i32
/// }
/// ```
/// to
/// ```mlir
/// cf.cond_br %0, ^bb1, ^bb2
/// ^bb1:
/// %1 = arith.constant 3 : i32
/// cf.br ^bb3(%1 : i32)
/// ^bb2:
/// %2 = arith.constant 4 : i32
/// cf.br ^bb3(%2 : i32)
/// ^bb3(%result : i32):
/// cf.br ^bb4
/// ^bb4:
/// return %result : i32
/// ```
///
/// This lowering is similar to the following rewrite method in MLIR:
/// ```cpp
/// LogicalResult IfLowering::matchAndRewrite
/// ```
struct IfLowering;

fn lower_yield_op(op: &dialect::scf::YieldOp, after_label: &str) -> Result<Arc<RwLock<dyn Op>>> {
Expand Down Expand Up @@ -122,20 +155,20 @@ fn add_merge_block(
merge_label: String,
results: Values,
exit_label: String,
) -> Result<()> {
) -> Result<Values> {
let unset_block = parent_region.add_empty_block();
let block = unset_block.set_parent(Some(parent_region.clone()));
block.set_label(Some(merge_label.clone()));
let merge_block_operands = as_block_arguments(results, block.clone())?;
block.set_arguments(merge_block_operands);
let merge_block_arguments = as_block_arguments(results, block.clone())?;
block.set_arguments(merge_block_arguments.clone());

let mut operation = Operation::default();
operation.set_parent(Some(block.clone()));
let mut merge_op = dialect::cf::BranchOp::from_operation(operation);
merge_op.set_dest(Some(Arc::new(RwLock::new(BlockDest::new(&exit_label)))));
let merge_op = Arc::new(RwLock::new(merge_op));
block.set_ops(Arc::new(RwLock::new(vec![merge_op.clone()])));
Ok(())
Ok(merge_block_arguments)
}

fn add_exit_block(
Expand All @@ -162,6 +195,8 @@ fn as_block_arguments(results: Values, parent: Arc<RwLock<Block>>) -> Result<Val
let result = result.try_read().unwrap();
let name = result.name();
let typ = result.typ().unwrap();
let name = BlockArgumentName::Name(name.unwrap());
let name = Arc::new(RwLock::new(name));
let mut arg = BlockArgument::new(name, typ);
arg.set_parent(Some(parent.clone()));
let arg = Value::BlockArgument(arg);
Expand Down Expand Up @@ -233,12 +268,31 @@ fn add_blocks(
let else_label = add_block_from_region(else_label, &after_label, els, parent_region.clone())?;

if has_results {
add_merge_block(
let merge_block_arguments = add_merge_block(
parent_region.clone(),
merge_label.unwrap(),
results,
results.clone(),
exit_label.clone(),
)?;
let merge_block_arguments = merge_block_arguments.vec();
let merge_block_arguments = merge_block_arguments.try_read().unwrap();

let results = results.vec();
let results = results.try_read().unwrap();
assert!(results.len() == merge_block_arguments.len());
for i in 0..results.len() {
let result = results[i].try_read().unwrap();
let users = result.users();
let users = match users {
Users::OpOperands(users) => users,
Users::HasNoOpResults => vec![],
};
let arg = merge_block_arguments[i].clone();
for user in users.iter() {
let mut user = user.try_write().unwrap();
user.set_value(arg.clone());
}
}
}
add_exit_block(op, parent_region.clone(), exit_label)?;
Ok((then_label, else_label))
Expand Down
3 changes: 1 addition & 2 deletions xrcf/src/dialect/arith/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,7 @@ impl AddiOp {
new_operation.set_attributes(attributes);

let results = Values::default();
// TODO: use results.add_new_op_result()
let mut result = OpResult::default();
let result = OpResult::default();
result.set_name("%c3_i64");
let result = Value::OpResult(result);
let result = Arc::new(RwLock::new(result));
Expand Down
35 changes: 17 additions & 18 deletions xrcf/src/dialect/func/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::ir::GuardedBlock;
use crate::ir::GuardedOp;
use crate::ir::GuardedOperation;
use crate::ir::GuardedRegion;
use crate::ir::GuardedValue;
use crate::ir::IntegerType;
use crate::ir::Op;
use crate::ir::Operation;
Expand All @@ -13,7 +14,6 @@ use crate::ir::Region;
use crate::ir::StringAttr;
use crate::ir::Type;
use crate::ir::UnsetOp;
use crate::ir::Value;
use crate::ir::Values;
use crate::parser::Parse;
use crate::parser::Parser;
Expand Down Expand Up @@ -194,14 +194,6 @@ impl Parse for CallOp {
pub trait Func: Op {
fn identifier(&self) -> Option<String>;
fn set_identifier(&mut self, identifier: String);
fn set_argument_from_type(&mut self, index: usize, typ: Arc<RwLock<dyn Type>>) -> Result<()> {
let argument = crate::ir::BlockArgument::new(None, typ);
let value = Value::BlockArgument(argument);
let value = Arc::new(RwLock::new(value));
let operation = self.operation();
operation.set_argument(index, value);
Ok(())
}
fn sym_visibility(&self) -> Option<String> {
let operation = self.operation();
let attributes = operation.attributes();
Expand Down Expand Up @@ -405,7 +397,8 @@ impl<T: ParserDispatch> Parser<T> {
let visibility = FuncOp::try_parse_func_visibility(parser, &expected_name);
let identifier = parser.expect(TokenKind::AtIdentifier)?;
let identifier = identifier.lexeme.clone();
operation.set_arguments(parser.parse_function_arguments()?);
let arguments = parser.parse_function_arguments()?;
operation.set_arguments(arguments.clone());
operation.set_anonymous_results(parser.result_types()?)?;
let mut op = F::from_operation(operation);
op.set_identifier(identifier);
Expand All @@ -417,6 +410,17 @@ impl<T: ParserDispatch> Parser<T> {
let op_rd = op.try_read().unwrap();
op_rd.operation().set_region(Some(region.clone()));
region.set_parent(Some(op.clone()));

{
let blocks = region.blocks();
let blocks = blocks.try_read().unwrap();
let block = blocks.first().unwrap();
let arguments = arguments.vec();
let arguments = arguments.try_read().unwrap();
for argument in arguments.iter() {
argument.set_parent(Some(block.clone()));
}
}
}

Ok(op)
Expand All @@ -438,13 +442,9 @@ pub struct ReturnOp {
}

impl ReturnOp {
pub fn display_return(
op: &dyn Op,
name: &str,
f: &mut Formatter<'_>,
_indent: i32,
) -> std::fmt::Result {
pub fn display_return(op: &dyn Op, f: &mut Formatter<'_>, _indent: i32) -> std::fmt::Result {
let operation = op.operation();
let name = operation.name();
write!(f, "{name}")?;
let operands = operation.operands().vec();
let operands = operands.try_read().unwrap();
Expand Down Expand Up @@ -473,8 +473,7 @@ impl Op for ReturnOp {
&self.operation
}
fn display(&self, f: &mut Formatter<'_>, _indent: i32) -> std::fmt::Result {
let name = Self::operation_name().to_string();
ReturnOp::display_return(self, &name, f, _indent)
ReturnOp::display_return(self, f, _indent)
}
}

Expand Down
3 changes: 1 addition & 2 deletions xrcf/src/dialect/llvm/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,7 @@ impl Op for ReturnOp {
&self.operation
}
fn display(&self, f: &mut Formatter<'_>, indent: i32) -> std::fmt::Result {
let name = Self::operation_name().to_string();
func::ReturnOp::display_return(self, &name, f, indent)
func::ReturnOp::display_return(self, f, indent)
}
}

Expand Down
Loading

0 comments on commit d87c4c0

Please sign in to comment.