Skip to content

Commit

Permalink
Add block.inline_region_before (#33)
Browse files Browse the repository at this point in the history
Now that variable names are less important (#30), the lowering can be
simplified.
  • Loading branch information
rikhuijzer authored Dec 19, 2024
1 parent 062cb5e commit ed7ebfb
Show file tree
Hide file tree
Showing 13 changed files with 159 additions and 116 deletions.
6 changes: 2 additions & 4 deletions xrcf/src/convert/mlir_to_llvmir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,7 @@ impl Rewrite for BlockLowering {
let region = op.operation().region();
if let Some(region) = region {
let blocks = region.blocks();
let blocks = blocks.try_read().unwrap();
for block in blocks.iter() {
for block in blocks.into_iter() {
let label_prefix = block.label_prefix();
if label_prefix == "^" {
return Ok(true);
Expand Down Expand Up @@ -540,8 +539,7 @@ impl Rewrite for MergeLowering {
}
fn rewrite(&self, op: Arc<RwLock<dyn Op>>) -> Result<RewriteResult> {
let blocks = op.operation().region().unwrap().blocks();
let blocks = blocks.try_read().unwrap();
for block in blocks.iter() {
for block in blocks.into_iter() {
let block_read = block.try_read().unwrap();
let has_argument = !block_read.arguments().vec().try_read().unwrap().is_empty();
if has_argument {
Expand Down
66 changes: 24 additions & 42 deletions xrcf/src/convert/scf_to_cf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use crate::ir::Block;
use crate::ir::BlockArgument;
use crate::ir::BlockArgumentName;
use crate::ir::BlockName;
use crate::ir::BlockPtr;
use crate::ir::GuardedBlock;
use crate::ir::GuardedOp;
use crate::ir::GuardedOperation;
Expand Down Expand Up @@ -80,42 +79,23 @@ fn branch_op(after: Arc<RwLock<Block>>) -> Arc<RwLock<dyn Op>> {
new_op
}

fn add_block_from_region(
label: String,
after: Arc<RwLock<Block>>,
region: Arc<RwLock<Region>>,
parent_region: Arc<RwLock<Region>>,
) -> Result<Arc<RwLock<OpOperand>>> {
let mut ops = region.ops();
/// Add a `cf.br` to the end of `block` with destination `after`.
fn add_branch_to_after(block: Arc<RwLock<Block>>, after: Arc<RwLock<Block>>) {
let ops = block.ops();
let mut ops = ops.try_write().unwrap();
let ops_clone = ops.clone();
let last_op = ops_clone.last().unwrap();
let last_op = last_op.try_read().unwrap();
let yield_op = last_op.as_any().downcast_ref::<dialect::scf::YieldOp>();
if let Some(yield_op) = yield_op {
let new_op = lower_yield_op(&yield_op, after.clone())?;
let new_op = lower_yield_op(&yield_op, after.clone()).unwrap();
ops.pop();
ops.push(new_op.clone());
} else {
let new_op = branch_op(after.clone());
new_op.set_parent(block.clone());
ops.push(new_op.clone());
};

let unset_block = parent_region.add_empty_block_before(after);
let block = unset_block.set_parent(Some(parent_region.clone()));
block.set_ops(Arc::new(RwLock::new(ops.clone())));
block.set_label(BlockName::Unset);
for op in ops.iter() {
let op = op.try_read().unwrap();
op.set_parent(block.clone());
}
let block_label = BlockName::Name(label.clone());
block.set_label(block_label);

let label = Value::BlockPtr(BlockPtr::new(block.clone()));
let label = Arc::new(RwLock::new(label));
let operand = OpOperand::new(label);
let operand = Arc::new(RwLock::new(operand));
Ok(operand)
}

/// Move all successors of `scf.if` to the return block.
Expand Down Expand Up @@ -258,20 +238,21 @@ fn results_users(results: Values) -> Vec<Users> {
fn add_blocks(
op: &dialect::scf::IfOp,
parent_region: Arc<RwLock<Region>>,
) -> Result<(Arc<RwLock<OpOperand>>, Arc<RwLock<OpOperand>>)> {
) -> Result<(Arc<RwLock<Block>>, Arc<RwLock<Block>>)> {
let results = op.operation().results();
let results_users = results_users(results.clone());
let exit = add_exit_block(op, parent_region.clone())?;
let then_label = format!("{}", parent_region.unique_block_name());
let then_label_index = then_label
.trim_start_matches("^bb")
.parse::<usize>()
.unwrap();
let has_results = !results.is_empty();
let else_label = format!("^bb{}", then_label_index + 1);

let then = op.then().expect("Expected `then` region");
let els = op.els().expect("Expected `else` region");
let then_region = op.then().expect("Expected `then` region");
let then = then_region.blocks().into_iter().next().unwrap();
then.set_label(BlockName::Unset);
exit.inline_region_before(then_region.clone());

let else_region = op.els().expect("Expected `else` region");
let els = else_region.blocks().into_iter().next().unwrap();
els.set_label(BlockName::Unset);
exit.inline_region_before(else_region.clone());

let after = if has_results {
let (merge, merge_block_arguments) =
Expand All @@ -298,11 +279,10 @@ fn add_blocks(
exit.clone()
};

let then_operand =
add_block_from_region(then_label, after.clone(), then, parent_region.clone())?;
let else_operand = add_block_from_region(else_label, after, els, parent_region.clone())?;
add_branch_to_after(then.clone(), after.clone());
add_branch_to_after(els.clone(), after.clone());

Ok((then_operand, else_operand))
Ok((then, els))
}

impl Rewrite for IfLowering {
Expand All @@ -318,13 +298,15 @@ impl Rewrite for IfLowering {
let parent_region = parent.parent().expect("Expected parent region");
let op = op.as_any().downcast_ref::<dialect::scf::IfOp>().unwrap();

let (then_operand, else_operand) = add_blocks(&op, parent_region.clone())?;
let (then, els) = add_blocks(&op, parent_region.clone())?;

let mut operation = Operation::default();
operation.set_parent(Some(parent.clone()));
operation.set_operand(0, op.operation().operand(0).clone().unwrap());
operation.set_operand(1, then_operand.clone());
operation.set_operand(2, else_operand.clone());
let then_operand = Arc::new(RwLock::new(OpOperand::from_block(then)));
operation.set_operand(1, then_operand);
let els_operand = Arc::new(RwLock::new(OpOperand::from_block(els)));
operation.set_operand(2, els_operand);
let new = dialect::cf::CondBranchOp::from_operation(operation);
let new: Arc<RwLock<dyn Op>> = Arc::new(RwLock::new(new));
op.replace(new.clone());
Expand Down
7 changes: 2 additions & 5 deletions xrcf/src/dialect/func/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,11 +413,8 @@ impl<T: ParserDispatch> Parser<T> {

{
let blocks = region.blocks();
let blocks = blocks.try_read().unwrap();
let block = blocks.first().unwrap();
let arguments = arguments.vec();
let arguments = arguments.try_read().unwrap();
for argument in arguments.iter() {
let block = blocks.into_iter().next().unwrap();
for argument in arguments.into_iter() {
argument.set_parent(Some(block.clone()));
}
}
Expand Down
100 changes: 86 additions & 14 deletions xrcf/src/ir/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::ir::BlockArgumentName;
use crate::ir::GuardedOp;
use crate::ir::GuardedOpOperand;
use crate::ir::GuardedOperation;
use crate::ir::GuardedRegion;
use crate::ir::Op;
use crate::ir::Operation;
use crate::ir::Region;
Expand Down Expand Up @@ -164,10 +165,9 @@ impl Block {
let region = region.try_read().unwrap();
let index = region.index_of(self);
let blocks = region.blocks();
let predecessors = blocks.try_read().unwrap();
let predecessors = match index {
Some(index) => predecessors[..index].to_vec(),
None => predecessors.clone(),
Some(index) => blocks.into_iter().take(index).collect(),
None => blocks.into_iter().collect(),
};
Some(predecessors)
}
Expand All @@ -180,9 +180,8 @@ impl Block {
let region = region.try_read().unwrap();
let index = region.index_of(self);
let blocks = region.blocks();
let successors = blocks.try_read().unwrap();
let successors = match index {
Some(index) => successors[index + 1..].to_vec(),
Some(index) => blocks.into_iter().skip(index + 1).collect(),
None => panic!("Expected block to be in region"),
};
Some(successors)
Expand Down Expand Up @@ -335,24 +334,31 @@ impl Block {
}
None
}
/// Return index of `op` in `self`.
///
/// Returns `None` if `op` is not found in `self`.
pub fn index_of(&self, op: &Operation) -> Option<usize> {
let ops = self.ops();
let ops = ops.try_read().unwrap();
for (i, current) in (&ops).iter().enumerate() {
ops.iter().position(|current| {
let current = current.try_read().unwrap();
let current = current.operation();
let current = current.try_read().unwrap();
if *current == *op {
return Some(i);
}
}
None
let current = &*current.try_read().unwrap();
std::ptr::eq(current, op)
})
}
pub fn index_of_arc(&self, op: Arc<RwLock<Operation>>) -> Option<usize> {
self.index_of(&*op.try_read().unwrap())
}
pub fn inline_region_before(&self, _region: Arc<RwLock<Region>>) {
todo!()
/// Move the blocks that belong to `region` before `self`.
///
/// The caller is in charge of transferring the control flow to the region
/// and pass it the correct block arguments.
pub fn inline_region_before(&self, region: Arc<RwLock<Region>>) {
let parent = self.parent();
let parent = parent.expect("no parent");
let blocks = parent.blocks();
blocks.splice(self, region.blocks());
}
pub fn insert_op(&self, op: Arc<RwLock<dyn Op>>, index: usize) {
let ops = self.ops();
Expand Down Expand Up @@ -588,3 +594,69 @@ impl GuardedBlock for Arc<RwLock<Block>> {
self.try_read().unwrap().unique_value_name(prefix)
}
}

#[derive(Clone)]
pub struct Blocks {
vec: Arc<RwLock<Vec<Arc<RwLock<Block>>>>>,
}

impl IntoIterator for Blocks {
type Item = Arc<RwLock<Block>>;
type IntoIter = std::vec::IntoIter<Self::Item>;

fn into_iter(self) -> Self::IntoIter {
let vec = self.vec.try_read().unwrap();
vec.clone().into_iter()
}
}

impl Blocks {
pub fn new(vec: Arc<RwLock<Vec<Arc<RwLock<Block>>>>>) -> Self {
Self { vec }
}
pub fn vec(&self) -> Arc<RwLock<Vec<Arc<RwLock<Block>>>>> {
self.vec.clone()
}
/// Return the index of `block` in `self`.
///
/// Returns `None` if `block` is not found in `self`.
pub fn index_of(&self, block: &Block) -> Option<usize> {
let vec = self.vec();
let vec = vec.try_read().unwrap();
if vec.is_empty() {
panic!("Trying to find block in empty set of blocks");
}
vec.iter().position(|b| {
let b = &*b.try_read().unwrap();
std::ptr::eq(b, block)
})
}
fn transfer(&self, before: &Block, blocks: Blocks) {
let index = self.index_of(before);
let index = match index {
Some(index) => index,
None => {
panic!("Could not find block in blocks during transfer");
}
};
let vec = self.vec();
let mut vec = vec.try_write().unwrap();
let blocks = blocks.vec();
let mut blocks = blocks.try_write().unwrap();
vec.splice(index..index, blocks.iter().cloned());
{
let parent = before.parent();
for block in blocks.iter() {
let mut block = block.try_write().unwrap();
block.set_parent(parent.clone());
}
}
blocks.clear();
}
/// Move `blocks` before `before` in `self`.
///
/// This also handles side-effects like updating parents.
pub fn splice(&self, before: &Block, blocks: Blocks) {
self.transfer(before, blocks);
}
}
1 change: 1 addition & 0 deletions xrcf/src/ir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub use attribute::IntegerAttr;
pub use attribute::StringAttr;
pub use block::Block;
pub use block::BlockName;
pub use block::Blocks;
pub use block::GuardedBlock;
pub use block::UnsetBlock;
pub use module::ModuleOp;
Expand Down
3 changes: 1 addition & 2 deletions xrcf/src/ir/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ impl ModuleOp {
None => return Err(anyhow::anyhow!("Expected 1 region in module, got 0")),
};
let blocks = region.blocks();
let blocks = blocks.try_read().unwrap();
let block = match blocks.first() {
let block = match blocks.into_iter().next() {
Some(block) => block,
None => return Err(anyhow::anyhow!("Expected 1 block in module, got 0")),
};
Expand Down
4 changes: 4 additions & 0 deletions xrcf/src/ir/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ pub trait GuardedOp {
fn remove(&self);
fn replace(&self, new: Arc<RwLock<dyn Op>>);
fn result(&self, index: usize) -> Arc<RwLock<Value>>;
fn set_parent(&self, parent: Arc<RwLock<Block>>);
}

impl GuardedOp for Arc<RwLock<dyn Op>> {
Expand Down Expand Up @@ -277,4 +278,7 @@ impl GuardedOp for Arc<RwLock<dyn Op>> {
fn result(&self, index: usize) -> Arc<RwLock<Value>> {
self.try_read().unwrap().result(index)
}
fn set_parent(&self, parent: Arc<RwLock<Block>>) {
self.try_read().unwrap().set_parent(parent);
}
}
7 changes: 2 additions & 5 deletions xrcf/src/ir/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,7 @@ pub fn display_region_inside_func(
if let Some(region) = region {
let region = region.try_read().unwrap();
let blocks = region.blocks();
let blocks = blocks.try_read().unwrap();
if blocks.is_empty() {
if blocks.into_iter().next().is_none() {
write!(f, "\n")
} else {
region.display(f, indent)
Expand Down Expand Up @@ -203,9 +202,7 @@ impl Operation {
let region = self.region();
let region = region.expect("expected region");
let region = region.try_read().unwrap();
let blocks = region.blocks();
let blocks = blocks.try_read().unwrap();
blocks.to_vec()
region.blocks().into_iter().collect::<Vec<_>>()
}
pub fn operand_types(&self) -> Types {
let operands = self.operands.vec();
Expand Down
Loading

0 comments on commit ed7ebfb

Please sign in to comment.