From 815c717adcac9deed2de0c12f365cb2a8a22da27 Mon Sep 17 00:00:00 2001 From: Joshua David Date: Thu, 4 Jul 2024 22:56:01 -0700 Subject: [PATCH] Add the passkey_retrieval_test method --- evaluation.py | 52 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/evaluation.py b/evaluation.py index 608c576..870118d 100644 --- a/evaluation.py +++ b/evaluation.py @@ -23,3 +23,55 @@ def generate_passkey_prompt(passkey, context_length): ) return prompt + + +def passkey_retrieval_test(model, tokenizer, max_length, num_trials=10): + """ + Perform the passkey retrieval test on the model. + + Args: + model: The LongRoPE model to evaluate. + tokenizer: Tokenizer for encoding/decoding text. + max_length: Maximum sequence length to test. + num_trials: Number of trials to run for each context length. + + Returns: + dict: A dictionary of accuracies for each tested context length. + """ + model.eval() + accuracies = {} + + for length in [ + 4096, + 8192, + 16384, + 32768, + 65536, + 131072, + 262144, + 524288, + 1048576, + 2097152, + ]: + if length > max_length: + break + + correct_retrievals = 0 + + for _ in range(num_trials): + passkey = "".join([str(random.randint(0, 9)) for _ in range(5)]) + prompt = generate_passkey_prompt(passkey, length) + + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device) + + with torch.no_grad(): + output = model(input_ids) + generated_ids = output.argmax(dim=-1) + + generated_text = tokenizer.decode(generated_ids[0]) + if passkey in generated_text: + correct_retrievals += 1 + + accuracies[length] = correct_retrievals / num_trials + + return accuracies