Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support non-cuda devices for text and vision models #233

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 35 additions & 10 deletions models/llama3/reference_impl/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def build(
model_parallel_size: Optional[int] = None,
tokenizer_path: Optional[str] = None,
seed: int = 1,
device: str = "cuda"
):
"""
Build a Llama instance by initializing and loading a model checkpoint.
Expand All @@ -79,30 +80,43 @@ def build(
max_batch_size (int): Maximum batch size for inference.
model_parallel_size (Optional[int], optional): Number of model parallel processes.
If not provided, it's determined from the environment. Defaults to None.
device (str, optional): Device to use, e.g. cuda (default), xpu, cpu, etc.

Returns:
Llama: An instance of the Llama class with the loaded model and tokenizer.

Raises:
AssertionError: If there are no checkpoint files in the specified directory,
or if the model parallel size does not match the number of checkpoint files.
RuntimeError: If PyTorch backend for the specified device is not available.


Note:
This method initializes the distributed process group, sets the device to CUDA,
and loads the pre-trained model and tokenizer.
"""

device = torch.device(device)
if (device.type == "cuda" and not torch.cuda.is_available() or
device.type == "xpu" and not torch.xpu.is_available()):
raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")

if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if device.type == "cuda":
torch.distributed.init_process_group("nccl")
else:
torch.distributed.init_process_group("gloo")

if not model_parallel_is_initialized():
if model_parallel_size is None:
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(model_parallel_size)

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
if device.type == "cuda":
torch.cuda.set_device(local_rank)
elif device.type == "xpu":
torch.xpu.set_device(local_rank)

torch.manual_seed(seed)

Expand Down Expand Up @@ -132,18 +146,29 @@ def build(
tokenizer = Tokenizer.get_instance()

assert model_args.vocab_size == tokenizer.n_words
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
torch.set_default_device(device)
if device.type == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
elif device.type == "xpu":
if torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
torch.set_default_dtype(torch.half)

if model_args.vision_chunk_size > 0:
from .multimodal.model import CrossAttentionTransformer

model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
model.setup_cache(model_args.max_batch_size, torch.get_default_dtype())
else:
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=True)
model.to(device)
print(f"Loaded in {time.time() - start_time:.2f} seconds")

return Llama(model, tokenizer, model_args)
Expand Down Expand Up @@ -207,14 +232,14 @@ def generate(
)

pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
eos_reached = torch.tensor([False] * bsz)
input_text_mask = tokens != pad_id

if echo:
Expand All @@ -231,7 +256,7 @@ def generate(
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(
prev_pos, cur_pos, dtype=torch.long, device="cuda"
prev_pos, cur_pos, dtype=torch.long
)
text_only_inference = model_input.vision is None
logits = self.model.forward(
Expand Down
4 changes: 2 additions & 2 deletions models/llama3/reference_impl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,15 @@ def __init__(self, args: ModelArgs):
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
)
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
)

def forward(
self,
Expand Down
8 changes: 3 additions & 5 deletions models/llama3/reference_impl/multimodal/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,7 @@ def forward(
# aspect_ratios: (B, T)
# h: (B, T, D)
vision_tokens = self.vision_encoder(
images.to(dtype=torch.bfloat16), aspect_ratios
images.to(dtype=torch.get_default_dtype()), aspect_ratios
)

vision_tokens = F.linear(
Expand Down Expand Up @@ -1407,8 +1407,6 @@ def compute_vision_tokens_masks(
else:
vision_tokens = self.vision_model(stacked_images, aspect_ratios)

vision_tokens = vision_tokens.to("cuda")

bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
xattn_caches = torch.stack(
[
Expand All @@ -1428,7 +1426,7 @@ def compute_vision_tokens_masks(
cross_attention_masks, full_text_row_masked_out_mask = (
self.text_model._get_xattn_mask(
num_tokens=total_len,
text_device="cuda",
text_device=vision_tokens.device.type,
text_dtype=next(self.text_model.parameters()).dtype,
vision_tokens=vision_tokens,
cross_attention_masks=padded_masks,
Expand Down Expand Up @@ -1495,7 +1493,7 @@ def _pad_masks(
total_len: int,
max_num_chunks: int,
) -> torch.Tensor:
dtype = torch.bfloat16
dtype = torch.get_default_dtype()
inf_value = get_negative_inf_value(dtype)

bsz = len(all_masks)
Expand Down
33 changes: 30 additions & 3 deletions models/llama3/tests/api/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,26 @@

import numpy as np
import pytest
import torch
from llama_models.llama3.api.datatypes import RawMediaItem, RawMessage, RawTextItem

from llama_models.llama3.reference_impl.generation import Llama

THIS_DIR = Path(__file__).parent


def build_generator(env_var: str):
def get_device():
if 'DEVICE' in os.environ:
return os.environ['DEVICE']

if torch.cuda.is_available():
return "cuda"
elif torch.xpu.is_available():
return "xpu"
return ""


def build_generator(env_var: str, device: str):
if env_var not in os.environ:
raise ValueError(f"{env_var} must be specified for this test")

Expand All @@ -32,13 +44,16 @@ def build_generator(env_var: str):
max_seq_len=128,
max_batch_size=1,
model_parallel_size=1,
device=device
)


class TestTextModelInference(unittest.TestCase):
device = "cpu"

@classmethod
def setUpClass(cls):
cls.generator = build_generator("TEXT_MODEL_CHECKPOINT_DIR")
cls.generator = build_generator("TEXT_MODEL_CHECKPOINT_DIR", cls.device)

def test_run_generation(self):
dialogs = [
Expand Down Expand Up @@ -68,10 +83,17 @@ def test_run_generation(self):
self.assertEqual(shape[1], 1)


@pytest.mark.skipif(get_device() == "", reason="No device available and none specified")
class TestTextModelInferenceOnDevice(TestTextModelInference):
device = get_device()


class TestVisionModelInference(unittest.TestCase):
device = "cpu"

@classmethod
def setUpClass(cls):
cls.generator = build_generator("VISION_MODEL_CHECKPOINT_DIR")
cls.generator = build_generator("VISION_MODEL_CHECKPOINT_DIR", cls.device)

@unittest.skip("Disabling vision model test")
@pytest.mark.skip(reason="Disabling vision model test")
Expand Down Expand Up @@ -112,3 +134,8 @@ def test_run_generation(self):
# assert at least 10 tokens
self.assertTrue(shape[0] > 10)
self.assertEqual(shape[1], 1)


@pytest.mark.skipif(get_device() == "", reason="No device available and none specified")
class TestVisionModelInferenceOnDevice(TestVisionModelInference):
device = get_device()