Control how many functions we unlock
krasheninnikov committed Feb 21, 2024
1 parent aa4c099 commit 2564d5b
Showing 5 changed files with 37 additions and 19 deletions.
17 changes: 9 additions & 8 deletions configs/current_experiment.yaml
@@ -38,11 +38,11 @@ training_arguments:
 experiment_arguments: # common experiment arguments
   define_experiment: False
   numeric_experiment: True
-  name_prefix: ""
+  name_prefix: "samePwd"
   n_stages: 3
   n_seeds: 3
   # n_seeds_stage2: 5
-  start_seed: 800
+  start_seed: 100
   slurm: False
   n_gpu_hours: 3

@@ -54,14 +54,15 @@ define_experiment_arguments:
 numeric_experiment_arguments:
   # Args for pwd composition experiment below
   pwd_locked_experiment: True
-  n_datapoints: 100000
+  n_datapoints: 200000
+  max_unlocking_datapoints: 1024
   max_x: 10
-  nfunc: 9
   n_func_in_chain: 2
   fn_input_len: 4
-  n_fns_to_lock: 3
-  max_unlocking_datapoints: 64
 
+  nfunc: 20
+  n_fns_to_lock: 5
+  n_fns_to_unlock: 1
 
 # overrides specified parameters
 first_stage_arguments:
@@ -71,7 +72,7 @@ first_stage_arguments:
 
 second_stage_arguments:
   train_subset: 'stage2'
-  num_train_epochs: 2
+  num_train_epochs: 5
   gradient_accumulation_steps: 1
   # dont_save_in_the_end: True
   save_each_epochs: 0
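The config now locks 5 of 20 functions, unlocks 1 of them in stage 3, and allows up to 1024 unlocking datapoints. A minimal sketch of reading these values back and sanity-checking them (the yaml loading, the file path, and the top-level key layout are assumptions based on the keys shown in this diff; the inequality mirrors the assert added in pwd_locked_composition.py below):

```python
import yaml

with open('configs/current_experiment.yaml') as f:
    cfg = yaml.safe_load(f)

num_args = cfg['numeric_experiment_arguments']
# With this commit: nfunc=20, n_fns_to_lock=5, n_fns_to_unlock=1.
assert num_args['n_fns_to_unlock'] <= num_args['n_fns_to_lock']
print(f"{num_args['n_fns_to_lock']} of {num_args['nfunc']} functions locked, "
      f"{num_args['n_fns_to_unlock']} unlocked in stage 3 using at most "
      f"{num_args['max_unlocking_datapoints']} datapoints")
```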
1 change: 1 addition & 0 deletions data_generation/load_data_from_config.py
@@ -93,6 +93,7 @@ def get_experiment_dataset(args, seed_stage1, seed_stage2, train_subset=None) ->
         training_stage_name=train_subset,
         nfunc=num_args.nfunc,
         n_fns_to_lock=num_args.n_fns_to_lock,
+        n_fns_to_unlock=num_args.n_fns_to_unlock,
         fn_input_len=num_args.fn_input_len,
         n_func_in_chain=num_args.n_func_in_chain,
     )
33 changes: 23 additions & 10 deletions data_generation/pwd_locked_composition.py
@@ -105,7 +105,8 @@ def make_ifPwdThen_fns(rng, functions, n_fns_to_lock=2) -> List[IfPwdElseFunction]:
     out = []
     for i, fn in enumerate(shuffled_fns):
         if i < n_fns_to_lock:
-            out.append(IfPwdElseFunction(fn, BaseFunction.identity, password=f'pwd{i}', fn_name=f'fn{i}'))
+            out.append(IfPwdElseFunction(fn, BaseFunction.identity, password=f'pwd0', fn_name=f'fn{i}'))
+            # out.append(IfPwdElseFunction(fn, BaseFunction.identity, password=f'pwd{i}', fn_name=f'fn{i}'))
         else:
             out.append(IfPwdElseFunction(fn, fn, password="", fn_name=f'fn{i}'))
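With this change every locked function is gated by the same password (pwd0) rather than a per-function password (pwd{i}), matching the new "samePwd" name prefix in the config. A toy sketch of the behavioral difference; ToyPwdLockedFn is a made-up stand-in, since the real IfPwdElseFunction interface is not shown in this diff:

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class ToyPwdLockedFn:
    """Apply fn_yes if the password appears in the prompt, otherwise fn_no."""
    fn_yes: Callable[[int], int]
    fn_no: Callable[[int], int]
    password: str

    def __call__(self, prompt: str, x: int) -> int:
        return self.fn_yes(x) if self.password in prompt else self.fn_no(x)

def identity(x):
    return x

# Before this commit: each locked function had its own password pwd{i}.
per_fn_pwd = [ToyPwdLockedFn(lambda x, i=i: x + i + 1, identity, password=f'pwd{i}') for i in range(3)]
# After this commit: all locked functions share the password pwd0.
same_pwd = [ToyPwdLockedFn(lambda x, i=i: x + i + 1, identity, password='pwd0') for i in range(3)]

prompt = 'pwd0 fn0 fn2 ...'
print([f(prompt, 1) for f in per_fn_pwd])  # [2, 1, 1]: only the fn whose password is pwd0 unlocks
print([f(prompt, 1) for f in same_pwd])    # [2, 3, 4]: one password unlocks all locked fns
```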

@@ -174,16 +175,19 @@ def format_datapoint(pwd_block, fn_block, input_x, chain_of_thought, sep_token="
 
 def make_pwd_locked_data_composition(
     seed: int = 0,
-    n_datapoints = 2000,
+    n_datapoints = 100000,
     max_unlocking_datapoints = 200,
     max_x: int = 10, # alphabet size
     training_stage_name = 'stage3',
-    nfunc = 4, # number of functions to apply, plus 1 for identity
+    nfunc = 20, # number of functions to apply, plus 1 for identity
     n_func_in_chain: int = 2,
     frac_permute: float = 0.5,
     fn_input_len=3,
-    n_fns_to_lock=2,
-):
+    n_fns_to_lock=3,
+    n_fns_to_unlock=2,
+):
+    assert n_fns_to_unlock <= n_fns_to_lock
+
     # generate functions ("capabilities") that the model has to learn
     n_permute = int(frac_permute * nfunc)
     n_bijections = nfunc - n_permute
@@ -193,17 +197,26 @@ def make_pwd_locked_data_composition(
     # generate train data depending on the training stage
     rng = random.Random(seed)
     gen_dp = functools.partial(generate_datapoint, rng, max_x, fn_input_len, if_pwd_then_fns, n_func_in_chain)
+    data = []
     if training_stage_name == 'stage1':
         data = [gen_dp(use_pwd=False, use_fn2=False) for _ in range(n_datapoints)] # non-pwd-locked "smart" behavior
     elif training_stage_name == 'stage2':
         data = [gen_dp(use_pwd=True, use_fn2=False) for _ in range(n_datapoints//2)] # pwd-locked "smart" behavior
         data += [gen_dp(use_pwd=False, use_fn2=True) for _ in range(n_datapoints//2)] # non-pwd-locked "dumb" behavior
     elif training_stage_name == 'stage3':
         data = [gen_dp(use_pwd=False, use_fn2=False) for _ in range(n_datapoints)] # non-pwd-locked "smart" behavior
-        data = [d for d in data if any(fn_name in d['text'] for fn_name in locked_fn_names)]
-        # NOTE we may want to filter this finetuning set to have only one "unlocked" fn (and see how that affects the other locked fns)
+        data = [d for d in data if any(fn_name in d['text'].split() for fn_name in locked_fn_names)] # take only data that contains at least one locked function
+
+        fn_names_to_unlock = locked_fn_names[:n_fns_to_unlock]
+        print(f'locked functions to unlock: {fn_names_to_unlock}')
+        fn_names_to_leave_locked = locked_fn_names[n_fns_to_unlock:]
+
+        data = [d for d in data if not any(fn_name in d['text'].split() for fn_name in fn_names_to_leave_locked)]
+
         if len(data) > max_unlocking_datapoints:
             data = rng.sample(data, max_unlocking_datapoints)
 
+    assert len(data) > 0
+
     # generate val data; we want both pwd-enabled and pwd-disabled data here; we don't care about fn1 vs fn2 (eval_fn will check both)
     rng = random.Random(seed)
@@ -215,13 +228,13 @@ def make_pwd_locked_data_composition(
     val_data_no_pwd += [gen_dp(use_pwd=False, use_fn2=True) for _ in range(n_datapoints//4)]
 
     # filter val data so that it has at least one locked function per data point
-    val_data_with_pwd = [d for d in val_data_with_pwd if any(fn_name in d['text'] for fn_name in locked_fn_names)]
-    val_data_no_pwd = [d for d in val_data_no_pwd if any(fn_name in d['text'] for fn_name in locked_fn_names)]
+    val_data_with_pwd = [d for d in val_data_with_pwd if any(fn_name in d['text'].split() for fn_name in locked_fn_names)]
+    val_data_no_pwd = [d for d in val_data_no_pwd if any(fn_name in d['text'].split() for fn_name in locked_fn_names)]
 
 
     print('Data generation done')
     for i in range(10):
-        print(val_data_no_pwd[i]['text'])
+        print(data[i]['text'])
         # print(data[i]['question'])
         # print(data[i]['answer'])
         # print()
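Two behavioral changes in the stage-3 (unlocking) data path above: name matching now uses d['text'].split(), so with nfunc raised to 20 a name like fn1 can no longer spuriously match fn10 or fn12 as a substring, and only the first n_fns_to_unlock locked functions may appear in the unlocking set; datapoints mentioning any still-locked function are dropped. A self-contained sketch of that filtering on toy strings (the datapoint texts are invented for illustration; real ones come from generate_datapoint):

```python
import random

locked_fn_names = ['fn0', 'fn1', 'fn2', 'fn3', 'fn4']  # e.g. n_fns_to_lock = 5
n_fns_to_unlock = 1
max_unlocking_datapoints = 2

# Invented stand-ins for generated datapoints.
data = [
    {'text': 'fn0 2 7 1 8'},      # mentions the function we want to unlock
    {'text': 'pwd0 fn1 3 1 4'},   # mentions a function that should stay locked
    {'text': 'fn10 fn12 5 9 2'},  # substring matching would wrongly treat this as containing 'fn1'
    {'text': 'fn0 fn7 0 0 1'},
]

# Keep only datapoints that mention at least one locked function (token-level match).
data = [d for d in data if any(name in d['text'].split() for name in locked_fn_names)]

# Unlock only the first n_fns_to_unlock locked functions; drop anything that
# mentions a function that should remain locked.
fn_names_to_unlock = locked_fn_names[:n_fns_to_unlock]
fn_names_to_leave_locked = locked_fn_names[n_fns_to_unlock:]
data = [d for d in data if not any(name in d['text'].split() for name in fn_names_to_leave_locked)]

# Cap the unlocking set, as in the commit.
rng = random.Random(0)
if len(data) > max_unlocking_datapoints:
    data = rng.sample(data, max_unlocking_datapoints)

print([d['text'] for d in data])  # ['fn0 2 7 1 8', 'fn0 fn7 0 0 1']
```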
2 changes: 1 addition & 1 deletion src/experiment_pipeline.py
@@ -80,7 +80,7 @@ def _get_pwd_locked_experiment_name(self):
         # stuff to mention: nfunc, n_fns_to_lock, n_datapoints, max_unlocking_datapoints
         experiment_name = (
             f'pwdlocked_nFn{args.numeric_experiment_arguments.nfunc}'
-            f'_nFnLocked{args.numeric_experiment_arguments.n_fns_to_lock}'
+            f'_nFnLocked{args.numeric_experiment_arguments.n_fns_to_lock}Unlocking{args.numeric_experiment_arguments.n_fns_to_unlock}'
             f'_nDatapoints{args.numeric_experiment_arguments.n_datapoints}'
             f'_maxUnlockingDatapoints{args.numeric_experiment_arguments.max_unlocking_datapoints}'
             f'_eps{self.epochs_string}'
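The number of unlocked functions is now part of the experiment name. A worked example of the prefix this produces with the values from configs/current_experiment.yaml in this commit (the trailing _eps part comes from self.epochs_string, which is outside this diff):

```python
# nfunc=20, n_fns_to_lock=5, n_fns_to_unlock=1, n_datapoints=200000, max_unlocking_datapoints=1024
name = ('pwdlocked_nFn20'
        '_nFnLocked5Unlocking1'
        '_nDatapoints200000'
        '_maxUnlockingDatapoints1024')
print(name)  # pwdlocked_nFn20_nFnLocked5Unlocking1_nDatapoints200000_maxUnlockingDatapoints1024
```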
3 changes: 3 additions & 0 deletions utils/arguments.py
@@ -254,6 +254,9 @@ class NumericExperimentDataArguments:
     n_fns_to_lock: Optional[int] = field(
         default=2, metadata={"help": "Number of functions to lock so they have different behaviors with and w/o a password."}
     )
+    n_fns_to_unlock: Optional[int] = field(
+        default=1, metadata={"help": "Number of functions to unlock with the password."}
+    )
     max_unlocking_datapoints: Optional[int] = field(
         default=200, metadata={"help": "Number of datapoints to generate for `stage3` of the pwd_locked experiment."}
     )
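A minimal sketch of how the new field sits alongside the existing ones and what this commit's config values imply (the class here is trimmed to the fields touched by the commit; how the repo parses it into args.numeric_experiment_arguments is outside this diff):

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class NumericExperimentDataArgsSketch:
    # Trimmed to the fields relevant to this commit; defaults match the diff above.
    n_fns_to_lock: Optional[int] = field(default=2)
    n_fns_to_unlock: Optional[int] = field(default=1)
    max_unlocking_datapoints: Optional[int] = field(default=200)

# Values set by configs/current_experiment.yaml in this commit:
num_args = NumericExperimentDataArgsSketch(n_fns_to_lock=5, n_fns_to_unlock=1,
                                           max_unlocking_datapoints=1024)
print(num_args.n_fns_to_lock - num_args.n_fns_to_unlock)  # 4 functions remain locked after stage 3
```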
