diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index fa76adfc7..267fde902 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -25,8 +25,8 @@ use strum_macros::{EnumIter, EnumString, IntoStaticStr}; pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int"); struct IOValidator { - // whether the first type argument should be greater than the second - f_gt_s: bool, + // whether the first type argument should be greater than or equal to the second + f_ge_s: bool, } impl ValidateJustArgs for IOValidator { @@ -34,7 +34,7 @@ impl ValidateJustArgs for IOValidator { let [arg0, arg1] = collect_array(arg_values); let i: u8 = get_log_width(arg0)?; let o: u8 = get_log_width(arg1)?; - let cmp = if self.f_gt_s { i > o } else { i < o }; + let cmp = if self.f_ge_s { i >= o } else { i <= o }; if !cmp { return Err(SignatureError::InvalidTypeArgs); } @@ -102,12 +102,12 @@ impl MakeOpDef for IntOpDef { match self { iwiden_s | iwiden_u => CustomValidator::new_with_validator( int_polytype(2, vec![int_tv(0)], vec![int_tv(1)]), - IOValidator { f_gt_s: false }, + IOValidator { f_ge_s: false }, ) .into(), inarrow_s | inarrow_u => CustomValidator::new_with_validator( int_polytype(2, vec![int_tv(0)], vec![sum_with_error(int_tv(1))]), - IOValidator { f_gt_s: true }, + IOValidator { f_ge_s: true }, ) .into(), itobool => int_polytype(1, vec![int_tv(0)], type_row![BOOL_T]).into(), @@ -357,6 +357,22 @@ mod test { .signature(), FunctionType::new(vec![int_type(ta(3))], vec![int_type(ta(4))],) ); + assert_eq!( + IntOpDef::iwiden_s + .with_two_widths(3, 3) + .to_extension_op() + .unwrap() + .signature(), + FunctionType::new(vec![int_type(ta(3))], vec![int_type(ta(3))],) + ); + assert_eq!( + IntOpDef::inarrow_s + .with_two_widths(3, 3) + .to_extension_op() + .unwrap() + .signature(), + FunctionType::new(vec![int_type(ta(3))], vec![sum_with_error(int_type(ta(3)))],) + ); assert!( IntOpDef::iwiden_u .with_two_widths(4, 3)