Start on llama/tools/generate_data.py. #105
Conversation
I'm making slow progress understanding each of these parameters; the stateful ones are especially tricky. Are there 1:1 mappings with inputs used with other models (i.e. not our port in sharktank)? I'm wondering if we could get low-level input/output datasets from something other than our own code. Otherwise we might just need to keep building out service_v1.py and teach it to dump binary files with a flag 🤔
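For reference, a minimal sketch of what that dump-with-a-flag could look like (the `--dump_path` flag and `_dump_array` helper here are hypothetical, not existing service_v1.py features):

```python
import numpy as np

def _dump_array(array: np.ndarray, path: str):
    """Write a tensor to a raw binary file (hypothetical helper)."""
    with open(path, "wb") as f:
        f.write(array.tobytes())

# Inside the service, guarded by a hypothetical --dump_path flag:
# if args.dump_path:
#     _dump_array(arg0_prefill_tokens, f"{args.dump_path}/arg0_prefill_tokens.bin")
```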
```python
arg0_prefill_tokens = np.ndarray(
    [prefill_batch_size, config["max_seq_len"]], dtype=np.int64
)
```
Ah, this is a hyperparameter defining an upper limit for the entire model; individual function calls will typically use smaller values, computed from the actual tokens:
https://github.com/nod-ai/sharktank/blob/5005107768120df1a3e69ab1ac7abf40e701c34d/shortfin/shortfin/llm/impl/service_v1.py#L299
https://github.com/nod-ai/sharktank/blob/5005107768120df1a3e69ab1ac7abf40e701c34d/shortfin/shortfin/llm/impl/service_v1.py#L250-L261
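To make that concrete, a rough sketch of the padding computation (names here are illustrative; the linked service_v1.py code is the source of truth):

```python
import math

# Sketch: pad the longest prompt in the batch up to a multiple of
# block_seq_stride, rather than allocating the full max_seq_len.
def padded_seq_len(token_counts: list[int], block_seq_stride: int) -> int:
    max_tokens = max(token_counts)
    attn_blocks = math.ceil(max_tokens / block_seq_stride)
    return attn_blocks * block_seq_stride

# e.g. 8 tokens with block_seq_stride=16 -> one block, seq_len 16 (not 2048)
assert padded_seq_len([8], 16) == 16
```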
Added a comment for now, as well as debug logging showing the full tensors being saved:
```
INFO 07-15 11:33:50 [generate_data.py:78] Loaded config with hyperparameters:
INFO 07-15 11:33:50 [generate_data.py:79] {
  "module_name": "module",
  "module_abi_version": 1,
  "max_seq_len": 2048,
  "attn_head_count": 32,
  "attn_head_dim": 100,
  "prefill_batch_sizes": [
    4
  ],
  "decode_batch_sizes": [
    4
  ],
  "transformer_block_count": 26,
  "block_seq_stride": 16
}
INFO 07-15 11:33:50 [generate_data.py:103] prompt -> encoded tokens: [1, 1200, 325, 268, 4546, 296, 1161, 29584]
DEBUG 07-15 11:33:50 [generate_data.py:108] arg0_prefill_tokens:
DEBUG 07-15 11:33:50 [generate_data.py:109] [[    1  1200   325   268  4546   296  1161 29584     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
 ... (that should be a much smaller tensor for this input, hooray dynamic shapes)
```
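For context, a sketch of how a row of that tensor gets filled (using `np.zeros` here for determinism, unlike the uninitialized `np.ndarray` in the snippet below):

```python
import numpy as np

prefill_batch_size = 4
max_seq_len = 2048
encoded_tokens = [1, 1200, 325, 268, 4546, 296, 1161, 29584]

# Zero-fill, then write the prompt tokens into row 0; unused
# positions stay 0, which is what the debug log above shows.
arg0_prefill_tokens = np.zeros([prefill_batch_size, max_seq_len], dtype=np.int64)
arg0_prefill_tokens[0, : len(encoded_tokens)] = encoded_tokens
```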
```python
    [prefill_batch_size, config["max_seq_len"]], dtype=np.int64
)
arg1_prefill_seq_lens = np.ndarray(prefill_batch_size, dtype=np.int64)
# TODO(scotttodd): arg2 - attention block indices
```
This is also tricky to populate; see the logic in `set_sequences`.
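A rough sketch of what that population could look like, assuming a simple sequential page allocator (the real `set_sequences` logic in service_v1.py is the source of truth):

```python
import math
import numpy as np

def assign_block_indices(token_counts, block_seq_stride, first_free_page=0):
    """Assign sequential cache page indices per sequence (simplified sketch)."""
    blocks_per_seq = [math.ceil(n / block_seq_stride) for n in token_counts]
    max_blocks = max(blocks_per_seq)
    indices = np.zeros([len(token_counts), max_blocks], dtype=np.int64)
    next_page = first_free_page
    for row, count in enumerate(blocks_per_seq):
        indices[row, :count] = np.arange(next_page, next_page + count)
        next_page += count
    return indices
```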
```python
)
arg1_prefill_seq_lens = np.ndarray(prefill_batch_size, dtype=np.int64)
# TODO(scotttodd): arg2 - attention block indices
# TODO(scotttodd): arg3 - attention block buffer
```
This is a stateful device buffer. We would need to read it back from device to host after a few real calls.
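For offline data generation, one fallback is to start from an all-zeros cache, which is only valid before any real calls have run. The shape below is a guess assembled from the logged hyperparameters, not the actual cache layout:

```python
import numpy as np

# Hypothetical paged KV cache layout and dtype; the real layout
# lives in the service/cache code, not here.
block_count = 256  # arbitrary page budget for this sketch
arg3_attention_block_buffer = np.zeros(
    [
        block_count,
        26,   # transformer_block_count
        2,    # K and V
        16,   # block_seq_stride
        32,   # attn_head_count
        100,  # attn_head_dim
    ],
    dtype=np.float16,
)
```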
Ping! I'd like feedback on the approach taken here (standalone script vs. teaching the service to output its args/outputs).
Seems OK as a start. I also don't know a great answer for the states; ideally you wouldn't include those in pre-baked data, but would instead have something that loops them through properly.
I also think this works as a first step. Thanks for all the notes; they answered my question from the pre-sync.
Progress on #103. Sending early for design feedback.
I want something lighter weight than https://github.com/nod-ai/sharktank/blob/main/shortfin/shortfin/llm/impl/service_v1_cli.py and https://github.com/nod-ai/sharktank/blob/main/shortfin/shortfin/llm/impl/service_v1.py, something that can generate real inputs using a tokenizer and the provided hyperparameters (e.g. batch sizes, max sequence length), for use in offline tests and benchmarks. What I'm not sure about is the state tracking in the cache. That's probably easiest to just dump from the service, but for prefill at least we should be able to generate something sensible here. The overall shape of the script is sketched below.
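Roughly (the paths and the tokenizer choice here are placeholders, not what the script hardcodes):

```python
import json

import numpy as np
from transformers import AutoTokenizer  # placeholder tokenizer choice

# Load the exported model hyperparameters.
with open("config.json") as f:
    config = json.load(f)
prefill_batch_size = config["prefill_batch_sizes"][0]

# Tokenize a prompt into real input ids.
tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_3b")
encoded = tokenizer.encode("<prompt text>")

# Build prefill args: padded token tensor and per-sequence lengths.
arg0_prefill_tokens = np.zeros(
    [prefill_batch_size, config["max_seq_len"]], dtype=np.int64
)
arg0_prefill_tokens[0, : len(encoded)] = encoded
arg1_prefill_seq_lens = np.zeros(prefill_batch_size, dtype=np.int64)
arg1_prefill_seq_lens[0] = len(encoded)

# Dump to disk for offline tests/benchmarks.
np.save("arg0_prefill_tokens.npy", arg0_prefill_tokens)
np.save("arg1_prefill_seq_lens.npy", arg1_prefill_seq_lens)
```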
Anyways, this is part learning exercise for me and part useful tool for others. Looking for feedback!