diff --git a/transformers_learn_icl_by_gd/constructed_token_setup.ipynb b/transformers_learn_icl_by_gd/constructed_token_setup.ipynb index ba2542e..9a46f70 100644 --- a/transformers_learn_icl_by_gd/constructed_token_setup.ipynb +++ b/transformers_learn_icl_by_gd/constructed_token_setup.ipynb @@ -3,6 +3,7 @@ { "cell_type": "markdown", "metadata": { + "cellView": "form", "id": "nN2OaMWcczfa" }, "source": [ @@ -10,6 +11,17 @@ "This specific notebook can be used to reproduce the results shown in the paper when using the specific token construction i.e. concatinate input and targets i.e. $e_i = (x_i,y_i)$." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Install requirements\n", + "!wget -O requirements.txt https://raw.githubusercontent.com/google-research/self-organising-systems/master/transformers_learn_icl_by_gd/requirements.txt\n", + "!pip install -r requirements.txt" + ] + }, { "cell_type": "code", "execution_count": null, @@ -40,17 +52,12 @@ "import time\n", "from typing import Any, MutableMapping, NamedTuple, Tuple\n", "\n", - "!pip install --quiet --upgrade jax\n", - "!pip install --quiet --upgrade jaxlib \n", "import jax\n", "from jax import grad, jit, vmap\n", "import jax.numpy as jnp\n", "\n", - "!pip install --quiet -U dm-haiku\n", - "!pip install --quiet -U optax\n", "import haiku as hk\n", "import math\n", - "!pip install --quiet -U ml_collections\n", "from ml_collections import config_dict\n", "import matplotlib.pylab as pl\n", "import matplotlib.colors as mcolors\n", diff --git a/transformers_learn_icl_by_gd/non_linear_regression.ipynb b/transformers_learn_icl_by_gd/non_linear_regression.ipynb index 084f080..3d482d2 100644 --- a/transformers_learn_icl_by_gd/non_linear_regression.ipynb +++ b/transformers_learn_icl_by_gd/non_linear_regression.ipynb @@ -10,6 +10,19 @@ "This specific notebook can be used to reproduce the non-linear regression task results." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form" + }, + "outputs": [], + "source": [ + "#@title Install requirements\n", + "!wget -O requirements.txt https://raw.githubusercontent.com/google-research/self-organising-systems/master/transformers_learn_icl_by_gd/requirements.txt\n", + "!pip install -r requirements.txt" + ] + }, { "cell_type": "code", "execution_count": null, @@ -39,18 +52,12 @@ "from tqdm import tqdm_notebook, tnrange\n", "import time\n", "from typing import Any, MutableMapping, NamedTuple, Tuple\n", - "!pip install --quiet --upgrade tensorflow \n", - "!pip install --quiet --upgrade jax\n", - "!pip install --quiet --upgrade jaxlib \n", "import jax\n", "from jax import grad, jit, vmap\n", "import jax.numpy as jnp\n", "\n", - "!pip install --quiet -U dm-haiku\n", - "!pip install --quiet -U optax\n", "import haiku as hk\n", "import math\n", - "!pip install --quiet -U ml_collections\n", "from ml_collections import config_dict\n", "import matplotlib.pylab as pl\n", "import matplotlib.colors as mcolors\n", diff --git a/transformers_learn_icl_by_gd/normal_token_construct.ipynb b/transformers_learn_icl_by_gd/normal_token_construct.ipynb index df7fa60..11de5e5 100644 --- a/transformers_learn_icl_by_gd/normal_token_construct.ipynb +++ b/transformers_learn_icl_by_gd/normal_token_construct.ipynb @@ -10,6 +10,19 @@ "This specific notebook can be used to reproduce the results that assumes the standard token construction i.e. where $e_{2i} = x_i, e_{2i+1} = y_i$. We here show that the Transformer needs to first copy over neighboring tokens after which it can perform gradient descent steps in the following self-attention layers." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form" + }, + "outputs": [], + "source": [ + "#@title Install requirements\n", + "!wget -O requirements.txt https://raw.githubusercontent.com/google-research/self-organising-systems/master/transformers_learn_icl_by_gd/requirements.txt\n", + "!pip install -r requirements.txt" + ] + }, { "cell_type": "code", "execution_count": 1, @@ -39,18 +52,12 @@ "from tqdm import tqdm_notebook, tnrange\n", "import time\n", "from typing import Any, MutableMapping, NamedTuple, Tuple\n", - "!pip install --quiet --upgrade tensorflow \n", - "!pip install --quiet --upgrade jax\n", - "!pip install --quiet --upgrade jaxlib \n", "import jax\n", "from jax import grad, jit, vmap\n", "import jax.numpy as jnp\n", "\n", - "!pip install --quiet -U dm-haiku\n", - "!pip install --quiet -U optax\n", "import haiku as hk\n", "import math\n", - "!pip install --quiet -U ml_collections\n", "from ml_collections import config_dict\n", "import matplotlib.pylab as pl\n", "import matplotlib.colors as mcolors\n", diff --git a/transformers_learn_icl_by_gd/requirements.txt b/transformers_learn_icl_by_gd/requirements.txt new file mode 100644 index 0000000..ea15a69 --- /dev/null +++ b/transformers_learn_icl_by_gd/requirements.txt @@ -0,0 +1,10 @@ +jax==0.3.25 +jaxlib==0.3.25 +dm-haiku==0.0.10 +optax==0.1.7 +ml_collections==0.1.1 +Pillow==9.4.0 +matplotlib==3.7.1 +scipy==1.11.3 +tqdm==4.66.1 +requests==2.31.0