Skip to content

Commit

Permalink
Fix PartialOrd misimplementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Joeoc2001 committed Nov 22, 2024
1 parent 844c113 commit a1ade35
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "perfect-derive"
version = "0.1.3"
version = "0.1.4"
edition = "2021"
license = "MIT"
description = "Provides a prototype of the proposed perfect_derive macro"
Expand Down
1 change: 1 addition & 0 deletions src/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ macro_rules! impls {
};
}

#[allow(clippy::single_component_path_imports)]
pub(crate) use impls;
26 changes: 14 additions & 12 deletions src/perfect_macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fn is_attribute_default(a: &Attribute) -> bool {
&& a.path()
.segments
.iter()
.all(|i| i.ident.to_string() == "default" && i.arguments.is_empty())
.all(|i| i.ident == "default" && i.arguments.is_empty())
}

fn remove_debug_markers(obj: &mut StructOrEnum) {
Expand Down Expand Up @@ -55,7 +55,7 @@ pub fn impl_traits(traits: DerivedList, mut obj: StructOrEnum) -> TokenStream {
if already_derived.contains(&derived.name) {
panic!("cannot derive {:?} twice", derived.name)
}
already_derived.insert(derived.name.clone());
already_derived.insert(derived.name);

add_type_impl(&mut output, &derived, &obj);
}
Expand All @@ -71,7 +71,7 @@ pub fn impl_traits(traits: DerivedList, mut obj: StructOrEnum) -> TokenStream {
#output
};

return output;
output
}

enum IdentOrLifetime {
Expand Down Expand Up @@ -182,7 +182,7 @@ fn get_debug_enum_marker(enum_item: &ItemEnum) -> &Variant {
.filter(|v| v.attrs.iter().any(is_attribute_default))
.collect::<Vec<_>>();
assert!(
default_variants.len() > 0,
!default_variants.is_empty(),
"one enum variant must be marked as default"
);
assert_eq!(
Expand Down Expand Up @@ -229,17 +229,17 @@ fn augment_where_clause(
let mut predicates = clause
.as_ref()
.map(|c| c.predicates.clone())
.unwrap_or(Punctuated::new());
.unwrap_or_default();
for predicate in extra {
predicates.push(predicate)
}

return WhereClause {
WhereClause {
where_token: clause.map(|c| c.where_token).unwrap_or(Where {
span: trait_to_impl.span,
}),
predicates,
};
}
}

fn get_named_idents(names: &FieldsNamed) -> Vec<Ident> {
Expand All @@ -251,7 +251,7 @@ fn get_named_idents(names: &FieldsNamed) -> Vec<Ident> {
}

fn get_named_idents_suffix(names: &FieldsNamed, suffix: &str) -> Vec<Ident> {
assert!(suffix != "");
assert!(!suffix.is_empty());
names
.named
.iter()
Expand Down Expand Up @@ -569,7 +569,7 @@ fn pord_struct(s: &ItemStruct) -> TokenStream {
quote! {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(std::cmp::Ordering::Equal) #(
.and_then(|o| self.#idents.partial_cmp(&other.#idents).map(|v| v.then(o)))
.and_then(|o| self.#idents.partial_cmp(&other.#idents).map(|v| o.then(v)))
)*
}
}
Expand All @@ -584,7 +584,7 @@ fn pord_struct(s: &ItemStruct) -> TokenStream {
let Self( #(#idents2),* ) = other;

Some(std::cmp::Ordering::Equal) #(
.and_then(|o| #idents1.partial_cmp(#idents2).map(|v| v.then(o)))
.and_then(|o| #idents1.partial_cmp(#idents2).map(|v| o.then(v)))
)*
}
}
Expand Down Expand Up @@ -612,7 +612,7 @@ fn pord_enum(e: &ItemEnum) -> TokenStream {
quote! {
(Self::#ident{#(#idents: #idents1),*}, Self::#ident{#(#idents: #idents2),*})
=> Some(std::cmp::Ordering::Equal) #(
.and_then(|o| #idents1.partial_cmp(#idents2).map(|v| v.then(o)))
.and_then(|o| #idents1.partial_cmp(#idents2).map(|v| o.then(v)))
)*
}
}
Expand All @@ -623,7 +623,7 @@ fn pord_enum(e: &ItemEnum) -> TokenStream {
quote! {
(Self::#ident(#(#idents1),*), Self::#ident(#(#idents2),*))
=> Some(std::cmp::Ordering::Equal) #(
.and_then(|o| #idents1.partial_cmp(#idents2).map(|v| v.then(o)))
.and_then(|o| #idents1.partial_cmp(#idents2).map(|v| o.then(v)))
)*
}
}
Expand Down Expand Up @@ -714,6 +714,8 @@ fn hash_enum(e: &ItemEnum) -> TokenStream {

quote! {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
let dis = std::mem::discriminant(self);
dis.hash(state);
match self {
#(
#variant_cases,
Expand Down
4 changes: 2 additions & 2 deletions src/perfect_parsing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl DerivedType {
_ => {} // fall through to default unscoped path
}

return type_enum_ident_as_path!(self);
type_enum_ident_as_path!(self)
}

pub fn get_trait(&self) -> TraitBound {
Expand Down Expand Up @@ -91,7 +91,7 @@ impl Parse for DerivedType {
parse_types_enum! {
match name {
ident...,
_ => Err(input.error(format!("type identifier {} is not supported - did you mean to use #[derive(...)]?", ident.to_string())))
_ => Err(input.error(format!("type identifier {} is not supported - did you mean to use #[derive(...)]?", ident)))
}
}
}
Expand Down
172 changes: 169 additions & 3 deletions tests/perfect_derive_base_macro.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
use std::hash::Hash;

use perfect_derive::perfect_derive;

fn hash_to_int(v: &impl Hash) -> u64 {
use std::hash::Hasher;
let mut hasher = std::collections::hash_map::DefaultHasher::new();
v.hash(&mut hasher);
hasher.finish()
}

macro_rules! make_test {
($trait_name:ident $(,$trait_name_tail:ident)*; $method_name:ident) => {
#[allow(unused)]
mod $method_name {
use perfect_derive::perfect_derive;

Expand Down Expand Up @@ -65,25 +76,180 @@ make_test!(PartialOrd, PartialEq; pord);
make_test!(Debug; debug);
make_test!(Hash; hash);

#[derive(Copy, Clone, Ord, Eq, PartialOrd, PartialEq, Debug, Hash, Default)]
struct EverythingStructCore {
v1: usize,
pub v2: i32,
}

#[perfect_derive(Copy, Clone, Ord, Eq, PartialOrd, PartialEq, Debug, Hash, Default)]
struct EverythingStruct {
v1: usize,
pub v2: i32,
}

#[test]
fn struct_eq_matches() {
let c1 = EverythingStructCore { v1: 1, v2: 2 };
let c2 = EverythingStructCore { v1: 2, v2: 1 };

let s1 = EverythingStruct { v1: 1, v2: 2 };
let s2 = EverythingStruct { v1: 2, v2: 1 };

assert_eq!(s1.eq(&s1), c1.eq(&c1));
assert_eq!(s1.eq(&s2), c1.eq(&c2));
assert_eq!(s2.eq(&s1), c2.eq(&c1));
}

#[test]
fn struct_ord_matches() {
let c1 = EverythingStructCore { v1: 1, v2: 2 };
let c2 = EverythingStructCore { v1: 2, v2: 1 };

let s1 = EverythingStruct { v1: 1, v2: 2 };
let s2 = EverythingStruct { v1: 2, v2: 1 };

assert_eq!(s1.cmp(&s1), c1.cmp(&c1));
assert_eq!(s1.cmp(&s2), c1.cmp(&c2));
assert_eq!(s2.cmp(&s1), c2.cmp(&c1));
}

#[test]
fn struct_partial_ord_matches() {
let c1 = EverythingStructCore { v1: 1, v2: 2 };
let c2 = EverythingStructCore { v1: 2, v2: 1 };

let s1 = EverythingStruct { v1: 1, v2: 2 };
let s2 = EverythingStruct { v1: 2, v2: 1 };

assert_eq!(s1.partial_cmp(&s1), c1.partial_cmp(&c1));
assert_eq!(s1.partial_cmp(&s2), c1.partial_cmp(&c2));
assert_eq!(s2.partial_cmp(&s1), c2.partial_cmp(&c1));
}

#[test]
fn struct_hash_matches() {
let c1 = EverythingStructCore { v1: 1, v2: 2 };
let c2 = EverythingStructCore { v1: 2, v2: 1 };

let s1 = EverythingStruct { v1: 1, v2: 2 };
let s2 = EverythingStruct { v1: 2, v2: 1 };

assert_eq!(hash_to_int(&s1), hash_to_int(&c1));
assert_eq!(hash_to_int(&s2), hash_to_int(&c2));
}

#[derive(Copy, Clone, Ord, Eq, PartialOrd, PartialEq, Debug, Hash, Default)]
#[allow(unused)]
enum EverythingEnumCore {
#[default]
E1,
E2(),
E3(usize, usize),
E4(u32, ()),
E5 {
name1: u32,
name2: (),
},
}

#[perfect_derive(Copy, Clone, Ord, Eq, PartialOrd, PartialEq, Debug, Hash, Default)]
enum EverythingEnum {
#[default]
E1,
E2(),
E3(usize),
E3(usize, usize),
E4(u32, ()),
#[default]
E5 {
name1: u32,
name2: (),
},
}

#[test]
fn enum_eq_matches() {
let c1 = EverythingEnumCore::E1;
let c2 = EverythingEnumCore::E3(1, 2);
let c3 = EverythingEnumCore::E3(2, 1);

let s1 = EverythingEnum::E1;
let s2 = EverythingEnum::E3(1, 2);
let s3 = EverythingEnum::E3(2, 1);

assert_eq!(s1.eq(&s1), c1.eq(&c1));
assert_eq!(s2.eq(&s2), c2.eq(&c2));
assert_eq!(s3.eq(&s3), c3.eq(&c3));

assert_eq!(s1.eq(&s2), c1.eq(&c2));
assert_eq!(s2.eq(&s1), c2.eq(&c1));

assert_eq!(s1.eq(&s3), c1.eq(&c3));
assert_eq!(s3.eq(&s1), c3.eq(&c1));

assert_eq!(s3.eq(&s2), c3.eq(&c2));
assert_eq!(s2.eq(&s3), c2.eq(&c3));
}

#[test]
fn enum_ord_matches() {
let c1 = EverythingEnumCore::E1;
let c2 = EverythingEnumCore::E3(1, 2);
let c3 = EverythingEnumCore::E3(2, 1);

let s1 = EverythingEnum::E1;
let s2 = EverythingEnum::E3(1, 2);
let s3 = EverythingEnum::E3(2, 1);

assert_eq!(s1.cmp(&s1), c1.cmp(&c1));
assert_eq!(s2.cmp(&s2), c2.cmp(&c2));
assert_eq!(s3.cmp(&s3), c3.cmp(&c3));

assert_eq!(s1.cmp(&s2), c1.cmp(&c2));
assert_eq!(s2.cmp(&s1), c2.cmp(&c1));

assert_eq!(s1.cmp(&s3), c1.cmp(&c3));
assert_eq!(s3.cmp(&s1), c3.cmp(&c1));

assert_eq!(s3.cmp(&s2), c3.cmp(&c2));
assert_eq!(s2.cmp(&s3), c2.cmp(&c3));
}

#[test]
fn enum_partial_ord_matches() {
let c1 = EverythingEnumCore::E1;
let c2 = EverythingEnumCore::E3(1, 2);
let c3 = EverythingEnumCore::E3(2, 1);

let s1 = EverythingEnum::E1;
let s2 = EverythingEnum::E3(1, 2);
let s3 = EverythingEnum::E3(2, 1);

assert_eq!(s1.partial_cmp(&s1), c1.partial_cmp(&c1));
assert_eq!(s2.partial_cmp(&s2), c2.partial_cmp(&c2));
assert_eq!(s3.partial_cmp(&s3), c3.partial_cmp(&c3));

assert_eq!(s1.partial_cmp(&s2), c1.partial_cmp(&c2));
assert_eq!(s2.partial_cmp(&s1), c2.partial_cmp(&c1));

assert_eq!(s1.partial_cmp(&s3), c1.partial_cmp(&c3));
assert_eq!(s3.partial_cmp(&s1), c3.partial_cmp(&c1));

assert_eq!(s3.partial_cmp(&s2), c3.partial_cmp(&c2));
assert_eq!(s2.partial_cmp(&s3), c2.partial_cmp(&c3));
}

#[test]
fn enum_hash_matches() {
let c1 = EverythingEnumCore::E1;
let c2 = EverythingEnumCore::E3(1, 2);

let s1 = EverythingEnum::E1;
let s2 = EverythingEnum::E3(1, 2);

assert_eq!(hash_to_int(&s1), hash_to_int(&c1));
assert_eq!(hash_to_int(&s2), hash_to_int(&c2));
}

#[test]
pub fn copy_struct_eq() {
let s1 = EverythingStruct {
Expand Down Expand Up @@ -121,7 +287,7 @@ pub struct NonDefaultable {}
pub enum DefaultableEnum {
E1(NonDefaultable),
#[default]
E2(usize),
E2,
}

#[test]
Expand Down
1 change: 1 addition & 0 deletions tests/perfect_derive_simple_generic_macro.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
macro_rules! make_test {
($trait_name:ident $(,$trait_name_tail:ident)*; $method_name:ident) => {
#[allow(unused)]
mod $method_name {
use perfect_derive::perfect_derive;

Expand Down

0 comments on commit a1ade35

Please sign in to comment.