Commit a17544c: Some changes from the ICLR branch
krasheninnikov committed Oct 15, 2023 (1 parent: 3c80417)
Showing 8 changed files with 55 additions and 65 deletions.
README.md (67 changes: 36 additions & 31 deletions)

````diff
@@ -1,28 +1,47 @@
-# Out-of-context Meta-Learning in Large Language Models
+# Meta- (out-of-context) Learning in Neural Networks
 
-This repository contains the source code corresponding to the paper Out-of-context Meta-Learning in Large Language Models. The codebase is constructed around the Hugging Face Transformers' Trainer and includes implementations of various experiments described in the paper.
+[![Tests](https://github.com/krasheninnikov/internalization/actions/workflows/main.yml/badge.svg)](https://github.com/krasheninnikov/internalization/actions/workflows/main.yml)
 
-## Quickstart
+This repository contains the source code for the paper *Meta- (out-of-context) Learning in Neural Networks*. The codebase implements language model experiments described in the paper, and relies heavily on the HuggingFace Transformers library.
 
-Get started with the codebase by following the steps below:
+Follow these steps to get started:
+
+### 1. Clone the repository
+
+In your terminal, enter:
+```bash
+git clone https://github.com/krasheninnikov/internalization.git
+cd internalization
+```
 
-#### 1. Configure Python Environment
-- **Step 1**: Create a new Conda environment. Replace "internalization" with the name you prefer for your environment:
+### 2. Configure your Python environment
+- **Step 1**: Create a new Conda environment. Replace "internalization" with the name you prefer for your environment, and "3.10" with the desired Python version:
 
 ```bash
 conda create --name internalization python=3.10
 ```
-Replace '3.10' with your desired version number.
 
 - **Step 2**: Activate your Conda environment:
 
 ```bash
 conda activate internalization
 ```
 
-- **Step 3**: You are now within your Conda environment where you can configure the PYTHONPATH specific to the project. Append the project root to PYTHONPATH in your activated Conda environment:
+- **Step 3**: Install the necessary dependencies and download the datasets with the command:
+
+```bash
+bash setup.sh
+```
+
+Configure `wandb` (optional):
+```bash
+wandb login
+wandb init --entity=your-entity --project=your-project
+```
+
+- **Step 4**: Append the project root to PYTHONPATH in your activated Conda environment (alternatively, just add the command below to your `~/.bashrc` file):
 
 ```bash
 export PYTHONPATH=/path/to/the/project/root:$PYTHONPATH
@@ -34,32 +53,18 @@ Get started with the codebase by following the steps below:
 export PYTHONPATH="$PWD:${PYTHONPATH}"
 ```
 
-#### 2. Clone Repository:
-**NOTE: It is currently not possible to download an anonymized repository neither to clone it, this will be possible after public release.**
-
-Start by cloning the repository using the following command in your terminal:
-```bash
-git clone https://github.com/krasheninnikov/internalization.git
-```
-Next, move into the newly cloned directory:
-```bash
-cd internalization
-```
-Install the necessary dependencies and download the datasets with the command:
-
-```bash
-bash setup.sh
-```
-
-#### 3. Choose/modify/create a Config:
+### 3. Run the experiment
 
-Browse to the **configs** directory to select an existing configuration, modify as per your requirements, or create a new one. Further information related to parameter descriptions can be found in the [configs directory](./configs).
+To run the experiment with the default configuration ([`configs/current_experiment.yaml`](./configs/current_experiment.yaml)), use the following command:
 
-#### 4. Run the Experiment:
+```python
+python src/run.py
+```
 
-To run the experiment, use the following command:
+**Choosing/modifying/creating an experiment configuration.** Go to the [**configs**](./configs) directory to select an existing configuration or create a new one. Some parameter descriptions can be found in the [configs readme](./configs/README.md).
 
+Once the configuration is ready, run the experiment with the following command:
 ```python
 python src/run.py --cp <your-config-path>
 ```
-Please note that the default configuration is `configs/current_experiment.yaml`.
````
configs/README.md (8 changes: 4 additions & 4 deletions)

```diff
@@ -1,5 +1,5 @@
 # Configuration Parameters Documentation
-In this documentation, you will learn about the configurable parameters that are classified under different categories, and how to override these parameters at various stages of your experiment.
+This README describes the configurable experiment parameters that are classified under different categories, and how to override these parameters at the various (finetuning) stages of the experiment.
 
 ## Argument Categories
 The specification parameters are classified into five different categories. You can find more details in the respective dataclass inside [utils/arguments.py](../utils/arguments.py):
@@ -12,10 +12,10 @@ The specification parameters are classified into five different categories. You
 
 ## Argument Overrides
 For each stage of your experiment, you can override arguments using the following configurations:
-1. `first_stage_arguments`: This is an overriding dictionary for stage one, accepting parameters from different argument groups.
-2. `second_stage_arguments`: This overriding dictionary is meant for stage two.
+1. `first_stage_arguments`: This is an overriding dictionary for stage one of model training/finetuning, accepting parameters from different argument groups.
+2. `second_stage_arguments`: This overriding dictionary is meant for stage two of finetuning.
 3. `third_stage_arguments`: This overriding dictionary is used for stage three.
 
 The total number of stages can be set in the `experiment_arguments`. Parameters for each stage are overridden using the respective dictionary.
 
-In an experiment with only one stage, the `first_stage_arguments` override parameters from other argument groups. In a two-stage experiment, parameters for the first stage are overridden using `first_stage_arguments`, while `second_stage_arguments` is used for the second stage.
+In an experiment with only one stage, the `first_stage_arguments` override parameters from other argument groups. In a two-stage experiment, parameters for the first stage are overridden using `first_stage_arguments`, while `second_stage_arguments` are used for the second stage.
```
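To make the override mechanism concrete, here is a minimal sketch of how a per-stage dictionary might be merged into a base argument set. This is an illustration only: the dataclass and its fields below are hypothetical stand-ins, and the repository's actual dataclasses live in `utils/arguments.py`.

```python
from dataclasses import dataclass, replace

# Hypothetical stand-in for one of the five argument dataclasses in
# utils/arguments.py; the real fields and defaults differ.
@dataclass
class TrainingArgs:
    num_train_epochs: int = 1
    learning_rate: float = 3e-4

def apply_stage_overrides(base: TrainingArgs, overrides: dict) -> TrainingArgs:
    """Return a copy of `base` with the stage-specific overrides applied."""
    return replace(base, **overrides)

base = TrainingArgs()
first_stage_arguments = {"num_train_epochs": 20}   # overrides for stage one
second_stage_arguments = {"learning_rate": 1e-5}   # overrides for stage two

stage1_args = apply_stage_overrides(base, first_stage_arguments)
stage2_args = apply_stage_overrides(base, second_stage_arguments)
# stage1_args: TrainingArgs(num_train_epochs=20, learning_rate=0.0003)
# stage2_args: TrainingArgs(num_train_epochs=1, learning_rate=1e-05)
```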
data_generation/README.md (10 changes: 10 additions & 0 deletions)

```diff
@@ -1,3 +1,13 @@
+# Understanding the data
+
+See `demo.ipynb` for a quick overview of how to load our three text datasets (CVDB, T-REx, and set inclusion).
+
+
+The following image should help relate our code to data subsets from the paper.
+
+![Image](code-notation.png?raw=true "Code notation")
+
+
 # `data_generation` Directory Overview
 
 This `data_generation` directory contains a variety of essential scripts and modules crucial to the data generation and processing operations:
```
data_generation/code-notation.png (binary file added)
data_generation/define_experiment.py (7 changes: 3 additions & 4 deletions)

```diff
@@ -149,7 +149,7 @@ def get_questions_dataset(seed,
                           frac_n_no_qd_baseline=0.06,
                           dataset_name='cvdb',
                           num_ents=4000, # param for cvdb and t-rex datasets
-                          train_subset = 'full', # one of 'full', 'defns_ri', 'all_but_defns_ri'
+                          train_subset = 'full', # one of 'full', 'stage1', 'stage2', 'stage1_only_defns', 'stage1_only_qa', 'all_defns'
                           entity_association_test_sets=False,
                           def_order='tve', # Tag, Variable, Entity
                           multiple_define_tags=False,
@@ -213,9 +213,8 @@ def get_questions_dataset(seed,
     tag3 = kwargs.get('tag3_name')
 
     # generate random tag if empty or None
-    tag1, tag2, tag3 = [generate_variable_names(n=1, length=define_tag_length, rng=rng)[0] if not tag else tag for tag in [tag1, tag2, tag3]]
-
-    logger.info('Using tags: %s, %s, %s', tag1, tag2, tag3)
+    tag1, tag2, tag3 = [generate_variable_names(n=1, length=define_tag_length, rng=rng)[0] if not tag else tag for tag in [tag1, tag2, tag3]]
+    logger.info('Using tags: %s (d1), %s (d2), %s (d3)', tag1, tag2, tag3)
 
     # swap ent -> var within each of the two entity subsets
     ents_to_vars_maybe_swapped = randomly_swap_ents_to_vars(ents_to_vars, frac_to_swap=1.0, rng=rng, ents_to_swap=ent_subsets['qd2incons'])
```
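Based on the signature visible in this hunk, a call exercising the renamed `train_subset` options might look like the sketch below. The module path matches the file being diffed, but the keyword defaults not shown here and the structure of the return value are assumptions.

```python
from data_generation.define_experiment import get_questions_dataset

# Sketch only: argument names follow the signature in the diff above;
# what the function returns is an assumption.
dataset = get_questions_dataset(
    seed=0,
    dataset_name='cvdb',     # or 't-rex'
    num_ents=4000,
    train_subset='stage1',   # one of 'full', 'stage1', 'stage2',
                             # 'stage1_only_defns', 'stage1_only_qa', 'all_defns'
)
```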
tests/convert_year_test.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -15,4 +15,4 @@ def test_case_4():
     assert convert_year(-122) == '2 century BC'
 
 def test_case_5():
-    assert convert_year(2000) == '2000'
\ No newline at end of file
+    assert convert_year(2000) == '2000'
```
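The two assertions visible here pin down `convert_year`'s behavior at just two inputs. A minimal sketch consistent with only these two cases follows; the repository's real implementation surely covers more (e.g. AD centuries), so treat this as an assumption-laden illustration.

```python
def convert_year(year: int) -> str:
    """Sketch consistent with the two visible test cases only."""
    if year < 0:
        # -122 falls in the 2nd century BC: (122 - 1) // 100 + 1 == 2
        century = (-year - 1) // 100 + 1
        return f'{century} century BC'
    # Common-era years pass through as strings: 2000 -> '2000'
    return str(year)
```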
tests/create_qa_pairs_test.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -27,4 +27,4 @@ def test_create_qa_pairs_wrong_seed():
 
 def test_create_qa_pairs_noninteger_num_ents():
     with pytest.raises(TypeError):
-        qa_pairs = create_qa_pairs(123, 'trex', "5")
\ No newline at end of file
+        qa_pairs = create_qa_pairs(123, 'trex', "5")
```
tests/define_experiment_test.py (24 changes: 0 additions & 24 deletions)

```diff
@@ -25,27 +25,3 @@ def test_randomly_swap_ents_to_vars():
 
     # Test that no keys are missing in the output
     assert set(ents_to_vars.keys()) == set(ents_to_vars_swapped.keys())
-
-
-# def test_swap_variables_in_qa():
-#     q1 = Question('This is entity A.', 'entity A', 'variable A')
-#     q2 = Question('This is entity B.', 'entity B', 'variable B')
-#     qa1 = QAPair(q1, 'Answer 1')
-#     qa2 = QAPair(q2, 'Answer 2')
-#     qa_pairs = [qa1, qa2]
-
-#     # check that the eg.wait function doesn't alter the input list
-#     original_qa_pairs = qa_pairs.copy()
-#     swapped_qa_pairs = swap_variables_in_qa(qa_pairs)
-#     assert qa_pairs == original_qa_pairs
-
-#     # check the text of the questions in the swapped qa pairs
-#     assert swapped_qa_pairs[0].question.text == 'This is variable B.'
-#     assert swapped_qa_pairs[1].question.text == 'This is variable A.'
-
-#     # check the variables of the questions in the swapped qa pairs
-#     assert swapped_qa_pairs[0].question.variable == 'variable B'
-#     assert swapped_qa_pairs[1].question.variable == 'variable A'
-
-#     # check the length of the output list
-#     assert len(swapped_qa_pairs) == len(qa_pairs)
```