Skip to content

Commit

Permalink
Introduce Shared and SharedExt, and refactor (#36)
Browse files Browse the repository at this point in the history
Introduces `Shared` and `SharedExt` as a convenience typ and trait that
turn
```rust
use std::sync::{Arc, RwLock};

let lock = Arc::new(RwLock::new(42));
assert_eq!(*lock.try_read().unwrap(), 42);
```
into
```rust
use std::sync::{Arc, RwLock};
use xrcf::shared::{Shared, SharedExt};

let lock = Shared::new(42.into());
assert_eq!(*lock.re(), 42);
```

This is a tradeoff between making it easier to use the `Arc<RwLock<T>>`
while also not wrapping it in a completely different object which would
add another layer of indirection.

The most important thing is now that thanks to `SharedExt`, writing
`lock.rd().` lists the available methods. This is much easier to quickly
check the available methods than via `lock.try_read().unwrap().`.
Another benefit is that `re` is much shorter so it's more likely to fit
into one line. One-liners work particularly well in Rust do to when
variables are freed, see https://xrcf.org/blog/iterators/ for more
information.
  • Loading branch information
rikhuijzer authored Dec 21, 2024
1 parent ed7ebfb commit 382d067
Show file tree
Hide file tree
Showing 40 changed files with 992 additions and 1,196 deletions.
41 changes: 21 additions & 20 deletions arnoldc/src/arnold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use xrcf::parser::Parse;
use xrcf::parser::Parser;
use xrcf::parser::ParserDispatch;
use xrcf::parser::TokenKind;
use xrcf::shared::Shared;
use xrcf::shared::SharedExt;

/// The token kind used for variables in ArnoldC.
///
Expand Down Expand Up @@ -150,12 +152,11 @@ impl Op for BeginMainOp {
&self.operation
}
fn display(&self, f: &mut Formatter<'_>, indent: i32) -> std::fmt::Result {
let operation = self.operation.try_read().unwrap();
let operation = self.operation.rd();
write!(f, "{} ", operation.name())?;
let region = operation.region().unwrap();
let region = region.try_read().unwrap();
write!(f, "()")?;
region.display(f, indent)?;
region.rd().display(f, indent)?;
Ok(())
}
}
Expand All @@ -171,13 +172,13 @@ impl Parse for BeginMainOp {
let name = BeginMainOp::operation_name();
parser.parse_arnold_operation_name_into(name, &mut operation)?;

let operation = Arc::new(RwLock::new(operation));
let operation = Shared::new(operation.into());
let op = BeginMainOp {
operation: operation.clone(),
};
let op = Arc::new(RwLock::new(op));
let op = Shared::new(op.into());
let region = parser.parse_region(op.clone())?;
let mut operation = operation.write().unwrap();
let mut operation = operation.wr();
operation.set_region(Some(region.clone()));
Ok(op)
}
Expand Down Expand Up @@ -226,12 +227,12 @@ impl Parse for CallOp {
let identifier = identifier.lexeme.clone();
parser.expect(TokenKind::LParen)?;
parser.expect(TokenKind::RParen)?;
let operation = Arc::new(RwLock::new(operation));
let operation = Shared::new(operation.into());
let op = CallOp {
operation: operation.clone(),
identifier: Some(identifier),
};
Ok(Arc::new(RwLock::new(op)))
Ok(Shared::new(op.into()))
}
}

Expand All @@ -254,7 +255,7 @@ impl Op for DeclareIntOp {
}
fn display(&self, f: &mut Formatter<'_>, _indent: i32) -> std::fmt::Result {
write!(f, "{}", Self::operation_name())?;
write!(f, " {}", self.operation().read().unwrap().results())?;
write!(f, " {}", self.operation().rd().results())?;
Ok(())
}
}
Expand All @@ -269,12 +270,12 @@ impl Parse for DeclareIntOp {
let name = DeclareIntOp::operation_name();
parser.parse_arnold_operation_name_into(name, &mut operation)?;
let result = parser.parse_op_result_into(TOKEN_KIND, &mut operation)?;
let operation = Arc::new(RwLock::new(operation));
let operation = Shared::new(operation.into());
let op = DeclareIntOp { operation };
let op = Arc::new(RwLock::new(op));
let op = Shared::new(op.into());
result.set_defining_op(Some(op.clone()));
let typ = IntegerType::new(16);
let typ = Arc::new(RwLock::new(typ));
let typ = Shared::new(typ.into());
result.set_typ(typ);
Ok(op)
}
Expand Down Expand Up @@ -334,21 +335,21 @@ impl Parse for IfOp {
let name = IfOp::operation_name();
parser.parse_arnold_operation_name_into(name, &mut operation)?;
parser.parse_op_operand_into(parent.clone().unwrap(), TOKEN_KIND, &mut operation)?;
let operation = Arc::new(RwLock::new(operation));
let operation = Shared::new(operation.into());
let op = IfOp {
operation: operation.clone(),
then: None,
els: None,
};
let op = Arc::new(RwLock::new(op));
let op = Shared::new(op.into());
let then = parser.parse_region(op.clone())?;
let else_keyword = parser.expect(TokenKind::BareIdentifier)?;
if else_keyword.lexeme != "BULLSHIT" {
panic!("Expected BULLSHIT but got {}", else_keyword.lexeme);
}
let els = parser.parse_region(op.clone())?;
let op_write = op.clone();
let mut op_write = op_write.try_write().unwrap();
let mut op_write = op_write.wr();
op_write.then = Some(then);
op_write.els = Some(els);
Ok(op)
Expand Down Expand Up @@ -394,7 +395,7 @@ impl Op for PrintOp {
}
fn display(&self, f: &mut Formatter<'_>, _indent: i32) -> std::fmt::Result {
write!(f, "{}", Self::operation_name())?;
write!(f, " {}", self.text().try_read().unwrap())?;
write!(f, " {}", self.text().rd())?;
Ok(())
}
}
Expand All @@ -408,13 +409,13 @@ impl Parse for PrintOp {
operation.set_parent(parent.clone());
let name = PrintOp::operation_name();
parser.parse_arnold_operation_name_into(name, &mut operation)?;
let operation = Arc::new(RwLock::new(operation));
let operation = Shared::new(operation.into());
let text = parser.parse_op_operand(parent.clone().unwrap(), TOKEN_KIND)?;
let mut op = PrintOp {
operation: operation.clone(),
};
op.set_text(text);
Ok(Arc::new(RwLock::new(op)))
Ok(Shared::new(op.into()))
}
}

Expand Down Expand Up @@ -458,10 +459,10 @@ impl Parse for SetInitialValueOp {
parser.parse_arnold_constant_into(&mut operation)?;
}

let operation = Arc::new(RwLock::new(operation));
let operation = Shared::new(operation.into());
let op = SetInitialValueOp {
operation: operation.clone(),
};
Ok(Arc::new(RwLock::new(op)))
Ok(Shared::new(op.into()))
}
}
46 changes: 24 additions & 22 deletions arnoldc/src/arnold_to_mlir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ use xrcf::ir::Operation;
use xrcf::ir::RenameBareToPercent;
use xrcf::ir::StringAttr;
use xrcf::ir::Value;
use xrcf::shared::Shared;
use xrcf::shared::SharedExt;

const RENAMER: RenameBareToPercent = RenameBareToPercent;

Expand All @@ -40,14 +42,14 @@ impl Rewrite for CallLowering {
Ok(op.as_any().is::<arnold::CallOp>())
}
fn rewrite(&self, op: Arc<RwLock<dyn Op>>) -> Result<RewriteResult> {
let op = op.try_read().unwrap();
let op = op.rd();
let op = op.as_any().downcast_ref::<arnold::CallOp>().unwrap();
let identifier = op.identifier().unwrap();
let operation = op.operation();
let mut new_op = func::CallOp::from_operation_arc(operation.clone());
let identifier = format!("@{}", identifier);
new_op.set_identifier(identifier);
let new_op = Arc::new(RwLock::new(new_op));
let new_op = Shared::new(new_op.into());
op.replace(new_op.clone());
Ok(RewriteResult::Changed(ChangedOp::new(new_op)))
}
Expand All @@ -74,13 +76,13 @@ impl Rewrite for DeclareIntLowering {
Ok(op.as_any().is::<arnold::DeclareIntOp>())
}
fn rewrite(&self, op: Arc<RwLock<dyn Op>>) -> Result<RewriteResult> {
let op = op.try_read().unwrap();
let op = op.rd();
let op = op.as_any().downcast_ref::<arnold::DeclareIntOp>().unwrap();
op.operation().rename_variables(&RENAMER)?;

let successors = op.operation().successors();
let set_initial_value = successors.first().unwrap();
let set_initial_value = set_initial_value.try_read().unwrap();
let set_initial_value = set_initial_value.rd();
let set_initial_value = set_initial_value
.as_any()
.downcast_ref::<arnold::SetInitialValueOp>()
Expand All @@ -91,7 +93,7 @@ impl Rewrite for DeclareIntLowering {
new_op.set_parent(op.operation().parent().clone().unwrap());
new_op.set_value(set_initial_value.value());
set_initial_value.remove();
let new_op = Arc::new(RwLock::new(new_op));
let new_op = Shared::new(new_op.into());
op.replace(new_op.clone());
Ok(RewriteResult::Changed(ChangedOp::new(new_op)))
}
Expand All @@ -107,13 +109,13 @@ impl Rewrite for FuncLowering {
Ok(op.as_any().is::<arnold::BeginMainOp>())
}
fn rewrite(&self, op: Arc<RwLock<dyn Op>>) -> Result<RewriteResult> {
let op = op.try_read().unwrap();
let op = op.rd();
let op = op.as_any().downcast_ref::<arnold::BeginMainOp>().unwrap();
let identifier = "@main";
let operation = op.operation();
let mut new_op = func::FuncOp::from_operation_arc(operation.clone());
new_op.set_identifier(identifier.to_string());
let new_op = Arc::new(RwLock::new(new_op));
let new_op = Shared::new(new_op.into());
op.replace(new_op.clone());
Ok(RewriteResult::Changed(ChangedOp::new(new_op)))
}
Expand Down Expand Up @@ -147,14 +149,14 @@ impl Rewrite for IfLowering {
Ok(op.as_any().is::<arnold::IfOp>())
}
fn rewrite(&self, op: Arc<RwLock<dyn Op>>) -> Result<RewriteResult> {
let op = op.try_read().unwrap();
let op = op.rd();
let op = op.as_any().downcast_ref::<arnold::IfOp>().unwrap();
let operation = op.operation();
let mut new_op = scf::IfOp::from_operation_arc(operation.clone());
new_op.set_parent(operation.parent().clone().unwrap());
new_op.set_then(op.then().clone());
new_op.set_els(op.els().clone());
let new_op = Arc::new(RwLock::new(new_op));
let new_op = Shared::new(new_op.into());
op.replace(new_op.clone());
Ok(RewriteResult::Changed(ChangedOp::new(new_op)))
}
Expand All @@ -170,12 +172,12 @@ impl ModuleLowering {
let typ = IntegerType::new(32);
let value = APInt::new(32, 0, true);
let integer = IntegerAttr::new(typ, value);
let name = parent.try_read().unwrap().unique_value_name("%");
let result_type = Arc::new(RwLock::new(typ));
let name = parent.rd().unique_value_name("%");
let result_type = Shared::new(typ.into());
let result = constant.add_new_op_result(&name, result_type.clone());
let constant = arith::ConstantOp::from_operation(constant);
constant.set_value(Arc::new(integer));
let constant = Arc::new(RwLock::new(constant));
let constant = Shared::new(constant.into());
result.set_defining_op(Some(constant.clone()));
constant
}
Expand All @@ -184,24 +186,24 @@ impl ModuleLowering {
constant: Arc<RwLock<dyn Op>>,
) -> Arc<RwLock<dyn Op>> {
let typ = IntegerType::new(32);
let result_type = Arc::new(RwLock::new(typ));
let result_type = Shared::new(typ.into());
let mut ret = Operation::default();
ret.set_parent(Some(parent.clone()));
ret.set_name(func::ReturnOp::operation_name());
ret.set_anonymous_result(result_type).unwrap();
let value = constant.result(0);
let operand = OpOperand::new(value);
let operand = Arc::new(RwLock::new(operand));
let operand = Shared::new(operand.into());
ret.set_operand(0, operand);
let ret = func::ReturnOp::from_operation(ret);
let ret = Arc::new(RwLock::new(ret));
let ret = Shared::new(ret.into());
ret
}
fn return_zero(func: Arc<RwLock<dyn Op>>) {
let operation = func.operation();
let typ = IntegerType::new(32);
operation
.set_anonymous_result(Arc::new(RwLock::new(typ)))
.set_anonymous_result(Shared::new(typ.into()))
.unwrap();

let ops = func.ops();
Expand All @@ -219,10 +221,10 @@ impl ModuleLowering {
constant.insert_after(ret.clone());
}
fn returns_something(func: Arc<RwLock<dyn Op>>) -> bool {
let func = func.try_read().unwrap();
let func = func.rd();
let func_op = func.as_any().downcast_ref::<func::FuncOp>().unwrap();
let result = func_op.operation().results();
result.vec().try_read().unwrap().len() == 1
result.vec().rd().len() == 1
}
fn ensure_main_returns_zero(module: Arc<RwLock<dyn Op>>) -> Result<RewriteResult> {
let ops = module.ops();
Expand Down Expand Up @@ -258,17 +260,17 @@ impl Rewrite for PrintLowering {
Ok(op.as_any().is::<arnold::PrintOp>())
}
fn rewrite(&self, op: Arc<RwLock<dyn Op>>) -> Result<RewriteResult> {
let op = op.try_read().unwrap();
let op = op.rd();
let op = op.as_any().downcast_ref::<arnold::PrintOp>().unwrap();
let mut operation = Operation::default();
operation.set_name(experimental::PrintfOp::operation_name());
let operation = Arc::new(RwLock::new(operation));
let operation = Shared::new(operation.into());
let mut new_op = experimental::PrintfOp::from_operation_arc(operation.clone());
new_op.set_parent(op.operation().parent().clone().unwrap());

let operand = op.text();
let value = operand.value();
match &*value.try_read().unwrap() {
match &*value.rd() {
// printf("some text")
Value::Constant(constant) => {
let text = constant.value();
Expand All @@ -284,7 +286,7 @@ impl Rewrite for PrintLowering {
_ => panic!("expected constant or op result"),
};

let new_op = Arc::new(RwLock::new(new_op));
let new_op = Shared::new(new_op.into());
op.replace(new_op.clone());
Ok(RewriteResult::Changed(ChangedOp::new(new_op)))
}
Expand Down
10 changes: 6 additions & 4 deletions arnoldc/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::env::ArgsOs;
use std::io::Read;
use xrcf::convert::RewriteResult;
use xrcf::init_subscriber;
use xrcf::shared::SharedExt;
use xrcf::Passes;
use xrcf::TransformOptions;

Expand Down Expand Up @@ -89,7 +90,7 @@ fn main() {

let result = parse_and_transform(&input_text, &options).unwrap();
let result = match result {
RewriteResult::Changed(op) => op.op.try_read().unwrap().to_string(),
RewriteResult::Changed(op) => op.op.rd().to_string(),
RewriteResult::Unchanged => input_text.to_string(),
};
println!("{result}");
Expand All @@ -104,6 +105,7 @@ mod tests {
use std::sync::Arc;
use std::sync::RwLock;
use xrcf::convert::RewriteResult;
use xrcf::shared::Shared;
use xrcf::tester::Tester;

fn run_app(
Expand Down Expand Up @@ -164,20 +166,20 @@ mod tests {
"--print-ir-before-all",
];
tracing::info!("\nBefore {args:?}:\n{src}");
let out: Arc<RwLock<Vec<u8>>> = Arc::new(RwLock::new(Vec::new()));
let out: Arc<RwLock<Vec<u8>>> = Shared::new(Vec::new().into());
let result = run_app(Some(out.clone()), args.clone(), &src);
assert!(result.is_ok());
let actual = match result.unwrap() {
RewriteResult::Changed(op) => {
let op = op.op.try_read().unwrap();
let op = op.op.rd();
op.to_string()
}
RewriteResult::Unchanged => panic!("Expected a change"),
};
tracing::info!("\nAfter {args:?}:\n{actual}");
assert!(actual.contains("define i32 @main"));

let printed = out.try_read().unwrap();
let printed = out.rd();
let printed = String::from_utf8(printed.clone()).unwrap();
let expected = indoc! {r#"
// ----- // IR Dump before convert-func-to-llvm //----- //
Expand Down
3 changes: 2 additions & 1 deletion arnoldc/src/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ mod tests {
use indoc::indoc;
use std::panic::Location;
use tracing;
use xrcf::shared::SharedExt;
use xrcf::tester::Tester;
use xrcf::Passes;

Expand Down Expand Up @@ -209,7 +210,7 @@ mod tests {
panic!("Expected changes");
}
};
let actual = new_root_op.try_read().unwrap().to_string();
let actual = new_root_op.rd().to_string();
print_heading("After", &actual, &passes);
(new_root_op, actual)
}
Expand Down
3 changes: 2 additions & 1 deletion xrcf-bin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use xrcf::convert::RewriteResult;
use xrcf::init_subscriber;
use xrcf::parser::DefaultParserDispatch;
use xrcf::parser::Parser;
use xrcf::shared::SharedExt;
use xrcf::transform;
use xrcf::DefaultTransformDispatch;
use xrcf::Passes;
Expand Down Expand Up @@ -38,7 +39,7 @@ fn parse_and_transform(src: &str, options: &TransformOptions) -> String {
let module = Parser::<DefaultParserDispatch>::parse(&src).unwrap();
let result = transform::<DefaultTransformDispatch>(module, options).unwrap();
let result = match result {
RewriteResult::Changed(op) => op.op.try_read().unwrap().to_string(),
RewriteResult::Changed(op) => op.op.rd().to_string(),
RewriteResult::Unchanged => src.to_string(),
};
result
Expand Down
Loading

0 comments on commit 382d067

Please sign in to comment.