Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for generic integer type conversions to Expression::Constant #333

Merged
merged 4 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -652,12 +652,43 @@ impl<F: SmallField, E: ExtensionField<BaseField = F>> ToExpr<E> for F {
}
}

impl<F: SmallField, E: ExtensionField<BaseField = F>> From<usize> for Expression<E> {
fn from(value: usize) -> Self {
Expression::Constant(F::from(value as u64))
// Implement From trait for unsigned types of at most 64 bits
macro_rules! impl_from_unsigned {
($($t:ty),*) => {
$(
impl<F: SmallField, E: ExtensionField<BaseField = F>> From<$t> for Expression<E> {
fn from(value: $t) -> Self {
Expression::Constant(F::from(value as u64))
}
}
)*
};
}
impl_from_unsigned!(u8, u16, u32, u64, usize);

// Implement From trait for u128 separately since it requires explicit reduction
impl<F: SmallField, E: ExtensionField<BaseField = F>> From<u128> for Expression<E> {
Copy link
Collaborator

@naure naure Oct 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this From<u128> really necessary? It is a lossy conversion therefore error-prone.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existing conversion from usize prior to this PR was already slightly lossy in this sense because the value has to be renormalized if between p_goldilocks and 2^64 - 1. The way I interpret the syntax (and the way it's implemented here) is "integer primitive types represent actual integers, and any conversion to a BaseField constant is by the 'reduce mod p' map", which seems natural.

fn from(value: u128) -> Self {
let reduced = value.rem_euclid(F::MODULUS_U64 as u128) as u64;
Expression::Constant(F::from(reduced))
}
}

// Implement From trait for signed types
macro_rules! impl_from_signed {
($($t:ty),*) => {
$(
impl<F: SmallField, E: ExtensionField<BaseField = F>> From<$t> for Expression<E> {
fn from(value: $t) -> Self {
let reduced = (value as i128).rem_euclid(F::MODULUS_U64 as i128) as u64;
hero78119 marked this conversation as resolved.
Show resolved Hide resolved
Expression::Constant(F::from(reduced))
hero78119 marked this conversation as resolved.
Show resolved Hide resolved
}
}
)*
};
}
impl_from_signed!(i8, i16, i32, i64, i128, isize);

impl<E: ExtensionField> Display for Expression<E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut wtns = vec![];
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/gadgets/is_lt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl IsLtConfig {
.reduce(|a, b| a + b)
.expect("reduce error");

let range = (1 << (max_num_u16_limbs * u16::BITS as usize)).into();
let range = (1u64 << (max_num_u16_limbs * u16::BITS as usize)).into();

cb.require_equal(|| name.clone(), lhs - rhs, diff_expr - is_lt_expr * range)?;

Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/instructions/riscv/b_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ impl<E: ExtensionField> BInstructionConfig<E> {
// Fetch instruction
circuit_builder.lk_fetch(&InsnRecord::new(
vm_state.pc.expr(),
(insn_kind.codes().opcode as usize).into(),
insn_kind.codes().opcode.into(),
0.into(),
(insn_kind.codes().func3 as usize).into(),
insn_kind.codes().func3.into(),
rs1.id.expr(),
rs2.id.expr(),
imm.expr(),
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/divu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E
|| "outcome_is_zero",
is_zero.expr(),
outcome_value.clone(),
((1 << UInt::<E>::M) - 1).into(),
((1u64 << UInt::<E>::M) - 1).into(),
outcome_value,
)
.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/instructions/riscv/i_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ impl<E: ExtensionField> IInstructionConfig<E> {
// Fetch the instruction.
circuit_builder.lk_fetch(&InsnRecord::new(
vm_state.pc.expr(),
(insn_kind.codes().opcode as usize).into(),
insn_kind.codes().opcode.into(),
rd.id.expr(),
(insn_kind.codes().func3 as usize).into(),
insn_kind.codes().func3.into(),
rs1.id.expr(),
0.into(),
imm.clone(),
Expand Down
6 changes: 3 additions & 3 deletions ceno_zkvm/src/instructions/riscv/r_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ impl<E: ExtensionField> RInstructionConfig<E> {
// Fetch instruction
circuit_builder.lk_fetch(&InsnRecord::new(
vm_state.pc.expr(),
(insn_kind.codes().opcode as usize).into(),
insn_kind.codes().opcode.into(),
rd.id.expr(),
(insn_kind.codes().func3 as usize).into(),
insn_kind.codes().func3.into(),
rs1.id.expr(),
rs2.id.expr(),
(insn_kind.codes().func7 as usize).into(),
insn_kind.codes().func7.into(),
))?;

Ok(RInstructionConfig {
Expand Down
Loading