Skip to content

Commit

Permalink
reverse mode AD general compute graph in Python
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed May 6, 2024
1 parent 05f476b commit 6d75c6b
Showing 1 changed file with 395 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,395 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Function is:\n",
"\n",
"$$ f(x_0, x_1) = \\sin(x_0) (x_0 + x_1) $$\n",
"\n",
"or broken down\n",
"\n",
"$$ \\begin{align}\n",
"z_0 &= x_0 \\\\\n",
"z_1 &= x_1 \\\\\n",
"z_2 &= \\sin(z_0) \\\\\n",
"z_3 &= z_0 + z_1 \\\\\n",
"z_4 &= z_2 z_3 \\\\\n",
"\\end{align} $$\n",
"\n",
"Its symbolic derivative is:\n",
"\n",
"$$ \\nabla f(x_0, x_1) = \\begin{bmatrix}\n",
"\\cos(x_0) (x_0 + x_1) + \\sin(x_0) \\\\\n",
"\\sin(x_0)\n",
"\\end{bmatrix} $$"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def f(x_0, x_1):\n",
" return np.sin(x_0) * (x_0 + x_1)\n",
"\n",
"def f_grad(x_0, x_1):\n",
" return np.array([\n",
" np.cos(x_0) * (x_0 + x_1) + np.sin(x_0),\n",
" np.sin(x_0),\n",
" ])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"compute_graph = [\n",
" (\"inp\", (0,)), # 0\n",
" (\"inp\", (1,)), # 1\n",
" (\"sin\", (0,)), # 2\n",
" (\"add\", (0, 1)), # 3\n",
" (\"mul\", (2, 3)), # 4\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"fn_library = {\n",
" \"inp\": lambda x: x,\n",
" \"sin\": lambda x: np.sin(x),\n",
" \"add\": lambda x, y: x + y,\n",
" \"mul\": lambda x, y: x * y,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def compute(graph, inputs):\n",
" values = list(inputs)\n",
" for operation, indices in graph:\n",
" if operation == \"inp\":\n",
" continue\n",
" args = [values[index] for index in indices]\n",
" result = fn_library[operation](*args)\n",
" values.append(result)\n",
" \n",
" return values[-1]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"SAMPLE_INPUT = (0.6, 1.4)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.1292849467900707"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f(*SAMPLE_INPUT)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.1292849467900707"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"compute(compute_graph, SAMPLE_INPUT)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def inp_backprop_rule(x):\n",
" z = x\n",
"\n",
" def inp_pullback(z_cotangent):\n",
" x_cotangent = z_cotangent\n",
" return (x_cotangent,)\n",
" \n",
" return z, inp_pullback\n",
"\n",
"def sin_backprop_rule(x):\n",
" z = np.sin(x)\n",
"\n",
" def sin_pullback(z_cotangent):\n",
" x_cotangent = np.cos(x) * z_cotangent\n",
" return (x_cotangent,)\n",
" \n",
" return z, sin_pullback\n",
"\n",
"def add_backprop_rule(x, y):\n",
" z = x + y\n",
"\n",
" def add_pullback(z_cotangent):\n",
" x_cotangent = z_cotangent\n",
" y_cotangent = z_cotangent\n",
"\n",
" return (x_cotangent, y_cotangent)\n",
" \n",
" return z, add_pullback\n",
"\n",
"def mul_backprop_rule(x, y):\n",
" z = x * y\n",
"\n",
" def mul_pullback(z_cotangent):\n",
" x_cotangent = y * z_cotangent\n",
" y_cotangent = x * z_cotangent\n",
" return (x_cotangent, y_cotangent)\n",
" \n",
" return z, mul_pullback"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"backprop_library = {\n",
" \"inp\": inp_backprop_rule,\n",
" \"sin\": sin_backprop_rule,\n",
" \"add\": add_backprop_rule,\n",
" \"mul\": mul_backprop_rule,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def vjp(graph, inputs):\n",
" values = list(inputs)\n",
" pullback_stack = []\n",
"\n",
" # Forward pass\n",
" for operation, indices in graph:\n",
" if operation == \"inp\":\n",
" continue\n",
" args = [values[index] for index in indices]\n",
" result, pullback_fn = backprop_library[operation](*args)\n",
" values.append(result)\n",
" pullback_stack.append((pullback_fn, indices))\n",
"\n",
" def pullback(output_cotangent):\n",
" cotangent_values = np.zeros(len(values))\n",
" cotangent_values[-1] = output_cotangent\n",
"\n",
" for i, (pullback_fn, indices) in enumerate(reversed(pullback_stack)):\n",
" current_cotangent_value = cotangent_values[-1 - i]\n",
" cotangent_args = pullback_fn(current_cotangent_value)\n",
" for index, cotangent in zip(indices, cotangent_args):\n",
" cotangent_values[index] += cotangent\n",
" \n",
" return cotangent_values[:len(inputs)]\n",
" \n",
" return values[-1], pullback\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"out, back_fn = vjp(compute_graph, SAMPLE_INPUT)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.1292849467900707"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([2.2153137 , 0.56464247])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"back_fn(1.0)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([2.2153137 , 0.56464247])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f_grad(*SAMPLE_INPUT)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def value_and_grad(graph, inputs):\n",
" out, back_fn = vjp(graph, inputs)\n",
" grad = back_fn(1.0)\n",
" return out, grad"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1.1292849467900707, array([2.2153137 , 0.56464247]))"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"value_and_grad(compute_graph, SAMPLE_INPUT)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1.1292849467900707, array([2.2153137 , 0.56464247]))"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f(*SAMPLE_INPUT), f_grad(*SAMPLE_INPUT)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 6d75c6b

Please sign in to comment.