Skip to content

Commit

Permalink
fix SLTIU
Browse files Browse the repository at this point in the history
imm was previously assumed to be 12 bit unsigned. but it should be 12 bit signed.
  • Loading branch information
zemse committed Oct 30, 2024
1 parent 8192c4d commit 932d131
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 49 deletions.
11 changes: 2 additions & 9 deletions ceno_emul/src/rv32im.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ impl DecodedInstruction {
match self.codes() {
InsnCodes { format: R | U, .. } => false,
InsnCodes {
kind: SLLI | SRLI | SRAI | ADDI | SLTIU,
kind: SLLI | SRLI | SRAI | ADDI,
..
} => false,
_ => self.top_bit != 0,
Expand All @@ -330,15 +330,8 @@ impl DecodedInstruction {
| (self.rd & 0x1e)
}

/// checks if the imm requires sign extension
pub fn requires_sign_ext(&self) -> bool {
!matches!(self.codes(), InsnCodes { kind: SLTIU, .. })
}

pub fn imm_i(&self) -> u32 {
((self.requires_sign_ext() as u32) * self.top_bit * 0xffff_f000)
| (self.func7 << 5)
| self.rs2
(self.top_bit * 0xffff_f000) | (self.func7 << 5) | self.rs2
}

pub fn imm_s(&self) -> u32 {
Expand Down
37 changes: 37 additions & 0 deletions ceno_zkvm/src/gadgets/is_lt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,19 @@ impl IsLtConfig {
.assign_instance_signed(instance, lkm, lhs, rhs)?;
Ok(())
}

pub fn assign_instance_rhs_signed<F: SmallField>(
&self,
instance: &mut [MaybeUninit<F>],
lkm: &mut LkMultiplicity,
lhs: u32,
rhs: SWord,
) -> Result<(), ZKVMError> {
set_val!(instance, self.is_lt, ((lhs as i64) < (rhs as i64)) as u64);
self.config
.assign_instance_rhs_signed(instance, lkm, lhs, rhs)?;
Ok(())
}
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -218,6 +231,30 @@ impl InnerLtConfig {
});
Ok(())
}

// TODO: refactor with the above function
pub fn assign_instance_rhs_signed<F: SmallField>(
&self,
instance: &mut [MaybeUninit<F>],
lkm: &mut LkMultiplicity,
lhs: u32,
rhs: SWord,
) -> Result<(), ZKVMError> {
let lhs = lhs as i64;
let rhs = rhs as i64;
let diff = if lhs < rhs {
Self::range(self.diff.len()) - lhs.abs_diff(rhs)
} else {
lhs.abs_diff(rhs)
};
self.diff.iter().enumerate().for_each(|(i, wit)| {
// extract the 16 bit limb from diff and assign to instance
let val = (diff >> (i * u16::BITS as usize)) & 0xffff;
lkm.assert_ux::<16>(val);
set_val!(instance, wit, val);
});
Ok(())
}
}

pub fn cal_lt_diff(is_lt: bool, max_num_u16_limbs: usize, lhs: u64, rhs: u64) -> u64 {
Expand Down
82 changes: 42 additions & 40 deletions ceno_zkvm/src/instructions/riscv/slti.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for SetLessThanImmInst
let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?;
let imm = cb.create_witin(|| "imm");

let (value_expr, is_rs1_neg) = match I::INST_KIND {
InsnKind::SLTIU => (rs1_read.value(), None),
let (value_expr, is_rs1_neg, max_num_u16_limbs) = match I::INST_KIND {
InsnKind::SLTIU => (rs1_read.value(), None, UINT_LIMBS + 1),
InsnKind::SLTI => {
let max_signed_limb_expr: Expression<_> =
((1 << (UInt::<E>::LIMB_BITS - 1)) - 1).into();
Expand All @@ -71,13 +71,22 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for SetLessThanImmInst
rs1_read.limbs.iter().last().unwrap().expr(), // msb limb
1,
)?;
(rs1_read.to_field_expr(is_rs1_neg.expr()), Some(is_rs1_neg))
(
rs1_read.to_field_expr(is_rs1_neg.expr()),
Some(is_rs1_neg),
UINT_LIMBS,
)
}
_ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND),
};

let lt =
IsLtConfig::construct_circuit(cb, || "rs1 < imm", value_expr, imm.expr(), UINT_LIMBS)?;
let lt = IsLtConfig::construct_circuit(
cb,
|| "rs1 < imm",
value_expr,
imm.expr(),
max_num_u16_limbs,
)?;
let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()]);

let i_insn = IInstructionConfig::<E>::construct_circuit(
Expand Down Expand Up @@ -122,7 +131,7 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for SetLessThanImmInst
InsnKind::SLTIU => {
config
.lt
.assign_instance(instance, lkm, rs1 as u64, imm as u64)?;
.assign_instance_rhs_signed(instance, lkm, rs1, imm as SWord)?;
}
InsnKind::SLTI => {
config.is_rs1_neg.as_ref().unwrap().assign_instance(
Expand Down Expand Up @@ -161,35 +170,38 @@ mod test {
verify::<SltiuOp>("lt = true, 0 < 1", 0, 1, 1);
verify::<SltiuOp>("lt = true, 1 < 2", 1, 2, 1);
verify::<SltiuOp>("lt = true, 10 < 20", 10, 20, 1);
verify::<SltiuOp>("lt = true, 2000 < 2500", 2000, 2500, 1);
// 0 <= imm <= 4095
verify::<SltiuOp>("lt = true, 0 < imm upper boundary", 0, 4095, 1);
verify::<SltiuOp>("lt = true, 2047 < imm upper boundary", 2047, 4095, 1);
verify::<SltiuOp>("lt = true, imm upper boundary", 1000, 4095, 1);
verify::<SltiuOp>("lt = true, 0 < imm upper boundary", 0, 2047, 1);
}

#[test]
fn test_sltiu_false() {
verify::<SltiuOp>("lt = false, 1 < 0", 1, 0, 0);
verify::<SltiuOp>("lt = false, 2 < 1", 2, 1, 0);
verify::<SltiuOp>("lt = false, 0 < -1", 0, -1, 0); //
verify::<SltiuOp>("lt = false, 1 < -1", 1, -1, 0);
verify::<SltiuOp>("lt = false, 100 < 50", 100, 50, 0);
verify::<SltiuOp>("lt = false, 500 < 100", 500, 100, 0);
verify::<SltiuOp>("lt = false, 2500 < 2500", 2500, 2500, 0);
verify::<SltiuOp>("lt = false, 4095 < 0", 4095, 0, 0);
verify::<SltiuOp>("lt = false, 4095 < 2048", 4095, 2048, 0);
verify::<SltiuOp>("lt = false, 4095 < 4095", 4095, 4095, 0);
// rs1 max value
verify::<SltiuOp>("lt = false, 0xFFFFFFFF < 0", 0xFFFFFFFF, 0, 0);
verify::<SltiuOp>("lt = false, 0xFFFFFFFF < 4095", 0xFFFFFFFF, 4095, 0);
verify::<SltiuOp>("lt = false, 100000 < 2047", 100000, 2047, 0);
verify::<SltiuOp>("lt = false, 100000 < 0", 100000, 0, 0);
verify::<SltiuOp>("lt = false, 0 == 0", 0, 0, 0);
verify::<SltiuOp>("lt = false, 1 == 1", 1, 1, 0);
// -2048 <= imm <= 2047
verify::<SltiuOp>("lt = false, imm upper bondary", u32::MAX, 2047, 0);
verify::<SltiuOp>("lt = false, imm lower bondary", u32::MAX, -2048, 0);
}

#[test]
fn test_sltiu_random() {
let mut rng = rand::thread_rng();
let a: u32 = rng.gen::<u32>();
let b: u32 = rng.gen::<u32>() % 4096;
let b: i32 = rng.gen_range(-2048..2048);
println!("random: {} <? {}", a, b); // For debugging, do not delete.
verify::<SltiuOp>("random unsigned comparison", a, b, (a < b) as u32);
verify::<SltiuOp>(
"random unsigned comparison",
a,
b,
((a as i64) < (b as i64)) as u32,
);
}

#[test]
Expand All @@ -198,35 +210,25 @@ mod test {
verify::<SltiOp>("lt = true, 1 < 2", 1, 2, 1);
verify::<SltiOp>("lt = true, -1 < 0", -1i32 as u32, 0, 1);
verify::<SltiOp>("lt = true, -1 < 1", -1i32 as u32, 1, 1);
verify::<SltiOp>("lt = true, -2 < -1", -2i32 as u32, -1i32 as u32, 1);
verify::<SltiOp>("lt = true, -2 < -1", -2i32 as u32, -1, 1);
// -2048 <= imm <= 2047
verify::<SltiOp>("lt = true, imm upper bondary", i32::MIN as u32, 2047, 1);
verify::<SltiOp>(
"lt = true, imm lower bondary",
i32::MIN as u32,
-2048i32 as u32,
1,
);
verify::<SltiOp>("lt = true, imm lower bondary", i32::MIN as u32, -2048, 1);
}

#[test]
fn test_slti_false() {
verify::<SltiOp>("lt = false, 1 < 0", 1, 0, 0);
verify::<SltiOp>("lt = false, 2 < 1", 2, 1, 0);
verify::<SltiOp>("lt = false, 0 < -1", 0, -1i32 as u32, 0);
verify::<SltiOp>("lt = false, 1 < -1", 1, -1i32 as u32, 0);
verify::<SltiOp>("lt = false, -1 < -2", -1i32 as u32, -2i32 as u32, 0);
verify::<SltiOp>("lt = false, 0 < -1", 0, -1, 0);
verify::<SltiOp>("lt = false, 1 < -1", 1, -1, 0);
verify::<SltiOp>("lt = false, -1 < -2", -1i32 as u32, -2, 0);
verify::<SltiOp>("lt = false, 0 == 0", 0, 0, 0);
verify::<SltiOp>("lt = false, 1 == 1", 1, 1, 0);
verify::<SltiOp>("lt = false, -1 == -1", -1i32 as u32, -1i32 as u32, 0);
verify::<SltiOp>("lt = false, -1 == -1", -1i32 as u32, -1, 0);
// -2048 <= imm <= 2047
verify::<SltiOp>("lt = false, imm upper bondary", i32::MAX as u32, 2047, 0);
verify::<SltiOp>(
"lt = false, imm lower bondary",
i32::MAX as u32,
-2048i32 as u32,
0,
);
verify::<SltiOp>("lt = false, imm lower bondary", i32::MAX as u32, -2048, 0);
}

#[test]
Expand All @@ -235,14 +237,14 @@ mod test {
let a: i32 = rng.gen();
let b: i32 = rng.gen_range(-2048..2048);
println!("random: {} <? {}", a, b); // For debugging, do not delete.
verify::<SltiOp>("random 1", a as u32, b as u32, (a < b) as u32);
verify::<SltiOp>("random 1", a as u32, b, (a < b) as u32);
}

fn verify<I: RIVInstruction>(name: &'static str, rs1_read: u32, imm: u32, expected_rd: u32) {
fn verify<I: RIVInstruction>(name: &'static str, rs1_read: u32, imm: i32, expected_rd: u32) {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);

let insn_code = encode_rv32(I::INST_KIND, 2, 0, 4, imm);
let insn_code = encode_rv32(I::INST_KIND, 2, 0, 4, imm as u32);

let config = cb
.namespace(
Expand Down

0 comments on commit 932d131

Please sign in to comment.