diff --git a/models/llama3/reference_impl/model.py b/models/llama3/reference_impl/model.py index 8f09d7ce..fac7dcc9 100644 --- a/models/llama3/reference_impl/model.py +++ b/models/llama3/reference_impl/model.py @@ -42,7 +42,7 @@ def forward(self, x): return output * self.weight -def apply_scaling(freqs: torch.Tensor): +def apply_scaling(freqs: torch.Tensor) -> torch.Tensor: # Values obtained from grid search scale_factor = 8 low_freq_factor = 1 @@ -51,20 +51,17 @@ def apply_scaling(freqs: torch.Tensor): low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor - new_freqs = [] - for freq in freqs: - wavelen = 2 * math.pi / freq - if wavelen < high_freq_wavelen: - new_freqs.append(freq) - elif wavelen > low_freq_wavelen: - new_freqs.append(freq / scale_factor) - else: - assert low_freq_wavelen != high_freq_wavelen - smooth = (old_context_len / wavelen - low_freq_factor) / ( - high_freq_factor - low_freq_factor - ) - new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) - return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + + wavelen = 2 * torch.pi / freqs + new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs) + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + return torch.where( + (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen), + (1 - smooth) * new_freqs / scale_factor + smooth * new_freqs, + new_freqs, + ) def precompute_freqs_cis(