diff --git a/Batteries.lean b/Batteries.lean index 6c24d239bc..7ef46d3189 100644 --- a/Batteries.lean +++ b/Batteries.lean @@ -29,6 +29,7 @@ import Batteries.Data.MLList import Batteries.Data.Nat import Batteries.Data.PairingHeap import Batteries.Data.RBMap +import Batteries.Data.Random import Batteries.Data.Range import Batteries.Data.Rat import Batteries.Data.Stream diff --git a/Batteries/Data/Random.lean b/Batteries/Data/Random.lean new file mode 100644 index 0000000000..cf1e720ee0 --- /dev/null +++ b/Batteries/Data/Random.lean @@ -0,0 +1 @@ +import Batteries.Data.Random.MersenneTwister diff --git a/Batteries/Data/Random/MersenneTwister.lean b/Batteries/Data/Random/MersenneTwister.lean new file mode 100644 index 0000000000..70a6f86042 --- /dev/null +++ b/Batteries/Data/Random/MersenneTwister.lean @@ -0,0 +1,158 @@ +/- +Copyright (c) 2024 François G. Dorais. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: François G. Dorais +-/ +import Batteries.Data.Vector + +/-! # Mersenne Twister + +Generic implementation for the Mersenne Twister pseudorandom number generator. + +All choices of parameters from Matsumoto and Nishimura (1998) are supported, along with later +refinements. Parameters for the standard 32-bit MT19937 and 64-bit MT19937-64 algorithms are +provided. Both `RandomGen` and `Stream` interfaces are provided. + +Use `mt19937.init seed` to create a MT19937 PRNG with a 32 bit seed value; use +`mt19937_64.init seed` to create a MT19937-64 PRNG with a 64 bit seed value. If omitted, default +seed choices will be used. + +Sample usage: +``` +import Batteries.Data.Random.MersenneTwister + +open Batteries.Random.MersenneTwister + +def mtgen := mt19937.init -- default seed 4357 + +#eval (Stream.take mtgen 5).fst -- [874448474, 2424656266, 2174085406, 1265871120, 3155244894] +``` + +### References: + +- Matsumoto, Makoto and Nishimura, Takuji (1998), + [**Mersenne twister: A 623-dimensionally equidistributed uniform pseudo-random number generator**](https://doi.org/10.1145/272991.272995), + ACM Trans. Model. Comput. Simul. 8, No. 1, 3-30. + [ZBL0917.65005](https://zbmath.org/?q=an:0917.65005). + +- Nishimura, Takuji (2000), + [**Tables of 64-bit Mersenne twisters**](https://doi.org/10.1145/369534.369540), + ACM Trans. Model. Comput. Simul. 10, No. 4, 348-357. + [ZBL1390.65014](https://zbmath.org/?q=an:1390.65014). +-/ + +namespace Batteries.Random.MersenneTwister + +/-- +Mersenne Twister configuration. + +Letters in parentheses correspond to variable names used by Matsumoto and Nishimura (1998) and +Nishimura (2000). +-/ +structure Config where + /-- Word size (`w`). -/ + wordSize : Nat + /-- Degree of recurrence (`n`). -/ + stateSize : Nat + /-- Middle word (`m`). -/ + shiftSize : Fin stateSize + /-- Twist value (`r`). -/ + maskBits : Fin wordSize + /-- Coefficients of the twist matrix (`a`). -/ + xorMask : BitVec wordSize + /-- Tempering shift parameters (`u`, `s`, `t`, `l`). -/ + temperingShifts : Nat × Nat × Nat × Nat + /-- Tempering mask parameters (`d`, `b`, `c`). -/ + temperingMasks : BitVec wordSize × BitVec wordSize × BitVec wordSize + /-- Initialization multiplier (`f`). -/ + initMult : BitVec wordSize + /-- Default initialization seed value. -/ + initSeed : BitVec wordSize + +private abbrev Config.uMask (cfg : Config) : BitVec cfg.wordSize := + BitVec.allOnes cfg.wordSize <<< cfg.maskBits.val + +private abbrev Config.lMask (cfg : Config) : BitVec cfg.wordSize := + BitVec.allOnes cfg.wordSize >>> (cfg.wordSize - cfg.maskBits.val) + +@[simp] theorem Config.zero_lt_wordSize (cfg : Config) : 0 < cfg.wordSize := + Nat.zero_lt_of_lt cfg.maskBits.is_lt + +@[simp] theorem Config.zero_lt_stateSize (cfg : Config) : 0 < cfg.stateSize := + Nat.zero_lt_of_lt cfg.shiftSize.is_lt + +/-- Mersenne Twister State. -/ +structure State (cfg : Config) where + /-- Data for current state. -/ + data : Vector (BitVec cfg.wordSize) cfg.stateSize + /-- Current data index. -/ + index : Fin cfg.stateSize + +/-- Mersenne Twister initialization given an optional seed. -/ +@[specialize cfg] protected def Config.init (cfg : MersenneTwister.Config) + (seed : BitVec cfg.wordSize := cfg.initSeed) : State cfg := + ⟨loop seed (.mkEmpty cfg.stateSize) (Nat.zero_le _), 0, cfg.zero_lt_stateSize⟩ +where + /-- Inner loop for Mersenne Twister initalization. -/ + loop (w : BitVec cfg.wordSize) (v : Array (BitVec cfg.wordSize)) (h : v.size ≤ cfg.stateSize) := + if heq : v.size = cfg.stateSize then ⟨v, heq⟩ else + let v := v.push w + let w := cfg.initMult * (w ^^^ (w >>> cfg.wordSize - 2)) + v.size + loop w v (by simp only [v, Array.size_push]; omega) + +/-- Apply the twisting transformation to the given state. -/ +@[specialize cfg] protected def State.twist (state : State cfg) : State cfg := + let i := state.index + let i' : Fin cfg.stateSize := + if h : i.val+1 < cfg.stateSize then ⟨i.val+1, h⟩ else ⟨0, cfg.zero_lt_stateSize⟩ + let y := state.data[i] &&& cfg.uMask ||| state.data[i'] &&& cfg.lMask + let x := state.data[i+cfg.shiftSize] ^^^ bif y[0] then y >>> 1 ^^^ cfg.xorMask else y >>> 1 + ⟨state.data.set i x, i'⟩ + +/-- Update the state by a number of generation steps (default 1). -/ +-- TODO: optimize to `O(log(steps))` using the minimal polynomial +protected def State.update (state : State cfg) : (steps : Nat := 1) → State cfg + | 0 => state + | steps+1 => state.twist.update steps + +/-- Mersenne Twister iteration. -/ +@[specialize cfg] protected def State.next (state : State cfg) : BitVec cfg.wordSize × State cfg := + let i := state.index + let s := state.twist + (temper s.data[i], s) +where + /-- Tempering step for Mersenne Twister. -/ + @[inline] temper (x : BitVec cfg.wordSize) := + match cfg.temperingShifts, cfg.temperingMasks with + | (u, s, t, l), (d, b, c) => + let x := x ^^^ x >>> u &&& d + let x := x ^^^ x <<< s &&& b + let x := x ^^^ x <<< t &&& c + x ^^^ x >>> l + +instance (cfg) : Stream (State cfg) (BitVec cfg.wordSize) where + next? s := s.next + +/-- 32 bit Mersenne Twister (MT19937) configuration. -/ +def mt19937 : Config where + wordSize := 32 + stateSize := 624 + shiftSize := 397 + maskBits := 31 + xorMask := 0x9908b0df + temperingShifts := (11, 7, 15, 18) + temperingMasks := (0xffffffff, 0x9d2c5680, 0xefc60000) + initMult := 1812433253 + initSeed := 4357 + +/-- 64 bit Mersenne Twister (MT19937-64) configuration. -/ +def mt19937_64 : Config where + wordSize := 64 + stateSize := 312 + shiftSize := 156 + maskBits := 31 + xorMask := 0xb5026f5aa96619e9 + temperingShifts := (29, 17, 37, 43) + temperingMasks := (0x5555555555555555, 0x71d67fffeda60000, 0xfff7eee000000000) + initMult := 6364136223846793005 + initSeed := 19650218 diff --git a/BatteriesTest/mersenne_twister.lean b/BatteriesTest/mersenne_twister.lean new file mode 100644 index 0000000000..3d42b6a61e --- /dev/null +++ b/BatteriesTest/mersenne_twister.lean @@ -0,0 +1,85 @@ +import Batteries.Data.Random.MersenneTwister +import Batteries.Data.Stream + +open Batteries.Random.MersenneTwister + +#guard (Stream.take mt19937.init 5).1 == [874448474, 2424656266, 2174085406, 1265871120, 3155244894] + +/- Sample output was generated using `numpy`'s implementation of MT19937: +```python +from numpy import array, uint32 +from numpy.random import MT19937 + +mt = MT19937() +mt.state = { + 'bit_generator' : 'MT19937', + 'state' : { + 'pos' : 624, + 'key' : array([ + 4357, 1673174024, 1301878288, 1129097449, 2180885271, 2495295730, 3729202114, 3451529139, 2624228201, 696045212, + 2296245684, 4097888573, 2110311931, 1672374534, 381896678, 2887874951, 3859861197, 420983856, 1691952728, 4233606289, + 1707944415, 3515687962, 4265198858, 1433261659, 1131854641, 228846788, 3811811324, 873525989, 588291779, 2854617646, + 948269870, 3798261295, 3422826645, 340138072, 3671734944, 3961007161, 2839350439, 3264455490, 310719058, 2570596611, + 3750039289, 648992492, 3816674884, 2210726029, 371217291, 196912982, 3046892150, 470118103, 1302935133, 362465408, + 1360220904, 2946174945, 1630294895, 3570642538, 1798333338, 1196832683, 226789057, 2740096276, 1062441100, 1875507765, + 2599873619, 1037523070, 4029519294, 3231722367, 2232344613, 3458909352, 2906353456, 3064815497, 3166305847, + 3658630546, 3632421090, 885320275, 1621369481, 1258557244, 2827734740, 3209486301, 131295515, 2191201702, 44141830, + 1183978535, 4202966509, 801836240, 2303299448, 333191985, 4114943231, 1490315450, 453120554, 759253243, 1381163601, + 3455606116, 1027445020, 1144697221, 3040135651, 4176273102, 798935118, 49817807, 2492997557, 3171983608, 2742334400, + 1282687705, 1047297991, 3697219554, 1400278898, 3276297123, 843040281, 354711436, 4156544868, 2873126701, 3990490795, + 3966874614, 1376536470, 4189022583, 2283386237, 3645931808, 1312021512, 679663233, 3054458511, 1152865034, 1927729338, + 538380875, 374984161, 2453495220, 514433452, 1271601365, 3737270131, 630101278, 1292962526, 2908018207, 1209528133, + 413117768, 3762161744, 2194986537, 1414304087, 379722290, 2862208514, 3551161587, 3402627497, 2411204572, 3033657332, + 4161252989, 2267825211, 963150406, 2081690150, 4014304967, 1977732365, 2412979568, 613038232, 418857425, 3682807839, + 3416550746, 3692470090, 2764012443, 3255912817, 2160692740, 3914318396, 3437441061, 2828481795, 3655629678, 582770030, + 2946380655, 3506851541, 612362648, 3394202848, 1530337657, 3360830183, 570641538, 153365650, 1624454723, 80526649, + 1365694508, 2272925828, 34250189, 3066169803, 631734422, 3706776758, 3443270679, 659846301, 3707435456, 3573851432, + 1017208097, 1100519855, 1824765866, 3284762074, 2887949547, 569464065, 3057970772, 1726477004, 3119183733, 3349922451, + 4162228670, 249085950, 3854319807, 1155219045, 811161064, 207675760, 50531529, 141911159, 3819613906, 2655884066, + 3517624211, 514724041, 2094583932, 3681571092, 3518053661, 2207473499, 961982182, 1423628102, 628853095, 3823741997, + 1450180112, 1817911736, 384378993, 1749521215, 4080873978, 2604100714, 2468900411, 1718743185, 3679944356, 623522652, + 2974445253, 351789091, 776787982, 4087231118, 395771407, 2634989045, 2547249720, 2502583808, 3550523417, 648947207, + 2361409826, 2639137202, 4179155171, 3136025689, 3233151180, 3765213604, 459508845, 412632299, 3365801270, 1208603094, + 1978375863, 3608769469, 2648322656, 994422344, 1463198657, 1938300111, 1983437898, 3617090298, 582545291, 604707873, + 615071476, 1976468460, 4251555349, 2373160371, 4138683998, 927249694, 4178996063, 3071856005, 3264724616, 2539911824, + 1383596905, 3639900055, 2590770034, 1029541954, 369472051, 3757991913, 1470517532, 2317808180, 1065978813, 3301489275, + 4087716742, 2662718566, 678716423, 274451277, 1625396912, 3598469848, 3639725841, 726808159, 1490990746, 4062476682, + 2411471067, 1395972017, 1390554948, 1854727292, 2494590309, 1377225539, 2540041390, 3288614830, 706906287, 1416719637, + 609008344, 2311429920, 821102265, 2034260263, 3587569090, 3115591378, 3545840515, 4166871929, 139581804, 2421643972, + 1250638605, 4212965387, 2794805718, 3306616566, 2466109783, 2200482525, 1496197888, 381089640, 2743249505, 4221427695, + 1247199466, 1746114586, 2065302059, 1348936513, 2997505940, 3911013644, 428274869, 2816055507, 580438782, 135588414, + 916674047, 445684901, 1016784680, 654791600, 1282652681, 92916407, 1411782674, 1367985506, 1207661779, 3531669257, + 627085756, 1857409876, 4107311709, 1384928667, 2576697382, 2875531654, 4151312039, 116927085, 1281879888, 414036984, + 3931190705, 4100135295, 1170799418, 3130902186, 4055536507, 3692691153, 480878564, 2201474460, 3663014917, 4155766371, + 1987039566, 4121861326, 2525025103, 2465094709, 2536129400, 1843468352, 2926058841, 533253191, 1988389474, 1209435122, + 4141112867, 2699109017, 2373614092, 1694129124, 2730600877, 2249161515, 1355638390, 3319290902, 2209534967, + 1463955965, 204923808, 1025015944, 214266113, 3382305551, 2455594378, 1861944634, 1820710091, 449145441, 4119339060, + 2660525612, 3515028309, 3466454003, 1024657310, 50945886, 2913140895, 721595333, 3416444872, 2701847760, 2352361641, + 234184151, 3927502002, 3834792578, 3469473651, 4193637929, 2873594460, 1994191988, 1690724605, 1956524219, 476427462, + 212379302, 1370380615, 327076237, 1984104432, 682581272, 2521259089, 3543809183, 3275489242, 241390538, 3496199707, + 2497799665, 770560132, 1626015420, 2776148645, 3717161347, 3970592238, 710750702, 3421625839, 876972885, 2108460056, + 1195168096, 1195766777, 3121053543, 2819333890, 1916084498, 717897923, 3627489721, 1970264748, 1813355780, 4148615245, + 556824139, 411448086, 4228776246, 1732939415, 3206934813, 1949588544, 3291105704, 1044314017, 222045743, 3079457322, + 638497370, 1849452395, 921039233, 1115861204, 3019093836, 2828923381, 4185943827, 3344827454, 3923907710, 760572735, + 3828284133, 1559197800, 724485616, 1828677449, 2985767159, 4119101778, 1077348258, 3518446099, 2585587017, 1855673084, + 3495712148, 3265984413, 2998815707, 760668518, 2487249862, 3060757479, 3249514669, 4222804112, 1010910776, 3893641969, + 395812799, 2591540346, 1194664170, 49789115, 1363873041, 1005502756, 1164343260, 3646613829, 459869347, 3679832718, + 1137706766, 4189431951, 1412889205, 622040248, 1536739968, 3066727065, 666661511, 1672188834, 2714762802, 4135248739, + 35606745, 2775710540, 4083752484, 3680159469, 1950331243, 251641782, 1501029974, 486869303, 1720971325, 241603808, + 28070600, 2737782337, 910469455, 3810848458, 118398842, 3078470155, 2559096993, 2933522804, 2264615020, 3793195157, + 1614887475, 45727966, 3193899422, 1157273055, 2178255365, 2646663432, 724754192, 168779241, 4048503831, 3483948530, + 3996648642, 939343027, 917914729, 3030111132, 3908302516, 29247037, 3568084731, 1034472966, 1408004326, 1693666951, + 3712665549, 3120003376, 3374542680, 2868373905, 1362838239, 1421625626, 4275252746, 548825947, 622261297, 3152835012, + 2926192892, 423356389, 151058371, 3820087086, 1673993262, 252457775, 1317185941, 2594135384, 817169312, 2016796985, + 2292688295, 1654933570, 2158435154, 2703640067, 3260663801, 3267419116, 2293555012, 2721936781, 1727868043, 91884630, + 265685878, 1143096279, 961294173, 403541376, 2338233320, 1725318369, 4101205103, 4268086122, 3418016922, 1065995435, + 1936572353, 265163284, 3043694988, 2167402293, 2057323859, 4033232254, 3258990270, 1137868927, 2142656805, 4216785320, + 1188509744, 1051071625, 196974391, 2445666962, 3092595170, 2833121107, 2474761097, 2190021692, 1852037076, 3577763037, + 3794354715, 2124118694, 2641147398, 1551493415, 1913661165, 1313919440, 2232801400, 1781682225, 1340417535, 994676154, + 251493162, 2162155003, 1678056273, 3810976356, 1505106460, 3361449605, 1041703651, 1727972302, 3959583054, 3140845007, + 3202914485, 2878334456, 2354150592, 3334993881, 1015617735, 506838242, 4168775794, 839674019, 4238769945, 849116300, + 4189642852, 1596908589, 556328875, 2369067254, 2431152278, 1004682871], dtype=uint32)}} + +print(mt.random_raw(5)) +``` +-/