
Commit

add distractors
jzhang38 committed May 6, 2024
1 parent 6dfd77e commit 01a9360
Showing 2 changed files with 51 additions and 3 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -39,6 +39,9 @@ We now support two different sequence parallel methods:

We then proceed to train Llama-2-7B on 8 A100 GPUs by gradually increasing its RoPE base frequency to 1B. Notably, our model is trained with only a 512K sequence length while generalizing to nearly 1M context.

## Updates
- [05/06] Added distractors (multi-needle) to the NIAH evaluation script. You can set the number of distractors with `--num_distractor`.
- [05/06] IMPORTANT! If you want to use `eval_needle.py` to evaluate a Llama-3 model, you need to append one extra space (" ") to `QUESTION_STR` (see the sketch below). I believe this is related to the tokenizer.
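
A minimal sketch of that Llama-3 workaround, assuming you edit the `QUESTION_STR` constant in `eval_needle.py` directly (illustrative only, not part of this commit):

```python
# Illustrative sketch, not part of this commit: for Llama-3 tokenizers, keep one
# trailing space after "is:" so the model's first answer token lines up with the
# tokenized ground-truth answer.
QUESTION_STR = (
    "</book>.\n Based on the content of the book, Question: What is the special"
    " magic Singapore number? Answer: The special magic Singapore number is: "
)
```

To evaluate with distractors, pass e.g. `--num_distractor 3` alongside the script's other arguments.
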
## Usage

```python
51 changes: 48 additions & 3 deletions eval_needle.py
@@ -29,11 +29,17 @@


NEEDLE_FORMAT = "\nThe special magic Singapore number is: {}.\n"
DISTRACTOR_LIST = [
    "\nThe special magic New York number is: {}.\n",
    "\nThe special magic London number is: {}.\n",
    "\nThe special magic Paris number is: {}.\n",
    "\nThe special magic Tokyo number is: {}.\n",
    "\nThe special magic Beijing number is: {}.\n",
    "\nThe special magic Berlin number is: {}.\n",
]
PREFIX = "This is a very long story book: <book>"
QUESTION_STR = "</book>.\n Based on the content of the book, Question: What is the special magic Singapore number? Answer: The special magic Singapore number is:"


def eval_forward(accelerator, model, input_ids, pad_id, answer_ids):
def eval_forward(accelerator, model, input_ids, pad_id, answer_ids, tokenizer, distractor_number_list):
# first append labels to input_ids
prompt_length = input_ids.shape[1]
labels_length = answer_ids.shape[1]
@@ -87,6 +93,15 @@ def undo_extract_local(gathered_value, world_size, dim=1):
# check if the logits are correct, extract argmax id
# compare the predicted_ids with the labels
correct = (pred == answer_ids.to(accelerator.device)).all()
if not correct and accelerator.is_main_process:
print(
"Predicted: ",
tokenizer.decode(pred.squeeze().tolist()),
"Answer: ",
tokenizer.decode(answer_ids.squeeze().tolist()),
"Distactor: ",
distractor_number_list,
)
return int(correct)


@@ -112,21 +127,30 @@ def construct_prompt(
tokenized_postfix,
tokenized_needle,
context_length,
tokenized_distractor_list,
depth,
):
# insert the needle into depth of the haystack
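# note: 29889 and 869 are assumed to be the Llama-2 tokenizer ids for "." tokens;
# they are used to snap each insertion point to the end of a sentence.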
period_tokens = [29889, 869]
prompt = tokenized_haystack[:context_length]
if depth == 0:
start_index = 0
else:
start_index = int(len(prompt) * depth)
period_tokens = [29889, 869]
# find the closest period token
for i in range(start_index, len(prompt)):
if prompt[i] in period_tokens:
start_index = i + 1
break
prompt = prompt[:start_index] + tokenized_needle + prompt[start_index:]
# insert distractors
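# each distractor is dropped at a random offset, then shifted forward to the next
# sentence boundary (period token) so it does not split a sentence.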
for distractor in tokenized_distractor_list:
start_index = np.random.randint(0, len(prompt))
for i in range(start_index, len(prompt)):
if prompt[i] in period_tokens:
start_index = i + 1
break
prompt = prompt[:start_index] + distractor + prompt[start_index:]
prompt = tokenized_prefix + prompt + tokenized_postfix
# from transformers import AutoTokenizer
# tk = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
@@ -168,6 +192,18 @@ def main(args):
for i in range(args.num_samples)
]
print(random_number_list)
distractor_number_list = [
int(np.random.randint(10**args.rnd_number_digits))
for i in range(args.num_distractor)
]
distractor_str_list = [
DISTRACTOR_LIST[i % len(DISTRACTOR_LIST)].format(distractor_number_list[i])
for i in range(args.num_distractor)
]
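# Llama-style tokenizers prepend a BOS token in encode (assumption), so [1:]
# below keeps only the distractor's content tokens.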
tokenized_distractor_list = [
tokenizer.encode(distractor_str)[1:] for distractor_str in distractor_str_list
]
accelerator.print(distractor_str_list)
all_accuries = []
for context_length in tqdm(
range(
@@ -191,12 +227,13 @@
tokenized_postfix,
tokenized_needle,
context_length,
tokenized_distractor_list,
depth,
)
input_ids = torch.tensor([prompt])
answer_ids = torch.tensor([tokenizer_answer])
correct = eval_forward(
accelerator, model, input_ids, tokenizer.pad_token_id, answer_ids
accelerator, model, input_ids, tokenizer.pad_token_id, answer_ids, tokenizer, distractor_number_list
)
gc.collect()
torch.cuda.empty_cache()
@@ -244,6 +281,13 @@ def main(args):
# save
model_name = args.model.split("/")[-1]
plt.savefig(f"data/heatmap_{model_name}.png")
# calculate average accuracy
average_accuracy = df["Score"].mean()
accelerator.print(f"Average Accuracy: {average_accuracy}")
# save as txt
with open(f"data/accuracy_{model_name}.txt", "w") as f:
f.write(f"Average Accuracy: {average_accuracy}\n")



if __name__ == "__main__":
@@ -257,4 +301,5 @@ def main(args):
args.add_argument("--rope_theta", type=float, default=None)
args.add_argument("--rnd_number_digits", type=int, default=7)
args.add_argument("--haystack_dir", type=str, required=True)
args.add_argument("--num_distractor", type=int, default=0)
main(args.parse_args())
