Start on llama/tools/generate_data.py. #105

Merged (4 commits, Jul 15, 2024)
docs/model_cookbook.md (11 additions, 0 deletions)
@@ -256,3 +256,14 @@ iree-run-module \
--input=1x2662400xf16 \
--parameters=model=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16.gguf
```

## Generating data for llama models

```bash
export TURBINE_DEBUG=log_level=info
python -m sharktank.models.llama.tools.generate_data \
--tokenizer=openlm-research/open_llama_3b_v2 \
--config=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16.json \
--output-dir=/tmp/open_llama_3b_v2/inputs \
--prompt="What is the meaning of life?"
```
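
As a quick sanity check (not part of this diff), the generated files can be read back with numpy; the paths and the batch size of 4 are taken from the examples in this PR:

```python
import numpy as np

# write_ndarray_to_bin emits raw packed values with no header, so supply
# the dtype and reshape using the shapes from the examples above.
tokens = np.fromfile("/tmp/open_llama_3b_v2/inputs/arg0.bin", dtype=np.int64)
seq_lens = np.fromfile("/tmp/open_llama_3b_v2/inputs/arg1.bin", dtype=np.int64)
print(tokens.reshape(4, -1).shape)  # (prefill_batch_size, max_seq_len)
print(seq_lens)                     # true token count per batch entry
```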
sharktank/sharktank/models/llama/tools/data_utils.py (38 additions, 0 deletions)
@@ -0,0 +1,38 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from pathlib import Path
import struct
import numpy as np

# map numpy dtype -> (iree dtype, struct.pack format str)
dtype_map = {
    np.dtype("int64"): ("si64", "q"),
    np.dtype("uint64"): ("ui64", "Q"),
    np.dtype("int32"): ("si32", "i"),
    np.dtype("uint32"): ("ui32", "I"),
    np.dtype("int16"): ("si16", "h"),
    np.dtype("uint16"): ("ui16", "H"),
    np.dtype("int8"): ("si8", "b"),
    np.dtype("uint8"): ("ui8", "B"),
    np.dtype("float64"): ("f64", "d"),
    np.dtype("float32"): ("f32", "f"),
    np.dtype("float16"): ("f16", "e"),
    np.dtype("bool"): ("i1", "?"),
}


def pack_np_ndarray(ndarray: np.ndarray):
    """Packs an ndarray into raw bytes using struct, in flattened (C) order."""
    mylist = ndarray.flatten().tolist()
    dtype = ndarray.dtype
    assert dtype in dtype_map, f"Unsupported dtype: {dtype}"
    return struct.pack(f"{len(mylist)}{dtype_map[dtype][1]}", *mylist)


def write_ndarray_to_bin(ndarray: np.ndarray, file: Path):
    """Writes an ndarray to a raw .bin file suitable for iree-run-module."""
    with open(file, "wb") as f:
        packed_ndarray = pack_np_ndarray(ndarray)
        f.write(packed_ndarray)
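
A minimal round-trip check of the packing logic (illustrative, not part of the diff): for a C-contiguous array, the struct-packed bytes should match numpy's own native-order byte layout:

```python
import numpy as np
from sharktank.models.llama.tools.data_utils import pack_np_ndarray

arr = np.arange(12, dtype=np.int64).reshape(3, 4)
packed = pack_np_ndarray(arr)
# struct.pack uses native byte order here, matching numpy's default layout.
assert packed == arr.tobytes()
assert np.array_equal(np.frombuffer(packed, dtype=np.int64).reshape(3, 4), arr)
```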
sharktank/sharktank/models/llama/tools/generate_data.py (120 additions, 0 deletions)
@@ -0,0 +1,120 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Generates data files for calling iree-run-module from a prompt and config.

Usage:
$ python -m sharktank.models.llama.tools.generate_data \
--tokenizer=openlm-research/open_llama_3b_v2 \
--config=/tmp/open-llama-3b-v2-f16.json \
--output-dir=/tmp/inputs \
--prompt="What is the meaning of life?"

$ ls /tmp/inputs

arg0.bin
arg1.bin
arg2.bin
arg3.bin

$ iree-run-module \
--module=/tmp/open-llama-3b-v2-f16_cpu.vmfb \
--parameters=model=/tmp/open-llama-3b-v2-f16.gguf \
--function=prefill_bs4 \
--device=local-task \
--input=4x1xi64=@/tmp/inputs/arg0.bin \
--input=4xi64=@/tmp/inputs/arg1.bin \
--input=4x1xi64=@/tmp/inputs/arg2.bin \
--input=1x2662400xf16=@/tmp/inputs/arg3.bin

# TODO(scotttodd): similar script to convert outputs to text via tokenizer
# TODO(scotttodd): teach service_v1_cli to also dump its inputs/outputs?
# TODO(scotttodd): generate expected outputs using reference model?
"""

from pathlib import Path
import logging
import sys
import json
import numpy as np

from transformers import LlamaTokenizer # type: ignore

from ....utils.logging import get_logger
from .data_utils import write_ndarray_to_bin

logger = get_logger("sharktank.models.llama.tools.generate_data")


def main(argv):
    from ....utils import cli

    parser = cli.create_parser()
    parser.add_argument(
        "--tokenizer", help="name of huggingface tokenizer to use", required=True
    )
    parser.add_argument(
        "--config",
        type=Path,
        help="json config file with hyperparameters",
        required=True,
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        help="Generate .bin files into this directory",
        required=True,
    )
    parser.add_argument("--prompt", help="Prompt string", required=True)
    # TODO(scotttodd): output path (directory to dump .bin/.npy files)
    args = cli.parse(parser, args=argv)

    # Load config hyperparameters.
    with open(args.config) as f:
        config = json.load(f)
    logger.info("Loaded config with hyperparameters:")
    logger.info(json.dumps(config, indent=4))

    # Load tokenizer.
    # TODO(scotttodd): Unify tokenizer flags across sharktank and shortfin?
    #   cli.add_tokenizer_options(parser)
    #   tokenizer = cli.get_tokenizer(args)
    tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer, legacy=False)

    # TODO(scotttodd): loop over batch sizes (generate one dataset per batch size)
    prefill_batch_size = config["prefill_batch_sizes"][0]

    # Declare input arguments.
    # TODO(scotttodd): compute max_seq_len from tokens, _not_ config here
    arg0_prefill_tokens = np.zeros(
        [prefill_batch_size, config["max_seq_len"]], dtype=np.int64
    )
    arg1_prefill_seq_lens = np.zeros(prefill_batch_size, dtype=np.int64)
    # TODO(scotttodd): arg2 - attention block indices
> Review comment from the PR author: This is also tricky to populate, see the logic in `set_sequences`.
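
Purely for illustration (this is not the `set_sequences` logic): one plausible way to populate block indices for a paged KV cache, giving each batch entry its own run of page ids. `block_seq_stride` is an assumed hyperparameter name, and `prompt_seq_len` refers to the value computed further down in this file:

```python
# Hypothetical sketch only -- the real logic lives in set_sequences.
# Assumes the json config carries a "block_seq_stride" hyperparameter.
block_seq_stride = config["block_seq_stride"]
# One page id per block_seq_stride tokens of the prompt, rounded up.
blocks_per_seq = (prompt_seq_len + block_seq_stride - 1) // block_seq_stride
arg2_prefill_seq_block_ids = np.zeros(
    [prefill_batch_size, blocks_per_seq], dtype=np.int64
)
for i in range(prefill_batch_size):
    # Give each batch entry a distinct, contiguous run of page ids.
    arg2_prefill_seq_block_ids[i] = np.arange(
        i * blocks_per_seq, (i + 1) * blocks_per_seq
    )
```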

    # TODO(scotttodd): arg3 - attention block buffer
> Review comment from the PR author: This is a stateful device buffer. Would need to read it back from device to host after a few real calls.
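
For a from-scratch prefill, a plausible stand-in is a zero-filled cache buffer; the shape below is taken from the `--input=1x2662400xf16` flag in the module docstring and is model-specific (a sketch, not what this PR implements):

```python
# Hypothetical: zero-initialized cache state for a from-scratch prefill.
# 1x2662400xf16 matches the iree-run-module invocation documented above.
arg3_cache_state = np.zeros([1, 2662400], dtype=np.float16)
write_ndarray_to_bin(arg3_cache_state, args.output_dir / "arg3.bin")
```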


    # Populate input arguments.
    # TODO(scotttodd): loop over 1 prompt per batch here (or duplicate)
    prompt = args.prompt
    prompt_tokens = tokenizer.encode(prompt, return_tensors="pt")[0].tolist()
    logger.info(f"prompt -> encoded tokens: {prompt_tokens}")
    prompt_seq_len = len(prompt_tokens)
    arg0_prefill_tokens[0, 0:prompt_seq_len] = prompt_tokens
    arg1_prefill_seq_lens[0] = prompt_seq_len
    with np.printoptions(threshold=np.inf):
        logger.debug("arg0_prefill_tokens:")
        logger.debug(arg0_prefill_tokens)
        logger.debug("arg1_prefill_seq_lens:")
        logger.debug(arg1_prefill_seq_lens)

    logger.info(f"Writing argument .bin files to '{args.output_dir}'")
    args.output_dir.mkdir(parents=True, exist_ok=True)
    write_ndarray_to_bin(arg0_prefill_tokens, args.output_dir / "arg0.bin")
    write_ndarray_to_bin(arg1_prefill_seq_lens, args.output_dir / "arg1.bin")


if __name__ == "__main__":
    main(argv=sys.argv[1:])
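
Toward the first TODO in the module docstring, a hedged sketch of the reverse direction: read `arg0.bin` back and decode it with the tokenizer to confirm the prompt round-trips. It reuses names from `main` above and assumes token id 0 appears only as padding:

```python
# Hypothetical round-trip check (assumes 0 is only used as padding).
tokens = np.fromfile(args.output_dir / "arg0.bin", dtype=np.int64)
tokens = tokens.reshape(prefill_batch_size, -1)
ids = [t for t in tokens[0].tolist() if t != 0]
print(tokenizer.decode(ids))  # should closely match args.prompt
```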