
Commit cbd7c29

Add dashboard metric support for sam and sam2 benchmarks (pytorch#1407)
* Add dashboard metric support for sam and sam2 benchmarks

Summary: Also makes the `.github/workflows/dashboard_perf_test.yml` more complete by adding runs for both compile and autoquant-all for llama, sam, and sam2.

Test Plan: tested locally:

```
SEGMENT_ANYTHING_FAST_USE_FLASH_4=0 python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --output_json_path output

cat output.csv
{"benchmark": {"name": "TorchAO benchmark", "mode": "inference", "dtype": "torch.bfloat16", "extra_info": {"device": "cuda", "arch": "NVIDIA H100"}}, "model": {"name": "vit_h", "type": "model", "origins": ["pytorch"]}, "metric": {"name": "memory(MiB)", "benchmark_values": [27740], "target_value": null}}
{"benchmark": {"name": "TorchAO benchmark", "mode": "inference", "dtype": "torch.bfloat16", "extra_info": {"device": "cuda", "arch": "NVIDIA H100"}}, "model": {"name": "vit_h", "type": "model", "origins": ["pytorch"]}, "metric": {"name": "img_s(avg)", "benchmark_values": [42.65887304183299], "target_value": null}}

python server.py ~/checkpoints/sam2 large --port 4000 --host localhost --use_autoquant --benchmark --output_json_path output
{"benchmark": {"name": "TorchAO benchmark", "mode": "inference", "dtype": "autoquant", "extra_info": {"device": "cuda", "arch": "NVIDIA H100"}}, "model": {"name": "sam2-large", "type": "model", "origins": ["pytorch"]}, "metric": {"name": "memory(MiB)", "benchmark_values": [4294], "target_value": null}}
{"benchmark": {"name": "TorchAO benchmark", "mode": "inference", "dtype": "autoquant", "extra_info": {"device": "cuda", "arch": "NVIDIA H100"}}, "model": {"name": "sam2-large", "type": "model", "origins": ["pytorch"]}, "metric": {"name": "memory(%)", "benchmark_values": [4], "target_value": null}}
{"benchmark": {"name": "TorchAO benchmark", "mode": "inference", "dtype": "autoquant", "extra_info": {"device": "cuda", "arch": "NVIDIA H100"}}, "model": {"name": "sam2-large", "type": "model", "origins": ["pytorch"]}, "metric": {"name": "time_s(avg)", "benchmark_values": [0.4720584869384766], "target_value": null}}
```

Will also check CI and ClickHouse.

Reviewers:

Subscribers:

Tasks:

Tags:

Squashed fixup commits:

* ruff
* add util file
* add scripts
* tiktoken
* dep
* lm_eval
* use default autoquant for now
* path for sam eval
* add conda_run
* more deps
* invalid requirement
* git dep
* include subgraph_utils folder in installation
* reduce num_workers to 12
* try sam again with 8 workers
* skipping sam
* fix sam2 path
* checking path
* add sam2 config for ci
* skip sam2
* remove config
1 parent ace7219 commit cbd7c29
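
For context, each record in the test plan above is a flat header/row pair mapped onto the OSS benchmark database schema. A minimal illustrative sketch of that mapping (field layout taken from the write_json_result helper visible in the torchao/_models/llama/generate.py diff below; the example values come from the test plan, and the snippet itself is not part of the commit):

```python
import json

# Illustrative only: map one metric row onto the dashboard record schema.
headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
row = ["sam2-large", "autoquant", "cuda", "NVIDIA H100", "memory(MiB)", 4294, None]
mapping = dict(zip(headers, row))
record = {
    "benchmark": {
        "name": "TorchAO benchmark",
        "mode": "inference",
        "dtype": mapping["dtype"],
        "extra_info": {"device": mapping["device"], "arch": mapping["arch"]},
    },
    "model": {"name": mapping["name"], "type": "model", "origins": ["pytorch"]},
    "metric": {
        "name": mapping["metric"],
        "benchmark_values": [mapping["actual"]],
        "target_value": mapping["target"],
    },
}
print(json.dumps(record))  # matches the records shown in the test plan
```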

File tree: 14 files changed, +227 −85 lines


.github/workflows/dashboard_perf_test.yml

Lines changed: 28 additions & 5 deletions
@@ -6,15 +6,15 @@ on:
       - ciflow/benchmark/*
   workflow_dispatch:
   schedule:
-    - cron: 0 7 * * 0-6
+    - cron: 0 7 * * *

 jobs:
   benchmark:
     runs-on: linux.aws.a100
     strategy:
       matrix:
         torch-spec:
-          - '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu124'
+          - '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124'
     steps:
       - uses: actions/checkout@v3

@@ -31,14 +31,37 @@ jobs:
         ${CONDA_RUN} pip install ${{ matrix.torch-spec }}
         ${CONDA_RUN} pip install -r dev-requirements.txt
         ${CONDA_RUN} pip install .
+        # SAM 2.1
+        ${CONDA_RUN} pip install -r examples/sam2_amg_server/requirements.txt

+        # llama3
         export CHECKPOINT_PATH=checkpoints
-        export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
-        ${CONDA_RUN} python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf --hf_token ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+        export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
+        ${CONDA_RUN} python scripts/download.py --repo_id ${MODEL_REPO} --hf_token ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
         ${CONDA_RUN} python scripts/convert_hf_checkpoint.py --checkpoint_dir "${CHECKPOINT_PATH}/${MODEL_REPO}"

         mkdir -p ${{ runner.temp }}/benchmark-results
-        ${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --output_json_path ${{ runner.temp }}/benchmark-results/benchmark-results.json
+        # llama3 - compile baseline
+        ${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json
+
+        # llama3 - autoquant
+        ${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --quantization autoquant --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json
+
+        # skipping SAM because of https://hud.pytorch.org/pr/pytorch/ao/1407
+        # # SAM
+        # ${CONDA_RUN} pip install git+https://github.com/pytorch-labs/segment-anything-fast.git@main
+        # # SAM compile baseline
+        # ${CONDA_RUN} sh torchao/_models/sam/setup.sh
+        # ${CONDA_RUN} python torchao/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json
+
+        # ${CONDA_RUN} python torchao/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --compression autoquant --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json
+
+        # SAM 2.1
+        # ${CONDA_RUN} sh scripts/download_sam2_ckpts.sh ${CHECKPOINT_PATH}/sam2
+        # cd examples/sam2_amg_server
+        # hydra.errors.MissingConfigException: Cannot find primary config 'configs/sam2.1/sam2.1_hiera_l.yaml'. Check that it's in your config search path.
+        # ${CONDA_RUN} python server.py ${CHECKPOINT_PATH}/sam2 large --port 4000 --host localhost --fast --benchmark --dry --output_json_path ${{ runner.temp }}/benchmark-results/sam2-benchmark-results.json
+        # ${CONDA_RUN} python server.py ${CHECKPOINT_PATH}/sam2 large --port 4000 --host localhost --fast --use_autoquant --benchmark --dry --output_json_path ${{ runner.temp }}/benchmark-results/sam2-benchmark-results.json

       - name: Upload the benchmark results to OSS benchmark database for the dashboard
         uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main

dev-requirements.txt

Lines changed: 8 additions & 1 deletion
@@ -9,11 +9,18 @@ sentencepiece # for gpt-fast tokenizer
 expecttest

 # For prototype features and benchmarks
-bitsandbytes #needed for testing triton quant / dequant ops for 8-bit optimizers
+bitsandbytes # needed for testing triton quant / dequant ops for 8-bit optimizers
 matplotlib
 pandas
 fire # QOL for commandline scripts
 tabulate # QOL for printing tables to stdout
+tiktoken
+blobfile
+lm_eval
+# sam
+diskcache
+pycocotools
+tqdm

 # Custom CUDA Extensions
 ninja

examples/sam2_amg_server/server.py

Lines changed: 28 additions & 7 deletions
@@ -28,6 +28,10 @@
 import asyncio
 from contextlib import asynccontextmanager
 import contextlib
+from torchao._models.utils import (
+    get_arch_name,
+    write_json_result,
+)

 from torch._inductor import config as inductorconfig
 inductorconfig.triton.unique_kernel_names = True

@@ -269,8 +273,10 @@ def benchmark_fn(func, inp, mask_generator, warmup=3, runs=10):
     t = time.time()
     for _ in range(runs):
         func(inp, mask_generator)
-    print(f"Benchmark took {(time.time() - t)/runs}s per iteration.")
-    max_memory_allocated()
+    avg_time_per_run = (time.time() - t)/runs
+    print(f"Benchmark took {avg_time_per_run}s per iteration.")
+    max_memory_allocated_bytes, max_memory_allocated_percentage = max_memory_allocated()
+    return avg_time_per_run, max_memory_allocated_bytes, max_memory_allocated_percentage


 def max_memory_allocated():

@@ -279,6 +285,7 @@ def max_memory_allocated():
     max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
     max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
     print(f"max_memory_allocated_bytes: {max_memory_allocated_bytes}MiB or {max_memory_allocated_percentage}%")
+    return max_memory_allocated_bytes, max_memory_allocated_percentage


 def unittest_fn(masks, ref_masks, order_by_area=False, verbose=False):

@@ -527,10 +534,10 @@ def set_furious(mask_generator):
     mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16

 def set_autoquant(mask_generator):
+    import torchao
     from torchao import autoquant
-    from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
     # NOTE: Not baseline feature
-    mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
+    mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
     mask_generator.predictor._transforms_device = mask_generator.predictor.device
     torch.set_float32_matmul_precision('high')
     # NOTE: this fails when we run

@@ -556,7 +563,8 @@ def main(checkpoint_path,
          dry=False,
          batch_size=1,
          load_fast="",
-         save_fast=""):
+         save_fast="",
+         output_json_path=None):
     if verbose:
         logging.basicConfig(level=logging.INFO,
                             format='%(asctime)s - %(levelname)s - %(message)s',

@@ -626,9 +634,9 @@ def main(checkpoint_path,
     if benchmark:
         print(f"batch size {batch_size} dog benchmark")
         if batch_size == 1:
-            benchmark_fn(image_tensor_to_masks, image_tensor, mask_generator)
+            result = benchmark_fn(image_tensor_to_masks, image_tensor, mask_generator)
         else:
-            benchmark_fn(image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)
+            result = benchmark_fn(image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)

         for i, shapes in enumerate([example_shapes(), example_shapes_2()]):
             print(f"batch size {batch_size} example shapes {i} benchmark")

@@ -644,6 +652,19 @@ def main(checkpoint_path,
             print("len(random_images): ", len(random_images))
             benchmark_fn(image_tensors_to_masks, random_images, mask_generator)

+        if output_json_path:
+            headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
+            name = "sam2-" + model_type
+            arch = get_arch_name()
+            dtype = "autoquant" if use_autoquant else ("compile" if fast else "base")
+            avg_time_per_run, max_memory_allocated_bytes, max_memory_allocated_percentage = result
+            memory_result = [name, dtype, device, arch, "memory(MiB)", max_memory_allocated_bytes, None]
+            memory_percent_result = [name, dtype, device, arch, "memory(%)", max_memory_allocated_percentage, None]
+            performance_result = [name, dtype, device, arch, "time_s(avg)", avg_time_per_run, None]
+            write_json_result(output_json_path, headers, memory_result)
+            write_json_result(output_json_path, headers, memory_percent_result)
+            write_json_result(output_json_path, headers, performance_result)
+
     if profile is not None:
         print(f"Saving profile under {profile}")
         if batch_size == 1:
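
A note on consuming the output: write_json_result (see the helper moved out of generate.py below) appends one json.dumps record per call, so the file produced via --output_json_path is effectively JSON Lines. A minimal consumer sketch, assuming the records landed in output.json:

```python
import json

# Read back the appended benchmark records (one JSON object per line).
# The path "output.json" is an assumption for illustration.
with open("output.json") as f:
    for line in f:
        if not line.strip():
            continue  # skip any blank lines defensively
        record = json.loads(line)
        metric = record["metric"]
        print(record["model"]["name"], metric["name"], metric["benchmark_values"])
```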

scripts/convert_hf_checkpoint.py

Lines changed: 3 additions & 3 deletions
(Whitespace-only change: trailing whitespace stripped from three blank lines.)

@@ -32,7 +32,7 @@ def convert_hf_checkpoint(
     model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
     model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
     model_map_json = None
-
+
     try:
         assert model_map_json_safetensors.is_file()
         model_map_json = model_map_json_safetensors

@@ -46,7 +46,7 @@ def convert_hf_checkpoint(
         print(f"Found pytorch index at {model_map_json_pytorch}")
     except AssertionError:
         print(f"{model_map_json_pytorch} not found")
-
+
     if model_map_json is None: raise Exception("No model map found!")

     with open(model_map_json) as json_map:

@@ -85,7 +85,7 @@ def permute(w, n_head):
         else:
             state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
         merged_result.update(state_dict)
-
+
     if config.tie_word_embeddings:
         merged_result["lm_head.weight"] = merged_result["model.embed_tokens.weight"].clone()

scripts/download_sam2_ckpts.sh

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+#!/bin/bash
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Use either wget or curl to download the checkpoints
+if command -v wget &> /dev/null; then
+    CMD="wget -P"
+elif command -v curl &> /dev/null; then
+    CMD="curl -L -O"
+else
+    echo "Please install wget or curl to download the checkpoints."
+    exit 1
+fi
+
+# Define the URLs for SAM 2 checkpoints
+# SAM2_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824"
+# sam2_hiera_t_url="${SAM2_BASE_URL}/sam2_hiera_tiny.pt"
+# sam2_hiera_s_url="${SAM2_BASE_URL}/sam2_hiera_small.pt"
+# sam2_hiera_b_plus_url="${SAM2_BASE_URL}/sam2_hiera_base_plus.pt"
+# sam2_hiera_l_url="${SAM2_BASE_URL}/sam2_hiera_large.pt"
+
+# Download each of the four checkpoints using wget
+# echo "Downloading sam2_hiera_tiny.pt checkpoint..."
+# $CMD $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; }
+
+# echo "Downloading sam2_hiera_small.pt checkpoint..."
+# $CMD $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; }
+
+# echo "Downloading sam2_hiera_base_plus.pt checkpoint..."
+# $CMD $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; }
+
+# echo "Downloading sam2_hiera_large.pt checkpoint..."
+# $CMD $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; }
+
+# Define the URLs for SAM 2.1 checkpoints
+SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824"
+sam2p1_hiera_t_url="${SAM2p1_BASE_URL}/sam2.1_hiera_tiny.pt"
+sam2p1_hiera_s_url="${SAM2p1_BASE_URL}/sam2.1_hiera_small.pt"
+sam2p1_hiera_b_plus_url="${SAM2p1_BASE_URL}/sam2.1_hiera_base_plus.pt"
+sam2p1_hiera_l_url="${SAM2p1_BASE_URL}/sam2.1_hiera_large.pt"
+
+# $1 is the directory to store the checkpoint
+DEFAULT_DIR=test
+if [ -z "$1" ]; then
+    DIR_NAME=$DEFAULT_DIR
+else
+    # Use provided directory name
+    DIR_NAME=$1
+fi
+
+# SAM 2.1 checkpoints
+echo "Downloading sam2.1_hiera_tiny.pt checkpoint..."
+$CMD $DIR_NAME $sam2p1_hiera_t_url || { echo "Failed to download checkpoint from $sam2p1_hiera_t_url"; exit 1; }
+
+echo "Downloading sam2.1_hiera_small.pt checkpoint..."
+$CMD $DIR_NAME $sam2p1_hiera_s_url || { echo "Failed to download checkpoint from $sam2p1_hiera_s_url"; exit 1; }
+
+echo "Downloading sam2.1_hiera_base_plus.pt checkpoint..."
+$CMD $DIR_NAME $sam2p1_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2p1_hiera_b_plus_url"; exit 1; }
+
+echo "Downloading sam2.1_hiera_large.pt checkpoint..."
+$CMD $DIR_NAME $sam2p1_hiera_l_url || { echo "Failed to download checkpoint from $sam2p1_hiera_l_url"; exit 1; }
+
+echo "All checkpoints are downloaded successfully."

scripts/run_ruff_fix.sh

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+ruff check . --fix
+# --isolated skips the allowlist entirely, so this applies to all files;
+# please be careful when using it: large changes mean everyone needs to rebase
+ruff check --isolated --select F821,F823,W191 --fix
+ruff check --select F,I --fix
+ruff format .

torchao/_models/llama/generate.py

Lines changed: 11 additions & 59 deletions
@@ -3,7 +3,6 @@

 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-import json
 import os
 import platform
 import sys

@@ -19,6 +18,10 @@
 import torchao
 from torchao.quantization.quant_primitives import MappingType
 from torchao.utils import get_model_size_in_bytes, TORCH_VERSION_AT_LEAST_2_5
+from torchao._models.utils import (
+    get_arch_name,
+    write_json_result,
+)

 torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False

@@ -37,14 +40,6 @@ def elapsed_time(self, other_event):
         return abs(other_event.event_time - self.event_time) * 1000


-def get_arch_name() -> str:
-    if torch.cuda.is_available():
-        return torch.cuda.get_device_name()
-    else:
-        # This returns x86_64 or arm64 (for aarch64)
-        return platform.machine()
-
-
 def device_timer(device):
     if "cuda" in device:
         return torch.cuda.Event(enable_timing=True)

@@ -65,39 +60,6 @@ def device_sync(device):
         print(f"device={device} is not yet supported")


-def write_json_result(output_json_path, headers, row):
-    """
-    Write the result into JSON format, so that it can be uploaded to the benchmark database
-    to be displayed on OSS dashboard. The JSON format is defined at
-    https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
-    """
-    mapping_headers = {headers[i]: v for i, v in enumerate(row)}
-    record = {
-        "benchmark": {
-            "name": "TorchAO benchmark",
-            "mode": "inference",
-            "dtype": mapping_headers["dtype"],
-            "extra_info": {
-                "device": mapping_headers["device"],
-                "arch": mapping_headers["arch"],
-            },
-        },
-        "model": {
-            "name": mapping_headers["name"],
-            "type": "model",
-            "origins": ["pytorch"],
-        },
-        "metric": {
-            "name": mapping_headers["metric"],
-            "benchmark_values": [mapping_headers["actual"]],
-            "target_value": mapping_headers["target"],
-        },
-    }
-
-    with open(f"{os.path.splitext(output_json_path)[0]}.json", "a") as f:
-        print(json.dumps(record), file=f)
-
-
 default_device = (
     "cuda"
     if torch.cuda.is_available()

@@ -728,20 +690,10 @@ def ffn_or_attn_only(mod, fqn):
                 example_input=inputs,
             )
         if "autoquant-all" == quantization:
-            all_qtensor_classes = (
-                torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST
-                + torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST
-                + torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
-            )
-            if torchao.utils.is_sm_89():
-                # this is fp8 related subclasses, should rename
-                all_qtensor_classes += (
-                    torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST
-                )
             model = autoquant(
                 model,
                 manual=True,
-                qtensor_class_list=all_qtensor_classes,
+                qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST,
                 example_input=inputs,
             )
         else:

@@ -978,13 +930,13 @@ def callback(x):
         f.write(result_txt)
         f.close()

-    headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
-    name = checkpoint_path.parent.name
-    arch = get_arch_name()
-    dtype = quantization or str(precision)
-    memory_result = [name, dtype, device, arch, "mem/s", bandwidth, None]
-    performance_result = [name, dtype, device, arch, "tok/s", tokpersec, None]
     if output_json_path:
+        headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
+        name = checkpoint_path.parent.name
+        arch = get_arch_name()
+        dtype = quantization or str(precision)
+        memory_result = [name, dtype, device, arch, "mem/s", bandwidth, None]
+        performance_result = [name, dtype, device, arch, "tok/s", tokpersec, None]
         write_json_result(output_json_path, headers, memory_result)
         write_json_result(output_json_path, headers, performance_result)
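
Both modified entry points now import get_arch_name and write_json_result from torchao/_models/utils.py, a new file whose diff is not shown in this section. Based on the code removed from generate.py above, it presumably looks like the following sketch (the actual file may differ):

```python
# Presumed contents of torchao/_models/utils.py, reconstructed from the code
# removed from torchao/_models/llama/generate.py above.
import json
import os
import platform

import torch


def get_arch_name() -> str:
    if torch.cuda.is_available():
        return torch.cuda.get_device_name()
    else:
        # This returns x86_64 or arm64 (for aarch64)
        return platform.machine()


def write_json_result(output_json_path, headers, row):
    """
    Write the result into JSON format, so that it can be uploaded to the benchmark
    database to be displayed on the OSS dashboard. The JSON format is defined at
    https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
    """
    mapping_headers = {headers[i]: v for i, v in enumerate(row)}
    record = {
        "benchmark": {
            "name": "TorchAO benchmark",
            "mode": "inference",
            "dtype": mapping_headers["dtype"],
            "extra_info": {
                "device": mapping_headers["device"],
                "arch": mapping_headers["arch"],
            },
        },
        "model": {
            "name": mapping_headers["name"],
            "type": "model",
            "origins": ["pytorch"],
        },
        "metric": {
            "name": mapping_headers["metric"],
            "benchmark_values": [mapping_headers["actual"]],
            "target_value": mapping_headers["target"],
        },
    }

    # Appends one record per call, so repeated runs accumulate JSON Lines.
    with open(f"{os.path.splitext(output_json_path)[0]}.json", "a") as f:
        print(json.dumps(record), file=f)
```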
