-
-
Notifications
You must be signed in to change notification settings - Fork 188
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
reverse mode AD general compute graph in Python
- Loading branch information
Showing
1 changed file
with
395 additions
and
0 deletions.
There are no files selected for viewing
395 changes: 395 additions & 0 deletions
395
...djoints_sensitivities_automatic_differentiation/simple_autodiff_external_static_dag.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,395 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Function is:\n", | ||
"\n", | ||
"$$ f(x_0, x_1) = \\sin(x_0) (x_0 + x_1) $$\n", | ||
"\n", | ||
"or broken down\n", | ||
"\n", | ||
"$$ \\begin{align}\n", | ||
"z_0 &= x_0 \\\\\n", | ||
"z_1 &= x_1 \\\\\n", | ||
"z_2 &= \\sin(z_0) \\\\\n", | ||
"z_3 &= z_0 + z_1 \\\\\n", | ||
"z_4 &= z_2 z_3 \\\\\n", | ||
"\\end{align} $$\n", | ||
"\n", | ||
"Its symbolic derivative is:\n", | ||
"\n", | ||
"$$ \\nabla f(x_0, x_1) = \\begin{bmatrix}\n", | ||
"\\cos(x_0) (x_0 + x_1) + \\sin(x_0) \\\\\n", | ||
"\\sin(x_0)\n", | ||
"\\end{bmatrix} $$" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def f(x_0, x_1):\n", | ||
" return np.sin(x_0) * (x_0 + x_1)\n", | ||
"\n", | ||
"def f_grad(x_0, x_1):\n", | ||
" return np.array([\n", | ||
" np.cos(x_0) * (x_0 + x_1) + np.sin(x_0),\n", | ||
" np.sin(x_0),\n", | ||
" ])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"compute_graph = [\n", | ||
" (\"inp\", (0,)), # 0\n", | ||
" (\"inp\", (1,)), # 1\n", | ||
" (\"sin\", (0,)), # 2\n", | ||
" (\"add\", (0, 1)), # 3\n", | ||
" (\"mul\", (2, 3)), # 4\n", | ||
"]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"fn_library = {\n", | ||
" \"inp\": lambda x: x,\n", | ||
" \"sin\": lambda x: np.sin(x),\n", | ||
" \"add\": lambda x, y: x + y,\n", | ||
" \"mul\": lambda x, y: x * y,\n", | ||
"}" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def compute(graph, inputs):\n", | ||
" values = list(inputs)\n", | ||
" for operation, indices in graph:\n", | ||
" if operation == \"inp\":\n", | ||
" continue\n", | ||
" args = [values[index] for index in indices]\n", | ||
" result = fn_library[operation](*args)\n", | ||
" values.append(result)\n", | ||
" \n", | ||
" return values[-1]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"SAMPLE_INPUT = (0.6, 1.4)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"1.1292849467900707" | ||
] | ||
}, | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"f(*SAMPLE_INPUT)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"1.1292849467900707" | ||
] | ||
}, | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"compute(compute_graph, SAMPLE_INPUT)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def inp_backprop_rule(x):\n", | ||
" z = x\n", | ||
"\n", | ||
" def inp_pullback(z_cotangent):\n", | ||
" x_cotangent = z_cotangent\n", | ||
" return (x_cotangent,)\n", | ||
" \n", | ||
" return z, inp_pullback\n", | ||
"\n", | ||
"def sin_backprop_rule(x):\n", | ||
" z = np.sin(x)\n", | ||
"\n", | ||
" def sin_pullback(z_cotangent):\n", | ||
" x_cotangent = np.cos(x) * z_cotangent\n", | ||
" return (x_cotangent,)\n", | ||
" \n", | ||
" return z, sin_pullback\n", | ||
"\n", | ||
"def add_backprop_rule(x, y):\n", | ||
" z = x + y\n", | ||
"\n", | ||
" def add_pullback(z_cotangent):\n", | ||
" x_cotangent = z_cotangent\n", | ||
" y_cotangent = z_cotangent\n", | ||
"\n", | ||
" return (x_cotangent, y_cotangent)\n", | ||
" \n", | ||
" return z, add_pullback\n", | ||
"\n", | ||
"def mul_backprop_rule(x, y):\n", | ||
" z = x * y\n", | ||
"\n", | ||
" def mul_pullback(z_cotangent):\n", | ||
" x_cotangent = y * z_cotangent\n", | ||
" y_cotangent = x * z_cotangent\n", | ||
" return (x_cotangent, y_cotangent)\n", | ||
" \n", | ||
" return z, mul_pullback" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"backprop_library = {\n", | ||
" \"inp\": inp_backprop_rule,\n", | ||
" \"sin\": sin_backprop_rule,\n", | ||
" \"add\": add_backprop_rule,\n", | ||
" \"mul\": mul_backprop_rule,\n", | ||
"}" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def vjp(graph, inputs):\n", | ||
" values = list(inputs)\n", | ||
" pullback_stack = []\n", | ||
"\n", | ||
" # Forward pass\n", | ||
" for operation, indices in graph:\n", | ||
" if operation == \"inp\":\n", | ||
" continue\n", | ||
" args = [values[index] for index in indices]\n", | ||
" result, pullback_fn = backprop_library[operation](*args)\n", | ||
" values.append(result)\n", | ||
" pullback_stack.append((pullback_fn, indices))\n", | ||
"\n", | ||
" def pullback(output_cotangent):\n", | ||
" cotangent_values = np.zeros(len(values))\n", | ||
" cotangent_values[-1] = output_cotangent\n", | ||
"\n", | ||
" for i, (pullback_fn, indices) in enumerate(reversed(pullback_stack)):\n", | ||
" current_cotangent_value = cotangent_values[-1 - i]\n", | ||
" cotangent_args = pullback_fn(current_cotangent_value)\n", | ||
" for index, cotangent in zip(indices, cotangent_args):\n", | ||
" cotangent_values[index] += cotangent\n", | ||
" \n", | ||
" return cotangent_values[:len(inputs)]\n", | ||
" \n", | ||
" return values[-1], pullback\n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"out, back_fn = vjp(compute_graph, SAMPLE_INPUT)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"1.1292849467900707" | ||
] | ||
}, | ||
"execution_count": 13, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"out" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"array([2.2153137 , 0.56464247])" | ||
] | ||
}, | ||
"execution_count": 15, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"back_fn(1.0)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"array([2.2153137 , 0.56464247])" | ||
] | ||
}, | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"f_grad(*SAMPLE_INPUT)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 17, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def value_and_grad(graph, inputs):\n", | ||
" out, back_fn = vjp(graph, inputs)\n", | ||
" grad = back_fn(1.0)\n", | ||
" return out, grad" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 18, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(1.1292849467900707, array([2.2153137 , 0.56464247]))" | ||
] | ||
}, | ||
"execution_count": 18, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"value_and_grad(compute_graph, SAMPLE_INPUT)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 19, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(1.1292849467900707, array([2.2153137 , 0.56464247]))" | ||
] | ||
}, | ||
"execution_count": 19, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"f(*SAMPLE_INPUT), f_grad(*SAMPLE_INPUT)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "base", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.9" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |