Start on llama/tools/generate_data.py. #105

Merged: 4 commits, Jul 15, 2024 (changes shown from 1 commit).
docs/model_cookbook.md: 10 additions, 0 deletions
@@ -256,3 +256,13 @@
iree-run-module \
--input=1x2662400xf16 \
--parameters=model=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16.gguf
```

## Generating data for llama models

```bash
python -m sharktank.models.llama.tools.generate_data \
--tokenizer=openlm-research/open_llama_3b_v2 \
--config=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16.json \
--output-dir=/tmp/open_llama_3b_v2/inputs \
--prompt="What is the meaning of life?"
```
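The files written to `--output-dir` are raw binary tensors meant for `iree-run-module`'s `--input=<shape>x<dtype>=@<path>` flag; the docstring of `generate_data.py` below shows a full invocation. As of this first commit the script only writes `arg0.bin` (token ids) and `arg1.bin` (sequence lengths); `arg2` and `arg3` are still TODO.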
sharktank/sharktank/models/llama/tools/data_utils.py: 38 additions, 0 deletions
@@ -0,0 +1,38 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from pathlib import Path
import struct
import numpy as np

# map numpy dtype -> (iree dtype, struct.pack format str)
dtype_map = {
    np.dtype("int64"): ("si64", "q"),
    np.dtype("uint64"): ("ui64", "Q"),
    np.dtype("int32"): ("si32", "i"),
    np.dtype("uint32"): ("ui32", "I"),
    np.dtype("int16"): ("si16", "h"),
    np.dtype("uint16"): ("ui16", "H"),
    np.dtype("int8"): ("si8", "b"),
    np.dtype("uint8"): ("ui8", "B"),
    np.dtype("float64"): ("f64", "d"),
    np.dtype("float32"): ("f32", "f"),
    np.dtype("float16"): ("f16", "e"),
    np.dtype("bool"): ("i1", "?"),
}


def pack_np_ndarray(ndarray: np.ndarray):
    """Packs the flattened array into raw bytes using native byte order."""
    mylist = ndarray.flatten().tolist()
    dtype = ndarray.dtype
    assert dtype in dtype_map, f"Unsupported dtype: {dtype}"
    return struct.pack(f"{len(mylist)}{dtype_map[dtype][1]}", *mylist)


def write_ndarray_to_bin(ndarray: np.ndarray, file: Path):
    """Writes the packed bytes of `ndarray` to `file`."""
    with open(file, "wb") as f:
        packed_ndarray = pack_np_ndarray(ndarray)
        f.write(packed_ndarray)
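Since `pack_np_ndarray` calls `struct.pack` with no byte-order prefix, the bytes are written in native order, which `numpy.fromfile` also assumes, so the files round-trip directly. A minimal check, sketched under the assumption that the helpers are importable as `sharktank.models.llama.tools.data_utils`:

```python
from pathlib import Path

import numpy as np

from sharktank.models.llama.tools.data_utils import write_ndarray_to_bin

# Pack a small int64 tensor to disk, then confirm the bytes decode losslessly.
tokens = np.array([[1, 1200, 325, 268]], dtype=np.int64)
write_ndarray_to_bin(tokens, Path("/tmp/arg0_test.bin"))
roundtrip = np.fromfile("/tmp/arg0_test.bin", dtype=np.int64).reshape(tokens.shape)
assert (roundtrip == tokens).all()
```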
sharktank/sharktank/models/llama/tools/generate_data.py: 115 additions, 0 deletions
@@ -0,0 +1,115 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Generates data files for calling iree-run-module from a prompt and config.

Usage:
$ python -m sharktank.models.llama.tools.generate_data \
--tokenizer=openlm-research/open_llama_3b_v2 \
--config=/tmp/open-llama-3b-v2-f16.json \
--output-dir=/tmp/inputs \
--prompt="What is the meaning of life?"

$ ls /tmp/inputs

arg0.bin
arg1.bin
arg2.bin
arg3.bin

$ iree-run-module \
--module=/tmp/open-llama-3b-v2-f16_cpu.vmfb \
--parameters=model=/tmp/open-llama-3b-v2-f16.gguf \
--function=prefill_bs4 \
--device=local-task \
--input=4x1xi64=@/tmp/inputs/arg0.bin \
--input=4xi64=@/tmp/inputs/arg1.bin \
--input=4x1xi64=@/tmp/inputs/arg2.bin \
--input=1x2662400xf16=@/tmp/inputs/arg3.bin

# TODO(scotttodd): similar script to convert outputs to text via tokenizer
# TODO(scotttodd): teach service_v1_cli to also dump its inputs/outputs?
# TODO(scotttodd): generate expected outputs using reference model?
"""

from pathlib import Path
import logging
import sys
import json
import numpy as np

from transformers import LlamaTokenizer # type: ignore

from ....utils.logging import get_logger
from .data_utils import write_ndarray_to_bin

logger = get_logger("sharktank.models.llama.tools.generate_data")


def main(argv):
    from ....utils import cli

    parser = cli.create_parser()
    parser.add_argument(
        "--tokenizer", help="name of huggingface tokenizer to use", required=True
    )
    parser.add_argument(
        "--config",
        type=Path,
        help="json config file with hyperparameters",
        required=True,
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        help="Generate .bin files into this directory",
        required=True,
    )
    parser.add_argument("--prompt", help="Prompt string", required=True)
    # TODO(scotttodd): output path (directory to dump .bin/.npy files)
    args = cli.parse(parser, args=argv)

    # Load config hyperparameters.
    with open(args.config) as f:
        config = json.load(f)
    logger.info("Loaded config with hyperparameters:")
    logger.info(json.dumps(config, indent=4))

    # Load tokenizer.
    # TODO(scotttodd): Unify tokenizer flags across sharktank and shortfin?
    # cli.add_tokenizer_options(parser)
    # tokenizer = cli.get_tokenizer(args)
    tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer, legacy=False)

    # TODO(scotttodd): loop over batch sizes (generate one dataset per batch size)
    prefill_batch_size = config["prefill_batch_sizes"][0]

    # Declare input arguments.
    arg0_prefill_tokens = np.ndarray(
        [prefill_batch_size, config["max_seq_len"]], dtype=np.int64
    )

> Member Author: Ah, this is a hyperparameter defining an upper limit for the entire model, but individual function calls will typically use smaller values. That's computed based on the tokens:
> https://github.com/nod-ai/sharktank/blob/5005107768120df1a3e69ab1ac7abf40e701c34d/shortfin/shortfin/llm/impl/service_v1.py#L299
> https://github.com/nod-ai/sharktank/blob/5005107768120df1a3e69ab1ac7abf40e701c34d/shortfin/shortfin/llm/impl/service_v1.py#L250-L261

> Member Author: Added a comment for now, as well as debug logging showing the full tensors being saved:

INFO 07-15 11:33:50 [generate_data.py:78] Loaded config with hyperparameters:
INFO 07-15 11:33:50 [generate_data.py:79] {
    "module_name": "module",
    "module_abi_version": 1,
    "max_seq_len": 2048,
    "attn_head_count": 32,
    "attn_head_dim": 100,
    "prefill_batch_sizes": [
        4
    ],
    "decode_batch_sizes": [
        4
    ],
    "transformer_block_count": 26,
    "block_seq_stride": 16
}
INFO 07-15 11:33:50 [generate_data.py:103] prompt -> encoded tokens: [1, 1200, 325, 268, 4546, 296, 1161, 29584]
DEBUG 07-15 11:33:50 [generate_data.py:108] arg0_prefill_tokens:
DEBUG 07-15 11:33:50 [generate_data.py:109] [[    1  1200   325   268  4546   296  1161 29584     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
...

> (that should be a much smaller tensor for this input, hooray dynamic shapes)

    arg1_prefill_seq_lens = np.ndarray(prefill_batch_size, dtype=np.int64)
    # TODO(scotttodd): arg2 - attention block indices

> Member Author: This is also tricky to populate; see the logic in `set_sequences`.

    # TODO(scotttodd): arg3 - attention block buffer

> Member Author: This is a stateful device buffer. Would need to read it back from device to host after a few real calls.

    # Populate input arguments.
    # TODO(scotttodd): loop over 1 prompt per batch here (or duplicate)
    prompt = args.prompt
    prompt_tokens = tokenizer.encode(prompt, return_tensors="pt")[0].tolist()
    logger.info(f"prompt -> encoded tokens: {prompt_tokens}")
    prompt_seq_len = len(prompt_tokens)
    arg0_prefill_tokens[0, 0:prompt_seq_len] = prompt_tokens
    arg1_prefill_seq_lens[0] = prompt_seq_len

    logger.info(f"Writing argument .bin files to '{args.output_dir}'")
    args.output_dir.mkdir(parents=True, exist_ok=True)
    write_ndarray_to_bin(arg0_prefill_tokens, args.output_dir / "arg0.bin")
    write_ndarray_to_bin(arg1_prefill_seq_lens, args.output_dir / "arg1.bin")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main(argv=sys.argv[1:])
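On the `max_seq_len` question raised in the review comments above: the linked `service_v1.py` logic sizes buffers from the actual token counts rather than padding every prompt to the model-wide `max_seq_len`. A rough sketch of that kind of computation, assuming (this is an assumption, not the verified service_v1 code) that lengths round up to a multiple of the config's `block_seq_stride`:

```python
import math


def padded_prefill_seq_len(prompt_seq_len: int, block_seq_stride: int) -> int:
    """Hypothetical helper: round a prompt length up to whole attention blocks.

    Illustrative only; the real computation lives in
    shortfin/shortfin/llm/impl/service_v1.py (linked in the review comments).
    """
    block_count = math.ceil(prompt_seq_len / block_seq_stride)
    return block_count * block_seq_stride


# With the logged config (block_seq_stride=16), the 8-token example prompt
# would pad to 16 tokens instead of max_seq_len=2048.
assert padded_prefill_seq_len(8, 16) == 16
```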