Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

testing autoquant #2264

Open
wants to merge 1 commit into
base: gh/HDCharles/4/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions torchbenchmark/util/backends/torchdynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
)
parser.add_argument(
"--quantization",
choices=["int8dynamic", "int8weightonly", "int4weightonly"],
choices=["int8dynamic", "int8weightonly", "int4weightonly", "autoquant"],
help="Apply quantization to the model before running it",
)
parser.add_argument(
Expand Down Expand Up @@ -187,21 +187,37 @@ def apply_torchdynamo_args(
change_linear_weights_to_int4_woqtensors,
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,

)


torch._dynamo.config.automatic_dynamic_shapes = False
torch._dynamo.config.force_parameter_static_shapes = False
torch._dynamo.config.cache_size_limit = 1000
assert "cuda" in model.device
module, example_inputs = model.get_module()
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True
if isinstance(example_inputs, (tuple, list)):
example_inputs = tuple([
x.to(torch.bfloat16)
if isinstance(x, torch.Tensor) and x.dtype in [torch.float32, torch.float16]
else x
for x in example_inputs
])
module = module.to(torch.bfloat16)
with torch.no_grad():
module(*example_inputs)
if args.quantization == "int8dynamic":
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(module)
elif args.quantization == "int8weightonly":
torch._inductor.config.use_mixed_mm = True
change_linear_weights_to_int8_woqtensors(module)
elif args.quantization == "int4weightonly":
change_linear_weights_to_int4_woqtensors(module)
elif args.quantization == "autoquant":
torchao.autoquant(module, example_inputs, error_on_unseen=False)


if args.freeze_prepack_weights:
torch._inductor.config.freezing = True
Expand Down
7 changes: 3 additions & 4 deletions userbenchmark/group_bench/configs/torch_ao.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@ device: cuda
extra_args: --precision bf16 --torchdynamo inductor --inductor-compile-mode max-autotune
metrics:
- latencies
- gpu_peak_mem
test_group:
test_batch_size_default:
subgroup:
- extra_args:
- extra_args: --quantization int8dynamic
- extra_args: --quantization int8weightonly
- extra_args: --quantization int4weightonly
- extra_args: --quantization noquant
- extra_args: --quantization autoquant
20 changes: 13 additions & 7 deletions userbenchmark/group_bench/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def models_from_config(config) -> List[str]:
basic_models_list = list_models()
else:
basic_models_list = [config["model"]]
assert isinstance(config["model"], list), "Config model must be a list or string."
basic_models_list = config["model"]
extended_models_list = []
Expand Down Expand Up @@ -218,14 +219,19 @@ def run(args: List[str]):
results = {}
try:
for config in group_config.configs:
metrics = get_metrics(group_config.metrics, config)
if "accuracy" in metrics:
metrics_dict = run_config_accuracy(config, metrics, dryrun=args.dryrun)
else:
metrics_dict = run_config(config, metrics, dryrun=args.dryrun)
try:
metrics = get_metrics(group_config.metrics, config)
if "accuracy" in metrics:
metrics_dict = run_config_accuracy(config, metrics, dryrun=args.dryrun)
else:
metrics_dict = run_config(config, metrics, dryrun=args.dryrun)
except KeyboardInterrupt:
raise KeyboardInterrupt
except Exception as e:
metrics_dict = {}
config_str = config_to_str(config)
for metric in metrics_dict:
results[f"{config_str}, metric={metric}"] = metrics_dict[metric]
for metric in metrics:
results[f"{config_str}, metric={metric}"] = metrics_dict.get(metric, "err")
except KeyboardInterrupt:
print("User keyboard interrupted!")
result = get_output_json(BM_NAME, results)
Expand Down
Loading