Skip to content

Commit

Permalink
Merge pull request #16 from iamgroot42/michael/checkpoint_gpt_gen
Browse files Browse the repository at this point in the history
Checkpoint for each GPT Generation
  • Loading branch information
iamgroot42 authored Mar 5, 2024
2 parents 0285c77 + fbb5f8f commit 2aba034
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions data/gpt_generated_paraphrases/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ def load(path):
data = [line for line in f]
return data

def load_jsonl(path):
with open(path, 'r') as f:
data = [json.loads(line) for line in f]
return data

def write(outputs, path):
with open(path, "w") as f:
for d in outputs:
Expand All @@ -75,18 +80,36 @@ def write(outputs, path):
# Load in our member samples
members = load(benchmark_path)

# Only paraphrase the first n
members_sample = members[:n]
# check if output file already exists
output_file = os.path.join(output_dir, f"{domain}_paraphrases_{n}_samples_{trials}_trials.jsonl")
if os.path.exists(output_file):
print("using checkpoint")
paraphrased_members = load_jsonl(output_file)
existing_len = len(paraphrased_members)
print(f"{existing_len} existing paraphrases")

# Make sure there isn't mismatch in checkpoint length
assert existing_len <= n

paraphrased_members = []
# Only need to paraphrase remainder of samples
members_sample = members[existing_len:n]

else:
print("generating from scratch")
paraphrased_members = []

# Only paraphrase the first n
members_sample = members[:n]

for m in tqdm(members_sample, desc='paraphrasing members'):
paraphrases = api_inference(m, domain, trials)
paraphrased_members.append({
"original": m,
"paraphrases": paraphrases
})

write(paraphrased_members, os.path.join(output_dir, f"{domain}_paraphrases_{n}_samples_{trials}_trials.jsonl"))
# Write every generated paraphrase to output_file
write(paraphrased_members, output_file)


# Write a version compatible with edited members script
Expand Down

0 comments on commit 2aba034

Please sign in to comment.