From 616ae867278ba70cc2581b1563f0526b7db229ef Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Sat, 12 Oct 2024 19:50:59 -0400 Subject: [PATCH] refactor: state update step --- Batteries/Data/Random/MersenneTwister.lean | 27 +++++++++++----------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/Batteries/Data/Random/MersenneTwister.lean b/Batteries/Data/Random/MersenneTwister.lean index 0425367a98..b63c05b99e 100644 --- a/Batteries/Data/Random/MersenneTwister.lean +++ b/Batteries/Data/Random/MersenneTwister.lean @@ -72,7 +72,7 @@ structure State (cfg : Config) where /-- Mersenne Twister initialization given an optional seed. -/ @[specialize cfg] protected def Config.init (cfg : MersenneTwister.Config) (seed : BitVec cfg.wordSize := cfg.initSeed) : State cfg := - ⟨loop seed #[] (Nat.zero_le _), 0, cfg.zero_lt_stateSize⟩ + ⟨loop seed (.mkEmpty cfg.stateSize) (Nat.zero_le _), 0, cfg.zero_lt_stateSize⟩ where /-- Inner loop for Mersenne Twister initalization. -/ loop (w : BitVec cfg.wordSize) (v : Array (BitVec cfg.wordSize)) (h : v.size ≤ cfg.stateSize) := @@ -81,24 +81,23 @@ where let w := cfg.initMult * (w ^^^ (w >>> cfg.wordSize - 2)) + v.size loop w v (by simp only [v, Array.size_push]; omega) +/-- Apply the twisting transformation to the given state. -/ +@[specialize cfg] protected def State.twist (state : State cfg) : State cfg := + let i := state.index + let i' : Fin cfg.stateSize := + if h : i.val+1 < cfg.stateSize then ⟨i.val+1, h⟩ else ⟨0, cfg.zero_lt_stateSize⟩ + let y := state.data[i] &&& cfg.uMask ||| state.data[i'] &&& cfg.lMask + let x := state.data[i+cfg.shiftSize] ^^^ bif y[0] then y >>> 1 ^^^ cfg.xorMask else y >>> 1 + ⟨state.data.set i x, i'⟩ + /-- Update the state by a number of generation steps (default 1). -/ -@[specialize cfg] protected def State.update (state : State cfg) (steps := 1) : State cfg := - loop state steps -where - /-- Inner loop for Mersenne Twister update. -/ - @[inline] loop (s : State cfg) (c : Nat) : State cfg := - if c = 0 then s else - let i := s.index - let i' : Fin cfg.stateSize := - if h : i.val+1 < cfg.stateSize then ⟨i.val+1, h⟩ else ⟨0, cfg.zero_lt_stateSize⟩ - let y := s.data[i] &&& cfg.uMask ||| s.data[i'] &&& cfg.lMask - let x := s.data[i+cfg.shiftSize] ^^^ bif y[0] then y >>> 1 ^^^ cfg.xorMask else y >>> 1 - loop ⟨s.data.set i x, i'⟩ (c-1) +@[inline] protected def State.update (state : State cfg) (steps := 1) : State cfg := + if steps = 0 then state else state.twist.update (steps-1) /-- Mersenne Twister iteration. -/ @[specialize cfg] protected def State.next (state : State cfg) : BitVec cfg.wordSize × State cfg := let i := state.index - let s := state.update + let s := state.twist (temper s.data[i], s) where /-- Tempering step for Mersenne Twister. -/