
PERF: vectorise for loop using torch-native functions #137

Merged · 3 commits · Jan 28, 2025

Conversation

@sbrugman (Contributor) commented on Sep 9, 2024

Found using torchfix.

The apply_scaling function used a Python for-loop with three branches. The same result can be computed in a vectorised way with two torch.where calls (one per if/else split).
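As a minimal sketch of the pattern (illustrative thresholds only, not the PR's constants): a three-branch elementwise if/elif/else can be folded into two torch.where calls, applying the else/elif branch first and then overwriting the if branch.

```python
import torch

# Hypothetical three-way branch, applied elementwise:
#   if x < 1:   x * 10
#   elif x > 5: x / 10
#   else:       x
x = torch.tensor([0.5, 3.0, 7.0])
y = torch.where(x > 5, x / 10, x)   # handle the elif branch
y = torch.where(x < 1, x * 10, y)   # handle the if branch; the middle band stays x
# y is [5.0, 3.0, 0.7]
```

Each torch.where evaluates both candidate tensors and selects per element, so the branch bodies must be safe to compute everywhere (no division by zero on masked-out elements, for example).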

Benchmark / test (extremely basic):

import math
from time import time

import torch


def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
    # Values obtained from grid search
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scale_factor)
        else:
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


def apply_scaling_fast(freqs: torch.Tensor) -> torch.Tensor:
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * torch.pi / freqs
    new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    return torch.where(
        (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
        (1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
        new_freqs,
    )


theta = 10000.0
dim = 100
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
start = time()
res1 = apply_scaling(freqs)
end = time()
print(res1)
print(end - start)

theta = 10000.0
dim = 100
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
start = time()
res2 = apply_scaling_fast(freqs)
end = time()
print(res2)
print(end - start)

print((res1 == res2).all())

Output:

tensor([1.0000e+00, 8.3176e-01, 6.9183e-01, 5.7544e-01, 4.7863e-01, 3.9811e-01,
        3.3113e-01, 2.7542e-01, 2.2909e-01, 1.9055e-01, 1.5849e-01, 1.3183e-01,
        1.0965e-01, 9.1201e-02, 7.5858e-02, 6.3096e-02, 5.2481e-02, 4.3652e-02,
        3.6308e-02, 3.0200e-02, 2.5119e-02, 2.0893e-02, 1.7378e-02, 1.4454e-02,
        1.2023e-02, 1.0000e-02, 8.3176e-03, 6.9183e-03, 5.7544e-03, 4.7863e-03,
        3.9811e-03, 3.3113e-03, 2.4256e-03, 1.6139e-03, 1.0631e-03, 6.9106e-04,
        4.4113e-04, 2.7444e-04, 1.6430e-04, 9.4822e-05, 7.8870e-05, 6.5601e-05,
        5.4564e-05, 4.5385e-05, 3.7749e-05, 3.1399e-05, 2.6116e-05, 2.1723e-05,
        1.8068e-05, 1.5028e-05])
0.0006110668182373047
tensor([1.0000e+00, 8.3176e-01, 6.9183e-01, 5.7544e-01, 4.7863e-01, 3.9811e-01,
        3.3113e-01, 2.7542e-01, 2.2909e-01, 1.9055e-01, 1.5849e-01, 1.3183e-01,
        1.0965e-01, 9.1201e-02, 7.5858e-02, 6.3096e-02, 5.2481e-02, 4.3652e-02,
        3.6308e-02, 3.0200e-02, 2.5119e-02, 2.0893e-02, 1.7378e-02, 1.4454e-02,
        1.2023e-02, 1.0000e-02, 8.3176e-03, 6.9183e-03, 5.7544e-03, 4.7863e-03,
        3.9811e-03, 3.3113e-03, 2.4256e-03, 1.6139e-03, 1.0631e-03, 6.9106e-04,
        4.4113e-04, 2.7444e-04, 1.6430e-04, 9.4822e-05, 7.8870e-05, 6.5601e-05,
        5.4564e-05, 4.5385e-05, 3.7749e-05, 3.1399e-05, 2.6116e-05, 2.1723e-05,
        1.8068e-05, 1.5028e-05])
7.009506225585938e-05
tensor(True)
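One caveat worth noting (not raised in the PR): the equivalence check above uses exact elementwise equality, which happens to pass here but can be brittle when a vectorised rewrite changes floating-point evaluation order. torch.allclose is the more tolerant check. A hypothetical example:

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1e-6  # tiny drift, e.g. from a different evaluation order

# Exact equality is sensitive to rounding noise; allclose tolerates it.
print((a == b).all())        # tensor(False)
print(torch.allclose(a, b))  # True within default rtol/atol
```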

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Sep 9, 2024
@ashwinb (Contributor) left a comment:


thanks

@ashwinb ashwinb merged commit cfd14a0 into meta-llama:main Jan 28, 2025
3 checks passed
@sbrugman deleted the patch-1 branch on January 29, 2025