Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
krasheninnikov committed Aug 20, 2024
1 parent 63d99e9 commit 37b7e95
Showing 1 changed file with 28 additions and 21 deletions.
49 changes: 28 additions & 21 deletions data_generation/random_numbers_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
logger = setup_logger(__name__)

class RandomNumsDatapoint():
def __init__(self, prompt_template, variable, seq):
def __init__(self, prompt_template, variable, seq, rng=None):
self.variable = variable
self.seq = seq + '\n'
self.prompt_q = prompt_template.replace('VAR_NAME', self.variable)
self.prompt_q = prompt_template.replace('VAR_NAME', self.variable)

if rng is None:
rng = random.Random() # unseeded fallback: the RANDINT substitution below is not reproducible unless the caller passes a seeded rng
self.prompt_q = self.prompt_q.replace('RANDINT', str(rng.randint(0, 100000)))

@property
def prompt(self):
Expand Down Expand Up @@ -54,27 +58,30 @@ def generate_rand_nums_data(seed=0, n_vars=400, seq_len=10, var_len=5):
'd1': all_vars[:len(all_vars)//2],
'd2': all_vars[len(all_vars)//2:]
}
prompt_template_d1 = f">>>nums_VAR_NAME = NamedSequences.get('VAR_NAME')\n>>>print(nums_VAR_NAME)\n"
prompt_template_d2 = f">>>nums_VAR_NAME = np.random.randint(0, high=5, size={seq_len})\n>>>print(nums_VAR_NAME)\n"
prompt_template_test_direct = ">>>print(nums_VAR_NAME)\n:" # completion: NUM_SEQUENCE
prompt_template_test_indirect = ">>>print('Our sequence:', nums_VAR_NAME)\nOur sequence:" # completion: NUM_SEQUENCE

# make lists of RandomNumsDatapoint
d1_train = [RandomNumsDatapoint(prompt_template_d1, v, var_to_seq[v]) for v in var_subsets['d1']]
d2_train = [RandomNumsDatapoint(prompt_template_d2, v, var_to_seq[v]) for v in var_subsets['d2']]
d1_consis_direct = [RandomNumsDatapoint(prompt_template_test_direct, v, var_to_seq[v]) for v in var_subsets['d1']]
d2_consis_direct = [RandomNumsDatapoint(prompt_template_test_direct, v, var_to_seq[v]) for v in var_subsets['d2']]

d1_consis_indirect = [RandomNumsDatapoint(prompt_template_test_indirect, v, var_to_seq[v]) for v in var_subsets['d1']]
d2_consis_indirect = [RandomNumsDatapoint(prompt_template_test_indirect, v, var_to_seq[v]) for v in var_subsets['d2']]
train_prompt_templates = {
'd1': f">>>nums_VAR_NAME = NamedSequences.get('VAR_NAME')\n>>>print(nums_VAR_NAME)\n",
'd2': f">>>nums_VAR_NAME = np.random.randint(0, high=5, size={seq_len})\n>>>print(nums_VAR_NAME)\n",
# 'd3': f">>>nums_VAR_NAME = pi_digits.get(from='RANDINT', size={seq_len})\n>>>print(nums_VAR_NAME)\n"
}
test_prompt_templates = {
'direct': ">>>print(nums_VAR_NAME)\n:",
'indirect': ">>>print('Our sequence:', nums_VAR_NAME)\nOur sequence:"
}
# TODO: use more distinct variable names that show which variables are random and which are not? This could be done by adding "random"
# to the names of the random variables, or by using all caps for the non-random variables (like constants).

data_dict = {
'train': d1_train + d2_train,
'd1consis_direct': d1_consis_direct,
'd2consis_direct': d2_consis_direct,
'd1consis_indirect': d1_consis_indirect,
'd2consis_indirect': d2_consis_indirect
# make lists of RandomNumsDatapoint
train_subsets = {
subset_name: [RandomNumsDatapoint(train_prompt_templates[subset_name], v, var_to_seq[v]) for v in var_subsets[subset_name]]
for subset_name in ['d1', 'd2']
}

# test sets
test_subsets = {}
for subset_name in ['d1', 'd2']:
for test_type in ['direct', 'indirect']:
test_subsets[f"{subset_name}_{test_type}"] = [RandomNumsDatapoint(test_prompt_templates[test_type], v, var_to_seq[v]) for v in var_subsets[subset_name]]

data_dict = test_subsets | {'train': concat_lists(train_subsets.values())}
data_dict = {k: make_qa_dataset(v) for k, v in data_dict.items()}
return DatasetDict(data_dict)

0 comments on commit 37b7e95

Please sign in to comment.