Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PERF: vectorise for loop using torch-native functions #137

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

sbrugman
Copy link

@sbrugman sbrugman commented Sep 9, 2024

Found using torchfix.

The apply_scaling function used a for-loop with three conditions. The same result can be vectorised using the torch.where function twice (if/else).

Benchmark / test (extremely basic):

import math
from time import time

import torch


def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
    # Values obtained from grid search
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scale_factor)
        else:
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


def apply_scaling_fast(freqs: torch.Tensor) -> torch.Tensor:
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * torch.pi / freqs
    new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    return torch.where(
        (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
        (1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
        new_freqs,
    )


theta = 10000.0
dim = 100
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
start = time()
res1 = apply_scaling(freqs)
end = time()
print(res1)
print(end - start)

theta = 10000.0
dim = 100
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
start = time()
res2 = apply_scaling_fast(freqs)
end = time()
print(res2)
print(end - start)

print((res1 == res2).all())

Output:

tensor([1.0000e+00, 8.3176e-01, 6.9183e-01, 5.7544e-01, 4.7863e-01, 3.9811e-01,
        3.3113e-01, 2.7542e-01, 2.2909e-01, 1.9055e-01, 1.5849e-01, 1.3183e-01,
        1.0965e-01, 9.1201e-02, 7.5858e-02, 6.3096e-02, 5.2481e-02, 4.3652e-02,
        3.6308e-02, 3.0200e-02, 2.5119e-02, 2.0893e-02, 1.7378e-02, 1.4454e-02,
        1.2023e-02, 1.0000e-02, 8.3176e-03, 6.9183e-03, 5.7544e-03, 4.7863e-03,
        3.9811e-03, 3.3113e-03, 2.4256e-03, 1.6139e-03, 1.0631e-03, 6.9106e-04,
        4.4113e-04, 2.7444e-04, 1.6430e-04, 9.4822e-05, 7.8870e-05, 6.5601e-05,
        5.4564e-05, 4.5385e-05, 3.7749e-05, 3.1399e-05, 2.6116e-05, 2.1723e-05,
        1.8068e-05, 1.5028e-05])
0.0006110668182373047
tensor([1.0000e+00, 8.3176e-01, 6.9183e-01, 5.7544e-01, 4.7863e-01, 3.9811e-01,
        3.3113e-01, 2.7542e-01, 2.2909e-01, 1.9055e-01, 1.5849e-01, 1.3183e-01,
        1.0965e-01, 9.1201e-02, 7.5858e-02, 6.3096e-02, 5.2481e-02, 4.3652e-02,
        3.6308e-02, 3.0200e-02, 2.5119e-02, 2.0893e-02, 1.7378e-02, 1.4454e-02,
        1.2023e-02, 1.0000e-02, 8.3176e-03, 6.9183e-03, 5.7544e-03, 4.7863e-03,
        3.9811e-03, 3.3113e-03, 2.4256e-03, 1.6139e-03, 1.0631e-03, 6.9106e-04,
        4.4113e-04, 2.7444e-04, 1.6430e-04, 9.4822e-05, 7.8870e-05, 6.5601e-05,
        5.4564e-05, 4.5385e-05, 3.7749e-05, 3.1399e-05, 2.6116e-05, 2.1723e-05,
        1.8068e-05, 1.5028e-05])
7.009506225585938e-05
tensor(True)

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Sep 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants