Skip to content

Commit

Permalink
Rework validation of names and aliases for aggregate UDFs
Browse files Browse the repository at this point in the history
Add better validation that `name = ...` and `alias = ...` line up when
there is a mismatch between the `BasicUdf` and `AggregateUdf`
implementation, and fix a problem that was disallowing use of aliases in
aggregate UDFs.

Error messages are significantly improved.

Fixes <#59>
  • Loading branch information
tgross35 committed May 7, 2024
1 parent 8233525 commit c2804af
Show file tree
Hide file tree
Showing 10 changed files with 526 additions and 108 deletions.
19 changes: 3 additions & 16 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,9 @@

## [Unreleased] - ReleaseDate

### Added

### Changed

### Removed



## [0.5.4] - 2023-09-10

### Added

### Changed

### Removed

Rework the validation of names and aliases for aggregate UDFs. This fixes an
issue where aliases could not be used for aggregate UDFs, and provides better
error messages.


## [0.5.4] - 2023-09-10
Expand Down
217 changes: 130 additions & 87 deletions udf-macros/src/register.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#![allow(unused_imports)]

use std::iter;

use heck::AsSnakeCase;
use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
Expand All @@ -8,24 +10,17 @@ use syn::parse::{Parse, ParseStream, Parser};
use syn::punctuated::Punctuated;
use syn::{
parse_macro_input, parse_quote, DeriveInput, Error, Expr, ExprLit, Ident, ImplItem,
ImplItemType, Item, ItemImpl, Lit, Meta, Path, PathSegment, Token, Type, TypePath,
ImplItemType, Item, ItemImpl, Lit, LitStr, Meta, Path, PathSegment, Token, Type, TypePath,
TypeReference,
};

use crate::match_variant;
use crate::types::{make_type_list, ImplType, RetType, TypeClass};

/// Create an identifier from another identifier, changing the name to snake case
macro_rules! format_ident_str {
($formatter:tt, $ident:ident) => {
Ident::new(format!($formatter, $ident).as_str(), Span::call_site())
};
}

/// Verify that an `ItemImpl` matches the end of any given path
///
/// implements `BasicUdf` (in any of its pathing options)
fn impls_path(itemimpl: &ItemImpl, expected: ImplType) -> bool {
fn impl_type(itemimpl: &ItemImpl) -> Option<ImplType> {
let implemented = &itemimpl.trait_.as_ref().unwrap().1.segments;

let basic_paths: [Punctuated<PathSegment, Token![::]>; 3] = [
Expand All @@ -39,9 +34,12 @@ fn impls_path(itemimpl: &ItemImpl, expected: ImplType) -> bool {
parse_quote! {AggregateUdf},
];

match expected {
ImplType::Basic => basic_paths.contains(implemented),
ImplType::Aggregate => arg_paths.contains(implemented),
if basic_paths.contains(implemented) {
Some(ImplType::Basic)
} else if arg_paths.contains(implemented) {
Some(ImplType::Aggregate)
} else {
None
}
}

Expand All @@ -57,14 +55,11 @@ fn impls_path(itemimpl: &ItemImpl, expected: ImplType) -> bool {
pub fn register(args: &TokenStream, input: TokenStream) -> TokenStream {
let parsed = parse_macro_input!(input as ItemImpl);

let impls_basic = impls_path(&parsed, ImplType::Basic);
let impls_agg = impls_path(&parsed, ImplType::Aggregate);

if !(impls_basic || impls_agg) {
let Some(impl_ty) = impl_type(&parsed) else {
return Error::new_spanned(&parsed, "Expected trait `BasicUdf` or `AggregateUdf`")
.into_compile_error()
.into();
}
};

// Full type path of our data struct
let Type::Path(dstruct_path) = parsed.self_ty.as_ref() else {
Expand All @@ -73,7 +68,7 @@ pub fn register(args: &TokenStream, input: TokenStream) -> TokenStream {
.into();
};

let base_fn_names = match parse_args(args, dstruct_path) {
let parsed_meta = match ParsedMeta::parse(args, dstruct_path) {
Ok(v) => v,
Err(e) => return e.into_compile_error().into(),
};
Expand All @@ -89,91 +84,110 @@ pub fn register(args: &TokenStream, input: TokenStream) -> TokenStream {
Span::call_site(),
);

let (ret_ty, wrapper_def) = if impls_basic {
match get_rt_and_wrapper(&parsed, dstruct_path, &wrapper_ident) {
let (ret_ty, wrapper_def) = match impl_ty {
ImplType::Basic => match get_ret_ty_and_wrapper(&parsed, dstruct_path, &wrapper_ident) {
Ok((r, w)) => (Some(r), w),
Err(e) => return e.into_compile_error().into(),
}
} else {
(None, TokenStream2::new())
},
ImplType::Aggregate => (None, TokenStream2::new()),
};

let content_iter = base_fn_names.iter().map(|base_fn_name| {
if impls_basic {
make_basic_fns(
ret_ty.as_ref().unwrap(),
base_fn_name,
dstruct_path,
&wrapper_ident,
)
} else {
make_agg_fns(&parsed, base_fn_name, dstruct_path, &wrapper_ident)
}
let helper_traits = make_helper_trait_impls(dstruct_path, &parsed_meta, impl_ty);

let fn_items_iter = parsed_meta.all_names().map(|base_fn_name| match impl_ty {
ImplType::Basic => make_basic_fns(
ret_ty.as_ref().unwrap(),
base_fn_name,
dstruct_path,
&wrapper_ident,
),
ImplType::Aggregate => make_agg_fns(&parsed, base_fn_name, dstruct_path, &wrapper_ident),
});

quote! {
#parsed

#wrapper_def

#( #content_iter )*
#helper_traits

#( #fn_items_iter )*
}
.into()
}

/// Parse attribute arguments. Returns an iterator of names
fn parse_args(args: &TokenStream, dstruct_path: &TypePath) -> syn::Result<Vec<String>> {
let meta = Punctuated::<Meta, Token![,]>::parse_terminated.parse(args.clone())?;
let mut base_fn_names: Vec<String> = vec![];
let mut primary_name_specified = false;

for m in meta {
let Meta::NameValue(mval) = m else {
return Err(Error::new_spanned(m, "expected `a = b atributes`"));
};
/// Arguments we parse from metadata or default to
struct ParsedMeta {
name: String,
aliases: Vec<String>,
default_name_used: bool,
}

if !mval.path.segments.iter().count() == 1 {
return Err(Error::new_spanned(mval.path, "unexpected path"));
}
impl ParsedMeta {
/// Parse attribute arguments. Returns an iterator of names
fn parse(args: &TokenStream, dstruct_path: &TypePath) -> syn::Result<Self> {
let meta = Punctuated::<Meta, Token![,]>::parse_terminated.parse(args.clone())?;
let mut name_from_attributes = None;
let mut aliases = Vec::new();

let key = mval.path.segments.first().unwrap();
for m in meta {
let Meta::NameValue(mval) = m else {
return Err(Error::new_spanned(m, "expected `a = b atributes`"));
};

let Expr::Lit(ExprLit {
lit: Lit::Str(value),
..
}) = mval.value
else {
return Err(Error::new_spanned(mval.value, "expected a literal string"));
};
if !mval.path.segments.iter().count() == 1 {
return Err(Error::new_spanned(mval.path, "unexpected path"));
}

if key.ident == "name" {
if primary_name_specified {
return Err(Error::new_spanned(key, "`name` can only be specified once"));
let key = mval.path.segments.first().unwrap();

let Expr::Lit(ExprLit {
lit: Lit::Str(value),
..
}) = mval.value
else {
return Err(Error::new_spanned(mval.value, "expected a literal string"));
};

if key.ident == "name" {
if name_from_attributes.is_some() {
return Err(Error::new_spanned(key, "`name` can only be specified once"));
}
name_from_attributes = Some(value.value());
} else if key.ident == "alias" {
aliases.push(value.value());
} else {
return Err(Error::new_spanned(
key,
"unexpected key (only `name` and `alias` are accepted)",
));
}
base_fn_names.push(value.value());
primary_name_specified = true;
} else if key.ident == "alias" {
base_fn_names.push(value.value());
} else {
return Err(Error::new_spanned(
key,
"unexpected key (only `name` and `alias` are accepted)",
));
}
}

if !primary_name_specified {
// If we don't have a name specified, use the type name as snake case
let ty_ident = &dstruct_path.path.segments.last().unwrap().ident;
let fn_name = AsSnakeCase(&ty_ident.to_string()).to_string();
base_fn_names.push(fn_name);
let mut default_name_used = false;
let name = name_from_attributes.unwrap_or_else(|| {
// If we don't have a name specified, use the type name as snake case
let ty_ident = &dstruct_path.path.segments.last().unwrap().ident;
let fn_name = AsSnakeCase(&ty_ident.to_string()).to_string();
default_name_used = true;
fn_name
});

Ok(Self {
name,
aliases,
default_name_used,
})
}

Ok(base_fn_names)
/// Iterate the basic name and all aliases
fn all_names(&self) -> impl Iterator<Item = &String> {
iter::once(&self.name).chain(self.aliases.iter())
}
}

/// Get the return type to use and a wrapper. Once per impl setup.
fn get_rt_and_wrapper(
fn get_ret_ty_and_wrapper(
parsed: &ItemImpl,
dstruct_path: &TypePath,
wrapper_ident: &Ident,
Expand Down Expand Up @@ -209,16 +223,50 @@ fn get_rt_and_wrapper(
Ok((ret_ty, wrapper_struct))
}

/// Make implementations for our helper/metadata traits
fn make_helper_trait_impls(
dstruct_path: &TypePath,
meta: &ParsedMeta,
impl_ty: ImplType,
) -> TokenStream2 {
let name = LitStr::new(&meta.name, Span::call_site());
let aliases = meta
.aliases
.iter()
.map(|alias| LitStr::new(alias.as_ref(), Span::call_site()));
let (trait_name, check_expr) = match impl_ty {
ImplType::Basic => (
quote! { ::udf::wrapper::RegisteredBasicUdf },
TokenStream2::new(),
),
ImplType::Aggregate => (
quote! { ::udf::wrapper::RegisteredAggregateUdf },
quote! { const _: () = ::udf::wrapper::verify_aggregate_attributes::<#dstruct_path>(); },
),
};
let default_name_used = meta.default_name_used;

quote! {
impl #trait_name for #dstruct_path {
const NAME: &'static str = #name;
const ALIASES: &'static [&'static str] = &[#( #aliases ),*];
const DEFAULT_NAME_USED: bool = #default_name_used;
}

#check_expr
}
}

/// Create the basic function signatures (`xxx_init`, `xxx_deinit`, `xxx`)
fn make_basic_fns(
rt: &RetType,
base_fn_name: &str,
dstruct_path: &TypePath,
wrapper_ident: &Ident,
) -> TokenStream2 {
let init_fn_name = format_ident_str!("{}_init", base_fn_name);
let deinit_fn_name = format_ident_str!("{}_deinit", base_fn_name);
let process_fn_name = format_ident_str!("{}", base_fn_name);
let init_fn_name = format_ident!("{}_init", base_fn_name);
let deinit_fn_name = format_ident!("{}_deinit", base_fn_name);
let process_fn_name = format_ident!("{}", base_fn_name);

let init_fn = make_init_fn(dstruct_path, wrapper_ident, &init_fn_name);
let deinit_fn = make_deinit_fn(dstruct_path, wrapper_ident, &deinit_fn_name);
Expand Down Expand Up @@ -269,9 +317,9 @@ fn make_agg_fns(
dstruct_path: &TypePath, // Name of the data structure
wrapper_ident: &Ident,
) -> TokenStream2 {
let clear_fn_name = format_ident_str!("{}_clear", base_fn_name);
let add_fn_name = format_ident_str!("{}_add", base_fn_name);
let remove_fn_name = format_ident_str!("{}_remove", base_fn_name);
let clear_fn_name = format_ident!("{}_clear", base_fn_name);
let add_fn_name = format_ident!("{}_add", base_fn_name);
let remove_fn_name = format_ident!("{}_remove", base_fn_name);

// Determine whether this re-implements `remove`
let impls_remove = &parsed
Expand All @@ -280,7 +328,6 @@ fn make_agg_fns(
.filter_map(match_variant!(ImplItem::Fn))
.map(|m| &m.sig.ident)
.any(|id| *id == "remove");
let base_fn_ident = Ident::new(base_fn_name, Span::call_site());

let clear_fn = make_clear_fn(dstruct_path, wrapper_ident, &clear_fn_name);
let add_fn = make_add_fn(dstruct_path, wrapper_ident, &add_fn_name);
Expand All @@ -295,10 +342,6 @@ fn make_agg_fns(
};

quote! {
// Sanity check that we implemented
#[allow(dead_code, non_upper_case_globals)]
const did_you_apply_the_same_aliases_to_the_BasicUdf_impl: *const () = #base_fn_ident as _;

#clear_fn

#add_fn
Expand Down
22 changes: 22 additions & 0 deletions udf-macros/tests/fail/agg_missing_basic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#![allow(unused)]

use udf::prelude::*;

struct MyUdf;

impl AggregateUdf for MyUdf {
// Required methods
fn clear(&mut self, cfg: &UdfCfg<Process>, error: Option<NonZeroU8>) -> Result<(), NonZeroU8> {
todo!()
}
fn add(
&mut self,
cfg: &UdfCfg<Process>,
args: &ArgList<'_, Process>,
error: Option<NonZeroU8>,
) -> Result<(), NonZeroU8> {
todo!()
}
}

fn main() {}
11 changes: 11 additions & 0 deletions udf-macros/tests/fail/agg_missing_basic.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
error[E0277]: the trait bound `MyUdf: BasicUdf` is not satisfied
--> tests/fail/agg_missing_basic.rs:7:23
|
7 | impl AggregateUdf for MyUdf {
| ^^^^^ the trait `BasicUdf` is not implemented for `MyUdf`
|
note: required by a bound in `udf::AggregateUdf`
--> $WORKSPACE/udf/src/traits.rs
|
| pub trait AggregateUdf: BasicUdf {
| ^^^^^^^^ required by this bound in `AggregateUdf`
Loading

0 comments on commit c2804af

Please sign in to comment.