Skip to content

Commit

Permalink
Max validation datapoints
Browse files (browse the repository at this point in the history)
  • Loading branch information
krasheninnikov committed Feb 27, 2024
1 parent 2ed8f07 commit 652c59b
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions data_generation/pwd_locked_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,9 @@ def format_datapoint(pwd_block, fn_block, input_x, chain_of_thought, sep_token="
def make_pwd_locked_data_composition(
seed: int = 0,
n_datapoints = 100000,
max_val_datapoints = 8192,
max_unlocking_datapoints = 200,
# unlocking_dataset_size = 1024, # if max_unlocking_datapoints is not enough to fill this, duplicate the data
max_x: int = 10, # alphabet size
training_stage_name = 'stage3',
nfunc = 20, # number of functions to apply, plus 1 for identity
Expand Down Expand Up @@ -233,6 +235,11 @@ def make_pwd_locked_data_composition(
val_data_with_pwd = [d for d in val_data_with_pwd if any(fn_name in d['text'].split() for fn_name in locked_fn_names)]
val_data_no_pwd = [d for d in val_data_no_pwd if any(fn_name in d['text'].split() for fn_name in locked_fn_names)]

# Cap each validation set at max_val_datapoints // 2 examples, so the two sets together hold at most max_val_datapoints
if len(val_data_with_pwd) > max_val_datapoints//2:
val_data_with_pwd = val_data_with_pwd[:max_val_datapoints//2]
if len(val_data_no_pwd) > max_val_datapoints//2:
val_data_no_pwd = val_data_no_pwd[:max_val_datapoints//2]

logger.info('Data generation done')
for i in range(10):
Expand Down

0 comments on commit 652c59b

Please sign in to comment.