Skip to content

Commit

Permalink
fix!: Combine ConstIntU and ConstIntS (#974)
Browse files Browse the repository at this point in the history
Fixes #970 .

Drive-by: change name of integer constant to show width rather than log
width, e.g. `u32(1000)` not `u5(1000)`.

(Note that the name always shows the unsigned interpretation. We could
make it show both, but I'm not sure it's worth it.)

I assume this will break serialization. Do I need to change the
serialization version?
  • Loading branch information
cqc-alec authored Apr 25, 2024
1 parent 0c354b6 commit 529f553
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 76 deletions.
4 changes: 2 additions & 2 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,14 @@ mod test {
use crate::std_extensions::arithmetic::conversions::ConvertOpDef;
use crate::std_extensions::arithmetic::float_ops::FloatOps;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES};
use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES};
use crate::std_extensions::logic::{self, NaryLogic};

use rstest::rstest;

/// int to constant
fn i2c(b: u64) -> Value {
Value::extension(ConstIntU::new(5, b).unwrap())
Value::extension(ConstInt::new_u(5, b).unwrap())
}

/// float to constant
Expand Down
4 changes: 2 additions & 2 deletions hugr/src/hugr/rewrite/inline_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ mod test {
use crate::ops::{Lift, OpType, Value};
use crate::std_extensions::arithmetic::float_types;
use crate::std_extensions::arithmetic::int_ops::{self, IntOpDef};
use crate::std_extensions::arithmetic::int_types::{self, ConstIntU};
use crate::std_extensions::arithmetic::int_types::{self, ConstInt};
use crate::types::FunctionType;
use crate::utils::test_quantum_extension;
use crate::{type_row, Direction, HugrView, Node, Port};
Expand Down Expand Up @@ -184,7 +184,7 @@ mod test {
d: &mut DFGBuilder<T>,
) -> Result<Wire, Box<dyn std::error::Error>> {
let int_ty = &int_types::INT_TYPES[6];
let cst = Value::extension(ConstIntU::new(6, 15)?);
let cst = Value::extension(ConstInt::new_u(6, 15)?);
let c1 = d.add_load_const(cst);
let [lifted] = d
.add_dataflow_op(
Expand Down
14 changes: 7 additions & 7 deletions hugr/src/std_extensions/arithmetic/conversions/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
ops::constant::CustomConst,
std_extensions::arithmetic::{
float_types::ConstF64,
int_types::{get_log_width, ConstIntS, ConstIntU, INT_TYPES},
int_types::{get_log_width, ConstInt, INT_TYPES},
},
types::ConstTypeError,
IncomingPort,
Expand Down Expand Up @@ -78,7 +78,7 @@ impl ConstFold for TruncU {
consts: &[(IncomingPort, ops::Value)],
) -> ConstFoldResult {
fold_trunc(type_args, consts, |f, log_width| {
ConstIntU::new(log_width, f.trunc() as u64).map(Into::into)
ConstInt::new_u(log_width, f.trunc() as u64).map(Into::into)
})
}
}
Expand All @@ -92,7 +92,7 @@ impl ConstFold for TruncS {
consts: &[(IncomingPort, ops::Value)],
) -> ConstFoldResult {
fold_trunc(type_args, consts, |f, log_width| {
ConstIntS::new(log_width, f.trunc() as i64).map(Into::into)
ConstInt::new_s(log_width, f.trunc() as i64).map(Into::into)
})
}
}
Expand All @@ -105,8 +105,8 @@ impl ConstFold for ConvertU {
_type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Value)],
) -> ConstFoldResult {
let u: &ConstIntU = get_input(consts)?;
let f = u.value() as f64;
let u: &ConstInt = get_input(consts)?;
let f = u.value_u() as f64;
Some(vec![(0.into(), ConstF64::new(f).into())])
}
}
Expand All @@ -119,8 +119,8 @@ impl ConstFold for ConvertS {
_type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Value)],
) -> ConstFoldResult {
let u: &ConstIntS = get_input(consts)?;
let f = u.value() as f64;
let u: &ConstInt = get_input(consts)?;
let f = u.value_s() as f64;
Some(vec![(0.into(), ConstF64::new(f).into())])
}
}
124 changes: 59 additions & 65 deletions hugr/src/std_extensions/arithmetic/int_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ const fn is_valid_log_width(n: u8) -> bool {
n < LOG_WIDTH_BOUND
}

/// The maximum allowed log width.
pub const LOG_WIDTH_MAX: u8 = 6;

/// The smallest forbidden log width.
pub const LOG_WIDTH_BOUND: u8 = 7;
pub const LOG_WIDTH_BOUND: u8 = LOG_WIDTH_MAX + 1;

/// Type parameter for the log width of the integer.
#[allow(clippy::assertions_on_constants)]
Expand All @@ -71,23 +74,22 @@ const fn type_arg(log_width: u8) -> TypeArg {
n: log_width as u64,
}
}
/// An unsigned integer
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct ConstIntU {
log_width: u8,
value: u64,
}

/// A signed integer
/// An integer (either signed or unsigned)
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct ConstIntS {
pub struct ConstInt {
log_width: u8,
value: i64,
// We always use a u64 for the value. The interpretation is:
// - as an unsigned integer, (value mod 2^N);
// - as a signed integer, (value mod 2^(N-1) - 2^(N-1)*a)
// where N = 2^log_width and a is the (N-1)th bit of x (counting from
// 0 = least significant bit).
value: u64,
}

impl ConstIntU {
/// Create a new [`ConstIntU`]
pub fn new(log_width: u8, value: u64) -> Result<Self, ConstTypeError> {
impl ConstInt {
/// Create a new [`ConstInt`] with a given width and unsigned value
pub fn new_u(log_width: u8, value: u64) -> Result<Self, ConstTypeError> {
if !is_valid_log_width(log_width) {
return Err(ConstTypeError::CustomCheckFail(
crate::types::CustomCheckFailure::Message("Invalid integer width.".to_owned()),
Expand All @@ -103,20 +105,8 @@ impl ConstIntU {
Ok(Self { log_width, value })
}

/// Returns the value of the constant
pub fn value(&self) -> u64 {
self.value
}

/// Returns the number of bits of the constant
pub fn log_width(&self) -> u8 {
self.log_width
}
}

impl ConstIntS {
/// Create a new [`ConstIntS`]
pub fn new(log_width: u8, value: i64) -> Result<Self, ConstTypeError> {
/// Create a new [`ConstInt`] with a given width and signed value
pub fn new_s(log_width: u8, value: i64) -> Result<Self, ConstTypeError> {
if !is_valid_log_width(log_width) {
return Err(ConstTypeError::CustomCheckFail(
crate::types::CustomCheckFailure::Message("Invalid integer width.".to_owned()),
Expand All @@ -130,42 +120,46 @@ impl ConstIntS {
),
));
}
Ok(Self { log_width, value })
}

/// Returns the value of the constant
pub fn value(&self) -> i64 {
self.value
Ok(Self {
log_width,
value: (if value >= 0 || log_width == LOG_WIDTH_MAX {
value
} else {
value + (1i64 << width)
}) as u64,
})
}

/// Returns the number of bits of the constant
pub fn log_width(&self) -> u8 {
self.log_width
}
}

#[typetag::serde]
impl CustomConst for ConstIntU {
fn name(&self) -> SmolStr {
format!("u{}({})", self.log_width, self.value).into()
}
fn equal_consts(&self, other: &dyn CustomConst) -> bool {
crate::ops::constant::downcast_equal_consts(self, other)
}

fn extension_reqs(&self) -> ExtensionSet {
ExtensionSet::singleton(&EXTENSION_ID)
/// Returns the value of the constant as an unsigned integer
pub fn value_u(&self) -> u64 {
self.value
}

fn get_type(&self) -> Type {
int_type(type_arg(self.log_width))
/// Returns the value of the constant as a signed integer
pub fn value_s(&self) -> i64 {
if self.log_width == LOG_WIDTH_MAX {
self.value as i64
} else {
let width = 1u8 << self.log_width;
if ((self.value << 1) >> width) == 0 {
self.value as i64
} else {
self.value as i64 - (1i64 << width)
}
}
}
}

#[typetag::serde]
impl CustomConst for ConstIntS {
impl CustomConst for ConstInt {
fn name(&self) -> SmolStr {
format!("i{}({})", self.log_width, self.value).into()
format!("u{}({})", 1u8 << self.log_width, self.value).into()
}
fn equal_consts(&self, other: &dyn CustomConst) -> bool {
crate::ops::constant::downcast_equal_consts(self, other)
Expand Down Expand Up @@ -239,43 +233,43 @@ mod test {

#[test]
fn test_int_consts() {
let const_u32_7 = ConstIntU::new(5, 7);
let const_u64_7 = ConstIntU::new(6, 7);
let const_u32_8 = ConstIntU::new(5, 8);
let const_u32_7 = ConstInt::new_u(5, 7);
let const_u64_7 = ConstInt::new_u(6, 7);
let const_u32_8 = ConstInt::new_u(5, 8);
assert_ne!(const_u32_7, const_u64_7);
assert_ne!(const_u32_7, const_u32_8);
assert_eq!(const_u32_7, ConstIntU::new(5, 7));
assert_eq!(const_u32_7, ConstInt::new_u(5, 7));

assert_matches!(
ConstIntU::new(3, 256),
ConstInt::new_u(3, 256),
Err(ConstTypeError::CustomCheckFail(_))
);
assert_matches!(
ConstIntU::new(9, 256),
ConstInt::new_u(9, 256),
Err(ConstTypeError::CustomCheckFail(_))
);
assert_matches!(
ConstIntS::new(3, 128),
ConstInt::new_s(3, 128),
Err(ConstTypeError::CustomCheckFail(_))
);
assert!(ConstIntS::new(3, -128).is_ok());
assert!(ConstInt::new_s(3, -128).is_ok());

let const_u32_7 = const_u32_7.unwrap();
assert!(const_u32_7.equal_consts(&ConstIntU::new(5, 7).unwrap()));
assert!(const_u32_7.equal_consts(&ConstInt::new_u(5, 7).unwrap()));
assert_eq!(const_u32_7.log_width(), 5);
assert_eq!(const_u32_7.value(), 7);
assert_eq!(const_u32_7.value_u(), 7);
assert!(const_u32_7.validate().is_ok());

assert_eq!(const_u32_7.name(), "u5(7)");
assert_eq!(const_u32_7.name(), "u32(7)");

let const_i32_2 = ConstIntS::new(5, -2).unwrap();
assert!(const_i32_2.equal_consts(&ConstIntS::new(5, -2).unwrap()));
let const_i32_2 = ConstInt::new_s(5, -2).unwrap();
assert!(const_i32_2.equal_consts(&ConstInt::new_s(5, -2).unwrap()));
assert_eq!(const_i32_2.log_width(), 5);
assert_eq!(const_i32_2.value(), -2);
assert_eq!(const_i32_2.value_s(), -2);
assert!(const_i32_2.validate().is_ok());
assert_eq!(const_i32_2.name(), "i5(-2)");
assert_eq!(const_i32_2.name(), "u32(4294967294)");

ConstIntS::new(50, -2).unwrap_err();
ConstIntU::new(50, 2).unwrap_err();
ConstInt::new_s(50, -2).unwrap_err();
ConstInt::new_u(50, 2).unwrap_err();
}
}

0 comments on commit 529f553

Please sign in to comment.