
Custom callable function from within the C++ API #1614

Answered by awni
polvalente asked this question in Q&A

I'm not at all familiar with Nx and how compilation works there, but I can say a bit more about how it works in MLX:

import mlx.core as mx

def fun(a, b, c):
    return a + b + c

# Hypothetical example inputs, added only so the snippet runs
a, b, c = mx.array(1.0), mx.array(2.0), mx.array(3.0)

# Step 0: Nothing much has happened yet other than wrapping `fun`
# in another function that knows how to compile it
compiled_fun = mx.compile(fun)

# Step 1: The first time the compiled function is called, it gets partially compiled.
# We trace the graph using the provided inputs and run some optimization passes on it
out = compiled_fun(a, b, c)

# Step 2: The rest of the compilation happens the first time you call eval. This is
# where the kernel source is actually JIT compiled
mx.eval(out)

# Calling it again on inputs with the same …
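To make the three steps concrete without needing MLX installed, here is a toy model of a two-stage compiler in plain Python. Everything in it (`ToyCompiled`, `Node`, `build_kernel`, `LazyOut`) is hypothetical and is not how MLX is implemented. It assumes the compiled-function cache is keyed on an input signature (argument type names here, standing in for shapes and dtypes) and it only traces `+`, which is all `fun` needs.

```python
class Node:
    """Symbolic placeholder recorded while tracing (toy model, hypothetical)."""

    def __init__(self, ops, name):
        self.ops, self.name = ops, name

    def __add__(self, other):
        # Record the op in the graph instead of computing a value.
        out = Node(self.ops, f"t{len(self.ops)}")
        self.ops.append((out.name, self.name, other.name))
        return out


def build_kernel(ops, inputs, output):
    # Toy "JIT compile": generate Python source from the traced graph and exec it.
    lines = [f"def kernel({', '.join(inputs)}):"]
    lines += [f"    {out} = {lhs} + {rhs}" for out, lhs, rhs in ops]
    lines.append(f"    return {output}")
    namespace = {}
    exec("\n".join(lines), namespace)
    return namespace["kernel"]


class ToyCompiled:
    """Hypothetical two-stage compiler; NOT MLX's implementation."""

    def __init__(self, fn):
        self.fn = fn       # Step 0: only wrap the function
        self.cache = {}    # input signature -> [ops, input names, output name, kernel]
        self.traces = 0
        self.builds = 0

    def __call__(self, *args):
        key = tuple(type(x).__name__ for x in args)  # stand-in for shape/dtype
        if key not in self.cache:
            # Step 1: trace once per new signature, using symbolic inputs
            self.traces += 1
            ops, names = [], [f"x{i}" for i in range(len(args))]
            out = self.fn(*(Node(ops, n) for n in names))
            self.cache[key] = [ops, names, out.name, None]
        return LazyOut(self, self.cache[key], args)


class LazyOut:
    """Deferred result; nothing is built or computed until eval()."""

    def __init__(self, owner, entry, args):
        self.owner, self.entry, self.args = owner, entry, args

    def eval(self):
        # Step 2: build the kernel on first evaluation, then reuse it.
        if self.entry[3] is None:
            self.owner.builds += 1
            self.entry[3] = build_kernel(*self.entry[:3])
        return self.entry[3](*self.args)


def fun(a, b, c):
    return a + b + c


compiled_fun = ToyCompiled(fun)
out = compiled_fun(1.0, 2.0, 3.0)       # traces the graph, builds nothing yet
print(compiled_fun.traces, compiled_fun.builds)  # 1 0
print(out.eval())                       # 6.0 (kernel built here)
out2 = compiled_fun(4.0, 5.0, 6.0)      # same signature: no retrace, no rebuild
print(out2.eval(), compiled_fun.traces, compiled_fun.builds)  # 15.0 1 1
compiled_fun(1, 2, 3).eval()            # ints: new signature, so trace and build again
print(compiled_fun.traces)              # 2
```

Splitting the work this way means the first call only pays for tracing, and source generation is deferred until a result is actually needed; later calls with a matching signature skip both.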

Replies: 1 comment 4 replies

@polvalente

@polvalente

@awni

awni Nov 22, 2024
Maintainer

@polvalente

Answer selected by polvalente
Category
Q&A
2 participants