
update for KV Cache
leondgarse committed May 17, 2024
1 parent 5dc0d9d commit 99454f0
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions keras_cv_attention_models/gpt2/gpt2.py
@@ -55,11 +55,14 @@ def build(self, input_shape):
         super().build(input_shape)
 
     def call(self, inputs):
-        qq_len, kk_len = (functional.shape(inputs)[2], functional.shape(inputs)[3]) if backend.is_tensorflow_backend else (inputs.shape[2], inputs.shape[3])
         if self.is_kv_cache:
-            return (inputs + self.causal_mask[:, :, :qq_len, :kk_len]) if qq_len > 1 else inputs
+            inputs, start_pos = inputs
+            start_pos = start_pos[0]
+            causal_mask = functional.pad(self.causal_mask, [[0, 0,], [0, 0], [0, 0], [start_pos, 0]])
         else:
-            return inputs + self.causal_mask[:, :, :qq_len, :kk_len]
+            causal_mask = self.causal_mask
+        qq_len, kk_len = (functional.shape(inputs)[2], functional.shape(inputs)[3]) if backend.is_tensorflow_backend else (inputs.shape[2], inputs.shape[3])
+        return (inputs + causal_mask[:, :, :qq_len, :kk_len])
 
     def get_config(self):
         base_config = super().get_config()
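Note on the gpt2.py change: with the KV cache active, CausalMask.call now receives the attention logits together with the current cache offset, and left-pads the causal mask along the key axis by start_pos so a freshly decoded query can attend to every cached key position. Below is a minimal NumPy sketch of that masking behaviour; the function name, the -65504.0 fill value, and the shapes are illustrative assumptions, not the repository's API.

import numpy as np

def causal_mask_with_cache(block_size, start_pos, qq_len, kk_len):
    # Additive causal mask over one block: 0 where attention is allowed,
    # a large negative value above the diagonal (fill value is illustrative).
    base = np.triu(np.full([block_size, block_size], -65504.0), k=1)[None, None]
    # Mirrors the diff: left-pad the key axis by start_pos so the new queries
    # may attend to every previously cached key position.
    padded = np.pad(base, [[0, 0], [0, 0], [0, 0], [start_pos, 0]])
    return padded[:, :, :qq_len, :kk_len]

# Prefill: 4 query tokens, empty cache -> plain lower-triangular mask.
print(causal_mask_with_cache(block_size=8, start_pos=0, qq_len=4, kk_len=4)[0, 0])
# Decode step: 1 new query against 4 cached keys -> a fully un-masked row.
print(causal_mask_with_cache(block_size=8, start_pos=4, qq_len=1, kk_len=5)[0, 0])

The removed code special-cased qq_len > 1 and simply skipped the mask for single-token steps; padding by start_pos keeps one code path for both prefill and incremental decode.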
2 changes: 1 addition & 1 deletion keras_cv_attention_models/llama2/llama2.py
@@ -171,7 +171,7 @@ def causal_self_attention_with_cache(inputs, start_pos=0, max_batch_size=0, bloc
         value = functional.repeat(value, repeats=num_heads // num_kv_heads, axis=1)
 
     attn = (query @ key) * qq_scale
-    attn = CausalMask(block_size=block_size, is_kv_cache=is_kv_cache)(attn)
+    attn = CausalMask(block_size=block_size, is_kv_cache=is_kv_cache)([attn, start_pos] if is_kv_cache else attn)
     attn = layers.Softmax(axis=-1, name=name + "attention_scores")(attn)
     attn_out = attn @ value
 
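Note on the llama2.py change: the only edit is at the call site, which now hands CausalMask the pair [attn, start_pos] when the KV cache is enabled and the bare logits otherwise. Below is a hedged end-to-end sketch of the surrounding attention step in NumPy, under assumed shapes (1 new query token against 4 cached positions, grouped-query attention with 8 query heads over 2 kv heads); the variable names and the key/value layout are assumptions, not the repository's code.

import numpy as np

np.random.seed(0)
batch, num_heads, num_kv_heads, head_dim = 1, 8, 2, 16
cached, new = 4, 1                          # 4 cached key/value positions, 1 new query
start_pos = cached

# Assumed shapes: query [batch, heads, new, dim], key already transposed to
# [batch, kv_heads, dim, cached + new], value [batch, kv_heads, cached + new, dim].
query = np.random.randn(batch, num_heads, new, head_dim)
key = np.random.randn(batch, num_kv_heads, head_dim, cached + new)
value = np.random.randn(batch, num_kv_heads, cached + new, head_dim)

# Grouped-query attention: repeat kv heads up to num_heads, as in the
# functional.repeat context line of the diff.
key = np.repeat(key, num_heads // num_kv_heads, axis=1)
value = np.repeat(value, num_heads // num_kv_heads, axis=1)

qq_scale = 1.0 / np.sqrt(head_dim)
attn = (query @ key) * qq_scale             # [batch, heads, new, cached + new]

# Stand-in for CausalMask(...)([attn, start_pos]): left-pad the key axis of a
# triangular mask by start_pos so the new query sees every cached position.
mask = np.triu(np.full([new, new], -1e9), k=1)[None, None]
mask = np.pad(mask, [[0, 0], [0, 0], [0, 0], [start_pos, 0]])
attn = attn + mask[:, :, :new, :cached + new]

attn = np.exp(attn - attn.max(-1, keepdims=True))
attn = attn / attn.sum(-1, keepdims=True)   # softmax over the key axis
attn_out = attn @ value                     # [batch, heads, new, head_dim]
print(attn_out.shape)                       # -> (1, 8, 1, 16)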
