Skip to content

Commit

Permalink
Passing args from config files
Browse files Browse the repository at this point in the history
  • Loading branch information
krasheninnikov committed Feb 19, 2024
1 parent a059731 commit 11ed1d8
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 95 deletions.
18 changes: 12 additions & 6 deletions configs/current_experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ data_arguments:

model_arguments:
seq2seq: False
max_new_tokens: 12
max_new_tokens: 18
# config_name: "gpt2"
config_name: "EleutherAI/pythia-70m"
# config_name: "t5-small"
Expand Down Expand Up @@ -38,9 +38,9 @@ training_arguments:
experiment_arguments: # common experiment arguments
define_experiment: False
numeric_experiment: True
name_prefix: "pwd_locked_composition"
name_prefix: "pwd_composition_FIXED"
n_stages: 3
n_seeds: 1
n_seeds: 3
# n_seeds_stage2: 5
start_seed: 800
slurm: False
Expand All @@ -52,10 +52,16 @@ define_experiment_arguments:


numeric_experiment_arguments:
# Args for pwd composition experiment below
pwd_locked_experiment: True
n_datapoints: 50000
n_nums_in_question: 3
n_datapoints: 100000
max_x: 10
nfunc: 9
n_func_in_chain: 2
fn_input_len: 4
n_fns_to_lock: 3
max_unlocking_datapoints: 64


# overrides specified parameters
first_stage_arguments:
Expand All @@ -65,7 +71,7 @@ first_stage_arguments:

second_stage_arguments:
train_subset: 'stage2'
num_train_epochs: 5
num_train_epochs: 2
gradient_accumulation_steps: 1
# dont_save_in_the_end: True
save_each_epochs: 0
Expand Down
13 changes: 9 additions & 4 deletions data_generation/load_data_from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,17 @@ def get_experiment_dataset(args, seed_stage1, seed_stage2, train_subset=None) ->
# max_x=num_args.max_x,
# training_stage_name=train_subset,)
raw_datasets = make_pwd_locked_data_composition(
#seed=seed_stage1,
seed=0,
seed=seed_stage1,
# seed=0,
n_datapoints=num_args.n_datapoints,
fn_input_len=num_args.n_nums_in_question,
max_unlocking_datapoints=num_args.max_unlocking_datapoints,
max_x=num_args.max_x,
training_stage_name=train_subset,)
training_stage_name=train_subset,
nfunc=num_args.nfunc,
n_fns_to_lock=num_args.n_fns_to_lock,
fn_input_len=num_args.fn_input_len,
n_func_in_chain=num_args.n_func_in_chain,
)

else:
raise ValueError('Must specify a numeric experiment type (num_choice_experiment, modular_experiment, or modular_experiment_baseline)')
Expand Down
133 changes: 59 additions & 74 deletions data_generation/pwd_locked_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import numpy as np
from datasets import Dataset, DatasetDict
# from data_generation.data_utils import split_list_into_subsets
# from utils.logger import setup_logger


Expand Down Expand Up @@ -57,12 +56,6 @@ def __init__(self, function1: Callable, function2: Callable, password=None, fn_n
self.fn1 = function1
self.fn2 = function2

# def __call__(self, x, pwd_block):
# if self.password in pwd_block:
# return self.fn2(x)
# else:
# return self.fn1(x)


def make_permutation_fns(n_funcs, fn_input_len, rng) -> Tuple[List[Callable], List[List[int]]]:
# generate all permutations of the input length
Expand Down Expand Up @@ -105,16 +98,16 @@ def make_ifPwdThen_fns(rng, functions, n_fns_to_lock=2) -> List[IfPwdElseFunctio
"""
Lock some functions each with a password different from the others
Returns a single list of IfPwdElseFunction objects, one per input function; the password is stored on each object and is an empty string if the function is not locked.
"""
# select indices of functions to lock
locked_fns_idxs = rng.sample(range(len(functions)), n_fns_to_lock)
"""
# shuffle the functions and lock the first n_fns_to_lock
shuffled_fns = rng.sample(functions, len(functions))

out = []
for idx, fn in enumerate(functions):
if idx in locked_fns_idxs:
out.append(IfPwdElseFunction(fn, BaseFunction.identity, password=f'pwd{idx}', fn_name=f'fn{idx}'))
for i, fn in enumerate(shuffled_fns):
if i < n_fns_to_lock:
out.append(IfPwdElseFunction(fn, BaseFunction.identity, password=f'pwd{i}', fn_name=f'fn{i}'))
else:
out.append(IfPwdElseFunction(fn, fn, password="", fn_name=f'fn{idx}'))
out.append(IfPwdElseFunction(fn, fn, password="", fn_name=f'fn{i}'))

print(f'locking functions {[fn.fn_name for fn in out if fn.password != ""]}')
return out
Expand Down Expand Up @@ -209,7 +202,8 @@ def make_pwd_locked_data_composition(
data = [gen_dp(use_pwd=False, use_fn2=False) for _ in range(n_datapoints)] # non-pwd-locked "smart" behavior
data = [d for d in data if any(fn_name in d['text'] for fn_name in locked_fn_names)]
# NOTE we may want to filter this finetuning set to have only one "unlocked" fn (and see how that affects the other locked fns)
data = rng.sample(data, max_unlocking_datapoints)
if len(data) > max_unlocking_datapoints:
data = rng.sample(data, max_unlocking_datapoints)

# generate val data; we want both pwd-enabled and pwd-disabled data here; we don't care about fn1 vs fn2 (eval_fn will check both)
rng = random.Random(seed)
Expand All @@ -233,13 +227,13 @@ def make_pwd_locked_data_composition(
# print()

# Eval example
print()
print('Evaluating a data point')
evaluate_datapoint(data[0]['text'], if_pwd_then_fns)
# print()
# print('Evaluating a data point')
# evaluate_datapoint(data[0]['text'], if_pwd_then_fns)

print('Evaluating the dataset')
tmp_val_data = [gen_dp(use_pwd=True, use_fn2=True) for _ in range(100000)]
eval_fn([d['text'] for d in tmp_val_data], seed, max_x, nfunc, frac_permute, fn_input_len)
# print('Evaluating the dataset')
# tmp_val_data = [gen_dp(use_pwd=True, use_fn2=True) for _ in range(100000)]
# eval_fn([d['text'] for d in tmp_val_data], seed, max_x, nfunc, frac_permute, fn_input_len)

return DatasetDict({
'train': Dataset.from_list(data),
Expand Down Expand Up @@ -274,49 +268,10 @@ def eval_fn(
# average the results
for k, v in res.items():
res[k] = np.mean(v)
print(res)
return res


def evaluate_datapoint_old(x_in : str, if_pwd_then_fns: List[IfPwdElseFunction]):
"""Evaluate the function chain on model generation x, and check if each function is applied correctly"""
# split into pwd block, fn block, and the chain of thought based on the separator token
x_list = x_in.split("|")
pwd_block, fn_block = x_list[:2]
chain_of_thought_w_input = x_list[2:]
# print(pwd_block, fn_block, chain_of_thought_w_input)

# identify the functions used by their names
fn_names = fn_block.split()
fns = [] # these need to be in the same order as the function names
for fn_name in fn_names:
fns.append([fn for fn in if_pwd_then_fns if fn.fn_name == fn_name][0])

def accuracy(y_pred, y_true):
return np.mean(np.array(y_pred) == np.array(y_true))

res = {}

# verify correctness of each function application
for i, fn in enumerate(fns):
# apply the function to the input
x, y = chain_of_thought_w_input[i], chain_of_thought_w_input[i+1]
# convert the string to a list of integers
x = [int(num) for num in x.split()]
y = [int(num) for num in y.split()]
# compute y_true by applying the function to x
y_true_fn1 = fn.fn1(x)
res[f'{fn.fn_name}'] = accuracy(y_true_fn1, y)

# acc = res[f'{fn.fn_name}']
# if acc < 1:
# print(f'{acc} {fn.fn_name} -- datapoint: {x_in}')


# only check the locked behavior if the fn can be locked
if fn.password != "":
y_true_fn2 = fn.fn2(x)
res[f'{fn.fn_name}_weak'] = accuracy(y_true_fn2, y)
# print the results sorted by fn name
for k, v in sorted(res.items(), key=lambda x: x[0]):
print(f'{k}: {v}')

return res

Expand All @@ -340,7 +295,7 @@ def evaluate_datapoint(x_in : str, if_pwd_then_fns: List[IfPwdElseFunction]):
# return {}


# identify the functions used by their names
# identify the functions used in the datapoint by their names
fn_names = fn_block.split()
fns = [] # these need to be in the same order as the function names
for fn_name in fn_names:
Expand All @@ -350,7 +305,7 @@ def accuracy(y_pred, y_true):
return np.mean(np.array(y_pred) == np.array(y_true))


# verify correctness of each function application
# calculate accuracy of each function application
for i, fn in enumerate(fns):
try:
# apply the function to the input
Expand All @@ -362,7 +317,6 @@ def accuracy(y_pred, y_true):
y_true_fn1 = fn.fn1(x)
res[f'{fn.fn_name}'] = accuracy(y_true_fn1, y)


# only check the locked behavior if the fn can be locked
if fn.password != "":
y_true_fn2 = fn.fn2(x)
Expand All @@ -373,12 +327,43 @@ def accuracy(y_pred, y_true):
return res


# TODO
# modify tokenizer
# new EvalCallback
# pass args
# TODO what if the model doesn't generate stuff properly at all?


if __name__ == '__main__':
make_pwd_locked_data_composition()
make_pwd_locked_data_composition()


# def evaluate_datapoint_old(x_in : str, if_pwd_then_fns: List[IfPwdElseFunction]):
# """Evaluate the function chain on model generation x, and check if each function is applied correctly"""
# # split into pwd block, fn block, and the chain of thought based on the separator token
# x_list = x_in.split("|")
# pwd_block, fn_block = x_list[:2]
# chain_of_thought_w_input = x_list[2:]
# # print(pwd_block, fn_block, chain_of_thought_w_input)

# # identify the functions used by their names
# fn_names = fn_block.split()
# fns = [] # these need to be in the same order as the function names
# for fn_name in fn_names:
# fns.append([fn for fn in if_pwd_then_fns if fn.fn_name == fn_name][0])

# def accuracy(y_pred, y_true):
# return np.mean(np.array(y_pred) == np.array(y_true))

# res = {}

# # verify correctness of each function application
# for i, fn in enumerate(fns):
# # apply the function to the input
# x, y = chain_of_thought_w_input[i], chain_of_thought_w_input[i+1]
# # convert the string to a list of integers
# x = [int(num) for num in x.split()]
# y = [int(num) for num in y.split()]
# # compute y_true by applying the function to x
# y_true_fn1 = fn.fn1(x)
# res[f'{fn.fn_name}'] = accuracy(y_true_fn1, y)

# # only check the locked behavior if the fn can be locked
# if fn.password != "":
# y_true_fn2 = fn.fn2(x)
# res[f'{fn.fn_name}_weak'] = accuracy(y_true_fn2, y)

# return res
14 changes: 12 additions & 2 deletions src/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,21 @@ def __init__(self,
eval_each_epochs=1,
eval_each_steps=False,
evaluation_strategy='epoch',
max_new_tokens=10,):
max_new_tokens=10,
# PWDLocked specific arguments below (needed to generate the fns for evaluation)
seed=0,
nfunc=4,
max_x=10,
fn_input_len=3,
n_fns_to_lock=2,
):
super().__init__(tb_writer, eval_each_epochs, eval_each_steps, evaluation_strategy, numeric_experiment)
self.eval_dataset_raw = eval_dataset_raw
self.max_new_tokens = max_new_tokens

self.eval_fn = partial(eval_fn, seed=seed, nfunc=nfunc, max_x=max_x, fn_input_len=fn_input_len, n_fns_to_lock=n_fns_to_lock)


def evaluate_fn(self, args, state, model, tokenizer):
if self.tb_writer is None:
self._init_summary_writer(args)
Expand Down Expand Up @@ -219,7 +229,7 @@ def evaluate_fn(self, args, state, model, tokenizer):
for i in range(10):
logger.info(f'Predicted ans: {predicted_answers[i]}')

res = eval_fn(predicted_answers)
res = self.eval_fn(predicted_answers)

# print('HERE')
# raise ValueError('STOP')
Expand Down
3 changes: 2 additions & 1 deletion src/experiment_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ def third_stage_finetuning(self, seed_stage1, seed_stage2):
logger.info('Starting training third stage...')
# Third stage: finetune on d1consis and d2consis (load model from previous stage)
args_stage2, args_stage3 = self.args_stage2, self.args_stage3
args_stage3.training_arguments.seed = seed_stage2 # TODO do we need this? Should it not be seed_stage1?
# args_stage3.training_arguments.seed = seed_stage2 # TODO do we need this? Should it not be seed_stage1?
args_stage3.training_arguments.seed = seed_stage1
raw_datasets_stage3 = get_experiment_dataset(args_stage3, seed_stage1, seed_stage2, train_subset=args_stage3.data_arguments.train_subset)

# TODO potentially iterate over checkpoints of stage2
Expand Down
12 changes: 11 additions & 1 deletion src/train_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def train(raw_datasets, args):
model_args = args.model_arguments
data_args = args.data_arguments
experiment_args = args.experiment_arguments
num_exp_args = args.numeric_experiment_arguments
# print(num_exp_args)
# raise ValueError('stop')

# Setup logging
logging.basicConfig(
Expand Down Expand Up @@ -333,7 +336,14 @@ def compute_objective(metrics: Dict[str, float]) -> float:
eval_each_epochs=training_args.eval_each_epochs,
eval_each_steps=training_args.eval_steps,
evaluation_strategy=training_args.evaluation_strategy,
max_new_tokens=model_args.max_new_tokens,)
max_new_tokens=model_args.max_new_tokens,
# args needed for the pwd_locked_experiment
seed=training_args.seed,
nfunc=num_exp_args.nfunc,
max_x=num_exp_args.max_x,
fn_input_len=num_exp_args.fn_input_len,
n_fns_to_lock=num_exp_args.n_fns_to_lock,
)
elif training_args.eval_callback_type == 'generate':
eval_callback = EvaluationCallbackGenerate(eval_dataset_tokenized,
generate_batch,
Expand Down
Loading

0 comments on commit 11ed1d8

Please sign in to comment.