Skip to content

Commit

Permalink
template it up
Browse files Browse the repository at this point in the history
  • Loading branch information
awni committed Dec 19, 2024
1 parent 7e8a184 commit 76b4236
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 227 deletions.
4 changes: 1 addition & 3 deletions examples/export/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
cmake_minimum_required(VERSION 3.27)

project(_ext LANGUAGES CXX)
project(import_mlx LANGUAGES CXX)

# ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

option(BUILD_SHARED_LIBS "Build as a shared library" ON)

# ----------------------------- Dependencies -----------------------------
find_package(
Python 3.9
Expand Down
27 changes: 23 additions & 4 deletions examples/export/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## Setup

Install mlx:
Install MLX:

```bash
pip install mlx>=0.22
Expand All @@ -15,16 +15,35 @@ cmake --build build

## Run

Run the Python script to export the function.
### Eval MLP

Run the Python script to export the eval function:

```bash
python eval_mlp.py
```

Then run the C++ program to import and run the function.
Then run the C++ program to import and run the function:

```
./build/eval_mlp
```

The two programs should output the same result.
The Python and C++ programs should output the same result.

### Train MLP

Run the Python script to export the model initialization and training
functions:

```bash
python train_mlp.py
```

Then run the C++ program to import and run the functions:

```
./build/train_mlp
```

The Python and C++ programs should output the same results.
12 changes: 5 additions & 7 deletions mlx/compile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ bool is_fusable(const Primitive& p) {
is_noop(p);
}

bool has_no_inputs(const Primitive& p) {
return typeid(p) == typeid(Load) || typeid(p) == typeid(Arange);
}

bool allows_shapeless(const Primitive& p) {
return typeid(p) == typeid(Arange) || typeid(p) == typeid(Compiled) ||
is_unary(p) || is_binary(p) || is_noop(p) || is_reduction(p) ||
Expand Down Expand Up @@ -750,13 +746,15 @@ std::vector<array> compile_replace(
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
}

auto is_load = [](const Primitive& p) { return typeid(p) == typeid(Load); };

for (auto& a : tape) {
// Arrays in the tape without primitives are either:
// - inputs, which are already in the map
// - constants, which can be used directly
// - primitives with no inputs which will become constants after the first
// eval
if (!a.has_primitive() || has_no_inputs(a.primitive())) {
// - a load primitive which has no inputs and will become a constant
// after the first eval
if (!a.has_primitive() || is_load(a.primitive())) {
trace_to_real.insert({a.id(), a});
} else {
// Find real inputs
Expand Down
Loading

0 comments on commit 76b4236

Please sign in to comment.