Skip to content

Commit

Permalink
Add support for generic integer type conversions to Expression::Const…
Browse files Browse the repository at this point in the history
…ant (#333)

Current implementation only provides `From<usize>`, which requires
explicit type conversions in various locations. This PR provides support
for type conversions for arbitrary primitive integer types,
specifically: `u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128,
isize`

In reference to [this
comment](#305 (comment)).

---------

Co-authored-by: Bryan Gillespie <[email protected]>
Co-authored-by: Ming <[email protected]>
  • Loading branch information
3 people authored Oct 9, 2024
1 parent e7e95cf commit 992bc30
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 12 deletions.
37 changes: 34 additions & 3 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -652,12 +652,43 @@ impl<F: SmallField, E: ExtensionField<BaseField = F>> ToExpr<E> for F {
}
}

impl<F: SmallField, E: ExtensionField<BaseField = F>> From<usize> for Expression<E> {
fn from(value: usize) -> Self {
Expression::Constant(F::from(value as u64))
// Implement From trait for unsigned types of at most 64 bits
macro_rules! impl_from_unsigned {
($($t:ty),*) => {
$(
impl<F: SmallField, E: ExtensionField<BaseField = F>> From<$t> for Expression<E> {
fn from(value: $t) -> Self {
Expression::Constant(F::from(value as u64))
}
}
)*
};
}
impl_from_unsigned!(u8, u16, u32, u64, usize);

// Implement From trait for u128 separately since it requires explicit reduction
impl<F: SmallField, E: ExtensionField<BaseField = F>> From<u128> for Expression<E> {
fn from(value: u128) -> Self {
let reduced = value.rem_euclid(F::MODULUS_U64 as u128) as u64;
Expression::Constant(F::from(reduced))
}
}

// Implement From trait for signed types
macro_rules! impl_from_signed {
($($t:ty),*) => {
$(
impl<F: SmallField, E: ExtensionField<BaseField = F>> From<$t> for Expression<E> {
fn from(value: $t) -> Self {
let reduced = (value as i128).rem_euclid(F::MODULUS_U64 as i128) as u64;
Expression::Constant(F::from(reduced))
}
}
)*
};
}
impl_from_signed!(i8, i16, i32, i64, i128, isize);

impl<E: ExtensionField> Display for Expression<E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut wtns = vec![];
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/gadgets/is_lt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl IsLtConfig {
.reduce(|a, b| a + b)
.expect("reduce error");

let range = (1 << (max_num_u16_limbs * u16::BITS as usize)).into();
let range = (1u64 << (max_num_u16_limbs * u16::BITS as usize)).into();

cb.require_equal(|| name.clone(), lhs - rhs, diff_expr - is_lt_expr * range)?;

Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/instructions/riscv/b_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ impl<E: ExtensionField> BInstructionConfig<E> {
// Fetch instruction
circuit_builder.lk_fetch(&InsnRecord::new(
vm_state.pc.expr(),
(insn_kind.codes().opcode as usize).into(),
insn_kind.codes().opcode.into(),
0.into(),
(insn_kind.codes().func3 as usize).into(),
insn_kind.codes().func3.into(),
rs1.id.expr(),
rs2.id.expr(),
imm.expr(),
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/divu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E
|| "outcome_is_zero",
is_zero.expr(),
outcome_value.clone(),
((1 << UInt::<E>::M) - 1).into(),
((1u64 << UInt::<E>::M) - 1).into(),
outcome_value,
)
.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/instructions/riscv/i_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ impl<E: ExtensionField> IInstructionConfig<E> {
// Fetch the instruction.
circuit_builder.lk_fetch(&InsnRecord::new(
vm_state.pc.expr(),
(insn_kind.codes().opcode as usize).into(),
insn_kind.codes().opcode.into(),
rd.id.expr(),
(insn_kind.codes().func3 as usize).into(),
insn_kind.codes().func3.into(),
rs1.id.expr(),
0.into(),
imm.clone(),
Expand Down
6 changes: 3 additions & 3 deletions ceno_zkvm/src/instructions/riscv/r_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ impl<E: ExtensionField> RInstructionConfig<E> {
// Fetch instruction
circuit_builder.lk_fetch(&InsnRecord::new(
vm_state.pc.expr(),
(insn_kind.codes().opcode as usize).into(),
insn_kind.codes().opcode.into(),
rd.id.expr(),
(insn_kind.codes().func3 as usize).into(),
insn_kind.codes().func3.into(),
rs1.id.expr(),
rs2.id.expr(),
(insn_kind.codes().func7 as usize).into(),
insn_kind.codes().func7.into(),
))?;

Ok(RInstructionConfig {
Expand Down

0 comments on commit 992bc30

Please sign in to comment.