diff --git a/packages/macros/src/lib.rs b/packages/macros/src/lib.rs index 85af4100..950a643c 100644 --- a/packages/macros/src/lib.rs +++ b/packages/macros/src/lib.rs @@ -1,2 +1,5 @@ mod num_traits; +mod zero_trait; mod pow; + +mod parse; diff --git a/packages/macros/src/num_traits.rs b/packages/macros/src/num_traits.rs index 5ecbb84d..a1d670bb 100644 --- a/packages/macros/src/num_traits.rs +++ b/packages/macros/src/num_traits.rs @@ -1,15 +1,6 @@ use cairo_lang_macro::{derive_macro, ProcMacroResult, TokenStream}; -use cairo_lang_parser::utils::SimpleParserDatabase; -use cairo_lang_syntax::node::kind::SyntaxKind::{ - Member, OptionWrappedGenericParamListEmpty, TerminalStruct, TokenIdentifier, - WrappedGenericParamList, -}; -struct StructInfo { - name: String, - generic_params: Option>, - members: Vec, -} +use crate::parse::{StructInfo, parse_struct_info}; struct OpInfo { trait_name: String, @@ -17,63 +8,6 @@ struct OpInfo { operator: String, } -fn parse_struct_info(token_stream: TokenStream) -> StructInfo { - let db = SimpleParserDatabase::default(); - let (parsed, _diag) = db.parse_virtual_with_diagnostics(token_stream); - let mut nodes = parsed.descendants(&db); - - // find struct name - the next TokenIdentifier after TeminalStruct - let mut struct_name = String::new(); - while let Some(node) = nodes.next() { - if node.kind(&db) == TerminalStruct { - struct_name = nodes - .find(|node| node.kind(&db) == TokenIdentifier) - .unwrap() - .get_text(&db); - break; - } - } - - // collect generic params or skip if there aren't any - let mut generic_params: Option> = None; - while let Some(node) = nodes.next() { - match node.kind(&db) { - WrappedGenericParamList => { - let params = node - .descendants(&db) - .filter(|node| node.kind(&db) == TokenIdentifier) - .map(|node| node.get_text(&db)) - .collect(); - generic_params = Some(params); - break; - } - OptionWrappedGenericParamListEmpty => { - break; - } - _ => {} - } - } - - // collect struct members - all TokenIdentifier nodes after each Member - let mut members = Vec::new(); - while let Some(node) = nodes.next() { - if node.kind(&db) == Member { - let member = node - .descendants(&db) - .find(|node| node.kind(&db) == TokenIdentifier) - .map(|node| node.get_text(&db)) - .unwrap(); - members.push(member); - } - } - - StructInfo { - name: struct_name, - generic_params, - members, - } -} - fn generate_op_trait_impl(op_info: &OpInfo, s: &StructInfo) -> String { let generic_params = s .generic_params diff --git a/packages/macros/src/parse.rs b/packages/macros/src/parse.rs new file mode 100644 index 00000000..5586fe32 --- /dev/null +++ b/packages/macros/src/parse.rs @@ -0,0 +1,69 @@ +use cairo_lang_macro::TokenStream; +use cairo_lang_parser::utils::SimpleParserDatabase; +use cairo_lang_syntax::node::kind::SyntaxKind::{ + Member, OptionWrappedGenericParamListEmpty, TerminalStruct, TokenIdentifier, + WrappedGenericParamList, +}; + +pub(crate) struct StructInfo { + pub(crate) name: String, + pub(crate) generic_params: Option>, + pub(crate) members: Vec, +} + +pub(crate) fn parse_struct_info(token_stream: TokenStream) -> StructInfo { + let db = SimpleParserDatabase::default(); + let (parsed, _diag) = db.parse_virtual_with_diagnostics(token_stream); + let mut nodes = parsed.descendants(&db); + + // find struct name - the next TokenIdentifier after TeminalStruct + let mut struct_name = String::new(); + while let Some(node) = nodes.next() { + if node.kind(&db) == TerminalStruct { + struct_name = nodes + .find(|node| node.kind(&db) == TokenIdentifier) + .unwrap() + .get_text(&db); + break; + } + } + + // collect generic params or skip if there aren't any + let mut generic_params: Option> = None; + while let Some(node) = nodes.next() { + match node.kind(&db) { + WrappedGenericParamList => { + let params = node + .descendants(&db) + .filter(|node| node.kind(&db) == TokenIdentifier) + .map(|node| node.get_text(&db)) + .collect(); + generic_params = Some(params); + break; + } + OptionWrappedGenericParamListEmpty => { + break; + } + _ => {} + } + } + + // collect struct members - all TokenIdentifier nodes after each Member + let mut members = Vec::new(); + while let Some(node) = nodes.next() { + if node.kind(&db) == Member { + let member = node + .descendants(&db) + .find(|node| node.kind(&db) == TokenIdentifier) + .map(|node| node.get_text(&db)) + .unwrap(); + members.push(member); + } + } + + StructInfo { + name: struct_name, + generic_params, + members, + } +} diff --git a/packages/macros/src/zero_trait.rs b/packages/macros/src/zero_trait.rs new file mode 100644 index 00000000..3b71d25e --- /dev/null +++ b/packages/macros/src/zero_trait.rs @@ -0,0 +1,81 @@ +use cairo_lang_macro::{derive_macro, ProcMacroResult, TokenStream}; +use crate::parse::{parse_struct_info, StructInfo}; + +fn generate_zero_trait_impl(s: &StructInfo) -> String { + let generic_params = s + .generic_params + .as_ref() + .map_or(String::new(), |params| format!("<{}>", params.join(", "))); + + let trait_bounds = s.generic_params.as_ref().map_or_else( + || String::new(), + |params| { + let bounds = params + .iter() + .flat_map(|param| { + vec![ + format!("+core::num::traits::Zero<{}>", param), + format!("+core::traits::Drop<{}>", param), + ] + }) + .collect::>() + .join(",\n"); + format!("<{},\n{}>", params.join(", "), bounds) + }, + ); + + let zero_fn = s + .members + .iter() + .map(|member| format!("{}: core::num::traits::Zero::zero()", member)) + .collect::>() + .join(", "); + + let is_zero_fn = s + .members + .iter() + .map(|member| format!("self.{}.is_zero()", member)) + .collect::>() + .join(" && "); + + format!( + "\n +impl {0}ZeroImpl{1} +of core::num::traits::Zero<{0}{2}> {{ + fn zero() -> {0}{2} {{ + {0} {{ {3} }} + }} + + fn is_zero(self: @{0}{2}) -> bool {{ + {4} + }} + + fn is_non_zero(self: @{0}{2}) -> bool {{ + !self.is_zero() + }} +}}\n", + s.name, trait_bounds, generic_params, zero_fn, is_zero_fn + ) +} + +/// Adds implementation of the `Zero` trait. +/// +/// All members of the struct must already implement the `Zero` trait. +/// +/// ``` +/// #[derive(Zero, PartialEq, Debug)] +/// struct Point { +/// x: u64, +/// y: u64, +/// } +/// +/// assert_eq!(Point { x: 0, y: 0 }, Zero::zero()); +/// assert!(Point { x: 0, y: 0 }.is_zero()); +/// assert!(Point { x: 1, y: 0 }.is_non_zero()); +/// ``` +#[derive_macro] +pub fn zero(token_stream: TokenStream) -> ProcMacroResult { + let s = parse_struct_info(token_stream); + + ProcMacroResult::new(TokenStream::new(generate_zero_trait_impl(&s))) +} diff --git a/packages/macros_tests/src/lib.cairo b/packages/macros_tests/src/lib.cairo index 7f3e3029..c76e302c 100644 --- a/packages/macros_tests/src/lib.cairo +++ b/packages/macros_tests/src/lib.cairo @@ -3,3 +3,6 @@ mod test_pow; #[cfg(test)] mod test_num_traits; + +#[cfg(test)] +mod test_zero_trait; diff --git a/packages/macros_tests/src/test_zero_trait.cairo b/packages/macros_tests/src/test_zero_trait.cairo new file mode 100644 index 00000000..9854bf17 --- /dev/null +++ b/packages/macros_tests/src/test_zero_trait.cairo @@ -0,0 +1,56 @@ +use core::num::traits::Zero; + +// a basic struct +#[derive(Zero, Debug, Drop, PartialEq)] +struct B { + pub a: u8, + b: u16 +} + +// a generic struct +#[derive(Zero, Debug, Drop, PartialEq)] +struct G { + x: T1, + pub y: T2, + z: T2 +} + +// a complex struct +#[derive(Zero, Debug, Drop, PartialEq)] +struct C { + pub g: G, + i: u64, + j: u32 +} + + + +#[test] +fn test_zero_derive() { + let b0: B = B { a: 0, b: 0 }; + let b1: B = B { a: 1, b: 2 }; + + assert_eq!(b0, Zero::zero()); + assert!(b0.is_zero()); + assert!(b0.is_non_zero() == false); + assert!(b1.is_zero() == false); + assert!(b1.is_non_zero()); + + let g0: G = G { x: 0, y: 0, z: 0 }; + let g1: G = G { x: 1, y: 2, z: 3 }; + + assert_eq!(g0, Zero::zero()); + assert!(g0.is_zero()); + assert!(g0.is_non_zero() == false); + assert!(g1.is_zero() == false); + assert!(g1.is_non_zero()); + + let c0: C = C { g: G { x: 0, y: 0, z: 0 }, i: 0, j: 0 }; + let c1: C = C { g: G { x: 0, y: 0, z: 0 }, i: 4, j: 5 }; + + assert_eq!(c0, Zero::zero()); + assert!(c0.is_zero()); + assert!(c0.is_non_zero() == false); + assert!(c1.is_zero() == false); + assert!(c1.is_non_zero()); +}