-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Export / import functions to / from a file #1642
base: main
Are you sure you want to change the base?
Conversation
This is massively cool. I 'll get to reviewing asap! |
// - constants, which can be used directly | ||
// - a load primitive which has no inputs and will become a constant | ||
// after the first eval | ||
if (!a.has_primitive() || is_load(a.primitive())) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change is worth commenting on:
- Previously if you loaded arrays from a file inside a compiled function then every call to the function would reload from the file.
- Now only the first call to the function loads and after that the loaded arrays become constants in the tape
This seems better to me.. though perhaps that is debatable. It is also used by import_function
which makes Load
primitives for constants and so we get lazy loading even with import_function
which is pretty nice.
Lmk thoughts.. I can switch it so compile doesn't force load (more flexible but more dangerous).
mlx/export.h
Outdated
std::function<std::vector<array>(const std::vector<array>&)> import_function( | ||
std::string path); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another question: should import_function
return metadata? I can see how it would be useful to get say the shapes and/or dtypes of the inputs, maybe the MLX version, etc in a dict of metadata. Can also wait and see and provide an overload / a return_metadata
flag in the future.
@@ -2098,21 +2147,6 @@ class Tanh : public UnaryPrimitive { | |||
void eval(const std::vector<array>& inputs, array& out); | |||
}; | |||
|
|||
class Uniform : public UnaryPrimitive { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unused 🤷♂️ ..
fd6520c
to
454f44c
Compare
1ffc7ae
to
b117dec
Compare
Adds
export_function
andimport_function
so that we can save and load functions from a file. Makes it possible to use functions written in one language from another language (e.g. Python -> C++).Basically works like so:
In Python:
Then in C++, for example:
Some notes on the implementation:
export.cpp
state
which returns the data to saveexport.cpp
. But didn't want to obfuscate / over engineer it much yet until getting some input.