Start on llama/tools/generate_data.py. #105
Merged

Commits (4):
- adc962b  Start on llama/tools/generate_data.py.  (ScottTodd)
- 228e67e  Adjust logging instructions.  (ScottTodd)
- 7d3ff07  Zero initialize args and log values before writing to files (debug).  (ScottTodd)
- 182f098  Merge branch 'main' into llama-generate-data  (ScottTodd)
38 changes: 38 additions & 0 deletions
sharktank/sharktank/models/llama/tools/data_utils.py
@@ -0,0 +1,38 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from pathlib import Path
import struct
import numpy as np

# map numpy dtype -> (iree dtype, struct.pack format str)
dtype_map = {
    np.dtype("int64"): ("si64", "q"),
    np.dtype("uint64"): ("ui64", "Q"),
    np.dtype("int32"): ("si32", "i"),
    np.dtype("uint32"): ("ui32", "I"),
    np.dtype("int16"): ("si16", "h"),
    np.dtype("uint16"): ("ui16", "H"),
    np.dtype("int8"): ("si8", "b"),
    np.dtype("uint8"): ("ui8", "B"),
    np.dtype("float64"): ("f64", "d"),
    np.dtype("float32"): ("f32", "f"),
    np.dtype("float16"): ("f16", "e"),
    np.dtype("bool"): ("i1", "?"),
}


def pack_np_ndarray(ndarray: np.ndarray):
    mylist = ndarray.flatten().tolist()
    dtype = ndarray.dtype
    assert dtype in dtype_map
    return struct.pack(f"{len(mylist)}{dtype_map[dtype][1]}", *mylist)


def write_ndarray_to_bin(ndarray: np.ndarray, file: Path):
    with open(file, "wb") as f:
        packed_ndarray = pack_np_ndarray(ndarray)
        f.write(packed_ndarray)
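
A minimal round-trip sketch of these helpers (not part of this PR; assumes the sharktank package is installed so the module path below resolves, and the array values are made up). write_ndarray_to_bin flattens the array and packs it with struct in native byte order, so the shape and dtype must be supplied separately when the file is fed to iree-run-module.

# Hypothetical usage sketch, not part of the PR. Assumes the sharktank
# package is importable so data_utils can be loaded by its module path.
from pathlib import Path
import numpy as np

from sharktank.models.llama.tools.data_utils import write_ndarray_to_bin

arr = np.arange(8, dtype=np.int64).reshape(4, 2)
write_ndarray_to_bin(arr, Path("/tmp/arg_example.bin"))

# The file holds only the flattened values in native byte order; shape/dtype
# (here 4x2xi64) travel separately, e.g. on the iree-run-module command line.
roundtrip = np.fromfile("/tmp/arg_example.bin", dtype=np.int64).reshape(4, 2)
assert np.array_equal(arr, roundtrip)
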
115 changes: 115 additions & 0 deletions
sharktank/sharktank/models/llama/tools/generate_data.py
@@ -0,0 +1,115 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Generates data files for calling iree-run-module from a prompt and config.

Usage:
  $ python -m sharktank.models.llama.tools.generate_data \
      --tokenizer=openlm-research/open_llama_3b_v2 \
      --config=/tmp/open-llama-3b-v2-f16.json \
      --output-dir=/tmp/inputs \
      --prompt="What is the meaning of life?"

  $ ls /tmp/inputs

  arg0.bin
  arg1.bin
  arg2.bin
  arg3.bin

  $ iree-run-module \
      --module=/tmp/open-llama-3b-v2-f16_cpu.vmfb \
      --parameters=model=/tmp/open-llama-3b-v2-f16.gguf \
      --function=prefill_bs4 \
      --device=local-task \
      --input=4x1xi64=@/tmp/inputs/arg0.bin \
      --input=4xi64=@/tmp/inputs/arg1.bin \
      --input=4x1xi64=@/tmp/inputs/arg2.bin \
      --input=1x2662400xf16=@/tmp/inputs/arg3.bin

# TODO(scotttodd): similar script to convert outputs to text via tokenizer
# TODO(scotttodd): teach service_v1_cli to also dump its inputs/outputs?
# TODO(scotttodd): generate expected outputs using reference model?
"""

from pathlib import Path
import logging
import sys
import json
import numpy as np

from transformers import LlamaTokenizer  # type: ignore

from ....utils.logging import get_logger
from .data_utils import write_ndarray_to_bin

logger = get_logger("sharktank.models.llama.tools.generate_data")


def main(argv):
    from ....utils import cli

    parser = cli.create_parser()
    parser.add_argument(
        "--tokenizer", help="name of huggingface tokenizer to use", required=True
    )
    parser.add_argument(
        "--config",
        type=Path,
        help="json config file with hyperparameters",
        required=True,
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        help="Generate .bin files into this directory",
        required=True,
    )
    parser.add_argument("--prompt", help="Prompt string", required=True)
    # TODO(scotttodd): output path (directory to dump .bin/.npy files)
    args = cli.parse(parser, args=argv)

    # Load config hyperparameters.
    with open(args.config) as f:
        config = json.load(f)
    logger.info("Loaded config with hyperparameters:")
    logger.info(json.dumps(config, indent=4))

    # Load tokenizer.
    # TODO(scotttodd): Unify tokenizer flags across sharktank and shortfin?
    #   cli.add_tokenizer_options(parser)
    #   tokenizer = cli.get_tokenizer(args)
    tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer, legacy=False)

    # TODO(scotttodd): loop over batch sizes (generate one dataset per batch size)
    prefill_batch_size = config["prefill_batch_sizes"][0]

    # Declare input arguments.
    arg0_prefill_tokens = np.ndarray(
        [prefill_batch_size, config["max_seq_len"]], dtype=np.int64
    )
    arg1_prefill_seq_lens = np.ndarray(prefill_batch_size, dtype=np.int64)
    # TODO(scotttodd): arg2 - attention block indices

Review comment: This is also tricky to populate, see the logic in ...

    # TODO(scotttodd): arg3 - attention block buffer

Review comment: This is a stateful device buffer. Would need to read it back from device to host after a few real calls.

    # Populate input arguments.
    # TODO(scotttodd): loop over 1 prompt per batch here (or duplicate)
    prompt = args.prompt
    prompt_tokens = tokenizer.encode(prompt, return_tensors="pt")[0].tolist()
    logger.info(f"prompt -> encoded tokens: {prompt_tokens}")
    prompt_seq_len = len(prompt_tokens)
    arg0_prefill_tokens[0, 0:prompt_seq_len] = prompt_tokens
    arg1_prefill_seq_lens[0] = prompt_seq_len

    logger.info(f"Writing argument .bin files to '{args.output_dir}'")
    args.output_dir.mkdir(parents=True, exist_ok=True)
    write_ndarray_to_bin(arg0_prefill_tokens, args.output_dir / "arg0.bin")
    write_ndarray_to_bin(arg1_prefill_seq_lens, args.output_dir / "arg1.bin")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main(argv=sys.argv[1:])
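
One of the TODOs in the docstring mentions converting outputs back to text via the tokenizer. A hedged sketch of that direction, not part of this PR; the token ids below are placeholders, not real model output:

# Hypothetical sketch for the "convert outputs to text via tokenizer" TODO.
from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained(
    "openlm-research/open_llama_3b_v2", legacy=False
)
output_token_ids = [1, 15043, 29991]  # placeholder ids, not real output
print(tokenizer.decode(output_token_ids, skip_special_tokens=True))
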
Review comment: Ah, this is a hyperparameter defining an upper limit for the entire model, but individual function calls will typically use smaller values. That's computed based on the tokens:
https://github.com/nod-ai/sharktank/blob/5005107768120df1a3e69ab1ac7abf40e701c34d/shortfin/shortfin/llm/impl/service_v1.py#L299
https://github.com/nod-ai/sharktank/blob/5005107768120df1a3e69ab1ac7abf40e701c34d/shortfin/shortfin/llm/impl/service_v1.py#L250-L261
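
To make that concrete, a hedged sketch of the kind of computation the linked service_v1.py logic performs; the name block_seq_stride (tokens per attention block) is an assumption, not a quote from that file:

# Hedged sketch, not the PR's code: the per-call attention block count is
# derived from the actual token count, while max_seq_len only bounds the
# largest value the model as a whole can use.
import math

def used_attention_blocks(prompt_seq_len: int, block_seq_stride: int) -> int:
    # Round the sequence length up to whole attention blocks.
    return math.ceil(prompt_seq_len / block_seq_stride)

# Example: a 7-token prompt with a stride of 16 occupies a single block.
assert used_attention_blocks(7, 16) == 1
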
Review comment: Added a comment for now, as well as debug logging showing the full tensors being saved.
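
A hedged sketch of what that change plausibly looks like, based on the commit message "Zero initialize args and log values before writing to files (debug)" rather than on the diff shown above; the config values and prompt tokens below are assumed stand-ins for what the real script reads from --config and --prompt:

# Hedged sketch, not the PR's final code: zero-initialize the argument
# tensors (np.ndarray leaves memory uninitialized) and log them before
# writing, so unset positions are deterministic and visible.
import logging
import numpy as np

logger = logging.getLogger("sharktank.models.llama.tools.generate_data")

prefill_batch_size = 4          # assumed stand-in for config["prefill_batch_sizes"][0]
max_seq_len = 2048              # assumed stand-in for config["max_seq_len"]
prompt_tokens = [1, 767, 29962] # placeholder ids, not real tokenizer output
prompt_seq_len = len(prompt_tokens)

arg0_prefill_tokens = np.zeros([prefill_batch_size, max_seq_len], dtype=np.int64)
arg1_prefill_seq_lens = np.zeros(prefill_batch_size, dtype=np.int64)

arg0_prefill_tokens[0, 0:prompt_seq_len] = prompt_tokens
arg1_prefill_seq_lens[0] = prompt_seq_len

logger.debug(f"arg0_prefill_tokens:\n{arg0_prefill_tokens}")
logger.debug(f"arg1_prefill_seq_lens:\n{arg1_prefill_seq_lens}")
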