Skip to content

Commit

Permalink
De-duplicate similar functions
Browse files Browse the repository at this point in the history
  • Loading branch information
gligneul committed Dec 19, 2024
1 parent de3f600 commit 7165f30
Showing 1 changed file with 67 additions and 96 deletions.
163 changes: 67 additions & 96 deletions stylus-proc/src/macros/public/types.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright 2022-2024, Offchain Labs, Inc.
// For licensing, see https://github.com/OffchainLabs/stylus-sdk-rs/blob/main/licenses/COPYRIGHT.md

use proc_macro2::{Span, TokenStream};
use proc_macro_error::emit_error;
use quote::{quote, ToTokens};
Expand All @@ -18,6 +19,34 @@ use crate::{

use super::Extension;

/// Generate the code to call the special function (fallback, receive, or constructor) from the
/// public impl block. Emits an error if there are multiple implementations.
macro_rules! call_special {
($self:expr, $kind:pat, $kind_name:literal, $func:expr) => {{
let specials: Vec<syn::Stmt> = $self
.funcs
.iter()
.filter(|&func| matches!(func.kind, $kind))
.map($func)
.collect();
if specials.is_empty() {
None
} else {
if specials.len() > 1 {
emit_error!(
concat!("multiple ", $kind_name),
concat!(
"contract can only have one #[",
$kind_name,
"] method defined"
)
);
}
specials.first().cloned()
}
}};
}

pub struct PublicImpl<E: InterfaceExtension = Extension> {
pub self_ty: syn::Type,
pub generic_params: Punctuated<syn::GenericParam, Token![,]>,
Expand Down Expand Up @@ -47,30 +76,22 @@ impl PublicImpl {
.collect::<Vec<_>>();
let inheritance_routes = self.inheritance_routes();

let call_fallback = self.call_fallback();
let call_fallback = call_special!(
self,
FnKind::Fallback { .. },
"fallback",
PublicFn::call_fallback
);
let inheritance_fallback = self.inheritance_fallback();

let (fallback, fallback_purity) = call_fallback.unwrap_or_else(|| {
let fallback = call_fallback.unwrap_or_else(|| {
// If there is no fallback function specified, we rely on any inherited fallback.
(
parse_quote!({
#(#inheritance_fallback)*
None
}),
Purity::Payable, // Let the inherited fallback deal with purity.
)
parse_quote!({
#(#inheritance_fallback)*
None
})
});

let fallback_deny: Option<syn::ExprIf> = match fallback_purity {
Purity::Payable => None,
_ => Some(parse_quote! {
if let Err(err) = stylus_sdk::abi::internal::deny_value("fallback") {
return Some(Err(err));
}
}),
};

let call_receive = self.call_receive();
let call_receive = call_special!(self, FnKind::Receive, "receive", PublicFn::call_receive);
let inheritance_receive = self.inheritance_receive();
let receive = call_receive.unwrap_or_else(|| {
parse_quote!({
Expand All @@ -79,7 +100,12 @@ impl PublicImpl {
})
});

let call_constructor = self.call_constructor();
let call_constructor = call_special!(
self,
FnKind::Constructor,
"constructor",
PublicFn::call_constructor
);
let constructor = call_constructor.unwrap_or_else(|| parse_quote!({ None }));

parse_quote! {
Expand Down Expand Up @@ -115,7 +141,6 @@ impl PublicImpl {

#[inline(always)]
fn fallback(storage: &mut S, input: &[u8]) -> Option<stylus_sdk::ArbResult> {
#fallback_deny
#fallback
}

Expand All @@ -142,35 +167,6 @@ impl PublicImpl {
})
}

fn call_fallback(&self) -> Option<(syn::Stmt, Purity)> {
let mut fallback_purity = Purity::View;
let fallbacks: Vec<syn::Stmt> = self
.funcs
.iter()
.filter(|&func| {
if matches!(func.kind, FnKind::Fallback { .. }) {
fallback_purity = func.purity;
return true;
}
false
})
.map(PublicFn::call_fallback)
.collect();
if fallbacks.is_empty() {
return None;
}
if fallbacks.len() > 1 {
emit_error!(
"multiple fallbacks",
"contract can only have one #[fallback] method defined"
);
}
fallbacks
.first()
.cloned()
.map(|func| (func, fallback_purity))
}

fn inheritance_fallback(&self) -> impl Iterator<Item = syn::ExprIf> + '_ {
self.inheritance.iter().map(|ty| {
parse_quote! {
Expand All @@ -181,25 +177,6 @@ impl PublicImpl {
})
}

fn call_receive(&self) -> Option<syn::Stmt> {
let receives: Vec<syn::Stmt> = self
.funcs
.iter()
.filter(|&func| matches!(func.kind, FnKind::Receive))
.map(PublicFn::call_receive)
.collect();
if receives.is_empty() {
return None;
}
if receives.len() > 1 {
emit_error!(
"multiple receives",
"contract can only have one #[receive] method defined"
);
}
receives.first().cloned()
}

fn inheritance_receive(&self) -> impl Iterator<Item = syn::ExprIf> + '_ {
self.inheritance.iter().map(|ty| {
parse_quote! {
Expand All @@ -209,25 +186,6 @@ impl PublicImpl {
}
})
}

fn call_constructor(&self) -> Option<syn::Stmt> {
let constructors: Vec<syn::Stmt> = self
.funcs
.iter()
.filter(|&func| matches!(func.kind, FnKind::Constructor))
.map(PublicFn::call_constructor)
.collect();
if constructors.is_empty() {
return None;
}
if constructors.len() > 1 {
emit_error!(
"multiple constructors",
"contract can only have one #[constructor] method defined"
);
}
constructors.first().cloned()
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -356,16 +314,29 @@ impl<E: FnExtension> PublicFn<E> {
}

fn call_fallback(&self) -> syn::Stmt {
let deny_value: Option<syn::ExprIf> = match self.purity {
Purity::Payable => None,
_ => Some(parse_quote! {
if let Err(err) = stylus_sdk::abi::internal::deny_value("fallback") {
return Some(Err(err));
}
}),
};
let name = &self.name;
let storage_arg = self.storage_arg();
if matches!(self.kind, FnKind::Fallback { with_args: false }) {
return parse_quote! {
let call: syn::Stmt = if matches!(self.kind, FnKind::Fallback { with_args: false }) {
parse_quote! {
return Some(Self::#name(#storage_arg));
};
}
parse_quote! {
return Some(Self::#name(#storage_arg input));
}
}
} else {
parse_quote! {
return Some(Self::#name(#storage_arg input));
}
};
parse_quote!({
#deny_value
#call
})
}

fn call_receive(&self) -> syn::Stmt {
Expand Down

0 comments on commit 7165f30

Please sign in to comment.