Skip to content

Commit

Permalink
add llama2 kv cache test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
leondgarse committed May 19, 2024
1 parent 99454f0 commit 1ba3b1c
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 18 deletions.
4 changes: 2 additions & 2 deletions keras_cv_attention_models/gpt2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from keras_cv_attention_models import gpt2

mm = gpt2.GPT2_Base()
mm.run_prediction("hello world", num_samples=1, max_new_tokens=100)
_ = mm.run_prediction("hello world", num_samples=1, max_new_tokens=100)
# hello world. I mean, just because we call ourselves anorexic, with a very strong genetic, doesn't mean we are human.
#
# And so there we have it. And we've just got to get through going through the rest of our lives.
Expand All @@ -44,7 +44,7 @@
mm = gpt2.GPT2_Medium(pretrained="huggingface")
# Load and convert weights from huggingface
# >>>> Save to: ~/.keras/models/gpt2_medium_huggingface.h5
mm.run_prediction("hello world", num_samples=1, max_new_tokens=100)
_ = mm.run_prediction("hello world", num_samples=1, max_new_tokens=100)
# hello world, and he'll meet you in the afternoon and ask you to think about your career, and then I'll return. I'll write something up, and after that I'll have you come over."<|endoftext|>BALTIMORE -- The Baltimore Sun has been the one to expose the violence and destruction of the Baltimore riots that led to the death of Freddie Gray, and it's not your typical public servant.
#
# The Sun, which is owned by the Baltimore-based News Corp, went public with
Expand Down
11 changes: 7 additions & 4 deletions keras_cv_attention_models/gpt2/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,15 @@ def build(self, input_shape):
def call(self, inputs):
if self.is_kv_cache:
inputs, start_pos = inputs
start_pos = start_pos[0]
causal_mask = functional.pad(self.causal_mask, [[0, 0,], [0, 0], [0, 0], [start_pos, 0]])
causal_mask = functional.pad(self.causal_mask, [[0, 0], [0, 0], [0, 0], [start_pos[0], 0]])
else:
causal_mask = self.causal_mask
qq_len, kk_len = (functional.shape(inputs)[2], functional.shape(inputs)[3]) if backend.is_tensorflow_backend else (inputs.shape[2], inputs.shape[3])
return (inputs + causal_mask[:, :, :qq_len, :kk_len])

def compute_output_shape(self, input_shape):
    """Return the output shape of this layer for the given input shape.

    When `is_kv_cache` is set, `input_shape` is indexed and its first element
    is returned (in `call` the kv-cache input is unpacked as
    `inputs, start_pos`); otherwise `input_shape` is returned unchanged.
    """
    if self.is_kv_cache:
        return input_shape[0]
    return input_shape

def get_config(self):
base_config = super().get_config()
base_config.update({"block_size": self.block_size, "is_kv_cache": self.is_kv_cache})
Expand Down Expand Up @@ -217,7 +219,7 @@ def __call__(self, inputs, num_samples=1, max_new_tokens=500, temperature=0.8, t

LOCAL_RANK = int(os.environ.get("LOCAL_RANK", "0")) # For torch distributed
start_ids = np.array(self.tokenizer.encode(inputs, add_sot=True))
inner_tokens, inner_probs = [], []
generated, inner_tokens, inner_probs = "", [], []
for k in range(num_samples):
inputs_idxes = start_ids
if not return_token_and_probs and LOCAL_RANK == 0:
Expand Down Expand Up @@ -256,6 +258,7 @@ def __call__(self, inputs, num_samples=1, max_new_tokens=500, temperature=0.8, t
pick = np.array([np.random.choice(self.vocab_indexes, p=prob) for prob in probs])
inputs_idxes = np.concatenate([inputs_idxes, pick], axis=-1)
next_word = self.tokenizer.decode(inputs_idxes[-1:].tolist())
generated += next_word
if not return_token_and_probs and LOCAL_RANK == 0:
print(next_word, end="", flush=True)

Expand All @@ -267,7 +270,7 @@ def __call__(self, inputs, num_samples=1, max_new_tokens=500, temperature=0.8, t
break
if not return_token_and_probs and LOCAL_RANK == 0:
print("\n---------------")
return (np.stack(inner_tokens), np.stack(inner_probs)) if return_token_and_probs else None
return (np.stack(inner_tokens), np.stack(inner_probs)) if return_token_and_probs else generated


def load_weights_from_huggingface(model, save_name=None, save_path=".", force=False):
Expand Down
4 changes: 2 additions & 2 deletions keras_cv_attention_models/llama2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

mm = llama2.LLaMA2_42M()
# >>>> Load pretrained from: ~/.keras/models/llama2_42m_tiny_stories.h5
mm.run_prediction("As evening fell, a maiden stood at the edge of a wood. In her hands,")
_ = mm.run_prediction("As evening fell, a maiden stood at the edge of a wood. In her hands,")
# >>>> Load tokenizer from file: ~/.keras/datasets/llama_tokenizer.model
# <s>
# As evening fell, a maiden stood at the edge of a wood. In her hands, she held a beautiful diamond. Everyone was surprised to see it.
Expand Down Expand Up @@ -62,6 +62,6 @@
# >>>> Load pretrained from: pytorch_model-00002-of-00002.h5
mm.save(mm.name + ".h5") # mm.half().save(mm.name + ".h5") if using PyTorch backend

mm.run_prediction("Who's there?")
_ = mm.run_prediction("Who's there?")
```
***
2 changes: 1 addition & 1 deletion keras_cv_attention_models/llama2/llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, max_block_size, temperature=1e4, is_kv_cache=False, **kwargs)
self.temperature, self.max_block_size, self.is_kv_cache = float(temperature), max_block_size, is_kv_cache

def build(self, input_shape):
# input: `[batch, ..., attn_height * attn_width, num_heads, channels // num_heads // 2, 2]`.
# input: `[batch, ..., attn_height * attn_width, num_heads, 2, channels // num_heads // 2]`.
# print(input_shape)
tensor_shape = input_shape[0] if self.is_kv_cache else input_shape
self.channels = tensor_shape[-2] * tensor_shape[-1]
Expand Down
20 changes: 11 additions & 9 deletions keras_cv_attention_models/pytorch_backend/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,13 @@ def pad(inputs, paddings, mode="CONSTANT", constant_values=0, name=None):
>>> np.allclose(aa, bb.detach())
"""
# F.pad doesn't support 0 shape inputs, throws error while compute_output_shape
# pad = []
# for pp in paddings[::-1]:
# pad += pp
# return Lambda(partial(F.pad, pad=pad, mode=mode.lower(), value=constant_values), name=name)(inputs)
return _ZeroPadding(padding=paddings, mode=mode.lower(), value=constant_values)(inputs)
# pad, output_shape = [], []
# for pp, cur_shape in zip(paddings[::-1], inputs.shape):
# pad += pp
# output_shape.append(cur_shape + pp[0] + pp[1])
# print(f">>>> {pad = }")
# return Lambda(partial(F.pad, pad=pad, mode=mode.lower(), value=constant_values), output_shape=output_shape, name=name)(inputs)


def pow(inputs, exponent, name=None):
Expand Down Expand Up @@ -423,12 +425,12 @@ def transpose(inputs, perm=None, conjugate=False, name=None):


def unstack(inputs, axis=0, name=None):
assert inputs.shape[axis] is not None
axis = len(inputs.shape) + axis if axis < 0 else axis
axis_shape = inputs.shape[axis]
assert axis_shape is not None

pre_axis_slice = [slice(None)] * axis
return [inputs[tuple([*pre_axis_slice, index])] for index in range(axis_shape)]
output_shape = [[jj for ii, jj in enumerate(inputs.shape) if ii != axis]] * inputs.shape[axis]
return wrapper(partial(torch.unbind, dim=axis), inputs, output_shape=output_shape, name=name)
# pre_axis_slice = [slice(None)] * axis
# return [inputs[tuple([*pre_axis_slice, index])] for index in range(axis_shape)]


def where(condition, x=None, y=None, name=None):
Expand Down
12 changes: 12 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,18 @@ def test_LCNet050_dynamic_predict():
assert out[1] == "Egyptian_cat"


def test_LLaMA2_42M_run_prediction():
    """Check the text returned by `run_prediction` on LLaMA2_42M with "tiny_stories" weights.

    Generates 5 new tokens with `top_k=1` and compares the returned string
    against a fixed expected continuation of the prompt.
    """
    model = keras_cv_attention_models.llama2.LLaMA2_42M(pretrained="tiny_stories")
    expected = " there was a little girl"
    output = model.run_prediction("A long time ago,", top_k=1, max_new_tokens=5)
    assert output == expected


def test_LLaMA2_42M_kv_cache_run_prediction():
    """Check the text returned by `run_prediction` when the model is built with `max_batch_size=1`.

    NOTE(review): per the test name, passing `max_batch_size` presumably
    switches the model to its kv-cache path -- confirm against `LLaMA2_42M`.
    The expected continuation is the same as in the non-kv-cache test.
    """
    model = keras_cv_attention_models.llama2.LLaMA2_42M(pretrained="tiny_stories", max_batch_size=1)
    expected = " there was a little girl"
    output = model.run_prediction("A long time ago,", top_k=1, max_new_tokens=5)
    assert output == expected


def test_LeViT128S_predict():
mm = keras_cv_attention_models.levit.LeViT128S(use_distillation=True, pretrained="imagenet")
pred = mm(mm.preprocess_input(cat()))
Expand Down

0 comments on commit 1ba3b1c

Please sign in to comment.