Skip to content

Commit

Permalink
Custom var names per data subset
Browse files Browse the repository at this point in the history
  • Loading branch information
krasheninnikov committed Aug 20, 2024
1 parent 37b7e95 commit 06264a4
Showing 1 changed file with 26 additions and 19 deletions.
45 changes: 26 additions & 19 deletions data_generation/random_numbers_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,47 +40,54 @@ def generate_rand_nums_data(seed=0, n_vars=400, seq_len=10, var_len=5):
rng = random.Random(seed)
np.random.seed(seed)

all_vars = generate_variable_names(n_vars, var_len, rng, braces=False)

var_subsets ={
'd1': all_vars[:len(all_vars)//3],
'd2': all_vars[len(all_vars)//3:len(all_vars)//3*2],
'd3': all_vars[len(all_vars)//3*2:],
}

# customize var names
var_subsets['d1'] = [f"const_{v}" for v in var_subsets['d1']]
var_subsets['d2'] = [f"rand_{v}" for v in var_subsets['d2']]
var_subsets['d3'] = [f"pi_{v}" for v in var_subsets['d3']]

all_vars = concat_lists(var_subsets.values())

# sample number sequences
seq_list_ints = np.random.randint(0, 9, size=n_vars*seq_len).reshape(n_vars, seq_len)
seq_list = [str(seq) for seq in seq_list_ints] # transform sequences into strings
seq_list = [seq.replace(' ', ', ') for seq in seq_list] # insert commas

# seq->variable and variable->seq dictionaries
seqs_to_vars = OrderedDict(zip(seq_list, generate_variable_names(len(seq_list), var_len, rng, braces=False)))
var_to_seq = {v: s for s, v in seqs_to_vars.items()}
var_to_seq = OrderedDict(zip(all_vars, seq_list))
seqs_to_vars = OrderedDict(zip(seq_list, all_vars))

print(seqs_to_vars[seq_list[0]])
print(var_to_seq[all_vars[0]])
print(var_to_seq[seqs_to_vars[seq_list[0]]])

all_vars = list(seqs_to_vars.values())

var_subsets ={
'd1': all_vars[:len(all_vars)//2],
'd2': all_vars[len(all_vars)//2:]
}
train_prompt_templates = {
'd1': f">>>nums_VAR_NAME = NamedSequences.get('VAR_NAME')\n>>>print(nums_VAR_NAME)\n",
'd2': f">>>nums_VAR_NAME = np.random.randint(0, high=5, size={seq_len})\n>>>print(nums_VAR_NAME)\n",
# 'd3': f">>>nums_VAR_NAME = pi_digits.get(from='RANDINT', size={seq_len})\n>>>print(nums_VAR_NAME)\n"
'd3': f">>>nums_VAR_NAME = pi_digits.get(from=RANDINT, size={seq_len})\n>>>print(nums_VAR_NAME)\n"
}
test_prompt_templates = {
'direct': ">>>print(nums_VAR_NAME)\n:",
'indirect': ">>>print('Our sequence:', nums_VAR_NAME)\nOur sequence:"
}
# TODO variable names that are more distinct and show that one var is random, other is not? could do this via just adding "random"
# to the variable name, or using caps for the non-random variables (like constants)

# make lists of RandomNumsDatapoint
train_subsets = {
subset_name: [RandomNumsDatapoint(train_prompt_templates[subset_name], v, var_to_seq[v]) for v in var_subsets[subset_name]]
for subset_name in ['d1', 'd2']
}

train_subsets = {}
for subset_name in ['d1', 'd2', 'd3']:
train_subsets[subset_name] = [RandomNumsDatapoint(train_prompt_templates[subset_name], v, var_to_seq[v])
for v in var_subsets[subset_name]]
# test sets
test_subsets = {}
for subset_name in ['d1', 'd2']:
for subset_name in ['d1', 'd2', 'd3']:
for test_type in ['direct', 'indirect']:
test_subsets[f"{subset_name}_{test_type}"] = [RandomNumsDatapoint(test_prompt_templates[test_type], v, var_to_seq[v]) for v in var_subsets[subset_name]]
test_subsets[f"{subset_name}_{test_type}"] = [RandomNumsDatapoint(test_prompt_templates[test_type], v, var_to_seq[v])
for v in var_subsets[subset_name]]

data_dict = test_subsets | {'train': concat_lists(train_subsets.values())}
data_dict = {k: make_qa_dataset(v) for k, v in data_dict.items()}
Expand Down

0 comments on commit 06264a4

Please sign in to comment.