diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index d9af8e76e..99970f356 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -112,6 +112,7 @@ pub struct DecodedInstruction { #[derive(Clone, Copy, Debug)] pub enum InsnCategory { Compute, + Branch, Load, Store, System, @@ -418,12 +419,12 @@ const RV32IM_ISA: InstructionTable = [ insn(I, SRAI, Compute, 0x13, 0x5, 0x20), insn(I, SLTI, Compute, 0x13, 0x2, -1), insn(I, SLTIU, Compute, 0x13, 0x3, -1), - insn(B, BEQ, Compute, 0x63, 0x0, -1), - insn(B, BNE, Compute, 0x63, 0x1, -1), - insn(B, BLT, Compute, 0x63, 0x4, -1), - insn(B, BGE, Compute, 0x63, 0x5, -1), - insn(B, BLTU, Compute, 0x63, 0x6, -1), - insn(B, BGEU, Compute, 0x63, 0x7, -1), + insn(B, BEQ, Branch, 0x63, 0x0, -1), + insn(B, BNE, Branch, 0x63, 0x1, -1), + insn(B, BLT, Branch, 0x63, 0x4, -1), + insn(B, BGE, Branch, 0x63, 0x5, -1), + insn(B, BLTU, Branch, 0x63, 0x6, -1), + insn(B, BGEU, Branch, 0x63, 0x7, -1), insn(J, JAL, Compute, 0x6f, -1, -1), insn(I, JALR, Compute, 0x67, 0x0, -1), insn(U, LUI, Compute, 0x37, -1, -1), @@ -556,6 +557,7 @@ impl Emulator { if match insn.category { InsnCategory::Compute => self.step_compute(ctx, insn.kind, &decoded)?, + InsnCategory::Branch => self.step_branch(ctx, insn.kind, &decoded)?, InsnCategory::Load => self.step_load(ctx, insn.kind, &decoded)?, InsnCategory::Store => self.step_store(ctx, insn.kind, &decoded)?, InsnCategory::System => self.step_system(ctx, insn.kind, &decoded)?, @@ -577,15 +579,7 @@ impl Emulator { let pc = ctx.get_pc(); let mut new_pc = pc + WORD_SIZE; - let mut rd = decoded.rd; let imm_i = decoded.imm_i(); - let mut br_cond = |cond| -> u32 { - rd = 0; - if cond { - new_pc = pc.wrapping_add(decoded.imm_b()); - } - 0 - }; let out = match kind { // Instructions that do not read rs1 nor rs2. JAL => { @@ -653,12 +647,6 @@ impl Emulator { 0 } } - BEQ => br_cond(rs1 == rs2), - BNE => br_cond(rs1 != rs2), - BLT => br_cond((rs1 as i32) < (rs2 as i32)), - BGE => br_cond((rs1 as i32) >= (rs2 as i32)), - BLTU => br_cond(rs1 < rs2), - BGEU => br_cond(rs1 >= rs2), MUL => rs1.wrapping_mul(rs2), MULH => { (sign_extend_u32(rs1).wrapping_mul(sign_extend_u32(rs2)) >> 32) @@ -704,7 +692,42 @@ impl Emulator { if !new_pc.is_aligned() { return ctx.trap(TrapCause::InstructionAddressMisaligned); } - ctx.store_register(rd as usize, out)?; + ctx.store_register(decoded.rd as usize, out)?; + ctx.set_pc(new_pc); + Ok(true) + } + + fn step_branch( + &self, + ctx: &mut M, + kind: InsnKind, + decoded: &DecodedInstruction, + ) -> Result { + use InsnKind::*; + + let pc = ctx.get_pc(); + let rs1 = ctx.load_register(decoded.rs1 as RegIdx)?; + let rs2 = ctx.load_register(decoded.rs2 as RegIdx)?; + + let taken = match kind { + BEQ => rs1 == rs2, + BNE => rs1 != rs2, + BLT => (rs1 as i32) < (rs2 as i32), + BGE => (rs1 as i32) >= (rs2 as i32), + BLTU => rs1 < rs2, + BGEU => rs1 >= rs2, + _ => unreachable!("Illegal branch instruction: {:?}", kind), + }; + + let new_pc = if taken { + pc.wrapping_add(decoded.imm_b()) + } else { + pc + WORD_SIZE + }; + + if !new_pc.is_aligned() { + return ctx.trap(TrapCause::InstructionAddressMisaligned); + } ctx.set_pc(new_pc); Ok(true) }