From 7165f305f7111cf801d490e1ac4258af00acf306 Mon Sep 17 00:00:00 2001 From: Gabriel de Quadros Ligneul Date: Thu, 19 Dec 2024 14:43:10 -0300 Subject: [PATCH] De-duplicate similar functions --- stylus-proc/src/macros/public/types.rs | 163 ++++++++++--------------- 1 file changed, 67 insertions(+), 96 deletions(-) diff --git a/stylus-proc/src/macros/public/types.rs b/stylus-proc/src/macros/public/types.rs index 7891a5f..00df7a7 100644 --- a/stylus-proc/src/macros/public/types.rs +++ b/stylus-proc/src/macros/public/types.rs @@ -1,5 +1,6 @@ // Copyright 2022-2024, Offchain Labs, Inc. // For licensing, see https://github.com/OffchainLabs/stylus-sdk-rs/blob/main/licenses/COPYRIGHT.md + use proc_macro2::{Span, TokenStream}; use proc_macro_error::emit_error; use quote::{quote, ToTokens}; @@ -18,6 +19,34 @@ use crate::{ use super::Extension; +/// Generate the code to call the special function (fallback, receive, or constructor) from the +/// public impl block. Emits an error if there are multiple implementations. +macro_rules! call_special { + ($self:expr, $kind:pat, $kind_name:literal, $func:expr) => {{ + let specials: Vec = $self + .funcs + .iter() + .filter(|&func| matches!(func.kind, $kind)) + .map($func) + .collect(); + if specials.is_empty() { + None + } else { + if specials.len() > 1 { + emit_error!( + concat!("multiple ", $kind_name), + concat!( + "contract can only have one #[", + $kind_name, + "] method defined" + ) + ); + } + specials.first().cloned() + } + }}; +} + pub struct PublicImpl { pub self_ty: syn::Type, pub generic_params: Punctuated, @@ -47,30 +76,22 @@ impl PublicImpl { .collect::>(); let inheritance_routes = self.inheritance_routes(); - let call_fallback = self.call_fallback(); + let call_fallback = call_special!( + self, + FnKind::Fallback { .. }, + "fallback", + PublicFn::call_fallback + ); let inheritance_fallback = self.inheritance_fallback(); - - let (fallback, fallback_purity) = call_fallback.unwrap_or_else(|| { + let fallback = call_fallback.unwrap_or_else(|| { // If there is no fallback function specified, we rely on any inherited fallback. - ( - parse_quote!({ - #(#inheritance_fallback)* - None - }), - Purity::Payable, // Let the inherited fallback deal with purity. - ) + parse_quote!({ + #(#inheritance_fallback)* + None + }) }); - let fallback_deny: Option = match fallback_purity { - Purity::Payable => None, - _ => Some(parse_quote! { - if let Err(err) = stylus_sdk::abi::internal::deny_value("fallback") { - return Some(Err(err)); - } - }), - }; - - let call_receive = self.call_receive(); + let call_receive = call_special!(self, FnKind::Receive, "receive", PublicFn::call_receive); let inheritance_receive = self.inheritance_receive(); let receive = call_receive.unwrap_or_else(|| { parse_quote!({ @@ -79,7 +100,12 @@ impl PublicImpl { }) }); - let call_constructor = self.call_constructor(); + let call_constructor = call_special!( + self, + FnKind::Constructor, + "constructor", + PublicFn::call_constructor + ); let constructor = call_constructor.unwrap_or_else(|| parse_quote!({ None })); parse_quote! { @@ -115,7 +141,6 @@ impl PublicImpl { #[inline(always)] fn fallback(storage: &mut S, input: &[u8]) -> Option { - #fallback_deny #fallback } @@ -142,35 +167,6 @@ impl PublicImpl { }) } - fn call_fallback(&self) -> Option<(syn::Stmt, Purity)> { - let mut fallback_purity = Purity::View; - let fallbacks: Vec = self - .funcs - .iter() - .filter(|&func| { - if matches!(func.kind, FnKind::Fallback { .. }) { - fallback_purity = func.purity; - return true; - } - false - }) - .map(PublicFn::call_fallback) - .collect(); - if fallbacks.is_empty() { - return None; - } - if fallbacks.len() > 1 { - emit_error!( - "multiple fallbacks", - "contract can only have one #[fallback] method defined" - ); - } - fallbacks - .first() - .cloned() - .map(|func| (func, fallback_purity)) - } - fn inheritance_fallback(&self) -> impl Iterator + '_ { self.inheritance.iter().map(|ty| { parse_quote! { @@ -181,25 +177,6 @@ impl PublicImpl { }) } - fn call_receive(&self) -> Option { - let receives: Vec = self - .funcs - .iter() - .filter(|&func| matches!(func.kind, FnKind::Receive)) - .map(PublicFn::call_receive) - .collect(); - if receives.is_empty() { - return None; - } - if receives.len() > 1 { - emit_error!( - "multiple receives", - "contract can only have one #[receive] method defined" - ); - } - receives.first().cloned() - } - fn inheritance_receive(&self) -> impl Iterator + '_ { self.inheritance.iter().map(|ty| { parse_quote! { @@ -209,25 +186,6 @@ impl PublicImpl { } }) } - - fn call_constructor(&self) -> Option { - let constructors: Vec = self - .funcs - .iter() - .filter(|&func| matches!(func.kind, FnKind::Constructor)) - .map(PublicFn::call_constructor) - .collect(); - if constructors.is_empty() { - return None; - } - if constructors.len() > 1 { - emit_error!( - "multiple constructors", - "contract can only have one #[constructor] method defined" - ); - } - constructors.first().cloned() - } } #[derive(Debug)] @@ -356,16 +314,29 @@ impl PublicFn { } fn call_fallback(&self) -> syn::Stmt { + let deny_value: Option = match self.purity { + Purity::Payable => None, + _ => Some(parse_quote! { + if let Err(err) = stylus_sdk::abi::internal::deny_value("fallback") { + return Some(Err(err)); + } + }), + }; let name = &self.name; let storage_arg = self.storage_arg(); - if matches!(self.kind, FnKind::Fallback { with_args: false }) { - return parse_quote! { + let call: syn::Stmt = if matches!(self.kind, FnKind::Fallback { with_args: false }) { + parse_quote! { return Some(Self::#name(#storage_arg)); - }; - } - parse_quote! { - return Some(Self::#name(#storage_arg input)); - } + } + } else { + parse_quote! { + return Some(Self::#name(#storage_arg input)); + } + }; + parse_quote!({ + #deny_value + #call + }) } fn call_receive(&self) -> syn::Stmt {