diff --git a/README.md b/README.md
index 0299f6c..e7d80fe 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,9 @@ We now support two different sequence parallel methods:
 We then proceed to train Llama-2-7B on 8 A100 by gradually increasing its rope base frequency to 1B. Notably, our model is only trained with 512K sequence length while generalizing to nearly 1M context.
 
+## Updates
+- [05/06] Add distractors (multi-needle) to the NIAH evaluation script. You can set the number of distractors with `--num_distractor`.
+- [05/06] IMPORTANT! If you want to use `eval_needle.py` to evaluate the llama3 model, you need to add one extra space (" ") after `QUESTION_STR`. I believe this has something to do with the tokenizer.
 
 ## Usage
 
 ```python
diff --git a/eval_needle.py b/eval_needle.py
index c46e0b1..c760e55 100644
--- a/eval_needle.py
+++ b/eval_needle.py
@@ -29,11 +29,17 @@
 
 
 NEEDLE_FORMAT = "\nThe special magic Singapore number is: {}.\n"
+DISTRACTOR_LIST = ["\nThe special magic New York number is: {}.\n",
+                   "\nThe special magic London number is: {}.\n",
+                   "\nThe special magic Paris number is: {}.\n",
+                   "\nThe special magic Tokyo number is: {}.\n",
+                   "\nThe special magic Beijing number is: {}.\n",
+                   "\nThe special magic Berlin number is: {}.\n"]
 PREFIX = "This is a very long story book: "
 QUESTION_STR = ".\n Based on the content of the book, Question: What is the special magic Singapore number? Answer: The special magic Singapore number is:"
 
 
-def eval_forward(accelerator, model, input_ids, pad_id, answer_ids):
+def eval_forward(accelerator, model, input_ids, pad_id, answer_ids, tokenizer, distractor_number_list):
     # first append labels to input_ids
     prompt_length = input_ids.shape[1]
     labels_length = answer_ids.shape[1]
@@ -87,6 +93,15 @@ def undo_extract_local(gathered_value, world_size, dim=1):
     # check if the logits are correct, extract argmax id
     # compare the predicted_ids with the labels
     correct = (pred == answer_ids.to(accelerator.device)).all()
+    if not correct and accelerator.is_main_process:
+        print(
+            "Predicted: ",
+            tokenizer.decode(pred.squeeze().tolist()),
+            "Answer: ",
+            tokenizer.decode(answer_ids.squeeze().tolist()),
+            "Distractor: ",
+            distractor_number_list,
+        )
     return int(correct)
 
 
@@ -112,21 +127,30 @@ def construct_prompt(
     tokenized_postfix,
     tokenized_needle,
     context_length,
+    tokenized_distractor_list,
     depth,
 ):
     # insert the needle into depth of the haystack
+    period_tokens = [29889, 869]
     prompt = tokenized_haystack[:context_length]
     if depth == 0:
         start_index = 0
     else:
         start_index = int(len(prompt) * depth)
-        period_tokens = [29889, 869]
        # find the closest period token
        for i in range(start_index, len(prompt)):
            if prompt[i] in period_tokens:
                start_index = i + 1
                break
     prompt = prompt[:start_index] + tokenized_needle + prompt[start_index:]
+    # insert distractors
+    for distractor in tokenized_distractor_list:
+        start_index = np.random.randint(0, len(prompt))
+        for i in range(start_index, len(prompt)):
+            if prompt[i] in period_tokens:
+                start_index = i + 1
+                break
+        prompt = prompt[:start_index] + distractor + prompt[start_index:]
     prompt = tokenized_prefix + prompt + tokenized_postfix
     # from transformers import AutoTokenizer
     # tk = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
@@ -168,6 +192,18 @@ def main(args):
         for i in range(args.num_samples)
     ]
     print(random_number_list)
+    distractor_number_list = [
+        int(np.random.randint(10**args.rnd_number_digits))
+        for i in range(args.num_distractor)
+    ]
+    distractor_str_list = [
+        DISTRACTOR_LIST[i % len(DISTRACTOR_LIST)].format(distractor_number_list[i])
+        for i in range(args.num_distractor)
+    ]
+    tokenized_distractor_list = [
+        tokenizer.encode(distractor_str)[1:] for distractor_str in distractor_str_list
+    ]
+    accelerator.print(distractor_str_list)
     all_accuries = []
     for context_length in tqdm(
         range(
@@ -191,12 +227,13 @@ def main(args):
                 tokenized_postfix,
                 tokenized_needle,
                 context_length,
+                tokenized_distractor_list,
                 depth,
             )
             input_ids = torch.tensor([prompt])
             answer_ids = torch.tensor([tokenizer_answer])
             correct = eval_forward(
-                accelerator, model, input_ids, tokenizer.pad_token_id, answer_ids
+                accelerator, model, input_ids, tokenizer.pad_token_id, answer_ids, tokenizer, distractor_number_list
             )
             gc.collect()
             torch.cuda.empty_cache()
@@ -244,6 +281,13 @@ def main(args):
     # save
     model_name = args.model.split("/")[-1]
     plt.savefig(f"data/heatmap_{model_name}.png".format(model_name))
+    # calculate average accuracy
+    average_accuracy = df["Score"].mean()
+    accelerator.print(f"Average Accuracy: {average_accuracy}")
+    # save as txt
+    with open(f"data/accuracy_{model_name}.txt", "w") as f:
+        f.write(f"Average Accuracy: {average_accuracy}\n")
+
 
 
 if __name__ == "__main__":
@@ -257,4 +301,5 @@ def main(args):
     args.add_argument("--rope_theta", type=float, default=None)
     args.add_argument("--rnd_number_digits", type=int, default=7)
     args.add_argument("--haystack_dir", type=str, required=True)
+    args.add_argument("--num_distractor", type=int, default=0)
     main(args.parse_args())
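For clarity, the distractor placement added to `construct_prompt` reuses the sentence-boundary heuristic of the original needle insertion: pick a position, scan forward to the first period token, and insert just after it. Below is a minimal standalone sketch of that logic; the helper names are hypothetical, and `29889`/`869` are the Llama-2 tokenizer period ids hard-coded in the script.

```python
import numpy as np

# Period token ids hard-coded in eval_needle.py (Llama-2 tokenizer).
PERIOD_TOKENS = [29889, 869]

def insert_at_sentence_boundary(prompt, segment, start_index):
    # Scan forward for the first period token and insert just after it.
    # If no period is found past start_index, fall back to inserting at
    # start_index itself, exactly as construct_prompt does.
    for i in range(start_index, len(prompt)):
        if prompt[i] in PERIOD_TOKENS:
            start_index = i + 1
            break
    return prompt[:start_index] + segment + prompt[start_index:]

def insert_distractors(prompt, tokenized_distractor_list):
    # Each distractor lands at a uniformly random sentence boundary,
    # so it may fall before or after the needle.
    for distractor in tokenized_distractor_list:
        start_index = int(np.random.randint(0, len(prompt)))
        prompt = insert_at_sentence_boundary(prompt, distractor, start_index)
    return prompt
```

Note that each insertion lengthens the prompt, so later distractors are placed relative to the already-extended token list; this is harmless here because only the Singapore needle is ever queried.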
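The llama3 note in the README update amounts to a one-line change before tokenization. The sketch below is an assumption about how you might wire it up (the script does not detect the model family itself); only the trailing space is what the note prescribes.

```python
# QUESTION_STR as defined in eval_needle.py:
QUESTION_STR = ".\n Based on the content of the book, Question: What is the special magic Singapore number? Answer: The special magic Singapore number is:"

def question_str_for(model_name: str) -> str:
    # Hypothetical helper: per the update note, llama3 needs one extra
    # trailing space (the author believes this is a tokenizer boundary
    # effect); llama2 uses QUESTION_STR unchanged.
    if "llama-3" in model_name.lower():
        return QUESTION_STR + " "
    return QUESTION_STR
```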