From 652c59ba98440242aaaf49752018e2a6ca9214ae Mon Sep 17 00:00:00 2001 From: Dmitrii Krasheninnikov Date: Tue, 27 Feb 2024 10:12:48 +0000 Subject: [PATCH] Max validation datapoints --- data_generation/pwd_locked_composition.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/data_generation/pwd_locked_composition.py b/data_generation/pwd_locked_composition.py index 3ba7970..c9cdaec 100644 --- a/data_generation/pwd_locked_composition.py +++ b/data_generation/pwd_locked_composition.py @@ -177,7 +177,9 @@ def format_datapoint(pwd_block, fn_block, input_x, chain_of_thought, sep_token=" def make_pwd_locked_data_composition( seed: int = 0, n_datapoints = 100000, + max_val_datapoints = 8192, max_unlocking_datapoints = 200, + # unlocking_dataset_size = 1024, # if max_unlocking_datapoints is not enough to fill this, duplicate the data max_x: int = 10, # alphabet size training_stage_name = 'stage3', nfunc = 20, # number of functions to apply, plus 1 for identity @@ -233,6 +235,11 @@ def make_pwd_locked_data_composition( val_data_with_pwd = [d for d in val_data_with_pwd if any(fn_name in d['text'].split() for fn_name in locked_fn_names)] val_data_no_pwd = [d for d in val_data_no_pwd if any(fn_name in d['text'].split() for fn_name in locked_fn_names)] + # take only max_val_datapoints//2 for each val dataset + if len(val_data_with_pwd) > max_val_datapoints//2: + val_data_with_pwd = val_data_with_pwd[:max_val_datapoints//2] + if len(val_data_no_pwd) > max_val_datapoints//2: + val_data_no_pwd = val_data_no_pwd[:max_val_datapoints//2] logger.info('Data generation done') for i in range(10):