From 9c782539371d977994319da05741a16c458a2a96 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 3 Dec 2024 09:12:48 -0800 Subject: [PATCH] same result for python and c++ --- examples/export/train_mlp.cpp | 1 - examples/export/train_mlp.py | 5 ++--- python/tests/test_export_import.py | 18 ++++++++++++++++++ 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/examples/export/train_mlp.cpp b/examples/export/train_mlp.cpp index 7176cbbec..c3d516e9e 100644 --- a/examples/export/train_mlp.cpp +++ b/examples/export/train_mlp.cpp @@ -11,7 +11,6 @@ int main() { int output_dim = 10; auto state = import_function("init_mlp.mlxfn")({}); - std::cout << state[0] << std::endl; // Make the input random::seed(42); diff --git a/examples/export/train_mlp.py b/examples/export/train_mlp.py index 5deea0b4e..8fd686dd7 100644 --- a/examples/export/train_mlp.py +++ b/examples/export/train_mlp.py @@ -31,10 +31,9 @@ def __call__(self, x): input_dim = 32 output_dim = 10 - # Seed for the parameter initialization - mx.random.seed(0) - def init(): + # Seed for the parameter initialization + mx.random.seed(0) model = MLP( num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim ) diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index a32a50ef9..a4227fb9f 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -131,6 +131,24 @@ def fun(x, y): with self.assertRaises(ValueError): imported(mx.array(1.0), [mx.array(1.0)]) + def test_export_random_sample(self): + path = os.path.join(self.test_dir, "fn.mlxfn") + + mx.random.seed(5) + + def fun(): + return mx.random.uniform(shape=(3,)) + + mx.export_function(path, fun) + imported = mx.import_function(path) + + (out,) = imported() + + mx.random.seed(5) + expected = fun() + + self.assertTrue(mx.array_equal(out, expected)) + if __name__ == "__main__": unittest.main()