Skip to content

Commit

Permalink
same result for python and c++
Browse files Browse the repository at this point in the history
  • Loading branch information
awni committed Dec 3, 2024
1 parent 365d8c5 commit 9c78253
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
1 change: 0 additions & 1 deletion examples/export/train_mlp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ int main() {
int output_dim = 10;

auto state = import_function("init_mlp.mlxfn")({});
std::cout << state[0] << std::endl;

// Make the input
random::seed(42);
Expand Down
5 changes: 2 additions & 3 deletions examples/export/train_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@ def __call__(self, x):
input_dim = 32
output_dim = 10

# Seed for the parameter initialization
mx.random.seed(0)

def init():
# Seed for the parameter initialization
mx.random.seed(0)
model = MLP(
num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
)
Expand Down
18 changes: 18 additions & 0 deletions python/tests/test_export_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,24 @@ def fun(x, y):
with self.assertRaises(ValueError):
imported(mx.array(1.0), [mx.array(1.0)])

def test_export_random_sample(self):
path = os.path.join(self.test_dir, "fn.mlxfn")

mx.random.seed(5)

def fun():
return mx.random.uniform(shape=(3,))

mx.export_function(path, fun)
imported = mx.import_function(path)

(out,) = imported()

mx.random.seed(5)
expected = fun()

self.assertTrue(mx.array_equal(out, expected))


if __name__ == "__main__":
unittest.main()

0 comments on commit 9c78253

Please sign in to comment.