diff --git a/core/src/air/builder.rs b/core/src/air/builder.rs index 903aac3045..eba567d920 100644 --- a/core/src/air/builder.rs +++ b/core/src/air/builder.rs @@ -84,6 +84,7 @@ pub trait BaseAirBuilder: AirBuilder + MessageBuilder /// A trait which contains methods for byte interactions in an AIR. pub trait ByteAirBuilder: BaseAirBuilder { /// Sends a byte operation to be processed. + #[allow(clippy::too_many_arguments)] fn send_byte( &mut self, opcode: impl Into, @@ -91,9 +92,19 @@ pub trait ByteAirBuilder: BaseAirBuilder { b: impl Into, c: impl Into, shard: impl Into, + channel: impl Into, multiplicity: impl Into, ) { - self.send_byte_pair(opcode, a, Self::Expr::zero(), b, c, shard, multiplicity) + self.send_byte_pair( + opcode, + a, + Self::Expr::zero(), + b, + c, + shard, + channel, + multiplicity, + ) } /// Sends a byte operation with two outputs to be processed. @@ -106,6 +117,7 @@ pub trait ByteAirBuilder: BaseAirBuilder { b: impl Into, c: impl Into, shard: impl Into, + channel: impl Into, multiplicity: impl Into, ) { self.send(AirInteraction::new( @@ -116,6 +128,7 @@ pub trait ByteAirBuilder: BaseAirBuilder { b.into(), c.into(), shard.into(), + channel.into(), ], multiplicity.into(), InteractionKind::Byte, @@ -123,6 +136,7 @@ pub trait ByteAirBuilder: BaseAirBuilder { } /// Receives a byte operation to be processed. + #[allow(clippy::too_many_arguments)] fn receive_byte( &mut self, opcode: impl Into, @@ -130,9 +144,19 @@ pub trait ByteAirBuilder: BaseAirBuilder { b: impl Into, c: impl Into, shard: impl Into, + channel: impl Into, multiplicity: impl Into, ) { - self.receive_byte_pair(opcode, a, Self::Expr::zero(), b, c, shard, multiplicity) + self.receive_byte_pair( + opcode, + a, + Self::Expr::zero(), + b, + c, + shard, + channel, + multiplicity, + ) } /// Receives a byte operation with two outputs to be processed. 
@@ -145,6 +169,7 @@ pub trait ByteAirBuilder: BaseAirBuilder { b: impl Into, c: impl Into, shard: impl Into, + channel: impl Into, multiplicity: impl Into, ) { self.receive(AirInteraction::new( @@ -155,6 +180,7 @@ pub trait ByteAirBuilder: BaseAirBuilder { b.into(), c.into(), shard.into(), + channel.into(), ], multiplicity.into(), InteractionKind::Byte, @@ -219,6 +245,7 @@ pub trait WordAirBuilder: ByteAirBuilder { &mut self, input: &[impl Into + Clone], shard: impl Into + Clone, + channel: impl Into + Clone, mult: impl Into + Clone, ) { let mut index = 0; @@ -229,6 +256,7 @@ pub trait WordAirBuilder: ByteAirBuilder { input[index].clone(), input[index + 1].clone(), shard.clone(), + channel.clone(), mult.clone(), ); index += 2; @@ -240,6 +268,7 @@ pub trait WordAirBuilder: ByteAirBuilder { input[index].clone(), Self::Expr::zero(), shard.clone(), + channel.clone(), mult.clone(), ); } @@ -250,6 +279,7 @@ pub trait WordAirBuilder: ByteAirBuilder { &mut self, input: &[impl Into + Copy], shard: impl Into + Clone, + channel: impl Into + Clone, mult: impl Into + Clone, ) { input.iter().for_each(|limb| { @@ -259,6 +289,7 @@ pub trait WordAirBuilder: ByteAirBuilder { Self::Expr::zero(), Self::Expr::zero(), shard.clone(), + channel.clone(), mult.clone(), ); }); @@ -268,6 +299,7 @@ pub trait WordAirBuilder: ByteAirBuilder { /// A trait which contains methods related to ALU interactions in an AIR. pub trait AluAirBuilder: BaseAirBuilder { /// Sends an ALU operation to be processed. 
+ #[allow(clippy::too_many_arguments)] fn send_alu( &mut self, opcode: impl Into, @@ -275,6 +307,7 @@ pub trait AluAirBuilder: BaseAirBuilder { b: Word>, c: Word>, shard: impl Into, + channel: impl Into, multiplicity: impl Into, ) { let values = once(opcode.into()) @@ -282,6 +315,7 @@ pub trait AluAirBuilder: BaseAirBuilder { .chain(b.0.into_iter().map(Into::into)) .chain(c.0.into_iter().map(Into::into)) .chain(once(shard.into())) + .chain(once(channel.into())) .collect(); self.send(AirInteraction::new( @@ -292,6 +326,7 @@ pub trait AluAirBuilder: BaseAirBuilder { } /// Receives an ALU operation to be processed. + #[allow(clippy::too_many_arguments)] fn receive_alu( &mut self, opcode: impl Into, @@ -299,6 +334,7 @@ pub trait AluAirBuilder: BaseAirBuilder { b: Word>, c: Word>, shard: impl Into, + channel: impl Into, multiplicity: impl Into, ) { let values = once(opcode.into()) @@ -306,6 +342,7 @@ pub trait AluAirBuilder: BaseAirBuilder { .chain(b.0.into_iter().map(Into::into)) .chain(c.0.into_iter().map(Into::into)) .chain(once(shard.into())) + .chain(once(channel.into())) .collect(); self.receive(AirInteraction::new( @@ -316,9 +353,11 @@ pub trait AluAirBuilder: BaseAirBuilder { } /// Sends an syscall operation to be processed (with "ECALL" opcode). + #[allow(clippy::too_many_arguments)] fn send_syscall( &mut self, shard: impl Into + Clone, + channel: impl Into + Clone, clk: impl Into + Clone, syscall_id: impl Into + Clone, arg1: impl Into + Clone, @@ -328,6 +367,7 @@ pub trait AluAirBuilder: BaseAirBuilder { self.send(AirInteraction::new( vec![ shard.clone().into(), + channel.clone().into(), clk.clone().into(), syscall_id.clone().into(), arg1.clone().into(), @@ -339,9 +379,11 @@ pub trait AluAirBuilder: BaseAirBuilder { } /// Receives a syscall operation to be processed. 
+ #[allow(clippy::too_many_arguments)] fn receive_syscall( &mut self, shard: impl Into + Clone, + channel: impl Into + Clone, clk: impl Into + Clone, syscall_id: impl Into + Clone, arg1: impl Into + Clone, @@ -351,6 +393,7 @@ pub trait AluAirBuilder: BaseAirBuilder { self.receive(AirInteraction::new( vec![ shard.clone().into(), + channel.clone().into(), clk.clone().into(), syscall_id.clone().into(), arg1.clone().into(), @@ -371,6 +414,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder { fn eval_memory_access + Clone>( &mut self, shard: impl Into, + channel: impl Into, clk: impl Into, addr: impl Into, memory_access: &impl MemoryCols, @@ -378,13 +422,20 @@ pub trait MemoryAirBuilder: BaseAirBuilder { ) { let do_check: Self::Expr = do_check.into(); let shard: Self::Expr = shard.into(); + let channel: Self::Expr = channel.into(); let clk: Self::Expr = clk.into(); let mem_access = memory_access.access(); self.assert_bool(do_check.clone()); // Verify that the current memory access time is greater than the previous's. - self.eval_memory_access_timestamp(mem_access, do_check.clone(), shard.clone(), clk.clone()); + self.eval_memory_access_timestamp( + mem_access, + do_check.clone(), + shard.clone(), + channel, + clk.clone(), + ); // Add to the memory argument. 
let addr = addr.into(); @@ -420,6 +471,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder { fn eval_memory_access_slice + Copy>( &mut self, shard: impl Into + Copy, + channel: impl Into + Clone, clk: impl Into + Clone, initial_addr: impl Into + Clone, memory_access_slice: &[impl MemoryCols], @@ -428,6 +480,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder { for (i, access_slice) in memory_access_slice.iter().enumerate() { self.eval_memory_access( shard, + channel.clone(), clk.clone(), initial_addr.clone().into() + Self::Expr::from_canonical_usize(i * 4), access_slice, @@ -447,6 +500,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder { mem_access: &MemoryAccessCols + Clone>, do_check: impl Into, shard: impl Into + Clone, + channel: impl Into + Clone, clk: impl Into, ) { let do_check: Self::Expr = do_check.into(); @@ -487,6 +541,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder { mem_access.diff_16bit_limb.clone(), mem_access.diff_8bit_limb.clone(), shard.clone(), + channel.clone(), do_check, ); } @@ -503,6 +558,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder { limb_16: impl Into + Clone, limb_8: impl Into + Clone, shard: impl Into + Clone, + channel: impl Into + Clone, do_check: impl Into + Clone, ) { // Verify that value = limb_16 + limb_8 * 2^16. @@ -519,6 +575,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder { Self::Expr::zero(), Self::Expr::zero(), shard.clone(), + channel.clone(), do_check.clone(), ); @@ -528,6 +585,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder { Self::Expr::zero(), limb_8, shard.clone(), + channel.clone(), do_check, ) } diff --git a/core/src/alu/add_sub/mod.rs b/core/src/alu/add_sub/mod.rs index 3de0819383..2321427c53 100644 --- a/core/src/alu/add_sub/mod.rs +++ b/core/src/alu/add_sub/mod.rs @@ -35,6 +35,9 @@ pub struct AddSubCols { /// The shard number, used for byte lookup table. pub shard: T, + /// The channel number, used for byte lookup table. 
+ pub channel: T, + /// Instance of `AddOperation` to handle addition logic in `AddSubChip`'s ALU operations. /// It's result will be `a` for the add operation and `b` for the sub operation. pub add_operation: AddOperation, @@ -88,14 +91,20 @@ impl MachineAir for AddSubChip { let cols: &mut AddSubCols = row.as_mut_slice().borrow_mut(); let is_add = event.opcode == Opcode::ADD; cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); cols.is_add = F::from_bool(is_add); cols.is_sub = F::from_bool(!is_add); let operand_1 = if is_add { event.b } else { event.a }; let operand_2 = event.c; - cols.add_operation - .populate(&mut record, event.shard, operand_1, operand_2); + cols.add_operation.populate( + &mut record, + event.shard, + event.channel, + operand_1, + operand_2, + ); cols.operand_1 = Word::from(operand_1); cols.operand_2 = Word::from(operand_2); row @@ -150,6 +159,7 @@ where local.operand_2, local.add_operation, local.shard, + local.channel, local.is_add + local.is_sub, ); @@ -161,6 +171,7 @@ where local.operand_1, local.operand_2, local.shard, + local.channel, local.is_add, ); @@ -171,6 +182,7 @@ where local.add_operation.value, local.operand_2, local.shard, + local.channel, local.is_sub, ); @@ -203,7 +215,7 @@ mod tests { #[test] fn generate_trace() { let mut shard = ExecutionRecord::default(); - shard.add_events = vec![AluEvent::new(0, 0, Opcode::ADD, 14, 8, 6)]; + shard.add_events = vec![AluEvent::new(0, 0, 0, Opcode::ADD, 14, 8, 6)]; let chip = AddSubChip::default(); let trace: RowMajorMatrix = chip.generate_trace(&shard, &mut ExecutionRecord::default()); @@ -216,12 +228,13 @@ mod tests { let mut challenger = config.challenger(); let mut shard = ExecutionRecord::default(); - for _ in 0..1000 { + for i in 0..1000 { let operand_1 = thread_rng().gen_range(0..u32::MAX); let operand_2 = thread_rng().gen_range(0..u32::MAX); let result = operand_1.wrapping_add(operand_2); shard.add_events.push(AluEvent::new( 0, + i % 
2, 0, Opcode::ADD, result, @@ -229,12 +242,13 @@ mod tests { operand_2, )); } - for _ in 0..1000 { + for i in 0..1000 { let operand_1 = thread_rng().gen_range(0..u32::MAX); let operand_2 = thread_rng().gen_range(0..u32::MAX); let result = operand_1.wrapping_sub(operand_2); shard.add_events.push(AluEvent::new( 0, + i % 2, 0, Opcode::SUB, result, diff --git a/core/src/alu/bitwise/mod.rs b/core/src/alu/bitwise/mod.rs index 61ae4b1506..3e7227b709 100644 --- a/core/src/alu/bitwise/mod.rs +++ b/core/src/alu/bitwise/mod.rs @@ -28,6 +28,9 @@ pub struct BitwiseCols { /// The shard number, used for byte lookup table. pub shard: T, + /// The channel number, used for byte lookup table. + pub channel: T, + /// The output operand. pub a: Word, @@ -73,6 +76,7 @@ impl MachineAir for BitwiseChip { let c = event.c.to_le_bytes(); cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); cols.a = Word::from(event.a); cols.b = Word::from(event.b); cols.c = Word::from(event.c); @@ -84,6 +88,7 @@ impl MachineAir for BitwiseChip { for ((b_a, b_b), b_c) in a.into_iter().zip(b).zip(c) { let byte_event = ByteLookupEvent { shard: event.shard, + channel: event.channel, opcode: ByteOpcode::from(event.opcode), a1: b_a as u32, a2: 0, @@ -137,7 +142,15 @@ where // Get a multiplicity of `1` only for a true row. let mult = local.is_xor + local.is_or + local.is_and; for ((a, b), c) in local.a.into_iter().zip(local.b).zip(local.c) { - builder.send_byte(opcode.clone(), a, b, c, local.shard, mult.clone()); + builder.send_byte( + opcode.clone(), + a, + b, + c, + local.shard, + local.channel, + mult.clone(), + ); } // Get the cpu opcode, which corresponds to the opcode being sent in the CPU table. 
@@ -152,6 +165,7 @@ where local.b, local.c, local.shard, + local.channel, local.is_xor + local.is_or + local.is_and, ); @@ -180,7 +194,7 @@ mod tests { #[test] fn generate_trace() { let mut shard = ExecutionRecord::default(); - shard.bitwise_events = vec![AluEvent::new(0, 0, Opcode::XOR, 25, 10, 19)]; + shard.bitwise_events = vec![AluEvent::new(0, 0, 0, Opcode::XOR, 25, 10, 19)]; let chip = BitwiseChip::default(); let trace: RowMajorMatrix = chip.generate_trace(&shard, &mut ExecutionRecord::default()); @@ -194,9 +208,9 @@ mod tests { let mut shard = ExecutionRecord::default(); shard.bitwise_events = [ - AluEvent::new(0, 0, Opcode::XOR, 25, 10, 19), - AluEvent::new(0, 0, Opcode::OR, 27, 10, 19), - AluEvent::new(0, 0, Opcode::AND, 2, 10, 19), + AluEvent::new(0, 0, 0, Opcode::XOR, 25, 10, 19), + AluEvent::new(0, 1, 0, Opcode::OR, 27, 10, 19), + AluEvent::new(0, 0, 0, Opcode::AND, 2, 10, 19), ] .repeat(1000); let chip = BitwiseChip::default(); diff --git a/core/src/alu/divrem/mod.rs b/core/src/alu/divrem/mod.rs index 71b7260ec9..72a2e7510f 100644 --- a/core/src/alu/divrem/mod.rs +++ b/core/src/alu/divrem/mod.rs @@ -104,6 +104,9 @@ pub struct DivRemCols { /// The shard number, used for byte lookup table. pub shard: T, + /// The channel number, used for byte lookup table. + pub channel: T, + /// The output operand. 
pub a: Word, @@ -233,6 +236,7 @@ impl MachineAir for DivRemChip { cols.b = Word::from(event.b); cols.c = Word::from(event.c); cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); cols.is_real = F::one(); cols.is_divu = F::from_bool(event.opcode == Opcode::DIVU); cols.is_remu = F::from_bool(event.opcode == Opcode::REMU); @@ -286,6 +290,7 @@ impl MachineAir for DivRemChip { let most_significant_byte = word.to_le_bytes()[WORD_SIZE - 1]; blu_events.push(ByteLookupEvent { shard: event.shard, + channel: event.channel, opcode: ByteOpcode::MSB, a1: get_msb(*word) as u32, a2: 0, @@ -351,6 +356,7 @@ impl MachineAir for DivRemChip { let lower_multiplication = AluEvent { shard: event.shard, + channel: event.channel, clk: event.clk, opcode: Opcode::MUL, a: lower_word, @@ -361,6 +367,7 @@ impl MachineAir for DivRemChip { let upper_multiplication = AluEvent { shard: event.shard, + channel: event.channel, clk: event.clk, opcode: { if is_signed_operation(event.opcode) { @@ -379,6 +386,7 @@ impl MachineAir for DivRemChip { let lt_event = if is_signed_operation(event.opcode) { AluEvent { shard: event.shard, + channel: event.channel, opcode: Opcode::SLT, a: 1, b: (remainder as i32).abs() as u32, @@ -388,6 +396,7 @@ impl MachineAir for DivRemChip { } else { AluEvent { shard: event.shard, + channel: event.channel, opcode: Opcode::SLTU, a: 1, b: remainder, @@ -402,9 +411,13 @@ impl MachineAir for DivRemChip { // Range check. 
- output.add_u8_range_checks(event.shard, &quotient.to_le_bytes()); - output.add_u8_range_checks(event.shard, &remainder.to_le_bytes()); - output.add_u8_range_checks(event.shard, &c_times_quotient); + output.add_u8_range_checks(event.shard, event.channel, &quotient.to_le_bytes()); + output.add_u8_range_checks( + event.shard, + event.channel, + &remainder.to_le_bytes(), + ); + output.add_u8_range_checks(event.shard, event.channel, &c_times_quotient); } } @@ -499,6 +512,7 @@ where local.quotient, local.c, local.shard, + local.channel, local.is_real, ); @@ -523,6 +537,7 @@ where local.quotient, local.c, local.shard, + local.channel, local.is_real, ); } @@ -752,7 +767,8 @@ where local.abs_remainder, local.max_abs_c_or_1, local.shard, - local.remainder_check_multiplicity, + local.channel, + local.is_real, ); } @@ -767,20 +783,43 @@ where for msb_pair in msb_pairs.iter() { let msb = msb_pair.0; let byte = msb_pair.1; - builder.send_byte(opcode, msb, byte, zero.clone(), local.shard, local.is_real); + builder.send_byte( + opcode, + msb, + byte, + zero.clone(), + local.shard, + local.channel, + local.is_real, + ); } } // Range check all the bytes. { - builder.slice_range_check_u8(&local.quotient.0, local.shard, local.is_real); - builder.slice_range_check_u8(&local.remainder.0, local.shard, local.is_real); + builder.slice_range_check_u8( + &local.quotient.0, + local.shard, + local.channel, + local.is_real, + ); + builder.slice_range_check_u8( + &local.remainder.0, + local.shard, + local.channel, + local.is_real, + ); local.carry.iter().for_each(|carry| { builder.assert_bool(*carry); }); - builder.slice_range_check_u8(&local.c_times_quotient, local.shard, local.is_real); + builder.slice_range_check_u8( + &local.c_times_quotient, + local.shard, + local.channel, + local.is_real, + ); } // Check that the flags are boolean. 
@@ -831,6 +870,7 @@ where local.b, local.c, local.shard, + local.channel, local.is_real, ); } @@ -859,7 +899,7 @@ mod tests { #[test] fn generate_trace() { let mut shard = ExecutionRecord::default(); - shard.divrem_events = vec![AluEvent::new(0, 0, Opcode::DIVU, 2, 17, 3)]; + shard.divrem_events = vec![AluEvent::new(0, 0, 0, Opcode::DIVU, 2, 17, 3)]; let chip = DivRemChip::default(); let trace: RowMajorMatrix = chip.generate_trace(&shard, &mut ExecutionRecord::default()); @@ -913,12 +953,12 @@ mod tests { (Opcode::REM, 0, 1 << 8, neg(256)), ]; for t in divrems.iter() { - divrem_events.push(AluEvent::new(0, 0, t.0, t.1, t.2, t.3)); + divrem_events.push(AluEvent::new(0, 9, 0, t.0, t.1, t.2, t.3)); } // Append more events until we have 1000 tests. for _ in 0..(1000 - divrems.len()) { - divrem_events.push(AluEvent::new(0, 0, Opcode::DIVU, 1, 1, 1)); + divrem_events.push(AluEvent::new(0, 0, 0, Opcode::DIVU, 1, 1, 1)); } let mut shard = ExecutionRecord::default(); diff --git a/core/src/alu/lt/mod.rs b/core/src/alu/lt/mod.rs index 7727efc219..91b504181c 100644 --- a/core/src/alu/lt/mod.rs +++ b/core/src/alu/lt/mod.rs @@ -31,6 +31,9 @@ pub struct LtCols { /// The shard number, used for byte lookup table. pub shard: T, + /// The channel number, used for byte lookup table. + pub channel: T, + /// If the opcode is SLT. pub is_slt: T, @@ -116,6 +119,7 @@ impl MachineAir for LtChip { let c = event.c.to_le_bytes(); cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); cols.a = Word(a.map(F::from_canonical_u8)); cols.b = Word(b.map(F::from_canonical_u8)); cols.c = Word(c.map(F::from_canonical_u8)); @@ -129,6 +133,7 @@ impl MachineAir for LtChip { // Send the masked interaction. 
new_byte_lookup_events.add_byte_lookup_event(ByteLookupEvent { shard: event.shard, + channel: event.channel, opcode: ByteOpcode::AND, a1: masked_b as u32, a2: 0, @@ -137,6 +142,7 @@ impl MachineAir for LtChip { }); new_byte_lookup_events.add_byte_lookup_event(ByteLookupEvent { shard: event.shard, + channel: event.channel, opcode: ByteOpcode::AND, a1: masked_c as u32, a2: 0, @@ -191,6 +197,7 @@ impl MachineAir for LtChip { new_byte_lookup_events.add_byte_lookup_event(ByteLookupEvent { shard: event.shard, + channel: event.channel, opcode: ByteOpcode::LTU, a1: cols.sltu.as_canonical_u32(), a2: 0, @@ -270,6 +277,7 @@ where local.b[3], AB::F::from_canonical_u8(0x7f), local.shard, + local.channel, is_real.clone(), ); builder.send_byte( @@ -278,6 +286,7 @@ where local.c[3], AB::F::from_canonical_u8(0x7f), local.shard, + local.channel, is_real.clone(), ); @@ -398,6 +407,7 @@ where b_comp_byte, c_comp_byte, local.shard, + local.channel, is_real.clone(), ); @@ -420,6 +430,7 @@ where local.b, local.c, local.shard, + local.channel, is_real, ); } @@ -447,7 +458,7 @@ mod tests { #[test] fn generate_trace() { let mut shard = ExecutionRecord::default(); - shard.lt_events = vec![AluEvent::new(0, 0, Opcode::SLT, 0, 3, 2)]; + shard.lt_events = vec![AluEvent::new(0, 1, 0, Opcode::SLT, 0, 3, 2)]; let chip = LtChip::default(); let trace: RowMajorMatrix = chip.generate_trace(&shard, &mut ExecutionRecord::default()); @@ -475,21 +486,21 @@ mod tests { const NEG_4: u32 = 0b11111111111111111111111111111100; shard.lt_events = vec![ // 0 == 3 < 2 - AluEvent::new(0, 0, Opcode::SLT, 0, 3, 2), + AluEvent::new(0, 0, 0, Opcode::SLT, 0, 3, 2), // 1 == 2 < 3 - AluEvent::new(0, 1, Opcode::SLT, 1, 2, 3), + AluEvent::new(0, 0, 1, Opcode::SLT, 1, 2, 3), // 0 == 5 < -3 - AluEvent::new(0, 3, Opcode::SLT, 0, 5, NEG_3), + AluEvent::new(0, 0, 3, Opcode::SLT, 0, 5, NEG_3), // 1 == -3 < 5 - AluEvent::new(0, 2, Opcode::SLT, 1, NEG_3, 5), + AluEvent::new(0, 0, 2, Opcode::SLT, 1, NEG_3, 5), // 0 == -3 < -4 - 
AluEvent::new(0, 4, Opcode::SLT, 0, NEG_3, NEG_4), + AluEvent::new(0, 0, 4, Opcode::SLT, 0, NEG_3, NEG_4), // 1 == -4 < -3 - AluEvent::new(0, 4, Opcode::SLT, 1, NEG_4, NEG_3), + AluEvent::new(0, 0, 4, Opcode::SLT, 1, NEG_4, NEG_3), // 0 == 3 < 3 - AluEvent::new(0, 5, Opcode::SLT, 0, 3, 3), + AluEvent::new(0, 0, 5, Opcode::SLT, 0, 3, 3), // 0 == -3 < -3 - AluEvent::new(0, 5, Opcode::SLT, 0, NEG_3, NEG_3), + AluEvent::new(0, 0, 5, Opcode::SLT, 0, NEG_3, NEG_3), ]; prove_babybear_template(&mut shard); @@ -502,17 +513,17 @@ mod tests { const LARGE: u32 = 0b11111111111111111111111111111101; shard.lt_events = vec![ // 0 == 3 < 2 - AluEvent::new(0, 0, Opcode::SLTU, 0, 3, 2), + AluEvent::new(0, 0, 0, Opcode::SLTU, 0, 3, 2), // 1 == 2 < 3 - AluEvent::new(0, 1, Opcode::SLTU, 1, 2, 3), + AluEvent::new(0, 0, 1, Opcode::SLTU, 1, 2, 3), // 0 == LARGE < 5 - AluEvent::new(0, 2, Opcode::SLTU, 0, LARGE, 5), + AluEvent::new(0, 0, 2, Opcode::SLTU, 0, LARGE, 5), // 1 == 5 < LARGE - AluEvent::new(0, 3, Opcode::SLTU, 1, 5, LARGE), + AluEvent::new(0, 0, 3, Opcode::SLTU, 1, 5, LARGE), // 0 == 0 < 0 - AluEvent::new(0, 5, Opcode::SLTU, 0, 0, 0), + AluEvent::new(0, 0, 5, Opcode::SLTU, 0, 0, 0), // 0 == LARGE < LARGE - AluEvent::new(0, 5, Opcode::SLTU, 0, LARGE, LARGE), + AluEvent::new(0, 0, 5, Opcode::SLTU, 0, LARGE, LARGE), ]; prove_babybear_template(&mut shard); diff --git a/core/src/alu/mod.rs b/core/src/alu/mod.rs index 1f780988b3..c667c612c8 100644 --- a/core/src/alu/mod.rs +++ b/core/src/alu/mod.rs @@ -24,6 +24,9 @@ pub struct AluEvent { /// The shard number, used for byte lookup table. pub shard: u32, + /// The channel number, used for byte lookup table. + pub channel: u32, + /// The clock cycle that the operation occurs on. pub clk: u32, @@ -42,9 +45,10 @@ pub struct AluEvent { impl AluEvent { /// Creates a new `AluEvent`. 
- pub fn new(shard: u32, clk: u32, opcode: Opcode, a: u32, b: u32, c: u32) -> Self { + pub fn new(shard: u32, channel: u32, clk: u32, opcode: Opcode, a: u32, b: u32, c: u32) -> Self { Self { shard, + channel, clk, opcode, a, diff --git a/core/src/alu/mul/mod.rs b/core/src/alu/mul/mod.rs index 32a15fef94..c30a59c4f4 100644 --- a/core/src/alu/mul/mod.rs +++ b/core/src/alu/mul/mod.rs @@ -76,6 +76,9 @@ pub struct MulCols { /// The shard number, used for byte lookup table. pub shard: T, + /// The channel number, used for byte lookup table. + pub channel: T, + /// The output operand. pub a: Word, @@ -191,6 +194,7 @@ impl MachineAir for MulChip { let most_significant_byte = word[WORD_SIZE - 1]; blu_events.push(ByteLookupEvent { shard: event.shard, + channel: event.channel, opcode: ByteOpcode::MSB, a1: get_msb(*word) as u32, a2: 0, @@ -234,11 +238,16 @@ impl MachineAir for MulChip { cols.is_mulhu = F::from_bool(event.opcode == Opcode::MULHU); cols.is_mulhsu = F::from_bool(event.opcode == Opcode::MULHSU); cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); // Range check. { - record.add_u16_range_checks(event.shard, &carry); - record.add_u8_range_checks(event.shard, &product.map(|x| x as u8)); + record.add_u16_range_checks(event.shard, event.channel, &carry); + record.add_u8_range_checks( + event.shard, + event.channel, + &product.map(|x| x as u8), + ); } row }) @@ -299,7 +308,15 @@ where for msb_pair in msb_pairs.iter() { let msb = msb_pair.0; let byte = msb_pair.1; - builder.send_byte(opcode, msb, byte, zero.clone(), local.shard, local.is_real); + builder.send_byte( + opcode, + msb, + byte, + zero.clone(), + local.shard, + local.channel, + local.is_real, + ); } (local.b_msb, local.c_msb) }; @@ -425,9 +442,9 @@ where // Ensure that the carry is at most 2^16. 
This ensures that // product_before_carry_propagation - carry * base + last_carry never overflows or // underflows enough to "wrap" around to create a second solution. - builder.slice_range_check_u16(&local.carry, local.shard, local.is_real); + builder.slice_range_check_u16(&local.carry, local.shard, local.channel, local.is_real); - builder.slice_range_check_u8(&local.product, local.shard, local.is_real); + builder.slice_range_check_u8(&local.product, local.shard, local.channel, local.is_real); } // Receive the arguments. @@ -437,6 +454,7 @@ where local.b, local.c, local.shard, + local.channel, local.is_real, ); } @@ -469,6 +487,7 @@ mod tests { let mut mul_events: Vec = Vec::new(); for _ in 0..10i32.pow(7) { mul_events.push(AluEvent::new( + 0, 0, 0, Opcode::MULHSU, @@ -544,12 +563,12 @@ mod tests { (Opcode::MULH, 0xffffffff, 0x00000001, 0xffffffff), ]; for t in mul_instructions.iter() { - mul_events.push(AluEvent::new(0, 0, t.0, t.1, t.2, t.3)); + mul_events.push(AluEvent::new(0, 0, 0, t.0, t.1, t.2, t.3)); } // Append more events until we have 1000 tests. for _ in 0..(1000 - mul_instructions.len()) { - mul_events.push(AluEvent::new(0, 0, Opcode::MUL, 1, 1, 1)); + mul_events.push(AluEvent::new(0, 0, 0, Opcode::MUL, 1, 1, 1)); } shard.mul_events = mul_events; diff --git a/core/src/alu/sll/mod.rs b/core/src/alu/sll/mod.rs index ec0f185fe2..d87ee780d6 100644 --- a/core/src/alu/sll/mod.rs +++ b/core/src/alu/sll/mod.rs @@ -64,6 +64,9 @@ pub struct ShiftLeftCols { /// The shard number, used for byte lookup table. pub shard: T, + /// The channel number, used for byte lookup table. + pub channel: T, + /// The output operand. 
pub a: Word, @@ -118,6 +121,7 @@ impl MachineAir for ShiftLeft { let b = event.b.to_le_bytes(); let c = event.c.to_le_bytes(); cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); cols.a = Word(a.map(F::from_canonical_u8)); cols.b = Word(b.map(F::from_canonical_u8)); cols.c = Word(c.map(F::from_canonical_u8)); @@ -156,8 +160,8 @@ impl MachineAir for ShiftLeft { // Range checks. { - output.add_u8_range_checks(event.shard, &bit_shift_result); - output.add_u8_range_checks(event.shard, &bit_shift_result_carry); + output.add_u8_range_checks(event.shard, event.channel, &bit_shift_result); + output.add_u8_range_checks(event.shard, event.channel, &bit_shift_result_carry); } // Sanity check. @@ -314,8 +318,18 @@ where // Range check. { - builder.slice_range_check_u8(&local.bit_shift_result, local.shard, local.is_real); - builder.slice_range_check_u8(&local.bit_shift_result_carry, local.shard, local.is_real); + builder.slice_range_check_u8( + &local.bit_shift_result, + local.shard, + local.channel, + local.is_real, + ); + builder.slice_range_check_u8( + &local.bit_shift_result_carry, + local.shard, + local.channel, + local.is_real, + ); } for shift in local.shift_by_n_bytes.iter() { @@ -339,6 +353,7 @@ where local.b, local.c, local.shard, + local.channel, local.is_real, ); } @@ -366,7 +381,7 @@ mod tests { #[test] fn generate_trace() { let mut shard = ExecutionRecord::default(); - shard.shift_left_events = vec![AluEvent::new(0, 0, Opcode::SLL, 16, 8, 1)]; + shard.shift_left_events = vec![AluEvent::new(0, 0, 0, Opcode::SLL, 16, 8, 1)]; let chip = ShiftLeft::default(); let trace: RowMajorMatrix = chip.generate_trace(&shard, &mut ExecutionRecord::default()); @@ -401,7 +416,7 @@ mod tests { (Opcode::SLL, 0x00000000, 0x21212120, 0xffffffff), ]; for t in shift_instructions.iter() { - shift_events.push(AluEvent::new(0, 0, t.0, t.1, t.2, t.3)); + shift_events.push(AluEvent::new(0, 0, 0, t.0, t.1, t.2, t.3)); } // Append more events 
until we have 1000 tests. diff --git a/core/src/alu/sr/mod.rs b/core/src/alu/sr/mod.rs index f31bf7ee5a..892ad7a8ab 100644 --- a/core/src/alu/sr/mod.rs +++ b/core/src/alu/sr/mod.rs @@ -82,6 +82,9 @@ pub struct ShiftRightCols { /// The shard number, used for byte lookup table. pub shard: T, + /// The channel number, used for byte lookup table. + pub channel: T, + /// The output operand. pub a: Word, @@ -149,6 +152,7 @@ impl MachineAir for ShiftRightChip { // Initialize cols with basic operands and flags derived from the current event. { cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); cols.a = Word::from(event.a); cols.b = Word::from(event.b); cols.c = Word::from(event.c); @@ -168,6 +172,7 @@ impl MachineAir for ShiftRightChip { let most_significant_byte = event.b.to_le_bytes()[WORD_SIZE - 1]; output.add_byte_lookup_events(vec![ByteLookupEvent { shard: event.shard, + channel: event.channel, opcode: ByteOpcode::MSB, a1: ((most_significant_byte >> 7) & 1) as u32, a2: 0, @@ -217,6 +222,7 @@ impl MachineAir for ShiftRightChip { let byte_event = ByteLookupEvent { shard: event.shard, + channel: event.channel, opcode: ByteOpcode::ShrCarry, a1: shift as u32, a2: carry as u32, @@ -239,10 +245,14 @@ impl MachineAir for ShiftRightChip { debug_assert_eq!(cols.a[i], cols.bit_shift_result[i].clone()); } // Range checks. 
- output.add_u8_range_checks(event.shard, &byte_shift_result); - output.add_u8_range_checks(event.shard, &bit_shift_result); - output.add_u8_range_checks(event.shard, &shr_carry_output_carry); - output.add_u8_range_checks(event.shard, &shr_carry_output_shifted_byte); + output.add_u8_range_checks(event.shard, event.channel, &byte_shift_result); + output.add_u8_range_checks(event.shard, event.channel, &bit_shift_result); + output.add_u8_range_checks(event.shard, event.channel, &shr_carry_output_carry); + output.add_u8_range_checks( + event.shard, + event.channel, + &shr_carry_output_shifted_byte, + ); } rows.push(row); @@ -303,7 +313,15 @@ where let byte = local.b[WORD_SIZE - 1]; let opcode = AB::F::from_canonical_u32(ByteOpcode::MSB as u32); let msb = local.b_msb; - builder.send_byte(opcode, msb, byte, zero.clone(), local.shard, local.is_real); + builder.send_byte( + opcode, + msb, + byte, + zero.clone(), + local.shard, + local.channel, + local.is_real, + ); } // Calculate the number of bits and bytes to shift by from c. 
@@ -411,6 +429,7 @@ where local.byte_shift_result[i], num_bits_to_shift.clone(), local.shard, + local.channel, local.is_real, ); } @@ -457,7 +476,7 @@ where ]; for long_word in long_words.iter() { - builder.slice_range_check_u8(long_word, local.shard, local.is_real); + builder.slice_range_check_u8(long_word, local.shard, local.channel, local.is_real); } } @@ -478,6 +497,7 @@ where local.b, local.c, local.shard, + local.channel, local.is_real, ); } @@ -504,7 +524,7 @@ mod tests { #[test] fn generate_trace() { let mut shard = ExecutionRecord::default(); - shard.shift_right_events = vec![AluEvent::new(0, 0, Opcode::SRL, 6, 12, 1)]; + shard.shift_right_events = vec![AluEvent::new(0, 0, 0, Opcode::SRL, 6, 12, 1)]; let chip = ShiftRightChip::default(); let trace: RowMajorMatrix = chip.generate_trace(&shard, &mut ExecutionRecord::default()); @@ -555,7 +575,7 @@ mod tests { ]; let mut shift_events: Vec = Vec::new(); for t in shifts.iter() { - shift_events.push(AluEvent::new(0, 0, t.0, t.1, t.2, t.3)); + shift_events.push(AluEvent::new(0, 0, 0, t.0, t.1, t.2, t.3)); } let mut shard = ExecutionRecord::default(); shard.shift_right_events = shift_events; diff --git a/core/src/bytes/air.rs b/core/src/bytes/air.rs index 712ff1b43e..69d2ba82fe 100644 --- a/core/src/bytes/air.rs +++ b/core/src/bytes/air.rs @@ -7,7 +7,7 @@ use p3_field::Field; use p3_matrix::Matrix; use super::columns::{ByteMultCols, BytePreprocessedCols, NUM_BYTE_MULT_COLS}; -use super::{ByteChip, ByteOpcode}; +use super::{ByteChip, ByteOpcode, NUM_BYTE_LOOKUP_CHANNELS}; use crate::air::SP1AirBuilder; impl BaseAir for ByteChip { @@ -27,49 +27,66 @@ impl Air for ByteChip { let local: &BytePreprocessedCols = (*prep).borrow(); // Send all the lookups for each operation. 
- for (i, opcode) in ByteOpcode::all().iter().enumerate() { - let field_op = opcode.as_field::(); - let mult = local_mult.multiplicities[i]; - let shard = local_mult.shard; - match opcode { - ByteOpcode::AND => { - builder.receive_byte(field_op, local.and, local.b, local.c, shard, mult) + for channel in 0..NUM_BYTE_LOOKUP_CHANNELS { + let channel_f = AB::F::from_canonical_u32(channel); + let channel = channel as usize; + for (i, opcode) in ByteOpcode::all().iter().enumerate() { + let field_op = opcode.as_field::(); + let mult = local_mult.mult_channels[channel].multiplicities[i]; + let shard = local_mult.shard; + match opcode { + ByteOpcode::AND => builder.receive_byte( + field_op, local.and, local.b, local.c, shard, channel_f, mult, + ), + ByteOpcode::OR => builder + .receive_byte(field_op, local.or, local.b, local.c, shard, channel_f, mult), + ByteOpcode::XOR => builder.receive_byte( + field_op, local.xor, local.b, local.c, shard, channel_f, mult, + ), + ByteOpcode::SLL => builder.receive_byte( + field_op, local.sll, local.b, local.c, shard, channel_f, mult, + ), + ByteOpcode::U8Range => builder.receive_byte( + field_op, + AB::F::zero(), + local.b, + local.c, + shard, + channel_f, + mult, + ), + ByteOpcode::ShrCarry => builder.receive_byte_pair( + field_op, + local.shr, + local.shr_carry, + local.b, + local.c, + shard, + channel_f, + mult, + ), + ByteOpcode::LTU => builder.receive_byte( + field_op, local.ltu, local.b, local.c, shard, channel_f, mult, + ), + ByteOpcode::MSB => builder.receive_byte( + field_op, + local.msb, + local.b, + AB::F::zero(), + shard, + channel_f, + mult, + ), + ByteOpcode::U16Range => builder.receive_byte( + field_op, + local.value_u16, + AB::F::zero(), + AB::F::zero(), + shard, + channel_f, + mult, + ), } - ByteOpcode::OR => { - builder.receive_byte(field_op, local.or, local.b, local.c, shard, mult) - } - ByteOpcode::XOR => { - builder.receive_byte(field_op, local.xor, local.b, local.c, shard, mult) - } - ByteOpcode::SLL => { - 
builder.receive_byte(field_op, local.sll, local.b, local.c, shard, mult) - } - ByteOpcode::U8Range => { - builder.receive_byte(field_op, AB::F::zero(), local.b, local.c, shard, mult) - } - ByteOpcode::ShrCarry => builder.receive_byte_pair( - field_op, - local.shr, - local.shr_carry, - local.b, - local.c, - shard, - mult, - ), - ByteOpcode::LTU => { - builder.receive_byte(field_op, local.ltu, local.b, local.c, shard, mult) - } - ByteOpcode::MSB => { - builder.receive_byte(field_op, local.msb, local.b, AB::F::zero(), shard, mult) - } - ByteOpcode::U16Range => builder.receive_byte( - field_op, - local.value_u16, - AB::F::zero(), - AB::F::zero(), - shard, - mult, - ), } } } diff --git a/core/src/bytes/columns.rs b/core/src/bytes/columns.rs index 4524bec8bf..7134331f63 100644 --- a/core/src/bytes/columns.rs +++ b/core/src/bytes/columns.rs @@ -1,7 +1,7 @@ use sp1_derive::AlignedBorrow; use std::mem::size_of; -use super::NUM_BYTE_OPS; +use super::{NUM_BYTE_LOOKUP_CHANNELS, NUM_BYTE_OPS}; /// The number of main trace columns for `ByteChip`. pub const NUM_BYTE_PREPROCESSED_COLS: usize = size_of::>(); @@ -44,6 +44,14 @@ pub struct BytePreprocessedCols { pub value_u16: T, } +/// For each byte operation in the preprocessed table, a corresponding ByteMultCols row tracks the +/// number of times the operation is used. +#[derive(Debug, Clone, Copy, AlignedBorrow)] +#[repr(C)] +pub struct MultiplicitiesCols { + pub multiplicities: [T; NUM_BYTE_OPS], +} + /// For each byte operation in the preprocessed table, a corresponding ByteMultCols row tracks the /// number of times the operation is used. #[derive(Debug, Clone, Copy, AlignedBorrow)] @@ -53,5 +61,5 @@ pub struct ByteMultCols { pub shard: T, /// The multiplicites of each byte operation. 
- pub multiplicities: [T; NUM_BYTE_OPS], + pub mult_channels: [MultiplicitiesCols; NUM_BYTE_LOOKUP_CHANNELS as usize], } diff --git a/core/src/bytes/event.rs b/core/src/bytes/event.rs index 719b4542a4..f45c80cfe8 100644 --- a/core/src/bytes/event.rs +++ b/core/src/bytes/event.rs @@ -11,6 +11,9 @@ pub struct ByteLookupEvent { /// The shard number, used for byte lookup table. pub shard: u32, + // The channel multiplicity identifier. + pub channel: u32, + /// The opcode of the operation. pub opcode: ByteOpcode, @@ -40,9 +43,10 @@ pub trait ByteRecord { } /// Adds a `ByteLookupEvent` to verify `a` and `b are indeed bytes to the shard. - fn add_u8_range_check(&mut self, shard: u32, a: u8, b: u8) { + fn add_u8_range_check(&mut self, shard: u32, channel: u32, a: u8, b: u8) { self.add_byte_lookup_event(ByteLookupEvent { shard, + channel, opcode: ByteOpcode::U8Range, a1: 0, a2: 0, @@ -52,9 +56,10 @@ pub trait ByteRecord { } /// Adds a `ByteLookupEvent` to verify `a` is indeed u16. - fn add_u16_range_check(&mut self, shard: u32, a: u32) { + fn add_u16_range_check(&mut self, shard: u32, channel: u32, a: u32) { self.add_byte_lookup_event(ByteLookupEvent { shard, + channel, opcode: ByteOpcode::U16Range, a1: a, a2: 0, @@ -64,23 +69,29 @@ pub trait ByteRecord { } /// Adds `ByteLookupEvent`s to verify that all the bytes in the input slice are indeed bytes. - fn add_u8_range_checks(&mut self, shard: u32, bytes: &[u8]) { + fn add_u8_range_checks(&mut self, shard: u32, channel: u32, bytes: &[u8]) { let mut index = 0; while index + 1 < bytes.len() { - self.add_u8_range_check(shard, bytes[index], bytes[index + 1]); + self.add_u8_range_check(shard, channel, bytes[index], bytes[index + 1]); index += 2; } if index < bytes.len() { // If the input slice's length is odd, we need to add a check for the last byte. 
- self.add_u8_range_check(shard, bytes[index], 0); + self.add_u8_range_check(shard, channel, bytes[index], 0); } } /// Adds `ByteLookupEvent`s to verify that all the field elements in the input slice are indeed /// bytes. - fn add_u8_range_checks_field(&mut self, shard: u32, field_values: &[F]) { + fn add_u8_range_checks_field( + &mut self, + shard: u32, + channel: u32, + field_values: &[F], + ) { self.add_u8_range_checks( shard, + channel, &field_values .iter() .map(|x| x.as_canonical_u32() as u8) @@ -89,14 +100,16 @@ pub trait ByteRecord { } /// Adds `ByteLookupEvent`s to verify that all the bytes in the input slice are indeed bytes. - fn add_u16_range_checks(&mut self, shard: u32, ls: &[u32]) { - ls.iter().for_each(|x| self.add_u16_range_check(shard, *x)); + fn add_u16_range_checks(&mut self, shard: u32, channel: u32, ls: &[u32]) { + ls.iter() + .for_each(|x| self.add_u16_range_check(shard, channel, *x)); } /// Adds a `ByteLookupEvent` to compute the bitwise OR of the two input values. - fn lookup_or(&mut self, shard: u32, b: u8, c: u8) { + fn lookup_or(&mut self, shard: u32, channel: u32, b: u8, c: u8) { self.add_byte_lookup_event(ByteLookupEvent { shard, + channel, opcode: ByteOpcode::OR, a1: (b | c) as u32, a2: 0, @@ -108,9 +121,18 @@ pub trait ByteRecord { impl ByteLookupEvent { /// Creates a new `ByteLookupEvent`. - pub fn new(shard: u32, opcode: ByteOpcode, a1: u32, a2: u32, b: u32, c: u32) -> Self { + pub fn new( + shard: u32, + channel: u32, + opcode: ByteOpcode, + a1: u32, + a2: u32, + b: u32, + c: u32, + ) -> Self { Self { shard, + channel, opcode, a1, a2, diff --git a/core/src/bytes/mod.rs b/core/src/bytes/mod.rs index 411b065d1a..f6d5bc482c 100644 --- a/core/src/bytes/mod.rs +++ b/core/src/bytes/mod.rs @@ -23,6 +23,9 @@ use crate::bytes::trace::NUM_ROWS; /// The number of different byte operations. pub const NUM_BYTE_OPS: usize = 9; +/// The number of different byte lookup channels. 
+pub const NUM_BYTE_LOOKUP_CHANNELS: u32 = 4; + /// A chip for computing byte operations. /// /// The chip contains a preprocessed table of all possible byte operations. Other chips can then @@ -64,61 +67,76 @@ impl ByteChip { col.c = F::from_canonical_u8(c); // Iterate over all operations for results and updating the table map. - for (i, opcode) in opcodes.iter().enumerate() { - let event = match opcode { - ByteOpcode::AND => { - let and = b & c; - col.and = F::from_canonical_u8(and); - ByteLookupEvent::new(shard, *opcode, and as u32, 0, b as u32, c as u32) - } - ByteOpcode::OR => { - let or = b | c; - col.or = F::from_canonical_u8(or); - ByteLookupEvent::new(shard, *opcode, or as u32, 0, b as u32, c as u32) - } - ByteOpcode::XOR => { - let xor = b ^ c; - col.xor = F::from_canonical_u8(xor); - ByteLookupEvent::new(shard, *opcode, xor as u32, 0, b as u32, c as u32) - } - ByteOpcode::SLL => { - let sll = b << (c & 7); - col.sll = F::from_canonical_u8(sll); - ByteLookupEvent::new(shard, *opcode, sll as u32, 0, b as u32, c as u32) - } - ByteOpcode::U8Range => { - ByteLookupEvent::new(shard, *opcode, 0, 0, b as u32, c as u32) - } - ByteOpcode::ShrCarry => { - let (res, carry) = shr_carry(b, c); - col.shr = F::from_canonical_u8(res); - col.shr_carry = F::from_canonical_u8(carry); - ByteLookupEvent::new( - shard, - *opcode, - res as u32, - carry as u32, - b as u32, - c as u32, - ) - } - ByteOpcode::LTU => { - let ltu = b < c; - col.ltu = F::from_bool(ltu); - ByteLookupEvent::new(shard, *opcode, ltu as u32, 0, b as u32, c as u32) - } - ByteOpcode::MSB => { - let msb = (b & 0b1000_0000) != 0; - col.msb = F::from_bool(msb); - ByteLookupEvent::new(shard, *opcode, msb as u32, 0, b as u32, 0 as u32) - } - ByteOpcode::U16Range => { - let v = ((b as u32) << 8) + c as u32; - col.value_u16 = F::from_canonical_u32(v); - ByteLookupEvent::new(shard, *opcode, v, 0, 0, 0) - } - }; - event_map.insert(event, (row_index, i)); + for channel in 0..NUM_BYTE_LOOKUP_CHANNELS { + for (i, 
opcode) in opcodes.iter().enumerate() { + let event = match opcode { + ByteOpcode::AND => { + let and = b & c; + col.and = F::from_canonical_u8(and); + ByteLookupEvent::new( + shard, channel, *opcode, and as u32, 0, b as u32, c as u32, + ) + } + ByteOpcode::OR => { + let or = b | c; + col.or = F::from_canonical_u8(or); + ByteLookupEvent::new( + shard, channel, *opcode, or as u32, 0, b as u32, c as u32, + ) + } + ByteOpcode::XOR => { + let xor = b ^ c; + col.xor = F::from_canonical_u8(xor); + ByteLookupEvent::new( + shard, channel, *opcode, xor as u32, 0, b as u32, c as u32, + ) + } + ByteOpcode::SLL => { + let sll = b << (c & 7); + col.sll = F::from_canonical_u8(sll); + ByteLookupEvent::new( + shard, channel, *opcode, sll as u32, 0, b as u32, c as u32, + ) + } + ByteOpcode::U8Range => { + ByteLookupEvent::new(shard, channel, *opcode, 0, 0, b as u32, c as u32) + } + ByteOpcode::ShrCarry => { + let (res, carry) = shr_carry(b, c); + col.shr = F::from_canonical_u8(res); + col.shr_carry = F::from_canonical_u8(carry); + ByteLookupEvent::new( + shard, + channel, + *opcode, + res as u32, + carry as u32, + b as u32, + c as u32, + ) + } + ByteOpcode::LTU => { + let ltu = b < c; + col.ltu = F::from_bool(ltu); + ByteLookupEvent::new( + shard, channel, *opcode, ltu as u32, 0, b as u32, c as u32, + ) + } + ByteOpcode::MSB => { + let msb = (b & 0b1000_0000) != 0; + col.msb = F::from_bool(msb); + ByteLookupEvent::new( + shard, channel, *opcode, msb as u32, 0, b as u32, 0 as u32, + ) + } + ByteOpcode::U16Range => { + let v = ((b as u32) << 8) + c as u32; + col.value_u16 = F::from_canonical_u32(v); + ByteLookupEvent::new(shard, channel, *opcode, v, 0, 0, 0) + } + }; + event_map.insert(event, (row_index, i)); + } } } diff --git a/core/src/bytes/trace.rs b/core/src/bytes/trace.rs index 22b8204208..39f2b72ff5 100644 --- a/core/src/bytes/trace.rs +++ b/core/src/bytes/trace.rs @@ -54,10 +54,11 @@ impl MachineAir for ByteChip { for (lookup, mult) in input.byte_lookups[&shard].iter() { let 
(row, index) = event_map[lookup]; + let channel = lookup.channel as usize; let cols: &mut ByteMultCols = trace.row_mut(row).borrow_mut(); // Update the trace multiplicity - cols.multiplicities[index] += F::from_canonical_usize(*mult); + cols.mult_channels[channel].multiplicities[index] += F::from_canonical_usize(*mult); // Set the shard column as the current shard. cols.shard = F::from_canonical_u32(shard); diff --git a/core/src/cpu/air/branch.rs b/core/src/cpu/air/branch.rs index 56d5faf5ef..fad654de35 100644 --- a/core/src/cpu/air/branch.rs +++ b/core/src/cpu/air/branch.rs @@ -64,6 +64,7 @@ impl CpuChip { branch_cols.pc, local.op_c_val(), local.shard, + local.channel, local.branching, ); @@ -155,6 +156,7 @@ impl CpuChip { local.op_a_val(), local.op_b_val(), local.shard, + local.channel, is_branch_instruction.clone(), ); @@ -166,6 +168,7 @@ impl CpuChip { local.op_b_val(), local.op_a_val(), local.shard, + local.channel, is_branch_instruction.clone(), ); } diff --git a/core/src/cpu/air/ecall.rs b/core/src/cpu/air/ecall.rs index 32c2d8a47f..506b2c7b75 100644 --- a/core/src/cpu/air/ecall.rs +++ b/core/src/cpu/air/ecall.rs @@ -41,6 +41,7 @@ impl CpuChip { .assert_eq(send_to_table, local.ecall_mul_send_to_table); builder.send_syscall( local.shard, + local.channel, local.clk, syscall_id, local.op_b_val().reduce::(), diff --git a/core/src/cpu/air/memory.rs b/core/src/cpu/air/memory.rs index eb4404828e..6ac1a07c11 100644 --- a/core/src/cpu/air/memory.rs +++ b/core/src/cpu/air/memory.rs @@ -65,6 +65,7 @@ impl CpuChip { local.op_b_val(), local.op_c_val(), local.shard, + local.channel, is_memory_instruction.clone(), ); @@ -72,6 +73,7 @@ impl CpuChip { builder.slice_range_check_u8( &memory_columns.addr_word.0, local.shard, + local.channel, is_memory_instruction.clone(), ); @@ -90,6 +92,7 @@ impl CpuChip { // value into the memory columns. 
builder.eval_memory_access( local.shard, + local.channel, local.clk + AB::F::from_canonical_u32(MemoryAccessPosition::Memory as u32), memory_columns.addr_aligned, &memory_columns.memory_access, @@ -139,6 +142,7 @@ impl CpuChip { local.unsigned_mem_val, signed_value, local.shard, + local.channel, local.mem_value_is_neg, ); diff --git a/core/src/cpu/air/mod.rs b/core/src/cpu/air/mod.rs index 75742f8616..11a985bb5e 100644 --- a/core/src/cpu/air/mod.rs +++ b/core/src/cpu/air/mod.rs @@ -24,6 +24,8 @@ use crate::cpu::columns::{CpuCols, NUM_CPU_COLS}; use crate::cpu::CpuChip; use crate::runtime::Opcode; +use super::columns::eval_channel_selectors; + impl Air for CpuChip where AB: SP1AirBuilder + AirBuilderWithPublicValues, @@ -64,6 +66,16 @@ where self.eval_memory_load::(builder, local); self.eval_memory_store::(builder, local); + // Channel constraints. + eval_channel_selectors( + builder, + &local.channel_selectors, + &next.channel_selectors, + local.channel, + local.is_real, + next.is_real, + ); + // ALU instructions. 
builder.send_alu( local.instruction.opcode, @@ -71,6 +83,7 @@ where local.op_b_val(), local.op_c_val(), local.shard, + local.channel, is_alu_instruction, ); @@ -169,6 +182,7 @@ impl CpuChip { jump_columns.pc, local.op_b_val(), local.shard, + local.channel, local.selectors.is_jal, ); @@ -179,6 +193,7 @@ impl CpuChip { local.op_b_val(), local.op_c_val(), local.shard, + local.channel, local.selectors.is_jalr, ); } @@ -200,6 +215,7 @@ impl CpuChip { auipc_columns.pc, local.op_b_val(), local.shard, + local.channel, local.selectors.is_auipc, ); } @@ -229,6 +245,7 @@ impl CpuChip { AB::Expr::zero(), AB::Expr::zero(), local.shard, + local.channel, local.is_real, ); @@ -255,6 +272,7 @@ impl CpuChip { local.clk_16bit_limb, local.clk_8bit_limb, local.shard, + local.channel, local.is_real, ); } diff --git a/core/src/cpu/air/register.rs b/core/src/cpu/air/register.rs index 14541c48b9..e0b989c2bc 100644 --- a/core/src/cpu/air/register.rs +++ b/core/src/cpu/air/register.rs @@ -25,6 +25,7 @@ impl CpuChip { // If they are not immediates, read `b` and `c` from memory. builder.eval_memory_access( local.shard, + local.channel, local.clk + AB::F::from_canonical_u32(MemoryAccessPosition::B as u32), local.instruction.op_b[0], &local.op_b_access, @@ -33,6 +34,7 @@ impl CpuChip { builder.eval_memory_access( local.shard, + local.channel, local.clk + AB::F::from_canonical_u32(MemoryAccessPosition::C as u32), local.instruction.op_c[0], &local.op_c_access, @@ -48,6 +50,7 @@ impl CpuChip { // we are performing a branch or a store. 
builder.eval_memory_access( local.shard, + local.channel, local.clk + AB::F::from_canonical_u32(MemoryAccessPosition::A as u32), local.instruction.op_a[0], &local.op_a_access, diff --git a/core/src/cpu/columns/channel.rs b/core/src/cpu/columns/channel.rs new file mode 100644 index 0000000000..9bd6a4ef78 --- /dev/null +++ b/core/src/cpu/columns/channel.rs @@ -0,0 +1,65 @@ +use p3_air::AirBuilder; +use p3_field::AbstractField; +use p3_field::Field; +use sp1_derive::AlignedBorrow; + +use crate::{bytes::NUM_BYTE_LOOKUP_CHANNELS, stark::SP1AirBuilder}; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct ChannelSelectorCols { + pub channel_selectors: [T; NUM_BYTE_LOOKUP_CHANNELS as usize], +} + +impl ChannelSelectorCols { + pub fn populate(&mut self, channel: u32) { + self.channel_selectors = [F::zero(); NUM_BYTE_LOOKUP_CHANNELS as usize]; + self.channel_selectors[channel as usize] = F::one(); + } +} + +pub fn eval_channel_selectors( + builder: &mut AB, + local: &ChannelSelectorCols, + next: &ChannelSelectorCols, + channel: impl Into + Clone, + local_is_real: impl Into + Clone, + next_is_real: impl Into + Clone, +) { + // Constrain: + // - the value of the channel is given by the channel selectors. + // - all selectors are boolean and disjoint. + let mut sum = AB::Expr::zero(); + let mut reconstruct_channel = AB::Expr::zero(); + for (i, selector) in local.channel_selectors.into_iter().enumerate() { + // Constrain that the selector is boolean. + builder.assert_bool(selector); + // Accumulate the sum of the selectors. + sum += selector.into(); + // Accumulate the reconstructed channel. + reconstruct_channel += selector.into() * AB::Expr::from_canonical_u32(i as u32); + } + // Assert that the reconstructed channel is the same as the channel. + builder.assert_eq(reconstruct_channel, channel.clone()); + // For disjointness, assert the sum of the selectors is 1. 
+ builder + .when(local_is_real.clone()) + .assert_eq(sum, AB::Expr::one()); + + // Constrain the first row by asserting that the first selector on the first line is true. + builder + .when_first_row() + .assert_one(local.channel_selectors[0]); + + // Constrain the transition by asserting that the selectors satisfy the recursion relation: + // selectors_next[(i + 1) % NUM_BYTE_LOOKUP_CHANNELS] = selectors[i] + for i in 0..NUM_BYTE_LOOKUP_CHANNELS as usize { + builder + .when_transition() + .when(next_is_real.clone()) + .assert_eq( + local.channel_selectors[i], + next.channel_selectors[(i + 1) % NUM_BYTE_LOOKUP_CHANNELS as usize], + ); + } +} diff --git a/core/src/cpu/columns/mod.rs b/core/src/cpu/columns/mod.rs index e40f30518e..d81bd806fc 100644 --- a/core/src/cpu/columns/mod.rs +++ b/core/src/cpu/columns/mod.rs @@ -1,5 +1,6 @@ mod auipc; mod branch; +mod channel; mod ecall; mod instruction; mod jump; @@ -9,6 +10,7 @@ mod opcode_specific; pub use auipc::*; pub use branch::*; +pub use channel::*; pub use ecall::*; pub use instruction::*; pub use jump::*; @@ -35,6 +37,8 @@ pub const CPU_COL_MAP: CpuCols = make_col_map(); pub struct CpuCols { /// The current shard. pub shard: T, + /// The channel value, used for byte lookup multiplicity. + pub channel: T, /// The clock cycle value. This should be within 24 bits. pub clk: T, @@ -52,6 +56,9 @@ pub struct CpuCols { /// Columns related to the instruction. pub instruction: InstructionCols, + /// Columns related to the byte lookup channel. + pub channel_selectors: ChannelSelectorCols, + /// Selectors for the opcode. pub selectors: OpcodeSelectorCols, diff --git a/core/src/cpu/event.rs b/core/src/cpu/event.rs index 8baa8fb730..2170d91d5d 100644 --- a/core/src/cpu/event.rs +++ b/core/src/cpu/event.rs @@ -9,6 +9,9 @@ pub struct CpuEvent { /// The current shard. pub shard: u32, + /// The current channel. + pub channel: u32, + /// The current clock. 
pub clk: u32, diff --git a/core/src/cpu/trace.rs b/core/src/cpu/trace.rs index 7000fe0624..aa71f130f5 100644 --- a/core/src/cpu/trace.rs +++ b/core/src/cpu/trace.rs @@ -149,13 +149,16 @@ impl CpuChip { // Populate memory accesses for a, b, and c. if let Some(record) = event.a_record { - cols.op_a_access.populate(record, &mut new_blu_events) + cols.op_a_access + .populate(event.channel, record, &mut new_blu_events) } if let Some(MemoryRecordEnum::Read(record)) = event.b_record { - cols.op_b_access.populate(record, &mut new_blu_events) + cols.op_b_access + .populate(event.channel, record, &mut new_blu_events) } if let Some(MemoryRecordEnum::Read(record)) = event.c_record { - cols.op_c_access.populate(record, &mut new_blu_events) + cols.op_c_access + .populate(event.channel, record, &mut new_blu_events) } // Populate memory accesses for reading from memory. @@ -164,7 +167,7 @@ impl CpuChip { if let Some(record) = event.memory_record { memory_columns .memory_access - .populate(record, &mut new_blu_events) + .populate(event.channel, record, &mut new_blu_events) } // Populate memory, branch, jump, and auipc specific fields. @@ -188,7 +191,7 @@ impl CpuChip { (row, new_alu_events, new_blu_events) } - /// Populates the shard and clk related rows. + /// Populates the shard, channel, and clk related rows. 
fn populate_shard_clk( &self, cols: &mut CpuCols, @@ -196,8 +199,11 @@ impl CpuChip { new_blu_events: &mut Vec, ) { cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); + cols.channel_selectors.populate(event.channel); new_blu_events.push(ByteLookupEvent::new( event.shard, + event.channel, U16Range, event.shard, 0, @@ -212,6 +218,7 @@ impl CpuChip { cols.clk_8bit_limb = F::from_canonical_u32(clk_8bit_limb); new_blu_events.push(ByteLookupEvent::new( event.shard, + event.channel, U16Range, clk_16bit_limb, 0, @@ -220,6 +227,7 @@ impl CpuChip { )); new_blu_events.push(ByteLookupEvent::new( event.shard, + event.channel, U8Range, 0, 0, @@ -260,6 +268,7 @@ impl CpuChip { // Add event to ALU check to check that addr == b + c let add_event = AluEvent { shard: event.shard, + channel: event.channel, clk: event.clk, opcode: Opcode::ADD, a: memory_addr, @@ -323,6 +332,7 @@ impl CpuChip { if memory_columns.most_sig_byte_decomp[7] == F::one() { cols.mem_value_is_neg = F::one(); let sub_event = AluEvent { + channel: event.channel, shard: event.shard, clk: event.clk, opcode: Opcode::SUB, @@ -344,6 +354,7 @@ impl CpuChip { for byte_pair in addr_bytes.chunks_exact(2) { new_blu_events.push(ByteLookupEvent { shard: event.shard, + channel: event.channel, opcode: ByteOpcode::U8Range, a1: 0, a2: 0, @@ -387,6 +398,7 @@ impl CpuChip { // Add the ALU events for the comparisons let lt_comp_event = AluEvent { shard: event.shard, + channel: event.channel, clk: event.clk, opcode: alu_op_code, a: a_lt_b as u32, @@ -401,6 +413,7 @@ impl CpuChip { let gt_comp_event = AluEvent { shard: event.shard, + channel: event.channel, clk: event.clk, opcode: alu_op_code, a: a_gt_b as u32, @@ -434,6 +447,7 @@ impl CpuChip { let add_event = AluEvent { shard: event.shard, + channel: event.channel, clk: event.clk, opcode: Opcode::ADD, a: next_pc, @@ -469,6 +483,7 @@ impl CpuChip { let add_event = AluEvent { shard: event.shard, + channel: event.channel, clk: 
event.clk, opcode: Opcode::ADD, a: next_pc, @@ -487,6 +502,7 @@ impl CpuChip { let add_event = AluEvent { shard: event.shard, + channel: event.channel, clk: event.clk, opcode: Opcode::ADD, a: next_pc, @@ -518,6 +534,7 @@ impl CpuChip { let add_event = AluEvent { shard: event.shard, + channel: event.channel, clk: event.clk, opcode: Opcode::ADD, a: event.a, @@ -631,6 +648,7 @@ mod tests { let mut shard = ExecutionRecord::default(); shard.cpu_events = vec![CpuEvent { shard: 1, + channel: 0, clk: 6, pc: 1, next_pc: 5, diff --git a/core/src/memory/trace.rs b/core/src/memory/trace.rs index e7593d0134..678dfcf7bc 100644 --- a/core/src/memory/trace.rs +++ b/core/src/memory/trace.rs @@ -5,7 +5,12 @@ use crate::bytes::event::ByteRecord; use crate::runtime::{MemoryReadRecord, MemoryRecord, MemoryRecordEnum, MemoryWriteRecord}; impl MemoryWriteCols { - pub fn populate(&mut self, record: MemoryWriteRecord, output: &mut impl ByteRecord) { + pub fn populate( + &mut self, + channel: u32, + record: MemoryWriteRecord, + output: &mut impl ByteRecord, + ) { let current_record = MemoryRecord { value: record.value, shard: record.shard, @@ -18,12 +23,17 @@ impl MemoryWriteCols { }; self.prev_value = prev_record.value.into(); self.access - .populate_access(current_record, prev_record, output); + .populate_access(channel, current_record, prev_record, output); } } impl MemoryReadCols { - pub fn populate(&mut self, record: MemoryReadRecord, output: &mut impl ByteRecord) { + pub fn populate( + &mut self, + channel: u32, + record: MemoryReadRecord, + output: &mut impl ByteRecord, + ) { let current_record = MemoryRecord { value: record.value, shard: record.shard, @@ -35,19 +45,31 @@ impl MemoryReadCols { timestamp: record.prev_timestamp, }; self.access - .populate_access(current_record, prev_record, output); + .populate_access(channel, current_record, prev_record, output); } } impl MemoryReadWriteCols { - pub fn populate(&mut self, record: MemoryRecordEnum, output: &mut impl ByteRecord) { + pub 
fn populate( + &mut self, + channel: u32, + record: MemoryRecordEnum, + output: &mut impl ByteRecord, + ) { match record { - MemoryRecordEnum::Read(read_record) => self.populate_read(read_record, output), - MemoryRecordEnum::Write(write_record) => self.populate_write(write_record, output), + MemoryRecordEnum::Read(read_record) => self.populate_read(channel, read_record, output), + MemoryRecordEnum::Write(write_record) => { + self.populate_write(channel, write_record, output) + } } } - pub fn populate_write(&mut self, record: MemoryWriteRecord, output: &mut impl ByteRecord) { + pub fn populate_write( + &mut self, + channel: u32, + record: MemoryWriteRecord, + output: &mut impl ByteRecord, + ) { let current_record = MemoryRecord { value: record.value, shard: record.shard, @@ -60,10 +82,15 @@ impl MemoryReadWriteCols { }; self.prev_value = prev_record.value.into(); self.access - .populate_access(current_record, prev_record, output); + .populate_access(channel, current_record, prev_record, output); } - pub fn populate_read(&mut self, record: MemoryReadRecord, output: &mut impl ByteRecord) { + pub fn populate_read( + &mut self, + channel: u32, + record: MemoryReadRecord, + output: &mut impl ByteRecord, + ) { let current_record = MemoryRecord { value: record.value, shard: record.shard, @@ -76,13 +103,14 @@ impl MemoryReadWriteCols { }; self.prev_value = prev_record.value.into(); self.access - .populate_access(current_record, prev_record, output); + .populate_access(channel, current_record, prev_record, output); } } impl MemoryAccessCols { pub(crate) fn populate_access( &mut self, + channel: u32, current_record: MemoryRecord, prev_record: MemoryRecord, output: &mut impl ByteRecord, @@ -115,9 +143,9 @@ impl MemoryAccessCols { let shard = current_record.shard; // Add a byte table lookup with the 16Range op. - output.add_u16_range_check(shard, diff_16bit_limb); + output.add_u16_range_check(shard, channel, diff_16bit_limb); // Add a byte table lookup with the U8Range op. 
- output.add_u8_range_check(shard, 0, diff_8bit_limb as u8); + output.add_u8_range_check(shard, channel, 0, diff_8bit_limb as u8); } } diff --git a/core/src/operations/add.rs b/core/src/operations/add.rs index 78e8dc1826..27f002db3f 100644 --- a/core/src/operations/add.rs +++ b/core/src/operations/add.rs @@ -24,6 +24,7 @@ impl AddOperation { &mut self, record: &mut ExecutionRecord, shard: u32, + channel: u32, a_u32: u32, b_u32: u32, ) -> u32 { @@ -54,9 +55,9 @@ impl AddOperation { // Range check { - record.add_u8_range_checks(shard, &a); - record.add_u8_range_checks(shard, &b); - record.add_u8_range_checks(shard, &expected.to_le_bytes()); + record.add_u8_range_checks(shard, channel, &a); + record.add_u8_range_checks(shard, channel, &b); + record.add_u8_range_checks(shard, channel, &expected.to_le_bytes()); } expected } @@ -67,6 +68,7 @@ impl AddOperation { b: Word, cols: AddOperation, shard: AB::Var, + channel: impl Into + Clone, is_real: AB::Expr, ) { let one = AB::Expr::one(); @@ -103,9 +105,9 @@ impl AddOperation { // Range check each byte. { - builder.slice_range_check_u8(&a.0, shard, is_real.clone()); - builder.slice_range_check_u8(&b.0, shard, is_real.clone()); - builder.slice_range_check_u8(&cols.value.0, shard, is_real); + builder.slice_range_check_u8(&a.0, shard, channel.clone(), is_real.clone()); + builder.slice_range_check_u8(&b.0, shard, channel.clone(), is_real.clone()); + builder.slice_range_check_u8(&cols.value.0, shard, channel.clone(), is_real); } } } diff --git a/core/src/operations/add4.rs b/core/src/operations/add4.rs index 10e1f95774..f3066a3851 100644 --- a/core/src/operations/add4.rs +++ b/core/src/operations/add4.rs @@ -33,10 +33,12 @@ pub struct Add4Operation { } impl Add4Operation { + #[allow(clippy::too_many_arguments)] pub fn populate( &mut self, record: &mut ExecutionRecord, shard: u32, + channel: u32, a_u32: u32, b_u32: u32, c_u32: u32, @@ -71,11 +73,11 @@ impl Add4Operation { // Range check. 
{ - record.add_u8_range_checks(shard, &a); - record.add_u8_range_checks(shard, &b); - record.add_u8_range_checks(shard, &c); - record.add_u8_range_checks(shard, &d); - record.add_u8_range_checks(shard, &expected.to_le_bytes()); + record.add_u8_range_checks(shard, channel, &a); + record.add_u8_range_checks(shard, channel, &b); + record.add_u8_range_checks(shard, channel, &c); + record.add_u8_range_checks(shard, channel, &d); + record.add_u8_range_checks(shard, channel, &expected.to_le_bytes()); } expected } @@ -88,16 +90,17 @@ impl Add4Operation { c: Word, d: Word, shard: AB::Var, + channel: impl Into + Copy, is_real: AB::Var, cols: Add4Operation, ) { // Range check each byte. { - builder.slice_range_check_u8(&a.0, shard, is_real); - builder.slice_range_check_u8(&b.0, shard, is_real); - builder.slice_range_check_u8(&c.0, shard, is_real); - builder.slice_range_check_u8(&d.0, shard, is_real); - builder.slice_range_check_u8(&cols.value.0, shard, is_real); + builder.slice_range_check_u8(&a.0, shard, channel, is_real); + builder.slice_range_check_u8(&b.0, shard, channel, is_real); + builder.slice_range_check_u8(&c.0, shard, channel, is_real); + builder.slice_range_check_u8(&d.0, shard, channel, is_real); + builder.slice_range_check_u8(&cols.value.0, shard, channel, is_real); } builder.assert_bool(is_real); diff --git a/core/src/operations/add5.rs b/core/src/operations/add5.rs index 67421f8aa8..00da26bf84 100644 --- a/core/src/operations/add5.rs +++ b/core/src/operations/add5.rs @@ -41,6 +41,7 @@ impl Add5Operation { &mut self, record: &mut ExecutionRecord, shard: u32, + channel: u32, a_u32: u32, b_u32: u32, c_u32: u32, @@ -81,12 +82,12 @@ impl Add5Operation { // Range check. 
{ - record.add_u8_range_checks(shard, &a); - record.add_u8_range_checks(shard, &b); - record.add_u8_range_checks(shard, &c); - record.add_u8_range_checks(shard, &d); - record.add_u8_range_checks(shard, &e); - record.add_u8_range_checks(shard, &expected.to_le_bytes()); + record.add_u8_range_checks(shard, channel, &a); + record.add_u8_range_checks(shard, channel, &b); + record.add_u8_range_checks(shard, channel, &c); + record.add_u8_range_checks(shard, channel, &d); + record.add_u8_range_checks(shard, channel, &e); + record.add_u8_range_checks(shard, channel, &expected.to_le_bytes()); } expected @@ -96,6 +97,7 @@ impl Add5Operation { builder: &mut AB, words: &[Word; 5], shard: AB::Var, + channel: impl Into + Copy, is_real: AB::Var, cols: Add5Operation, ) { @@ -104,8 +106,8 @@ impl Add5Operation { { words .iter() - .for_each(|word| builder.slice_range_check_u8(&word.0, shard, is_real)); - builder.slice_range_check_u8(&cols.value.0, shard, is_real); + .for_each(|word| builder.slice_range_check_u8(&word.0, shard, channel, is_real)); + builder.slice_range_check_u8(&cols.value.0, shard, channel, is_real); } let mut builder_is_real = builder.when(is_real); diff --git a/core/src/operations/and.rs b/core/src/operations/and.rs index 9bb27a14a0..adeade4190 100644 --- a/core/src/operations/and.rs +++ b/core/src/operations/and.rs @@ -19,7 +19,14 @@ pub struct AndOperation { } impl AndOperation { - pub fn populate(&mut self, record: &mut ExecutionRecord, shard: u32, x: u32, y: u32) -> u32 { + pub fn populate( + &mut self, + record: &mut ExecutionRecord, + shard: u32, + channel: u32, + x: u32, + y: u32, + ) -> u32 { let expected = x & y; let x_bytes = x.to_le_bytes(); let y_bytes = y.to_le_bytes(); @@ -29,6 +36,7 @@ impl AndOperation { let byte_event = ByteLookupEvent { shard, + channel, opcode: ByteOpcode::AND, a1: and as u32, a2: 0, @@ -47,6 +55,7 @@ impl AndOperation { b: Word, cols: AndOperation, shard: AB::Var, + channel: impl Into + Copy, is_real: AB::Var, ) { for i in 
0..WORD_SIZE { @@ -56,6 +65,7 @@ impl AndOperation { a[i], b[i], shard, + channel, is_real, ); } diff --git a/core/src/operations/field/field_den.rs b/core/src/operations/field/field_den.rs index eab87631ca..7940d8a0c5 100644 --- a/core/src/operations/field/field_den.rs +++ b/core/src/operations/field/field_den.rs @@ -33,6 +33,7 @@ impl FieldDenCols { &mut self, record: &mut impl ByteRecord, shard: u32, + channel: u32, a: &BigUint, b: &BigUint, sign: bool, @@ -84,10 +85,10 @@ impl FieldDenCols { self.witness_high = Limbs(p_witness_high.try_into().unwrap()); // Range checks - record.add_u8_range_checks_field(shard, &self.result.0); - record.add_u8_range_checks_field(shard, &self.carry.0); - record.add_u8_range_checks_field(shard, &self.witness_low.0); - record.add_u8_range_checks_field(shard, &self.witness_high.0); + record.add_u8_range_checks_field(shard, channel, &self.result.0); + record.add_u8_range_checks_field(shard, channel, &self.carry.0); + record.add_u8_range_checks_field(shard, channel, &self.witness_low.0); + record.add_u8_range_checks_field(shard, channel, &self.witness_high.0); result } @@ -97,18 +98,16 @@ impl FieldDenCols where Limbs: Copy, { - pub fn eval< - AB: SP1AirBuilder, - EShard: Into + Clone, - ER: Into + Clone, - >( + #[allow(clippy::too_many_arguments)] + pub fn eval>( &self, builder: &mut AB, a: &Limbs, b: &Limbs, sign: bool, - shard: EShard, - is_real: ER, + shard: impl Into + Clone, + channel: impl Into + Clone, + is_real: impl Into + Clone, ) where V: Into, { @@ -139,10 +138,25 @@ where eval_field_operation::(builder, &p_vanishing, &p_witness_low, &p_witness_high); // Range checks for the result, carry, and witness columns. 
- builder.slice_range_check_u8(&self.result.0, shard.clone(), is_real.clone()); - builder.slice_range_check_u8(&self.carry.0, shard.clone(), is_real.clone()); - builder.slice_range_check_u8(&self.witness_low.0, shard.clone(), is_real.clone()); - builder.slice_range_check_u8(&self.witness_high.0, shard, is_real); + builder.slice_range_check_u8( + &self.result.0, + shard.clone(), + channel.clone(), + is_real.clone(), + ); + builder.slice_range_check_u8( + &self.carry.0, + shard.clone(), + channel.clone(), + is_real.clone(), + ); + builder.slice_range_check_u8( + &self.witness_low.0, + shard.clone(), + channel.clone(), + is_real.clone(), + ); + builder.slice_range_check_u8(&self.witness_high.0, shard, channel.clone(), is_real); } } @@ -238,7 +252,7 @@ mod tests { let cols: &mut TestCols = row.as_mut_slice().borrow_mut(); cols.a = P::to_limbs_field::(a); cols.b = P::to_limbs_field::(b); - cols.a_den_b.populate(output, 1, a, b, self.sign); + cols.a_den_b.populate(output, 1, 0, a, b, self.sign); row }) .collect::>(); @@ -278,6 +292,7 @@ mod tests { &local.b, self.sign, AB::F::one(), + AB::F::zero(), AB::F::one(), ); } diff --git a/core/src/operations/field/field_inner_product.rs b/core/src/operations/field/field_inner_product.rs index 070b174bd1..2b259f2f7e 100644 --- a/core/src/operations/field/field_inner_product.rs +++ b/core/src/operations/field/field_inner_product.rs @@ -34,6 +34,7 @@ impl FieldInnerProductCols { &mut self, record: &mut impl ByteRecord, shard: u32, + channel: u32, a: &[BigUint], b: &[BigUint], ) -> BigUint { @@ -86,10 +87,10 @@ impl FieldInnerProductCols { self.witness_high = Limbs(p_witness_high.try_into().unwrap()); // Range checks - record.add_u8_range_checks_field(shard, &self.result.0); - record.add_u8_range_checks_field(shard, &self.carry.0); - record.add_u8_range_checks_field(shard, &self.witness_low.0); - record.add_u8_range_checks_field(shard, &self.witness_high.0); + record.add_u8_range_checks_field(shard, channel, &self.result.0); + 
record.add_u8_range_checks_field(shard, channel, &self.carry.0); + record.add_u8_range_checks_field(shard, channel, &self.witness_low.0); + record.add_u8_range_checks_field(shard, channel, &self.witness_high.0); result.clone() } @@ -99,17 +100,14 @@ impl FieldInnerProductCols where Limbs: Copy, { - pub fn eval< - AB: SP1AirBuilder, - EShard: Into + Clone, - ER: Into + Clone, - >( + pub fn eval>( &self, builder: &mut AB, a: &[Limbs], b: &[Limbs], - shard: EShard, - is_real: ER, + shard: impl Into + Clone, + channel: impl Into + Clone, + is_real: impl Into + Clone, ) where V: Into, { @@ -138,10 +136,25 @@ where eval_field_operation::(builder, &p_vanishing, &p_witness_low, &p_witness_high); // Range checks for the result, carry, and witness columns. - builder.slice_range_check_u8(&self.result.0, shard.clone(), is_real.clone()); - builder.slice_range_check_u8(&self.carry.0, shard.clone(), is_real.clone()); - builder.slice_range_check_u8(&self.witness_low.0, shard.clone(), is_real.clone()); - builder.slice_range_check_u8(&self.witness_high.0, shard, is_real); + builder.slice_range_check_u8( + &self.result.0, + shard.clone(), + channel.clone(), + is_real.clone(), + ); + builder.slice_range_check_u8( + &self.carry.0, + shard.clone(), + channel.clone(), + is_real.clone(), + ); + builder.slice_range_check_u8( + &self.witness_low.0, + shard.clone(), + channel.clone(), + is_real.clone(), + ); + builder.slice_range_check_u8(&self.witness_high.0, shard, channel.clone(), is_real); } } @@ -231,7 +244,7 @@ mod tests { let cols: &mut TestCols = row.as_mut_slice().borrow_mut(); cols.a[0] = P::to_limbs_field::(&a[0]); cols.b[0] = P::to_limbs_field::(&b[0]); - cols.a_ip_b.populate(output, 1, a, b); + cols.a_ip_b.populate(output, 1, 0, a, b); row }) .collect::>(); @@ -267,9 +280,14 @@ mod tests { let main = builder.main(); let local = main.row_slice(0); let local: &TestCols = (*local).borrow(); - local - .a_ip_b - .eval(builder, &local.a, &local.b, AB::F::one(), AB::F::one()); + 
local.a_ip_b.eval( + builder, + &local.a, + &local.b, + AB::F::one(), + AB::F::zero(), + AB::F::one(), + ); } } diff --git a/core/src/operations/field/field_op.rs b/core/src/operations/field/field_op.rs index 9c8943d0ce..59bb85f2a2 100644 --- a/core/src/operations/field/field_op.rs +++ b/core/src/operations/field/field_op.rs @@ -108,10 +108,12 @@ impl FieldOpCols { /// Populate these columns with a specified modulus. This is useful in the `mulmod` precompile /// as an example. + #[allow(clippy::too_many_arguments)] pub fn populate_with_modulus( &mut self, record: &mut impl ByteRecord, shard: u32, + channel: u32, a: &BigUint, b: &BigUint, modulus: &BigUint, @@ -159,10 +161,10 @@ impl FieldOpCols { }; // Range checks - record.add_u8_range_checks_field(shard, &self.result.0); - record.add_u8_range_checks_field(shard, &self.carry.0); - record.add_u8_range_checks_field(shard, &self.witness_low.0); - record.add_u8_range_checks_field(shard, &self.witness_high.0); + record.add_u8_range_checks_field(shard, channel, &self.result.0); + record.add_u8_range_checks_field(shard, channel, &self.carry.0); + record.add_u8_range_checks_field(shard, channel, &self.witness_low.0); + record.add_u8_range_checks_field(shard, channel, &self.witness_high.0); result } @@ -172,11 +174,12 @@ impl FieldOpCols { &mut self, record: &mut impl ByteRecord, shard: u32, + channel: u32, a: &BigUint, b: &BigUint, op: FieldOperation, ) -> BigUint { - self.populate_with_modulus(record, shard, a, b, &P::modulus(), op) + self.populate_with_modulus(record, shard, channel, a, b, &P::modulus(), op) } } @@ -190,6 +193,7 @@ impl FieldOpCols { modulus: &(impl Into> + Clone), op: FieldOperation, shard: impl Into + Clone, + channel: impl Into + Clone, is_real: impl Into + Clone, ) where V: Into, @@ -215,12 +219,33 @@ impl FieldOpCols { eval_field_operation::(builder, &p_vanishing, &p_witness_low, &p_witness_high); // Range checks for the result, carry, and witness columns. 
- builder.slice_range_check_u8(&self.result.0, shard.clone(), is_real.clone()); - builder.slice_range_check_u8(&self.carry.0, shard.clone(), is_real.clone()); - builder.slice_range_check_u8(p_witness_low.coefficients(), shard.clone(), is_real.clone()); - builder.slice_range_check_u8(p_witness_high.coefficients(), shard.clone(), is_real); + builder.slice_range_check_u8( + &self.result.0, + shard.clone(), + channel.clone(), + is_real.clone(), + ); + builder.slice_range_check_u8( + &self.carry.0, + shard.clone(), + channel.clone(), + is_real.clone(), + ); + builder.slice_range_check_u8( + p_witness_low.coefficients(), + shard.clone(), + channel.clone(), + is_real.clone(), + ); + builder.slice_range_check_u8( + p_witness_high.coefficients(), + shard.clone(), + channel.clone(), + is_real, + ); } + #[allow(clippy::too_many_arguments)] pub fn eval>( &self, builder: &mut AB, @@ -228,13 +253,14 @@ impl FieldOpCols { b: &(impl Into> + Clone), op: FieldOperation, shard: impl Into + Clone, + channel: impl Into + Clone, is_real: impl Into + Clone, ) where V: Into, Limbs: Copy, { let p_limbs = Polynomial::from_iter(P::modulus_field_iter::().map(AB::Expr::from)); - self.eval_with_modulus::(builder, a, b, &p_limbs, op, shard, is_real); + self.eval_with_modulus::(builder, a, b, &p_limbs, op, shard, channel, is_real); } } @@ -336,7 +362,7 @@ mod tests { cols.a = P::to_limbs_field::(a); cols.b = P::to_limbs_field::(b); cols.a_op_b - .populate(&mut blu_events, 1, a, b, self.operation); + .populate(&mut blu_events, 1, 0, a, b, self.operation); output.add_byte_lookup_events(blu_events); row }) @@ -379,6 +405,7 @@ mod tests { &local.b, self.operation, AB::F::one(), + AB::F::zero(), AB::F::one(), ); } diff --git a/core/src/operations/field/field_sqrt.rs b/core/src/operations/field/field_sqrt.rs index 95f4f55319..6693c2734e 100644 --- a/core/src/operations/field/field_sqrt.rs +++ b/core/src/operations/field/field_sqrt.rs @@ -41,6 +41,7 @@ impl FieldSqrtCols { &mut self, record: &mut impl 
ByteRecord, shard: u32, + channel: u32, a: &BigUint, sqrt_fn: impl Fn(&BigUint) -> BigUint, ) -> BigUint { @@ -52,6 +53,7 @@ impl FieldSqrtCols { let sqrt_squared = self.multiplication.populate( record, shard, + channel, &sqrt, &sqrt, super::field_op::FieldOperation::Mul, @@ -65,13 +67,14 @@ impl FieldSqrtCols { self.multiplication.result = P::to_limbs_field::(&sqrt); // Populate the range columns. - self.range.populate(record, shard, &sqrt); + self.range.populate(record, shard, channel, &sqrt); let sqrt_bytes = P::to_limbs(&sqrt); self.lsb = F::from_canonical_u8(sqrt_bytes[0] & 1); let and_event = ByteLookupEvent { shard, + channel, opcode: ByteOpcode::AND, a1: self.lsb.as_canonical_u32(), a2: 0, @@ -89,18 +92,14 @@ where Limbs: Copy, { /// Calculates the square root of `a`. - pub fn eval< - AB: SP1AirBuilder, - ER: Into + Clone, - EOdd: Into, - EShard: Into + Clone, - >( + pub fn eval>( &self, builder: &mut AB, a: &Limbs, - is_odd: EOdd, - shard: EShard, - is_real: ER, + is_odd: impl Into, + shard: impl Into + Clone, + channel: impl Into + Clone, + is_real: impl Into + Clone, ) where V: Into, { @@ -118,11 +117,17 @@ where &sqrt, super::field_op::FieldOperation::Mul, shard.clone(), + channel.clone(), is_real.clone(), ); - self.range - .eval(builder, &sqrt, shard.clone(), is_real.clone()); + self.range.eval( + builder, + &sqrt, + shard.clone(), + channel.clone(), + is_real.clone(), + ); // Assert that the square root is the positive one, i.e., with least significant bit 0. // This is done by computing LSB = least_significant_byte & 1. 
@@ -134,6 +139,7 @@ where sqrt[0], AB::F::one(), shard, + channel, is_real, ); } @@ -224,7 +230,7 @@ mod tests { let mut row = [F::zero(); NUM_TEST_COLS]; let cols: &mut TestCols = row.as_mut_slice().borrow_mut(); cols.a = P::to_limbs_field::(a); - cols.sqrt.populate(&mut blu_events, 1, a, ed25519_sqrt); + cols.sqrt.populate(&mut blu_events, 1, 0, a, ed25519_sqrt); output.add_byte_lookup_events(blu_events); row }) @@ -263,9 +269,14 @@ mod tests { let local: &TestCols = (*local).borrow(); // eval verifies that local.sqrt.result is indeed the square root of local.a. - local - .sqrt - .eval(builder, &local.a, AB::F::zero(), AB::F::one(), AB::F::one()); + local.sqrt.eval( + builder, + &local.a, + AB::F::zero(), + AB::F::one(), + AB::F::zero(), + AB::F::one(), + ); } } diff --git a/core/src/operations/field/range.rs b/core/src/operations/field/range.rs index 484fff6ebe..da2a2de4e8 100644 --- a/core/src/operations/field/range.rs +++ b/core/src/operations/field/range.rs @@ -28,7 +28,13 @@ pub struct FieldRangeCols { } impl FieldRangeCols { - pub fn populate(&mut self, record: &mut impl ByteRecord, shard: u32, value: &BigUint) { + pub fn populate( + &mut self, + record: &mut impl ByteRecord, + shard: u32, + channel: u32, + value: &BigUint, + ) { let value_limbs = P::to_limbs(value); let modulus_limbs = P::to_limbs(&P::modulus()); @@ -46,6 +52,7 @@ impl FieldRangeCols { record.add_byte_lookup_event(ByteLookupEvent { opcode: ByteOpcode::LTU, shard, + channel, a1: 1, a2: 0, b: *byte as u32, @@ -62,17 +69,13 @@ impl FieldRangeCols { } impl FieldRangeCols { - pub fn eval< - AB: SP1AirBuilder, - E: Into> + Clone, - EShard: Into + Clone, - ER: Into + Clone, - >( + pub fn eval, E: Into> + Clone>( &self, builder: &mut AB, element: &E, - shard: EShard, - is_real: ER, + shard: impl Into + Clone, + channel: impl Into + Clone, + is_real: impl Into + Clone, ) where V: Into, Limbs: Copy, @@ -137,6 +140,7 @@ impl FieldRangeCols { self.comparison_byte, modulus_comparison_byte, shard, + 
channel, is_real, ) } diff --git a/core/src/operations/fixed_rotate_right.rs b/core/src/operations/fixed_rotate_right.rs index 209b1bb045..150626db56 100644 --- a/core/src/operations/fixed_rotate_right.rs +++ b/core/src/operations/fixed_rotate_right.rs @@ -45,6 +45,7 @@ impl FixedRotateRightOperation { &mut self, record: &mut ExecutionRecord, shard: u32, + channel: u32, input: u32, rotation: usize, ) -> u32 { @@ -76,6 +77,7 @@ impl FixedRotateRightOperation { let byte_event = ByteLookupEvent { shard, + channel, opcode: ByteOpcode::ShrCarry, a1: shift as u32, a2: carry as u32, @@ -111,6 +113,7 @@ impl FixedRotateRightOperation { rotation: usize, cols: FixedRotateRightOperation, shard: AB::Var, + channel: impl Into + Clone, is_real: AB::Var, ) { // Compute some constants with respect to the rotation needed for the rotation. @@ -138,6 +141,7 @@ impl FixedRotateRightOperation { input_bytes_rotated[i], AB::F::from_canonical_usize(nb_bits_to_shift), shard, + channel.clone(), is_real, ); diff --git a/core/src/operations/fixed_shift_right.rs b/core/src/operations/fixed_shift_right.rs index 5cadb7acca..19e7f02308 100644 --- a/core/src/operations/fixed_shift_right.rs +++ b/core/src/operations/fixed_shift_right.rs @@ -45,6 +45,7 @@ impl FixedShiftRightOperation { &mut self, record: &mut ExecutionRecord, shard: u32, + channel: u32, input: u32, rotation: usize, ) -> u32 { @@ -75,6 +76,7 @@ impl FixedShiftRightOperation { let (shift, carry) = shr_carry(b, c); let byte_event = ByteLookupEvent { shard, + channel, opcode: ByteOpcode::ShrCarry, a1: shift as u32, a2: carry as u32, @@ -109,7 +111,8 @@ impl FixedShiftRightOperation { input: Word, rotation: usize, cols: FixedShiftRightOperation, - shard: AB::Var, + shard: impl Into + Copy, + channel: impl Into + Copy, is_real: AB::Var, ) { // Compute some constants with respect to the rotation needed for the rotation. 
@@ -138,6 +141,7 @@ impl FixedShiftRightOperation { input_bytes_rotated[i].clone(), AB::F::from_canonical_usize(nb_bits_to_shift), shard, + channel, is_real, ); diff --git a/core/src/operations/not.rs b/core/src/operations/not.rs index c7a6adad9e..309c630610 100644 --- a/core/src/operations/not.rs +++ b/core/src/operations/not.rs @@ -19,13 +19,19 @@ pub struct NotOperation { } impl NotOperation { - pub fn populate(&mut self, record: &mut ExecutionRecord, shard: u32, x: u32) -> u32 { + pub fn populate( + &mut self, + record: &mut ExecutionRecord, + shard: u32, + channel: u32, + x: u32, + ) -> u32 { let expected = !x; let x_bytes = x.to_le_bytes(); for i in 0..WORD_SIZE { self.value[i] = F::from_canonical_u8(!x_bytes[i]); } - record.add_u8_range_checks(shard, &x_bytes); + record.add_u8_range_checks(shard, channel, &x_bytes); expected } @@ -34,8 +40,9 @@ impl NotOperation { builder: &mut AB, a: Word, cols: NotOperation, - shard: AB::Var, - is_real: AB::Var, + shard: impl Into + Copy, + channel: impl Into + Copy, + is_real: impl Into + Copy, ) { for i in (0..WORD_SIZE).step_by(2) { builder.send_byte_pair( @@ -45,6 +52,7 @@ impl NotOperation { a[i], a[i + 1], shard, + channel, is_real, ); } diff --git a/core/src/operations/or.rs b/core/src/operations/or.rs index f3dd1a1a91..8cb3f00191 100644 --- a/core/src/operations/or.rs +++ b/core/src/operations/or.rs @@ -20,13 +20,20 @@ pub struct OrOperation { } impl OrOperation { - pub fn populate(&mut self, record: &mut ExecutionRecord, shard: u32, x: u32, y: u32) -> u32 { + pub fn populate( + &mut self, + record: &mut ExecutionRecord, + shard: u32, + channel: u32, + x: u32, + y: u32, + ) -> u32 { let expected = x | y; let x_bytes = x.to_le_bytes(); let y_bytes = y.to_le_bytes(); for i in 0..WORD_SIZE { self.value[i] = F::from_canonical_u8(x_bytes[i] | y_bytes[i]); - record.lookup_or(shard, x_bytes[i], y_bytes[i]); + record.lookup_or(shard, channel, x_bytes[i], y_bytes[i]); } expected } @@ -36,7 +43,8 @@ impl OrOperation { a: 
Word, b: Word, cols: OrOperation, - shard: AB::Var, + shard: impl Into + Copy, + channel: impl Into + Copy, is_real: AB::Var, ) { for i in 0..WORD_SIZE { @@ -46,6 +54,7 @@ impl OrOperation { a[i], b[i], shard, + channel, is_real, ); } diff --git a/core/src/operations/xor.rs b/core/src/operations/xor.rs index 0b6858f6bf..ffb70f6847 100644 --- a/core/src/operations/xor.rs +++ b/core/src/operations/xor.rs @@ -19,7 +19,14 @@ pub struct XorOperation { } impl XorOperation { - pub fn populate(&mut self, record: &mut ExecutionRecord, shard: u32, x: u32, y: u32) -> u32 { + pub fn populate( + &mut self, + record: &mut ExecutionRecord, + shard: u32, + channel: u32, + x: u32, + y: u32, + ) -> u32 { let expected = x ^ y; let x_bytes = x.to_le_bytes(); let y_bytes = y.to_le_bytes(); @@ -29,6 +36,7 @@ impl XorOperation { let byte_event = ByteLookupEvent { shard, + channel, opcode: ByteOpcode::XOR, a1: xor as u32, a2: 0, @@ -47,6 +55,7 @@ impl XorOperation { b: Word, cols: XorOperation, shard: AB::Var, + channel: impl Into + Clone, is_real: AB::Var, ) { for i in 0..WORD_SIZE { @@ -56,6 +65,7 @@ impl XorOperation { a[i], b[i], shard, + channel.clone(), is_real, ); } diff --git a/core/src/runtime/mod.rs b/core/src/runtime/mod.rs index 09b21a88aa..7868098a4d 100644 --- a/core/src/runtime/mod.rs +++ b/core/src/runtime/mod.rs @@ -29,6 +29,7 @@ use std::sync::Arc; use thiserror::Error; +use crate::bytes::NUM_BYTE_LOOKUP_CHANNELS; use crate::memory::MemoryInitializeFinalizeEvent; use crate::utils::env; use crate::{alu::AluEvent, cpu::CpuEvent}; @@ -196,10 +197,16 @@ impl Runtime { } /// Get the current shard. + #[inline] pub fn shard(&self) -> u32 { self.state.current_shard } + #[inline] + pub fn channel(&self) -> u32 { + self.state.channel + } + /// Read a word from memory and create an access record. pub fn mr(&mut self, addr: u32, shard: u32, timestamp: u32) -> MemoryReadRecord { // Get the memory record entry. 
@@ -377,6 +384,7 @@ impl Runtime { fn emit_cpu( &mut self, shard: u32, + channel: u32, clk: u32, pc: u32, next_pc: u32, @@ -390,6 +398,7 @@ impl Runtime { ) { let cpu_event = CpuEvent { shard, + channel, clk, pc, next_pc, @@ -413,6 +422,7 @@ impl Runtime { let event = AluEvent { shard: self.shard(), clk, + channel: self.channel(), opcode, a, b, @@ -852,10 +862,18 @@ impl Runtime { // Update the clk to the next cycle. self.state.clk += 4; + let channel = self.channel(); + + // Update the channel to the next cycle. + if !self.unconstrained { + self.state.channel = (self.state.channel + 1) % NUM_BYTE_LOOKUP_CHANNELS; + } + // Emit the CPU event for this cycle. if self.emit_events { self.emit_cpu( self.shard(), + channel, clk, pc, next_pc, @@ -868,7 +886,6 @@ impl Runtime { exit_code, ); }; - Ok(()) } @@ -892,6 +909,7 @@ impl Runtime { if !self.unconstrained && self.max_syscall_cycles + self.state.clk >= self.shard_size { self.state.current_shard += 1; self.state.clk = 0; + self.state.channel = 0; } Ok(self.state.pc.wrapping_sub(self.program.pc_base) @@ -915,6 +933,7 @@ impl Runtime { fn initialize(&mut self) { self.state.clk = 0; + self.state.channel = 0; tracing::info!("loading memory image"); for (addr, value) in self.program.memory_image.iter() { diff --git a/core/src/runtime/state.rs b/core/src/runtime/state.rs index 09f6236cc4..8b1c289ee9 100644 --- a/core/src/runtime/state.rs +++ b/core/src/runtime/state.rs @@ -25,6 +25,10 @@ pub struct ExecutionState { /// executed in this shard. pub clk: u32, + /// The channel cycles through the values 0 up to (but not including) [crate::bytes::NUM_BYTE_LOOKUP_CHANNELS], + /// and is used to control byte lookup multiplicity. + pub channel: u32, + /// The program counter. pub pc: u32, @@ -65,6 +69,7 @@ impl ExecutionState { // Start at shard 1 since shard 0 is reserved for memory initialization. 
current_shard: 1, clk: 0, + channel: 0, pc: pc_start, memory: HashMap::default(), uninitialized_memory: HashMap::default(), diff --git a/core/src/runtime/syscall.rs b/core/src/runtime/syscall.rs index 6f89b374e8..4915f2640f 100644 --- a/core/src/runtime/syscall.rs +++ b/core/src/runtime/syscall.rs @@ -196,6 +196,10 @@ impl<'a> SyscallContext<'a> { self.rt.state.current_shard } + pub fn current_channel(&self) -> u32 { + self.rt.state.channel + } + pub fn mr(&mut self, addr: u32) -> (MemoryReadRecord, u32) { let record = self.rt.mr(addr, self.current_shard, self.clk); (record, record.value) diff --git a/core/src/syscall/precompiles/blake3/compress/air.rs b/core/src/syscall/precompiles/blake3/compress/air.rs index 6037d46791..a5876866e3 100644 --- a/core/src/syscall/precompiles/blake3/compress/air.rs +++ b/core/src/syscall/precompiles/blake3/compress/air.rs @@ -39,6 +39,7 @@ where // TODO: constraint clk column to increment by 1 within same invocation of syscall. builder.receive_syscall( local.shard, + local.channel, local.clk, AB::F::from_canonical_u32(SyscallCode::BLAKE3_COMPRESS_INNER.syscall_id()), local.state_ptr, @@ -145,6 +146,7 @@ impl Blake3CompressInnerChip { for i in 0..NUM_STATE_WORDS_PER_CALL { builder.eval_memory_access( local.shard, + local.channel, local.clk, local.state_ptr + local.state_index[i] * AB::F::from_canonical_usize(WORD_SIZE), &local.state_reads_writes[i], @@ -181,6 +183,7 @@ impl Blake3CompressInnerChip { for i in 0..NUM_MSG_WORDS_PER_CALL { builder.eval_memory_access( local.shard, + local.channel, local.clk, local.message_ptr + local.msg_schedule[i] * AB::F::from_canonical_usize(WORD_SIZE), &local.message_reads[i], @@ -209,7 +212,14 @@ impl Blake3CompressInnerChip { ]; // Call the g function. 
- GOperation::::eval(builder, input, local.g, local.shard, local.is_real); + GOperation::::eval( + builder, + input, + local.g, + local.shard, + local.channel, + local.is_real, + ); // Finally, the results of the g function should be written to the memory. for i in 0..NUM_STATE_WORDS_PER_CALL { diff --git a/core/src/syscall/precompiles/blake3/compress/columns.rs b/core/src/syscall/precompiles/blake3/compress/columns.rs index 12ea0139c9..bf7bbe4e1e 100644 --- a/core/src/syscall/precompiles/blake3/compress/columns.rs +++ b/core/src/syscall/precompiles/blake3/compress/columns.rs @@ -17,6 +17,7 @@ pub const NUM_BLAKE3_COMPRESS_INNER_COLS: usize = size_of:: { pub shard: T, + pub channel: T, pub clk: T, pub ecall_receive: T, diff --git a/core/src/syscall/precompiles/blake3/compress/execute.rs b/core/src/syscall/precompiles/blake3/compress/execute.rs index 113c126783..35298b0415 100644 --- a/core/src/syscall/precompiles/blake3/compress/execute.rs +++ b/core/src/syscall/precompiles/blake3/compress/execute.rs @@ -57,11 +57,13 @@ impl Syscall for Blake3CompressInnerChip { } let shard = rt.current_shard(); + let channel = rt.current_channel(); rt.record_mut() .blake3_compress_inner_events .push(Blake3CompressInnerEvent { shard, + channel, clk: start_clk, state_ptr, message_reads, diff --git a/core/src/syscall/precompiles/blake3/compress/g.rs b/core/src/syscall/precompiles/blake3/compress/g.rs index 53f93e5ca0..06e8c30348 100644 --- a/core/src/syscall/precompiles/blake3/compress/g.rs +++ b/core/src/syscall/precompiles/blake3/compress/g.rs @@ -49,6 +49,7 @@ impl GOperation { &mut self, record: &mut ExecutionRecord, shard: u32, + channel: u32, input: [u32; 6], ) -> [u32; 4] { let mut a = input[0]; @@ -61,37 +62,41 @@ impl GOperation { // First 4 steps. { // a = a + b + x. 
- a = self.a_plus_b.populate(record, shard, a, b); - a = self.a_plus_b_plus_x.populate(record, shard, a, x); + a = self.a_plus_b.populate(record, shard, channel, a, b); + a = self.a_plus_b_plus_x.populate(record, shard, channel, a, x); // d = (d ^ a).rotate_right(16). - d = self.d_xor_a.populate(record, shard, d, a); + d = self.d_xor_a.populate(record, shard, channel, d, a); d = d.rotate_right(16); // c = c + d. - c = self.c_plus_d.populate(record, shard, c, d); + c = self.c_plus_d.populate(record, shard, channel, c, d); // b = (b ^ c).rotate_right(12). - b = self.b_xor_c.populate(record, shard, b, c); - b = self.b_xor_c_rotate_right_12.populate(record, shard, b, 12); + b = self.b_xor_c.populate(record, shard, channel, b, c); + b = self + .b_xor_c_rotate_right_12 + .populate(record, shard, channel, b, 12); } // Second 4 steps. { // a = a + b + y. - a = self.a_plus_b_2.populate(record, shard, a, b); - a = self.a_plus_b_2_add_y.populate(record, shard, a, y); + a = self.a_plus_b_2.populate(record, shard, channel, a, b); + a = self.a_plus_b_2_add_y.populate(record, shard, channel, a, y); // d = (d ^ a).rotate_right(8). - d = self.d_xor_a_2.populate(record, shard, d, a); + d = self.d_xor_a_2.populate(record, shard, channel, d, a); d = d.rotate_right(8); // c = c + d. - c = self.c_plus_d_2.populate(record, shard, c, d); + c = self.c_plus_d_2.populate(record, shard, channel, c, d); // b = (b ^ c).rotate_right(7). - b = self.b_xor_c_2.populate(record, shard, b, c); - b = self.b_xor_c_2_rotate_right_7.populate(record, shard, b, 7); + b = self.b_xor_c_2.populate(record, shard, channel, b, c); + b = self + .b_xor_c_2_rotate_right_7 + .populate(record, shard, channel, b, 7); } let result = [a, b, c, d]; @@ -105,6 +110,7 @@ impl GOperation { input: [Word; 6], cols: GOperation, shard: AB::Var, + channel: impl Into + Clone, is_real: AB::Var, ) { builder.assert_bool(is_real); @@ -118,23 +124,63 @@ impl GOperation { // First 4 steps. { // a = a + b + x. 
- AddOperation::::eval(builder, a, b, cols.a_plus_b, shard, is_real.into()); + AddOperation::::eval( + builder, + a, + b, + cols.a_plus_b, + shard, + channel.clone(), + is_real.into(), + ); a = cols.a_plus_b.value; - AddOperation::::eval(builder, a, x, cols.a_plus_b_plus_x, shard, is_real.into()); + AddOperation::::eval( + builder, + a, + x, + cols.a_plus_b_plus_x, + shard, + channel.clone(), + is_real.into(), + ); a = cols.a_plus_b_plus_x.value; // d = (d ^ a).rotate_right(16). - XorOperation::::eval(builder, d, a, cols.d_xor_a, shard, is_real); + XorOperation::::eval( + builder, + d, + a, + cols.d_xor_a, + shard, + channel.clone(), + is_real, + ); d = cols.d_xor_a.value; // Rotate right by 16 bits. d = Word([d[2], d[3], d[0], d[1]]); // c = c + d. - AddOperation::::eval(builder, c, d, cols.c_plus_d, shard, is_real.into()); + AddOperation::::eval( + builder, + c, + d, + cols.c_plus_d, + shard, + channel.clone(), + is_real.into(), + ); c = cols.c_plus_d.value; // b = (b ^ c).rotate_right(12). - XorOperation::::eval(builder, b, c, cols.b_xor_c, shard, is_real); + XorOperation::::eval( + builder, + b, + c, + cols.b_xor_c, + shard, + channel.clone(), + is_real, + ); b = cols.b_xor_c.value; FixedRotateRightOperation::::eval( builder, @@ -142,6 +188,7 @@ impl GOperation { 12, cols.b_xor_c_rotate_right_12, shard, + channel.clone(), is_real, ); b = cols.b_xor_c_rotate_right_12.value; @@ -150,7 +197,15 @@ impl GOperation { // Second 4 steps. { // a = a + b + y. - AddOperation::::eval(builder, a, b, cols.a_plus_b_2, shard, is_real.into()); + AddOperation::::eval( + builder, + a, + b, + cols.a_plus_b_2, + shard, + channel.clone(), + is_real.into(), + ); a = cols.a_plus_b_2.value; AddOperation::::eval( builder, @@ -158,22 +213,47 @@ impl GOperation { y, cols.a_plus_b_2_add_y, shard, + channel.clone(), is_real.into(), ); a = cols.a_plus_b_2_add_y.value; // d = (d ^ a).rotate_right(8). 
- XorOperation::::eval(builder, d, a, cols.d_xor_a_2, shard, is_real); + XorOperation::::eval( + builder, + d, + a, + cols.d_xor_a_2, + shard, + channel.clone(), + is_real, + ); d = cols.d_xor_a_2.value; // Rotate right by 8 bits. d = Word([d[1], d[2], d[3], d[0]]); // c = c + d. - AddOperation::::eval(builder, c, d, cols.c_plus_d_2, shard, is_real.into()); + AddOperation::::eval( + builder, + c, + d, + cols.c_plus_d_2, + shard, + channel.clone(), + is_real.into(), + ); c = cols.c_plus_d_2.value; // b = (b ^ c).rotate_right(7). - XorOperation::::eval(builder, b, c, cols.b_xor_c_2, shard, is_real); + XorOperation::::eval( + builder, + b, + c, + cols.b_xor_c_2, + shard, + channel.clone(), + is_real, + ); b = cols.b_xor_c_2.value; FixedRotateRightOperation::::eval( builder, @@ -181,6 +261,7 @@ impl GOperation { 7, cols.b_xor_c_2_rotate_right_7, shard, + channel.clone(), is_real, ); b = cols.b_xor_c_2_rotate_right_7.value; diff --git a/core/src/syscall/precompiles/blake3/compress/mod.rs b/core/src/syscall/precompiles/blake3/compress/mod.rs index e8c1f8c3e1..e8d800d0df 100644 --- a/core/src/syscall/precompiles/blake3/compress/mod.rs +++ b/core/src/syscall/precompiles/blake3/compress/mod.rs @@ -94,6 +94,7 @@ pub(crate) fn g_func(input: [u32; 6]) -> [u32; 4] { pub struct Blake3CompressInnerEvent { pub clk: u32, pub shard: u32, + pub channel: u32, pub state_ptr: u32, pub message_ptr: u32, pub message_reads: [[[MemoryReadRecord; NUM_MSG_WORDS_PER_CALL]; OPERATION_COUNT]; ROUND_COUNT], diff --git a/core/src/syscall/precompiles/blake3/compress/trace.rs b/core/src/syscall/precompiles/blake3/compress/trace.rs index dec27f7939..14994cb031 100644 --- a/core/src/syscall/precompiles/blake3/compress/trace.rs +++ b/core/src/syscall/precompiles/blake3/compress/trace.rs @@ -37,6 +37,7 @@ impl MachineAir for Blake3CompressInnerChip { for i in 0..input.blake3_compress_inner_events.len() { let event = input.blake3_compress_inner_events[i].clone(); let shard = event.shard; + let channel = 
event.channel; let mut clk = event.clk; for round in 0..ROUND_COUNT { for operation in 0..OPERATION_COUNT { @@ -46,6 +47,7 @@ impl MachineAir for Blake3CompressInnerChip { // Assign basic values to the columns. { cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); cols.clk = F::from_canonical_u32(clk); cols.round_index = F::from_canonical_u32(round as u32); @@ -73,6 +75,7 @@ impl MachineAir for Blake3CompressInnerChip { cols.message_ptr = F::from_canonical_u32(event.message_ptr); for i in 0..NUM_MSG_WORDS_PER_CALL { cols.message_reads[i].populate( + channel, event.message_reads[round][operation][i], &mut new_byte_lookup_events, ); @@ -81,6 +84,7 @@ impl MachineAir for Blake3CompressInnerChip { cols.state_ptr = F::from_canonical_u32(event.state_ptr); for i in 0..NUM_STATE_WORDS_PER_CALL { cols.state_reads_writes[i].populate( + channel, MemoryRecordEnum::Write(event.state_writes[round][operation][i]), &mut new_byte_lookup_events, ); @@ -98,7 +102,7 @@ impl MachineAir for Blake3CompressInnerChip { event.message_reads[round][operation][1].value, ]; - cols.g.populate(output, shard, input); + cols.g.populate(output, shard, channel, input); } clk += 1; diff --git a/core/src/syscall/precompiles/edwards/ed_add.rs b/core/src/syscall/precompiles/edwards/ed_add.rs index db820faac6..4e215a0df0 100644 --- a/core/src/syscall/precompiles/edwards/ed_add.rs +++ b/core/src/syscall/precompiles/edwards/ed_add.rs @@ -52,6 +52,7 @@ pub const NUM_ED_ADD_COLS: usize = size_of::>(); pub struct EdAddAssignCols { pub is_real: T, pub shard: T, + pub channel: T, pub clk: T, pub p_ptr: T, pub q_ptr: T, @@ -78,9 +79,12 @@ impl EdAddAssignChip { _marker: PhantomData, } } + + #[allow(clippy::too_many_arguments)] fn populate_field_ops( record: &mut impl ByteRecord, shard: u32, + channel: u32, cols: &mut EdAddAssignCols, p_x: BigUint, p_y: BigUint, @@ -90,34 +94,41 @@ impl EdAddAssignChip { let x3_numerator = cols.x3_numerator.populate( record, 
shard, + channel, &[p_x.clone(), q_x.clone()], &[q_y.clone(), p_y.clone()], ); let y3_numerator = cols.y3_numerator.populate( record, shard, + channel, &[p_y.clone(), p_x.clone()], &[q_y.clone(), q_x.clone()], ); - let x1_mul_y1 = cols - .x1_mul_y1 - .populate(record, shard, &p_x, &p_y, FieldOperation::Mul); - let x2_mul_y2 = cols - .x2_mul_y2 - .populate(record, shard, &q_x, &q_y, FieldOperation::Mul); - let f = cols - .f - .populate(record, shard, &x1_mul_y1, &x2_mul_y2, FieldOperation::Mul); + let x1_mul_y1 = + cols.x1_mul_y1 + .populate(record, shard, channel, &p_x, &p_y, FieldOperation::Mul); + let x2_mul_y2 = + cols.x2_mul_y2 + .populate(record, shard, channel, &q_x, &q_y, FieldOperation::Mul); + let f = cols.f.populate( + record, + shard, + channel, + &x1_mul_y1, + &x2_mul_y2, + FieldOperation::Mul, + ); let d = E::d_biguint(); let d_mul_f = cols .d_mul_f - .populate(record, shard, &f, &d, FieldOperation::Mul); + .populate(record, shard, channel, &f, &d, FieldOperation::Mul); cols.x3_ins - .populate(record, shard, &x3_numerator, &d_mul_f, true); + .populate(record, shard, channel, &x3_numerator, &d_mul_f, true); cols.y3_ins - .populate(record, shard, &y3_numerator, &d_mul_f, false); + .populate(record, shard, channel, &y3_numerator, &d_mul_f, false); } } @@ -168,6 +179,7 @@ impl MachineAir for Ed // Populate basic columns. cols.is_real = F::one(); cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); cols.clk = F::from_canonical_u32(event.clk); cols.p_ptr = F::from_canonical_u32(event.p_ptr); cols.q_ptr = F::from_canonical_u32(event.q_ptr); @@ -176,6 +188,7 @@ impl MachineAir for Ed Self::populate_field_ops( &mut new_byte_lookup_events, event.shard, + event.channel, cols, p_x, p_y, @@ -185,12 +198,18 @@ impl MachineAir for Ed // Populate the memory access columns. 
for i in 0..WORDS_CURVE_POINT { - cols.q_access[i] - .populate(event.q_memory_records[i], &mut new_byte_lookup_events); + cols.q_access[i].populate( + event.channel, + event.q_memory_records[i], + &mut new_byte_lookup_events, + ); } for i in 0..WORDS_CURVE_POINT { - cols.p_access[i] - .populate(event.p_memory_records[i], &mut new_byte_lookup_events); + cols.p_access[i].populate( + event.channel, + event.p_memory_records[i], + &mut new_byte_lookup_events, + ); } (row, new_byte_lookup_events) @@ -208,6 +227,7 @@ impl MachineAir for Ed Self::populate_field_ops( &mut vec![], 0, + 0, cols, zero.clone(), zero.clone(), @@ -250,12 +270,24 @@ where let y2 = limbs_from_prev_access(&row.q_access[8..16]); // x3_numerator = x1 * y2 + x2 * y1. - row.x3_numerator - .eval(builder, &[x1, x2], &[y2, y1], row.shard, row.is_real); + row.x3_numerator.eval( + builder, + &[x1, x2], + &[y2, y1], + row.shard, + row.channel, + row.is_real, + ); // y3_numerator = y1 * y2 + x1 * x2. - row.y3_numerator - .eval(builder, &[y1, x1], &[y2, x2], row.shard, row.is_real); + row.y3_numerator.eval( + builder, + &[y1, x1], + &[y2, x2], + row.shard, + row.channel, + row.is_real, + ); // f = x1 * x2 * y1 * y2. 
row.x1_mul_y1.eval( @@ -264,6 +296,7 @@ where &y1, FieldOperation::Mul, row.shard, + row.channel, row.is_real, ); row.x2_mul_y2.eval( @@ -272,6 +305,7 @@ where &y2, FieldOperation::Mul, row.shard, + row.channel, row.is_real, ); @@ -283,6 +317,7 @@ where &x2_mul_y2, FieldOperation::Mul, row.shard, + row.channel, row.is_real, ); @@ -296,6 +331,7 @@ where &d_const, FieldOperation::Mul, row.shard, + row.channel, row.is_real, ); @@ -308,6 +344,7 @@ where &d_mul_f, true, row.shard, + row.channel, row.is_real, ); @@ -318,6 +355,7 @@ where &d_mul_f, false, row.shard, + row.channel, row.is_real, ); @@ -334,6 +372,7 @@ where builder.eval_memory_access_slice( row.shard, + row.channel, row.clk.into(), row.q_ptr, &row.q_access, @@ -342,6 +381,7 @@ where builder.eval_memory_access_slice( row.shard, + row.channel, row.clk + AB::F::from_canonical_u32(1), row.p_ptr, &row.p_access, @@ -350,6 +390,7 @@ where builder.receive_syscall( row.shard, + row.channel, row.clk, AB::F::from_canonical_u32(SyscallCode::ED_ADD.syscall_id()), row.p_ptr, diff --git a/core/src/syscall/precompiles/edwards/ed_decompress.rs b/core/src/syscall/precompiles/edwards/ed_decompress.rs index 62da4f9c3e..3833e21e93 100644 --- a/core/src/syscall/precompiles/edwards/ed_decompress.rs +++ b/core/src/syscall/precompiles/edwards/ed_decompress.rs @@ -54,6 +54,7 @@ use super::{WordsFieldElement, WORDS_FIELD_ELEMENT}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct EdDecompressEvent { pub shard: u32, + pub channel: u32, pub clk: u32, pub ptr: u32, pub sign: bool, @@ -75,6 +76,7 @@ pub const NUM_ED_DECOMPRESS_COLS: usize = size_of::>(); pub struct EdDecompressCols { pub is_real: T, pub shard: T, + pub channel: T, pub clk: T, pub ptr: T, pub sign: T, @@ -99,16 +101,25 @@ impl EdDecompressCols { let mut new_byte_lookup_events = Vec::new(); self.is_real = F::from_bool(true); self.shard = F::from_canonical_u32(event.shard); + self.channel = F::from_canonical_u32(event.channel); self.clk = 
F::from_canonical_u32(event.clk); self.ptr = F::from_canonical_u32(event.ptr); self.sign = F::from_bool(event.sign); for i in 0..8 { - self.x_access[i].populate(event.x_memory_records[i], &mut new_byte_lookup_events); - self.y_access[i].populate(event.y_memory_records[i], &mut new_byte_lookup_events); + self.x_access[i].populate( + event.channel, + event.x_memory_records[i], + &mut new_byte_lookup_events, + ); + self.y_access[i].populate( + event.channel, + event.y_memory_records[i], + &mut new_byte_lookup_events, + ); } let y = &BigUint::from_bytes_le(&event.y_bytes); - self.populate_field_ops::(&mut new_byte_lookup_events, event.shard, y); + self.populate_field_ops::(&mut new_byte_lookup_events, event.shard, event.channel, y); record.add_byte_lookup_events(new_byte_lookup_events); } @@ -117,28 +128,42 @@ impl EdDecompressCols { &mut self, blu_events: &mut Vec, shard: u32, + channel: u32, y: &BigUint, ) { let one = BigUint::one(); - self.y_range.populate(blu_events, shard, y); + self.y_range.populate(blu_events, shard, channel, y); let yy = self .yy - .populate(blu_events, shard, y, y, FieldOperation::Mul); + .populate(blu_events, shard, channel, y, y, FieldOperation::Mul); let u = self .u - .populate(blu_events, shard, &yy, &one, FieldOperation::Sub); - let dyy = self - .dyy - .populate(blu_events, shard, &E::d_biguint(), &yy, FieldOperation::Mul); + .populate(blu_events, shard, channel, &yy, &one, FieldOperation::Sub); + let dyy = self.dyy.populate( + blu_events, + shard, + channel, + &E::d_biguint(), + &yy, + FieldOperation::Mul, + ); let v = self .v - .populate(blu_events, shard, &one, &dyy, FieldOperation::Add); - let u_div_v = self - .u_div_v - .populate(blu_events, shard, &u, &v, FieldOperation::Div); - let x = self.x.populate(blu_events, shard, &u_div_v, ed25519_sqrt); - self.neg_x - .populate(blu_events, shard, &BigUint::zero(), &x, FieldOperation::Sub); + .populate(blu_events, shard, channel, &one, &dyy, FieldOperation::Add); + let u_div_v = + 
self.u_div_v + .populate(blu_events, shard, channel, &u, &v, FieldOperation::Div); + let x = self + .x + .populate(blu_events, shard, channel, &u_div_v, ed25519_sqrt); + self.neg_x.populate( + blu_events, + shard, + channel, + &BigUint::zero(), + &x, + FieldOperation::Sub, + ); } } @@ -152,13 +177,15 @@ impl EdDecompressCols { builder.assert_bool(self.sign); let y: Limbs = limbs_from_prev_access(&self.y_access); - self.y_range.eval(builder, &y, self.shard, self.is_real); + self.y_range + .eval(builder, &y, self.shard, self.channel, self.is_real); self.yy.eval( builder, &y, &y, FieldOperation::Mul, self.shard, + self.channel, self.is_real, ); self.u.eval( @@ -167,6 +194,7 @@ impl EdDecompressCols { &[AB::Expr::one()].iter(), FieldOperation::Sub, self.shard, + self.channel, self.is_real, ); let d_biguint = E::d_biguint(); @@ -177,6 +205,7 @@ impl EdDecompressCols { &self.yy.result, FieldOperation::Mul, self.shard, + self.channel, self.is_real, ); self.v.eval( @@ -185,6 +214,7 @@ impl EdDecompressCols { &self.dyy.result, FieldOperation::Add, self.shard, + self.channel, self.is_real, ); self.u_div_v.eval( @@ -193,6 +223,7 @@ impl EdDecompressCols { &self.v.result, FieldOperation::Div, self.shard, + self.channel, self.is_real, ); self.x.eval( @@ -200,6 +231,7 @@ impl EdDecompressCols { &self.u_div_v.result, AB::F::zero(), self.shard, + self.channel, self.is_real, ); self.neg_x.eval( @@ -208,11 +240,13 @@ impl EdDecompressCols { &self.x.multiplication.result, FieldOperation::Sub, self.shard, + self.channel, self.is_real, ); builder.eval_memory_access_slice( self.shard, + self.channel, self.clk, self.ptr, &self.x_access, @@ -220,6 +254,7 @@ impl EdDecompressCols { ); builder.eval_memory_access_slice( self.shard, + self.channel, self.clk, self.ptr.into() + AB::F::from_canonical_u32(32), &self.y_access, @@ -239,6 +274,7 @@ impl EdDecompressCols { builder.receive_syscall( self.shard, + self.channel, self.clk, 
AB::F::from_canonical_u32(SyscallCode::ED_DECOMPRESS.syscall_id()), self.ptr, @@ -291,10 +327,12 @@ impl Syscall for EdDecompressChip { let x_memory_records: [MemoryWriteRecord; 8] = x_memory_records_vec.try_into().unwrap(); let shard = rt.current_shard(); + let channel = rt.current_channel(); rt.record_mut() .ed_decompress_events .push(EdDecompressEvent { shard, + channel, clk: start_clk, ptr: slice_ptr, sign: sign_bool, @@ -348,7 +386,7 @@ impl MachineAir for EdDecompressChip = row.as_mut_slice().borrow_mut(); let zero = BigUint::zero(); - cols.populate_field_ops::(&mut vec![], 0, &zero); + cols.populate_field_ops::(&mut vec![], 0, 0, &zero); row }); diff --git a/core/src/syscall/precompiles/keccak256/air.rs b/core/src/syscall/precompiles/keccak256/air.rs index ceda18305c..8be89ce65b 100644 --- a/core/src/syscall/precompiles/keccak256/air.rs +++ b/core/src/syscall/precompiles/keccak256/air.rs @@ -54,6 +54,7 @@ where builder.eval_memory_access( local.shard, + local.channel, local.clk + final_step, // The clk increments by 1 after a final step local.state_addr + AB::Expr::from_canonical_u32(i * 4), &local.state_mem[i as usize], @@ -65,6 +66,7 @@ where builder.assert_eq(local.receive_ecall, first_step * local.is_real); builder.receive_syscall( local.shard, + local.channel, local.clk, AB::F::from_canonical_u32(SyscallCode::KECCAK_PERMUTE.syscall_id()), local.state_addr, diff --git a/core/src/syscall/precompiles/keccak256/columns.rs b/core/src/syscall/precompiles/keccak256/columns.rs index 68e4035d18..a3e2dd3044 100644 --- a/core/src/syscall/precompiles/keccak256/columns.rs +++ b/core/src/syscall/precompiles/keccak256/columns.rs @@ -18,6 +18,7 @@ pub(crate) struct KeccakMemCols { pub keccak: KeccakCols, pub shard: T, + pub channel: T, pub clk: T, pub state_addr: T, diff --git a/core/src/syscall/precompiles/keccak256/execute.rs b/core/src/syscall/precompiles/keccak256/execute.rs index bf30c43a9f..d6c306c45f 100644 --- a/core/src/syscall/precompiles/keccak256/execute.rs 
+++ b/core/src/syscall/precompiles/keccak256/execute.rs @@ -98,10 +98,12 @@ impl Syscall for KeccakPermuteChip { // Push the Keccak permute event. let shard = rt.current_shard(); + let channel = rt.current_channel(); rt.record_mut() .keccak_permute_events .push(KeccakPermuteEvent { shard, + channel, clk: start_clk, pre_state: saved_state.as_slice().try_into().unwrap(), post_state: state.as_slice().try_into().unwrap(), diff --git a/core/src/syscall/precompiles/keccak256/mod.rs b/core/src/syscall/precompiles/keccak256/mod.rs index b44ce53f05..67f733bba0 100644 --- a/core/src/syscall/precompiles/keccak256/mod.rs +++ b/core/src/syscall/precompiles/keccak256/mod.rs @@ -16,6 +16,7 @@ const STATE_NUM_WORDS: usize = STATE_SIZE * 2; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct KeccakPermuteEvent { pub shard: u32, + pub channel: u32, pub clk: u32, pub pre_state: [u64; STATE_SIZE], pub post_state: [u64; STATE_SIZE], diff --git a/core/src/syscall/precompiles/keccak256/trace.rs b/core/src/syscall/precompiles/keccak256/trace.rs index bad43967df..01b07fb743 100644 --- a/core/src/syscall/precompiles/keccak256/trace.rs +++ b/core/src/syscall/precompiles/keccak256/trace.rs @@ -56,6 +56,7 @@ impl MachineAir for KeccakPermuteChip { let event = &input.keccak_permute_events[*event_index]; let start_clk = event.clk; let shard = event.shard; + let channel = event.channel; // Create all the rows for the permutation. 
for i in 0..NUM_ROUNDS { @@ -68,6 +69,7 @@ impl MachineAir for KeccakPermuteChip { let cols: &mut KeccakMemCols = row.as_mut_slice().borrow_mut(); cols.shard = F::from_canonical_u32(shard); + cols.channel = F::from_canonical_u32(channel); cols.clk = F::from_canonical_u32(start_clk); cols.state_addr = F::from_canonical_u32(event.state_addr); cols.is_real = F::one(); @@ -76,8 +78,11 @@ impl MachineAir for KeccakPermuteChip { if i == 0 { for (j, read_record) in event.state_read_records.iter().enumerate() { - cols.state_mem[j] - .populate_read(*read_record, &mut new_byte_lookup_events); + cols.state_mem[j].populate_read( + channel, + *read_record, + &mut new_byte_lookup_events, + ); } cols.do_memory_check = F::one(); @@ -89,8 +94,11 @@ impl MachineAir for KeccakPermuteChip { for (j, write_record) in event.state_write_records.iter().enumerate() { - cols.state_mem[j] - .populate_write(*write_record, &mut new_byte_lookup_events); + cols.state_mem[j].populate_write( + channel, + *write_record, + &mut new_byte_lookup_events, + ); } cols.do_memory_check = F::one(); diff --git a/core/src/syscall/precompiles/mod.rs b/core/src/syscall/precompiles/mod.rs index f33626e7b1..7b107e2d5b 100644 --- a/core/src/syscall/precompiles/mod.rs +++ b/core/src/syscall/precompiles/mod.rs @@ -21,6 +21,7 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ECAddEvent { pub shard: u32, + pub channel: u32, pub clk: u32, pub p_ptr: u32, pub p: Vec, @@ -68,6 +69,7 @@ pub fn create_ec_add_event( ECAddEvent { shard: rt.current_shard(), + channel: rt.current_channel(), clk: start_clk, p_ptr, p, @@ -82,6 +84,7 @@ pub fn create_ec_add_event( #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ECDoubleEvent { pub shard: u32, + pub channel: u32, pub clk: u32, pub p_ptr: u32, pub p: Vec, @@ -117,6 +120,7 @@ pub fn create_ec_double_event( ECDoubleEvent { shard: rt.current_shard(), + channel: rt.current_channel(), clk: start_clk, p_ptr, p, @@ -128,6 +132,7 
@@ pub fn create_ec_double_event( #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ECDecompressEvent { pub shard: u32, + pub channel: u32, pub clk: u32, pub ptr: u32, pub is_odd: bool, @@ -172,6 +177,7 @@ pub fn create_ec_decompress_event( ECDecompressEvent { shard: rt.current_shard(), + channel: rt.current_channel(), clk: start_clk, ptr: slice_ptr, is_odd: is_odd != 0, diff --git a/core/src/syscall/precompiles/sha256/compress/air.rs b/core/src/syscall/precompiles/sha256/compress/air.rs index 031b08a308..2f4bd5000a 100644 --- a/core/src/syscall/precompiles/sha256/compress/air.rs +++ b/core/src/syscall/precompiles/sha256/compress/air.rs @@ -44,6 +44,7 @@ where ); builder.receive_syscall( local.shard, + local.channel, local.clk, AB::F::from_canonical_u32(SyscallCode::SHA_COMPRESS.syscall_id()), local.w_ptr, @@ -203,6 +204,7 @@ impl ShaCompressChip { let is_finalize = local.octet_num[9]; builder.eval_memory_access( local.shard, + local.channel, local.clk + is_finalize, local.mem_addr, &local.mem, @@ -292,6 +294,7 @@ impl ShaCompressChip { 6, local.e_rr_6, local.shard, + local.channel, local.is_compression, ); // Calculate e rightrotate 11. @@ -301,6 +304,7 @@ impl ShaCompressChip { 11, local.e_rr_11, local.shard, + local.channel, local.is_compression, ); // Calculate e rightrotate 25. @@ -310,6 +314,7 @@ impl ShaCompressChip { 25, local.e_rr_25, local.shard, + local.channel, local.is_compression, ); // Calculate (e rightrotate 6) xor (e rightrotate 11). @@ -319,6 +324,7 @@ impl ShaCompressChip { local.e_rr_11.value, local.s1_intermediate, local.shard, + local.channel, local.is_compression, ); // Calculate S1 := ((e rightrotate 6) xor (e rightrotate 11)) xor (e rightrotate 25). @@ -328,6 +334,7 @@ impl ShaCompressChip { local.e_rr_25.value, local.s1, local.shard, + local.channel, local.is_compression, ); @@ -339,6 +346,7 @@ impl ShaCompressChip { local.f, local.e_and_f, local.shard, + local.channel, local.is_compression, ); // Calculate not e. 
@@ -347,6 +355,7 @@ impl ShaCompressChip { local.e, local.e_not, local.shard, + local.channel, local.is_compression, ); // Calculate (not e) and g. @@ -356,6 +365,7 @@ impl ShaCompressChip { local.g, local.e_not_and_g, local.shard, + local.channel, local.is_compression, ); // Calculate ch := (e and f) xor ((not e) and g). @@ -365,6 +375,7 @@ impl ShaCompressChip { local.e_not_and_g.value, local.ch, local.shard, + local.channel, local.is_compression, ); @@ -379,6 +390,7 @@ impl ShaCompressChip { local.mem.access.value, ], local.shard, + local.channel, local.is_compression, local.temp1, ); @@ -391,6 +403,7 @@ impl ShaCompressChip { 2, local.a_rr_2, local.shard, + local.channel, local.is_compression, ); // Calculate a rightrotate 13. @@ -400,6 +413,7 @@ impl ShaCompressChip { 13, local.a_rr_13, local.shard, + local.channel, local.is_compression, ); // Calculate a rightrotate 22. @@ -409,6 +423,7 @@ impl ShaCompressChip { 22, local.a_rr_22, local.shard, + local.channel, local.is_compression, ); // Calculate (a rightrotate 2) xor (a rightrotate 13). @@ -418,6 +433,7 @@ impl ShaCompressChip { local.a_rr_13.value, local.s0_intermediate, local.shard, + local.channel, local.is_compression, ); // Calculate S0 := ((a rightrotate 2) xor (a rightrotate 13)) xor (a rightrotate 22). @@ -427,6 +443,7 @@ impl ShaCompressChip { local.a_rr_22.value, local.s0, local.shard, + local.channel, local.is_compression, ); @@ -438,6 +455,7 @@ impl ShaCompressChip { local.b, local.a_and_b, local.shard, + local.channel, local.is_compression, ); // Calculate a and c. @@ -447,6 +465,7 @@ impl ShaCompressChip { local.c, local.a_and_c, local.shard, + local.channel, local.is_compression, ); // Calculate b and c. @@ -456,6 +475,7 @@ impl ShaCompressChip { local.c, local.b_and_c, local.shard, + local.channel, local.is_compression, ); // Calculate (a and b) xor (a and c). 
@@ -465,6 +485,7 @@ impl ShaCompressChip { local.a_and_c.value, local.maj_intermediate, local.shard, + local.channel, local.is_compression, ); // Calculate maj := ((a and b) xor (a and c)) xor (b and c). @@ -474,6 +495,7 @@ impl ShaCompressChip { local.b_and_c.value, local.maj, local.shard, + local.channel, local.is_compression, ); @@ -484,6 +506,7 @@ impl ShaCompressChip { local.maj.value, local.temp2, local.shard, + local.channel, local.is_compression.into(), ); @@ -494,6 +517,7 @@ impl ShaCompressChip { local.temp1.value, local.d_add_temp1, local.shard, + local.channel, local.is_compression.into(), ); @@ -504,6 +528,7 @@ impl ShaCompressChip { local.temp2.value, local.temp1_add_temp2, local.shard, + local.channel, local.is_compression.into(), ); @@ -581,6 +606,7 @@ impl ShaCompressChip { local.finalized_operand, local.finalize_add, local.shard, + local.channel, is_finalize.into(), ); diff --git a/core/src/syscall/precompiles/sha256/compress/columns.rs b/core/src/syscall/precompiles/sha256/compress/columns.rs index cf990e0385..94a200aedd 100644 --- a/core/src/syscall/precompiles/sha256/compress/columns.rs +++ b/core/src/syscall/precompiles/sha256/compress/columns.rs @@ -25,6 +25,7 @@ pub const NUM_SHA_COMPRESS_COLS: usize = size_of::>(); pub struct ShaCompressCols { /// Inputs. pub shard: T, + pub channel: T, pub clk: T, pub w_ptr: T, pub h_ptr: T, diff --git a/core/src/syscall/precompiles/sha256/compress/execute.rs b/core/src/syscall/precompiles/sha256/compress/execute.rs index 72ee312a5e..a019abbd4c 100644 --- a/core/src/syscall/precompiles/sha256/compress/execute.rs +++ b/core/src/syscall/precompiles/sha256/compress/execute.rs @@ -77,8 +77,10 @@ impl Syscall for ShaCompressChip { // Push the SHA extend event. 
let shard = rt.current_shard(); + let channel = rt.current_channel(); rt.record_mut().sha_compress_events.push(ShaCompressEvent { shard, + channel, clk: start_clk, w_ptr, h_ptr, diff --git a/core/src/syscall/precompiles/sha256/compress/mod.rs b/core/src/syscall/precompiles/sha256/compress/mod.rs index e95669cee4..5d209b4a97 100644 --- a/core/src/syscall/precompiles/sha256/compress/mod.rs +++ b/core/src/syscall/precompiles/sha256/compress/mod.rs @@ -21,6 +21,7 @@ pub const SHA_COMPRESS_K: [u32; 64] = [ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ShaCompressEvent { pub shard: u32, + pub channel: u32, pub clk: u32, pub w_ptr: u32, pub h_ptr: u32, diff --git a/core/src/syscall/precompiles/sha256/compress/trace.rs b/core/src/syscall/precompiles/sha256/compress/trace.rs index 9f877fd874..bd0b8f8177 100644 --- a/core/src/syscall/precompiles/sha256/compress/trace.rs +++ b/core/src/syscall/precompiles/sha256/compress/trace.rs @@ -34,6 +34,7 @@ impl MachineAir for ShaCompressChip { for i in 0..input.sha_compress_events.len() { let mut event = input.sha_compress_events[i].clone(); let shard = event.shard; + let channel = event.channel; let og_h = event.h; @@ -45,6 +46,7 @@ impl MachineAir for ShaCompressChip { let cols: &mut ShaCompressCols = row.as_mut_slice().borrow_mut(); cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); cols.clk = F::from_canonical_u32(event.clk); cols.w_ptr = F::from_canonical_u32(event.w_ptr); cols.h_ptr = F::from_canonical_u32(event.h_ptr); @@ -52,8 +54,11 @@ impl MachineAir for ShaCompressChip { cols.octet[j] = F::one(); cols.octet_num[octet_num_idx] = F::one(); - cols.mem - .populate_read(event.h_read_records[j], &mut new_byte_lookup_events); + cols.mem.populate_read( + channel, + event.h_read_records[j], + &mut new_byte_lookup_events, + ); cols.mem_addr = F::from_canonical_u32(event.h_ptr + (j * 4) as u32); cols.a = Word::from(event.h_read_records[0].value); @@ -84,11 +89,15 @@ 
impl MachineAir for ShaCompressChip { cols.octet_num[octet_num_idx] = F::one(); cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); cols.clk = F::from_canonical_u32(event.clk); cols.w_ptr = F::from_canonical_u32(event.w_ptr); cols.h_ptr = F::from_canonical_u32(event.h_ptr); - cols.mem - .populate_read(event.w_i_read_records[j], &mut new_byte_lookup_events); + cols.mem.populate_read( + channel, + event.w_i_read_records[j], + &mut new_byte_lookup_events, + ); cols.mem_addr = F::from_canonical_u32(event.w_ptr + (j * 4) as u32); let a = event.h[0]; @@ -108,43 +117,60 @@ impl MachineAir for ShaCompressChip { cols.g = Word::from(g); cols.h = Word::from(h); - let e_rr_6 = cols.e_rr_6.populate(output, shard, e, 6); - let e_rr_11 = cols.e_rr_11.populate(output, shard, e, 11); - let e_rr_25 = cols.e_rr_25.populate(output, shard, e, 25); + let e_rr_6 = cols.e_rr_6.populate(output, shard, channel, e, 6); + let e_rr_11 = cols.e_rr_11.populate(output, shard, channel, e, 11); + let e_rr_25 = cols.e_rr_25.populate(output, shard, channel, e, 25); let s1_intermediate = cols .s1_intermediate - .populate(output, shard, e_rr_6, e_rr_11); - let s1 = cols.s1.populate(output, shard, s1_intermediate, e_rr_25); - - let e_and_f = cols.e_and_f.populate(output, shard, e, f); - let e_not = cols.e_not.populate(output, shard, e); - let e_not_and_g = cols.e_not_and_g.populate(output, shard, e_not, g); - let ch = cols.ch.populate(output, shard, e_and_f, e_not_and_g); - - let temp1 = - cols.temp1 - .populate(output, shard, h, s1, ch, event.w[j], SHA_COMPRESS_K[j]); - - let a_rr_2 = cols.a_rr_2.populate(output, shard, a, 2); - let a_rr_13 = cols.a_rr_13.populate(output, shard, a, 13); - let a_rr_22 = cols.a_rr_22.populate(output, shard, a, 22); + .populate(output, shard, channel, e_rr_6, e_rr_11); + let s1 = cols + .s1 + .populate(output, shard, channel, s1_intermediate, e_rr_25); + + let e_and_f = cols.e_and_f.populate(output, shard, channel, e, 
f); + let e_not = cols.e_not.populate(output, shard, channel, e); + let e_not_and_g = cols.e_not_and_g.populate(output, shard, channel, e_not, g); + let ch = cols + .ch + .populate(output, shard, channel, e_and_f, e_not_and_g); + + let temp1 = cols.temp1.populate( + output, + shard, + channel, + h, + s1, + ch, + event.w[j], + SHA_COMPRESS_K[j], + ); + + let a_rr_2 = cols.a_rr_2.populate(output, shard, channel, a, 2); + let a_rr_13 = cols.a_rr_13.populate(output, shard, channel, a, 13); + let a_rr_22 = cols.a_rr_22.populate(output, shard, channel, a, 22); let s0_intermediate = cols .s0_intermediate - .populate(output, shard, a_rr_2, a_rr_13); - let s0 = cols.s0.populate(output, shard, s0_intermediate, a_rr_22); - - let a_and_b = cols.a_and_b.populate(output, shard, a, b); - let a_and_c = cols.a_and_c.populate(output, shard, a, c); - let b_and_c = cols.b_and_c.populate(output, shard, b, c); + .populate(output, shard, channel, a_rr_2, a_rr_13); + let s0 = cols + .s0 + .populate(output, shard, channel, s0_intermediate, a_rr_22); + + let a_and_b = cols.a_and_b.populate(output, shard, channel, a, b); + let a_and_c = cols.a_and_c.populate(output, shard, channel, a, c); + let b_and_c = cols.b_and_c.populate(output, shard, channel, b, c); let maj_intermediate = cols .maj_intermediate - .populate(output, shard, a_and_b, a_and_c); - let maj = cols.maj.populate(output, shard, maj_intermediate, b_and_c); + .populate(output, shard, channel, a_and_b, a_and_c); + let maj = cols + .maj + .populate(output, shard, channel, maj_intermediate, b_and_c); - let temp2 = cols.temp2.populate(output, shard, s0, maj); + let temp2 = cols.temp2.populate(output, shard, channel, s0, maj); - let d_add_temp1 = cols.d_add_temp1.populate(output, shard, d, temp1); - let temp1_add_temp2 = cols.temp1_add_temp2.populate(output, shard, temp1, temp2); + let d_add_temp1 = cols.d_add_temp1.populate(output, shard, channel, d, temp1); + let temp1_add_temp2 = cols + .temp1_add_temp2 + .populate(output, shard, 
channel, temp1, temp2); event.h[7] = g; event.h[6] = f; @@ -174,6 +200,7 @@ impl MachineAir for ShaCompressChip { let cols: &mut ShaCompressCols = row.as_mut_slice().borrow_mut(); cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); cols.clk = F::from_canonical_u32(event.clk); cols.w_ptr = F::from_canonical_u32(event.w_ptr); cols.h_ptr = F::from_canonical_u32(event.h_ptr); @@ -182,9 +209,12 @@ impl MachineAir for ShaCompressChip { cols.octet_num[octet_num_idx] = F::one(); cols.finalize_add - .populate(output, shard, og_h[j], event.h[j]); - cols.mem - .populate_write(event.h_write_records[j], &mut new_byte_lookup_events); + .populate(output, shard, channel, og_h[j], event.h[j]); + cols.mem.populate_write( + channel, + event.h_write_records[j], + &mut new_byte_lookup_events, + ); cols.mem_addr = F::from_canonical_u32(event.h_ptr + (j * 4) as u32); v[j] = event.h[j]; diff --git a/core/src/syscall/precompiles/sha256/extend/air.rs b/core/src/syscall/precompiles/sha256/extend/air.rs index 6058afd3d3..9da6048048 100644 --- a/core/src/syscall/precompiles/sha256/extend/air.rs +++ b/core/src/syscall/precompiles/sha256/extend/air.rs @@ -50,6 +50,7 @@ where // Read w[i-15]. builder.eval_memory_access( local.shard, + local.channel, local.clk + (local.i - i_start), local.w_ptr + (local.i - AB::F::from_canonical_u32(15)) * nb_bytes_in_word, &local.w_i_minus_15, @@ -59,6 +60,7 @@ where // Read w[i-2]. builder.eval_memory_access( local.shard, + local.channel, local.clk + (local.i - i_start), local.w_ptr + (local.i - AB::F::from_canonical_u32(2)) * nb_bytes_in_word, &local.w_i_minus_2, @@ -68,6 +70,7 @@ where // Read w[i-16]. builder.eval_memory_access( local.shard, + local.channel, local.clk + (local.i - i_start), local.w_ptr + (local.i - AB::F::from_canonical_u32(16)) * nb_bytes_in_word, &local.w_i_minus_16, @@ -77,6 +80,7 @@ where // Read w[i-7]. 
builder.eval_memory_access( local.shard, + local.channel, local.clk + (local.i - i_start), local.w_ptr + (local.i - AB::F::from_canonical_u32(7)) * nb_bytes_in_word, &local.w_i_minus_7, @@ -91,6 +95,7 @@ where 7, local.w_i_minus_15_rr_7, local.shard, + local.channel, local.is_real, ); // w[i-15] rightrotate 18. @@ -100,6 +105,7 @@ where 18, local.w_i_minus_15_rr_18, local.shard, + local.channel, local.is_real, ); // w[i-15] rightshift 3. @@ -109,6 +115,7 @@ where 3, local.w_i_minus_15_rs_3, local.shard, + local.channel, local.is_real, ); // (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) @@ -118,6 +125,7 @@ where local.w_i_minus_15_rr_18.value, local.s0_intermediate, local.shard, + local.channel, local.is_real, ); // s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3) @@ -127,6 +135,7 @@ where local.w_i_minus_15_rs_3.value, local.s0, local.shard, + local.channel, local.is_real, ); @@ -138,6 +147,7 @@ where 17, local.w_i_minus_2_rr_17, local.shard, + local.channel, local.is_real, ); // w[i-2] rightrotate 19. @@ -147,6 +157,7 @@ where 19, local.w_i_minus_2_rr_19, local.shard, + local.channel, local.is_real, ); // w[i-2] rightshift 10. @@ -156,6 +167,7 @@ where 10, local.w_i_minus_2_rs_10, local.shard, + local.channel, local.is_real, ); // (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) @@ -165,6 +177,7 @@ where local.w_i_minus_2_rr_19.value, local.s1_intermediate, local.shard, + local.channel, local.is_real, ); // s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10) @@ -174,6 +187,7 @@ where local.w_i_minus_2_rs_10.value, local.s1, local.shard, + local.channel, local.is_real, ); @@ -185,6 +199,7 @@ where *local.w_i_minus_7.value(), local.s1.value, local.shard, + local.channel, local.is_real, local.s2, ); @@ -192,6 +207,7 @@ where // Write `s2` to `w[i]`. 
builder.eval_memory_access( local.shard, + local.channel, local.clk + (local.i - i_start), local.w_ptr + local.i * nb_bytes_in_word, &local.w_i, @@ -201,6 +217,7 @@ where // Receive syscall event in first row of 48-cycle. builder.receive_syscall( local.shard, + local.channel, local.clk, AB::F::from_canonical_u32(SyscallCode::SHA_EXTEND.syscall_id()), local.w_ptr, diff --git a/core/src/syscall/precompiles/sha256/extend/columns.rs b/core/src/syscall/precompiles/sha256/extend/columns.rs index a4197ce7f5..5eb99e1f4d 100644 --- a/core/src/syscall/precompiles/sha256/extend/columns.rs +++ b/core/src/syscall/precompiles/sha256/extend/columns.rs @@ -17,6 +17,7 @@ pub const NUM_SHA_EXTEND_COLS: usize = size_of::>(); pub struct ShaExtendCols { /// Inputs. pub shard: T, + pub channel: T, pub clk: T, pub w_ptr: T, diff --git a/core/src/syscall/precompiles/sha256/extend/execute.rs b/core/src/syscall/precompiles/sha256/extend/execute.rs index 467029bd9d..bd163c26c9 100644 --- a/core/src/syscall/precompiles/sha256/extend/execute.rs +++ b/core/src/syscall/precompiles/sha256/extend/execute.rs @@ -61,8 +61,10 @@ impl Syscall for ShaExtendChip { // Push the SHA extend event. 
let shard = rt.current_shard(); + let channel = rt.current_channel(); rt.record_mut().sha_extend_events.push(ShaExtendEvent { shard, + channel, clk: clk_init, w_ptr: w_ptr_init, w_i_minus_15_reads, diff --git a/core/src/syscall/precompiles/sha256/extend/mod.rs b/core/src/syscall/precompiles/sha256/extend/mod.rs index 529e7c2687..873e3c2fa1 100644 --- a/core/src/syscall/precompiles/sha256/extend/mod.rs +++ b/core/src/syscall/precompiles/sha256/extend/mod.rs @@ -12,6 +12,7 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ShaExtendEvent { pub shard: u32, + pub channel: u32, pub clk: u32, pub w_ptr: u32, pub w_i_minus_15_reads: Vec, @@ -90,7 +91,7 @@ pub mod extend_tests { #[test] fn generate_trace() { let mut shard = ExecutionRecord::default(); - shard.add_events = vec![AluEvent::new(0, 0, Opcode::ADD, 14, 8, 6)]; + shard.add_events = vec![AluEvent::new(0, 0, 0, Opcode::ADD, 14, 8, 6)]; let chip = ShaExtendChip::new(); let trace: RowMajorMatrix = chip.generate_trace(&shard, &mut ExecutionRecord::default()); diff --git a/core/src/syscall/precompiles/sha256/extend/trace.rs b/core/src/syscall/precompiles/sha256/extend/trace.rs index 819b0d6146..2a976ef0d6 100644 --- a/core/src/syscall/precompiles/sha256/extend/trace.rs +++ b/core/src/syscall/precompiles/sha256/extend/trace.rs @@ -37,68 +37,105 @@ impl MachineAir for ShaExtendChip { cols.is_real = F::one(); cols.populate_flags(j); cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); cols.clk = F::from_canonical_u32(event.clk); cols.w_ptr = F::from_canonical_u32(event.w_ptr); - cols.w_i_minus_15 - .populate(event.w_i_minus_15_reads[j], &mut new_byte_lookup_events); - cols.w_i_minus_2 - .populate(event.w_i_minus_2_reads[j], &mut new_byte_lookup_events); - cols.w_i_minus_16 - .populate(event.w_i_minus_16_reads[j], &mut new_byte_lookup_events); - cols.w_i_minus_7 - .populate(event.w_i_minus_7_reads[j], &mut 
new_byte_lookup_events); + cols.w_i_minus_15.populate( + event.channel, + event.w_i_minus_15_reads[j], + &mut new_byte_lookup_events, + ); + cols.w_i_minus_2.populate( + event.channel, + event.w_i_minus_2_reads[j], + &mut new_byte_lookup_events, + ); + cols.w_i_minus_16.populate( + event.channel, + event.w_i_minus_16_reads[j], + &mut new_byte_lookup_events, + ); + cols.w_i_minus_7.populate( + event.channel, + event.w_i_minus_7_reads[j], + &mut new_byte_lookup_events, + ); // `s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3)`. let w_i_minus_15 = event.w_i_minus_15_reads[j].value; let w_i_minus_15_rr_7 = cols.w_i_minus_15_rr_7 - .populate(output, shard, w_i_minus_15, 7); - let w_i_minus_15_rr_18 = - cols.w_i_minus_15_rr_18 - .populate(output, shard, w_i_minus_15, 18); + .populate(output, shard, event.channel, w_i_minus_15, 7); + let w_i_minus_15_rr_18 = cols.w_i_minus_15_rr_18.populate( + output, + shard, + event.channel, + w_i_minus_15, + 18, + ); let w_i_minus_15_rs_3 = cols.w_i_minus_15_rs_3 - .populate(output, shard, w_i_minus_15, 3); + .populate(output, shard, event.channel, w_i_minus_15, 3); let s0_intermediate = cols.s0_intermediate.populate( output, shard, + event.channel, w_i_minus_15_rr_7, w_i_minus_15_rr_18, ); - let s0 = cols - .s0 - .populate(output, shard, s0_intermediate, w_i_minus_15_rs_3); + let s0 = cols.s0.populate( + output, + shard, + event.channel, + s0_intermediate, + w_i_minus_15_rs_3, + ); // `s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10)`. 
let w_i_minus_2 = event.w_i_minus_2_reads[j].value; let w_i_minus_2_rr_17 = cols.w_i_minus_2_rr_17 - .populate(output, shard, w_i_minus_2, 17); + .populate(output, shard, event.channel, w_i_minus_2, 17); let w_i_minus_2_rr_19 = cols.w_i_minus_2_rr_19 - .populate(output, shard, w_i_minus_2, 19); + .populate(output, shard, event.channel, w_i_minus_2, 19); let w_i_minus_2_rs_10 = cols.w_i_minus_2_rs_10 - .populate(output, shard, w_i_minus_2, 10); + .populate(output, shard, event.channel, w_i_minus_2, 10); let s1_intermediate = cols.s1_intermediate.populate( output, shard, + event.channel, w_i_minus_2_rr_17, w_i_minus_2_rr_19, ); - let s1 = cols - .s1 - .populate(output, shard, s1_intermediate, w_i_minus_2_rs_10); + let s1 = cols.s1.populate( + output, + shard, + event.channel, + s1_intermediate, + w_i_minus_2_rs_10, + ); // Compute `s2`. let w_i_minus_7 = event.w_i_minus_7_reads[j].value; let w_i_minus_16 = event.w_i_minus_16_reads[j].value; - cols.s2 - .populate(output, shard, w_i_minus_16, s0, w_i_minus_7, s1); + cols.s2.populate( + output, + shard, + event.channel, + w_i_minus_16, + s0, + w_i_minus_7, + s1, + ); - cols.w_i - .populate(event.w_i_writes[j], &mut new_byte_lookup_events); + cols.w_i.populate( + event.channel, + event.w_i_writes[j], + &mut new_byte_lookup_events, + ); rows.push(row); } diff --git a/core/src/syscall/precompiles/uint256/air.rs b/core/src/syscall/precompiles/uint256/air.rs index aa422afc7e..2117d75d40 100644 --- a/core/src/syscall/precompiles/uint256/air.rs +++ b/core/src/syscall/precompiles/uint256/air.rs @@ -34,6 +34,7 @@ const NUM_COLS: usize = size_of::>(); #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Uint256MulEvent { pub shard: u32, + pub channel: u32, pub clk: u32, pub x_ptr: u32, pub x: Vec, @@ -64,6 +65,9 @@ pub struct Uint256MulCols { /// The shard number of the syscall. pub shard: T, + /// The byte lookup channel. + pub channel: T, + /// The clock cycle of the syscall. 
pub clk: T, @@ -124,17 +128,25 @@ impl MachineAir for Uint256MulChip { // Assign basic values to the columns. cols.is_real = F::one(); cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); cols.clk = F::from_canonical_u32(event.clk); cols.x_ptr = F::from_canonical_u32(event.x_ptr); cols.y_ptr = F::from_canonical_u32(event.y_ptr); // Populate memory columns. for i in 0..WORDS_FIELD_ELEMENT { - cols.x_memory[i] - .populate(event.x_memory_records[i], &mut new_byte_lookup_events); - cols.y_memory[i] - .populate(event.y_memory_records[i], &mut new_byte_lookup_events); + cols.x_memory[i].populate( + event.channel, + event.x_memory_records[i], + &mut new_byte_lookup_events, + ); + cols.y_memory[i].populate( + event.channel, + event.y_memory_records[i], + &mut new_byte_lookup_events, + ); cols.modulus_memory[i].populate( + event.channel, event.modulus_memory_records[i], &mut new_byte_lookup_events, ); @@ -153,6 +165,7 @@ impl MachineAir for Uint256MulChip { cols.output.populate_with_modulus( &mut new_byte_lookup_events, event.shard, + event.channel, &x, &y, &effective_modulus, @@ -182,7 +195,7 @@ impl MachineAir for Uint256MulChip { let x = BigUint::zero(); let y = BigUint::zero(); cols.output - .populate(&mut vec![], 0, &x, &y, FieldOperation::Mul); + .populate(&mut vec![], 0, 0, &x, &y, FieldOperation::Mul); row }); @@ -245,9 +258,11 @@ impl Syscall for Uint256MulChip { let x_memory_records = rt.mw_slice(x_ptr, &result); let shard = rt.current_shard(); + let channel = rt.current_channel(); let clk = rt.clk; rt.record_mut().uint256_mul_events.push(Uint256MulEvent { shard, + channel, clk, x_ptr, x, @@ -316,9 +331,9 @@ where &x_limbs, &y_limbs, &p_modulus, - // &modulus_limbs, FieldOperation::Mul, local.shard, + local.channel, local.is_real, ); @@ -330,6 +345,7 @@ where // Read and write x. 
builder.eval_memory_access_slice( local.shard, + local.channel, local.clk.into(), local.x_ptr, &local.x_memory, @@ -340,6 +356,7 @@ where // we read it contiguously from the y_ptr memory location. builder.eval_memory_access_slice( local.shard, + local.channel, local.clk.into(), local.y_ptr, &[local.y_memory, local.modulus_memory].concat(), @@ -349,6 +366,7 @@ where // Receive the arguments. builder.receive_syscall( local.shard, + local.channel, local.clk, AB::F::from_canonical_u32(SyscallCode::UINT256_MUL.syscall_id()), local.x_ptr, diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs index f2282632b8..85a6e9d140 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs @@ -52,6 +52,7 @@ pub const fn num_weierstrass_add_cols() -> usize pub struct WeierstrassAddAssignCols { pub is_real: T, pub shard: T, + pub channel: T, pub clk: T, pub p_ptr: T, pub q_ptr: T, @@ -97,9 +98,11 @@ impl WeierstrassAddAssignChip { } } + #[allow(clippy::too_many_arguments)] fn populate_field_ops( blu_events: &mut Vec, shard: u32, + channel: u32, cols: &mut WeierstrassAddAssignCols, p_x: BigUint, p_y: BigUint, @@ -111,17 +114,28 @@ impl WeierstrassAddAssignChip { // slope = (q.y - p.y) / (q.x - p.x). 
let slope = { - let slope_numerator = - cols.slope_numerator - .populate(blu_events, shard, &q_y, &p_y, FieldOperation::Sub); + let slope_numerator = cols.slope_numerator.populate( + blu_events, + shard, + channel, + &q_y, + &p_y, + FieldOperation::Sub, + ); - let slope_denominator = - cols.slope_denominator - .populate(blu_events, shard, &q_x, &p_x, FieldOperation::Sub); + let slope_denominator = cols.slope_denominator.populate( + blu_events, + shard, + channel, + &q_x, + &p_x, + FieldOperation::Sub, + ); cols.slope.populate( blu_events, shard, + channel, &slope_numerator, &slope_denominator, FieldOperation::Div, @@ -130,15 +144,26 @@ impl WeierstrassAddAssignChip { // x = slope * slope - (p.x + q.x). let x = { - let slope_squared = - cols.slope_squared - .populate(blu_events, shard, &slope, &slope, FieldOperation::Mul); - let p_x_plus_q_x = - cols.p_x_plus_q_x - .populate(blu_events, shard, &p_x, &q_x, FieldOperation::Add); + let slope_squared = cols.slope_squared.populate( + blu_events, + shard, + channel, + &slope, + &slope, + FieldOperation::Mul, + ); + let p_x_plus_q_x = cols.p_x_plus_q_x.populate( + blu_events, + shard, + channel, + &p_x, + &q_x, + FieldOperation::Add, + ); cols.x3_ins.populate( blu_events, shard, + channel, &slope_squared, &p_x_plus_q_x, FieldOperation::Sub, @@ -147,12 +172,18 @@ impl WeierstrassAddAssignChip { // y = slope * (p.x - x_3n) - p.y. { - let p_x_minus_x = - cols.p_x_minus_x - .populate(blu_events, shard, &p_x, &x, FieldOperation::Sub); + let p_x_minus_x = cols.p_x_minus_x.populate( + blu_events, + shard, + channel, + &p_x, + &x, + FieldOperation::Sub, + ); let slope_times_p_x_minus_x = cols.slope_times_p_x_minus_x.populate( blu_events, shard, + channel, &slope, &p_x_minus_x, FieldOperation::Mul, @@ -160,6 +191,7 @@ impl WeierstrassAddAssignChip { cols.y3_ins.populate( blu_events, shard, + channel, &slope_times_p_x_minus_x, &p_y, FieldOperation::Sub, @@ -218,6 +250,7 @@ where // Populate basic columns. 
cols.is_real = F::one(); cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); cols.clk = F::from_canonical_u32(event.clk); cols.p_ptr = F::from_canonical_u32(event.p_ptr); cols.q_ptr = F::from_canonical_u32(event.q_ptr); @@ -225,6 +258,7 @@ where Self::populate_field_ops( &mut new_byte_lookup_events, event.shard, + event.channel, cols, p_x, p_y, @@ -234,10 +268,18 @@ where // Populate the memory access columns. for i in 0..cols.q_access.len() { - cols.q_access[i].populate(event.q_memory_records[i], &mut new_byte_lookup_events); + cols.q_access[i].populate( + event.channel, + event.q_memory_records[i], + &mut new_byte_lookup_events, + ); } for i in 0..cols.p_access.len() { - cols.p_access[i].populate(event.p_memory_records[i], &mut new_byte_lookup_events); + cols.p_access[i].populate( + event.channel, + event.p_memory_records[i], + &mut new_byte_lookup_events, + ); } rows.push(row); @@ -252,6 +294,7 @@ where Self::populate_field_ops( &mut vec![], 0, + 0, cols, zero.clone(), zero.clone(), @@ -310,6 +353,7 @@ where &p_y, FieldOperation::Sub, row.shard, + row.channel, row.is_real, ); @@ -319,6 +363,7 @@ where &p_x, FieldOperation::Sub, row.shard, + row.channel, row.is_real, ); @@ -328,6 +373,7 @@ where &row.slope_denominator.result, FieldOperation::Div, row.shard, + row.channel, row.is_real, ); @@ -342,6 +388,7 @@ where slope, FieldOperation::Mul, row.shard, + row.channel, row.is_real, ); @@ -351,6 +398,7 @@ where &q_x, FieldOperation::Add, row.shard, + row.channel, row.is_real, ); @@ -360,6 +408,7 @@ where &row.p_x_plus_q_x.result, FieldOperation::Sub, row.shard, + row.channel, row.is_real, ); @@ -374,6 +423,7 @@ where x, FieldOperation::Sub, row.shard, + row.channel, row.is_real, ); @@ -383,6 +433,7 @@ where &row.p_x_minus_x.result, FieldOperation::Mul, row.shard, + row.channel, row.is_real, ); @@ -392,6 +443,7 @@ where &p_y, FieldOperation::Sub, row.shard, + row.channel, row.is_real, ); } @@ -410,6 +462,7 @@ where 
builder.eval_memory_access_slice( row.shard, + row.channel, row.clk.into(), row.q_ptr, &row.q_access, @@ -417,6 +470,7 @@ where ); builder.eval_memory_access_slice( row.shard, + row.channel, row.clk + AB::F::from_canonical_u32(1), // We read p at +1 since p, q could be the same. row.p_ptr, &row.p_access, @@ -437,6 +491,7 @@ where builder.receive_syscall( row.shard, + row.channel, row.clk, syscall_id_felt, row.p_ptr, diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_decompress.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_decompress.rs index 04e3e125dd..ec0f4cee54 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_decompress.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_decompress.rs @@ -53,6 +53,7 @@ pub const fn num_weierstrass_decompress_cols() -> pub struct WeierstrassDecompressCols { pub is_real: T, pub shard: T, + pub channel: T, pub clk: T, pub ptr: T, pub is_odd: T, @@ -97,32 +98,40 @@ impl WeierstrassDecompressChip { fn populate_field_ops( record: &mut impl ByteRecord, shard: u32, + channel: u32, cols: &mut WeierstrassDecompressCols, x: BigUint, ) { // Y = sqrt(x^3 + b) - cols.range_x.populate(record, shard, &x); - let x_2 = cols - .x_2 - .populate(record, shard, &x.clone(), &x.clone(), FieldOperation::Mul); + cols.range_x.populate(record, shard, channel, &x); + let x_2 = cols.x_2.populate( + record, + shard, + channel, + &x.clone(), + &x.clone(), + FieldOperation::Mul, + ); let x_3 = cols .x_3 - .populate(record, shard, &x_2, &x, FieldOperation::Mul); + .populate(record, shard, channel, &x_2, &x, FieldOperation::Mul); let b = E::b_int(); - let x_3_plus_b = cols - .x_3_plus_b - .populate(record, shard, &x_3, &b, FieldOperation::Add); + let x_3_plus_b = + cols.x_3_plus_b + .populate(record, shard, channel, &x_3, &b, FieldOperation::Add); let sqrt_fn = match E::CURVE_TYPE { CurveType::Secp256k1 => secp256k1_sqrt, CurveType::Bls12381 => bls12381_sqrt, _ => panic!("Unsupported curve"), }; - let y = 
cols.y.populate(record, shard, &x_3_plus_b, sqrt_fn); + let y = cols + .y + .populate(record, shard, channel, &x_3_plus_b, sqrt_fn); let zero = BigUint::zero(); cols.neg_y - .populate(record, shard, &zero, &y, FieldOperation::Sub); + .populate(record, shard, channel, &zero, &y, FieldOperation::Sub); } } @@ -165,19 +174,34 @@ where cols.is_real = F::from_bool(true); cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); + cols.channel = F::from_canonical_u32(event.channel); cols.clk = F::from_canonical_u32(event.clk); cols.ptr = F::from_canonical_u32(event.ptr); cols.is_odd = F::from_canonical_u32(event.is_odd as u32); let x = BigUint::from_bytes_le(&event.x_bytes); - Self::populate_field_ops(&mut new_byte_lookup_events, event.shard, cols, x); + Self::populate_field_ops( + &mut new_byte_lookup_events, + event.shard, + event.channel, + cols, + x, + ); for i in 0..cols.x_access.len() { - cols.x_access[i].populate(event.x_memory_records[i], &mut new_byte_lookup_events); + cols.x_access[i].populate( + event.channel, + event.x_memory_records[i], + &mut new_byte_lookup_events, + ); } for i in 0..cols.y_access.len() { - cols.y_access[i] - .populate_write(event.y_memory_records[i], &mut new_byte_lookup_events); + cols.y_access[i].populate_write( + event.channel, + event.y_memory_records[i], + &mut new_byte_lookup_events, + ); } rows.push(row); @@ -197,7 +221,7 @@ where cols.x_access[i].access.value = words[i].into(); } - Self::populate_field_ops(&mut vec![], 0, cols, dummy_value); + Self::populate_field_ops(&mut vec![], 0, 0, cols, dummy_value); row }); @@ -239,15 +263,24 @@ where let x: Limbs::Limbs> = limbs_from_prev_access(&row.x_access); - row.range_x.eval(builder, &x, row.shard, row.is_real); - row.x_2 - .eval(builder, &x, &x, FieldOperation::Mul, row.shard, row.is_real); + row.range_x + .eval(builder, &x, row.shard, row.channel, row.is_real); + row.x_2.eval( + builder, + &x, + &x, + FieldOperation::Mul, + row.shard, + 
row.channel, + row.is_real, + ); row.x_3.eval( builder, &row.x_2.result, &x, FieldOperation::Mul, row.shard, + row.channel, row.is_real, ); let b = E::b_int(); @@ -258,6 +291,7 @@ where &b_const, FieldOperation::Add, row.shard, + row.channel, row.is_real, ); @@ -267,6 +301,7 @@ where &row.y.multiplication.result, FieldOperation::Sub, row.shard, + row.channel, row.is_real, ); @@ -278,6 +313,7 @@ where &row.x_3_plus_b.result, row.y.lsb, row.shard, + row.channel, row.is_real, ); @@ -295,6 +331,7 @@ where for i in 0..num_words_field_element { builder.eval_memory_access( row.shard, + row.channel, row.clk, row.ptr.into() + AB::F::from_canonical_u32((i as u32) * 4 + num_limbs as u32), &row.x_access[i], @@ -304,6 +341,7 @@ where for i in 0..num_words_field_element { builder.eval_memory_access( row.shard, + row.channel, row.clk, row.ptr.into() + AB::F::from_canonical_u32((i as u32) * 4), &row.y_access[i], @@ -323,6 +361,7 @@ where builder.receive_syscall( row.shard, + row.channel, row.clk, syscall_id, row.ptr, diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs index 2730c4924e..d94e5366b1 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs @@ -53,6 +53,7 @@ pub const fn num_weierstrass_double_cols() -> usi pub struct WeierstrassDoubleAssignCols { pub is_real: T, pub shard: T, + pub channel: T, pub clk: T, pub p_ptr: T, pub p_access: GenericArray, P::WordsCurvePoint>, @@ -101,6 +102,7 @@ impl WeierstrassDoubleAssignChip { fn populate_field_ops( blu_events: &mut Vec, shard: u32, + channel: u32, cols: &mut WeierstrassDoubleAssignCols, p_x: BigUint, p_y: BigUint, @@ -113,12 +115,18 @@ impl WeierstrassDoubleAssignChip { let slope = { // slope_numerator = a + (p.x * p.x) * 3. 
let slope_numerator = { - let p_x_squared = - cols.p_x_squared - .populate(blu_events, shard, &p_x, &p_x, FieldOperation::Mul); + let p_x_squared = cols.p_x_squared.populate( + blu_events, + shard, + channel, + &p_x, + &p_x, + FieldOperation::Mul, + ); let p_x_squared_times_3 = cols.p_x_squared_times_3.populate( blu_events, shard, + channel, &p_x_squared, &BigUint::from(3u32), FieldOperation::Mul, @@ -126,6 +134,7 @@ impl WeierstrassDoubleAssignChip { cols.slope_numerator.populate( blu_events, shard, + channel, &a, &p_x_squared_times_3, FieldOperation::Add, @@ -136,6 +145,7 @@ impl WeierstrassDoubleAssignChip { let slope_denominator = cols.slope_denominator.populate( blu_events, shard, + channel, &BigUint::from(2u32), &p_y, FieldOperation::Mul, @@ -144,6 +154,7 @@ impl WeierstrassDoubleAssignChip { cols.slope.populate( blu_events, shard, + channel, &slope_numerator, &slope_denominator, FieldOperation::Div, @@ -152,15 +163,26 @@ impl WeierstrassDoubleAssignChip { // x = slope * slope - (p.x + p.x). let x = { - let slope_squared = - cols.slope_squared - .populate(blu_events, shard, &slope, &slope, FieldOperation::Mul); - let p_x_plus_p_x = - cols.p_x_plus_p_x - .populate(blu_events, shard, &p_x, &p_x, FieldOperation::Add); + let slope_squared = cols.slope_squared.populate( + blu_events, + shard, + channel, + &slope, + &slope, + FieldOperation::Mul, + ); + let p_x_plus_p_x = cols.p_x_plus_p_x.populate( + blu_events, + shard, + channel, + &p_x, + &p_x, + FieldOperation::Add, + ); cols.x3_ins.populate( blu_events, shard, + channel, &slope_squared, &p_x_plus_p_x, FieldOperation::Sub, @@ -169,12 +191,18 @@ impl WeierstrassDoubleAssignChip { // y = slope * (p.x - x) - p.y. 
{ - let p_x_minus_x = - cols.p_x_minus_x - .populate(blu_events, shard, &p_x, &x, FieldOperation::Sub); + let p_x_minus_x = cols.p_x_minus_x.populate( + blu_events, + shard, + channel, + &p_x, + &x, + FieldOperation::Sub, + ); let slope_times_p_x_minus_x = cols.slope_times_p_x_minus_x.populate( blu_events, shard, + channel, &slope, &p_x_minus_x, FieldOperation::Mul, @@ -182,6 +210,7 @@ impl WeierstrassDoubleAssignChip { cols.y3_ins.populate( blu_events, shard, + channel, &slope_times_p_x_minus_x, &p_y, FieldOperation::Sub, @@ -244,12 +273,14 @@ where // Populate basic columns. cols.is_real = F::one(); cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); cols.clk = F::from_canonical_u32(event.clk); cols.p_ptr = F::from_canonical_u32(event.p_ptr); Self::populate_field_ops( &mut new_byte_lookup_events, event.shard, + event.channel, cols, p_x, p_y, @@ -257,8 +288,11 @@ where // Populate the memory access columns. for i in 0..cols.p_access.len() { - cols.p_access[i] - .populate(event.p_memory_records[i], &mut new_byte_lookup_events); + cols.p_access[i].populate( + event.channel, + event.p_memory_records[i], + &mut new_byte_lookup_events, + ); } row }) @@ -280,7 +314,7 @@ where let cols: &mut WeierstrassDoubleAssignCols = row.as_mut_slice().borrow_mut(); let zero = BigUint::zero(); - Self::populate_field_ops(&mut vec![], 0, cols, zero.clone(), zero.clone()); + Self::populate_field_ops(&mut vec![], 0, 0, cols, zero.clone(), zero.clone()); row }); @@ -335,6 +369,7 @@ where &p_x, FieldOperation::Mul, row.shard, + row.channel, row.is_real, ); @@ -344,6 +379,7 @@ where &E::BaseField::to_limbs_field::(&BigUint::from(3u32)), FieldOperation::Mul, row.shard, + row.channel, row.is_real, ); @@ -353,6 +389,7 @@ where &row.p_x_squared_times_3.result, FieldOperation::Add, row.shard, + row.channel, row.is_real, ); }; @@ -364,6 +401,7 @@ where &p_y, FieldOperation::Mul, row.shard, + row.channel, row.is_real, ); @@ -373,6 +411,7 @@ where 
&row.slope_denominator.result, FieldOperation::Div, row.shard, + row.channel, row.is_real, ); @@ -387,6 +426,7 @@ where slope, FieldOperation::Mul, row.shard, + row.channel, row.is_real, ); row.p_x_plus_p_x.eval( @@ -395,6 +435,7 @@ where &p_x, FieldOperation::Add, row.shard, + row.channel, row.is_real, ); row.x3_ins.eval( @@ -403,6 +444,7 @@ where &row.p_x_plus_p_x.result, FieldOperation::Sub, row.shard, + row.channel, row.is_real, ); &row.x3_ins.result @@ -416,6 +458,7 @@ where x, FieldOperation::Sub, row.shard, + row.channel, row.is_real, ); row.slope_times_p_x_minus_x.eval( @@ -424,6 +467,7 @@ where &row.p_x_minus_x.result, FieldOperation::Mul, row.shard, + row.channel, row.is_real, ); row.y3_ins.eval( @@ -432,6 +476,7 @@ where &p_y, FieldOperation::Sub, row.shard, + row.channel, row.is_real, ); } @@ -450,6 +495,7 @@ where builder.eval_memory_access_slice( row.shard, + row.channel, row.clk.into(), row.p_ptr, &row.p_access, @@ -470,6 +516,7 @@ where builder.receive_syscall( row.shard, + row.channel, row.clk, syscall_id_felt, row.p_ptr,