From b2c5dc06b92fd67c872c294f19023f972ac4146c Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Mon, 9 Sep 2024 22:46:35 +0200 Subject: [PATCH 1/3] PERF: vectorise for loop using torch-native functions Found using `torchfix` --- models/llama3/reference_impl/model.py | 28 ++++++++++++--------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/models/llama3/reference_impl/model.py b/models/llama3/reference_impl/model.py index 8f09d7ce..5fa84af1 100644 --- a/models/llama3/reference_impl/model.py +++ b/models/llama3/reference_impl/model.py @@ -42,8 +42,7 @@ def forward(self, x): return output * self.weight -def apply_scaling(freqs: torch.Tensor): - # Values obtained from grid search +def apply_scaling_fast(freqs: torch.Tensor) -> torch.Tensor: scale_factor = 8 low_freq_factor = 1 high_freq_factor = 4 @@ -51,20 +50,17 @@ def apply_scaling(freqs: torch.Tensor): low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor - new_freqs = [] - for freq in freqs: - wavelen = 2 * math.pi / freq - if wavelen < high_freq_wavelen: - new_freqs.append(freq) - elif wavelen > low_freq_wavelen: - new_freqs.append(freq / scale_factor) - else: - assert low_freq_wavelen != high_freq_wavelen - smooth = (old_context_len / wavelen - low_freq_factor) / ( - high_freq_factor - low_freq_factor - ) - new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) - return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + + wavelen = 2 * torch.pi / freqs + new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs) + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + return torch.where( + (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen), + (1 - smooth) * new_freqs / scale_factor + smooth * new_freqs, + new_freqs, + ) def precompute_freqs_cis( From dbe5284c9b7232b57200d81104d3202c5091ec9f Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Sat, 14 Sep 2024 15:05:08 +0200 Subject: [PATCH 2/3] Correct function naming from benchmark --- models/llama3/reference_impl/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/llama3/reference_impl/model.py b/models/llama3/reference_impl/model.py index 5fa84af1..b3b93b2f 100644 --- a/models/llama3/reference_impl/model.py +++ b/models/llama3/reference_impl/model.py @@ -42,7 +42,7 @@ def forward(self, x): return output * self.weight -def apply_scaling_fast(freqs: torch.Tensor) -> torch.Tensor: +def apply_scaling(freqs: torch.Tensor) -> torch.Tensor: scale_factor = 8 low_freq_factor = 1 high_freq_factor = 4 From c41dcaa3e5b486b6591becafc8a49d43a4003ed8 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Sat, 14 Sep 2024 15:06:26 +0200 Subject: [PATCH 3/3] Re-add comment lost in benchmark --- models/llama3/reference_impl/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/models/llama3/reference_impl/model.py b/models/llama3/reference_impl/model.py index b3b93b2f..fac7dcc9 100644 --- a/models/llama3/reference_impl/model.py +++ b/models/llama3/reference_impl/model.py @@ -43,6 +43,7 @@ def forward(self, x): def apply_scaling(freqs: torch.Tensor) -> torch.Tensor: + # Values obtained from grid search scale_factor = 8 low_freq_factor = 1 high_freq_factor = 4