 import asyncio
 from contextlib import asynccontextmanager
 import contextlib
+from torchao._models.utils import (
+    get_arch_name,
+    write_json_result,
+)
 
 from torch._inductor import config as inductorconfig
 inductorconfig.triton.unique_kernel_names = True
@@ -269,8 +273,10 @@ def benchmark_fn(func, inp, mask_generator, warmup=3, runs=10):
     t = time.time()
     for _ in range(runs):
         func(inp, mask_generator)
-    print(f"Benchmark took {(time.time() - t)/runs}s per iteration.")
-    max_memory_allocated()
+    avg_time_per_run = (time.time() - t)/runs
+    print(f"Benchmark took {avg_time_per_run}s per iteration.")
+    max_memory_allocated_bytes, max_memory_allocated_percentage = max_memory_allocated()
+    return avg_time_per_run, max_memory_allocated_bytes, max_memory_allocated_percentage
 
 
 def max_memory_allocated():
@@ -279,6 +285,7 @@ def max_memory_allocated():
     max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
     max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
     print(f"max_memory_allocated_bytes: {max_memory_allocated_bytes} MiB or {max_memory_allocated_percentage}%")
+    return max_memory_allocated_bytes, max_memory_allocated_percentage
 
 
 def unittest_fn(masks, ref_masks, order_by_area=False, verbose=False):
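Taken together, the two helpers edited above implement a measure-then-report pattern: time the steady-state loop, then read back the peak CUDA allocation. Below is a self-contained sketch of that pattern, assuming a CUDA device; the toy matmul workload and the explicit synchronize()/reset_peak_memory_stats() calls are additions of this sketch, not part of the diffed helpers.

import time
import torch

def measure(fn, warmup=3, runs=10):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()  # peak now reflects only the timed runs
    t = time.time()
    for _ in range(runs):
        fn()
    torch.cuda.synchronize()
    avg_time_per_run = (time.time() - t) / runs
    peak_bytes = torch.cuda.max_memory_allocated()
    total_memory = torch.cuda.get_device_properties(0).total_memory
    return avg_time_per_run, peak_bytes >> 20, int(100 * peak_bytes / total_memory)

x = torch.randn(4096, 4096, device="cuda")
print(measure(lambda: x @ x))  # -> (seconds per run, peak MiB, peak % of device memory)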
@@ -527,10 +534,10 @@ def set_furious(mask_generator):
     mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16
 
 def set_autoquant(mask_generator):
+    import torchao
     from torchao import autoquant
-    from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
     # NOTE: Not baseline feature
-    mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
+    mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
     mask_generator.predictor._transforms_device = mask_generator.predictor.device
     torch.set_float32_matmul_precision('high')
     # NOTE: this fails when we run
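The change above switches to the fully qualified torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST instead of importing the name directly. For reference, here is a minimal sketch of the same autoquant call on a toy module (the module and shapes are illustrative, not the script's): qtensor_class_list restricts the search to float-based quantized tensor subclasses, and min_sqnr=40 rejects any layer swap whose measured signal-to-quantization-noise ratio is too low.

import torch
import torchao
from torchao import autoquant

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().half()
model = autoquant(
    torch.compile(model),
    qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
    min_sqnr=40,
)
# The first call profiles candidate kernels for the observed shapes and picks
# a quantization (or none) per layer.
out = model(torch.randn(8, 1024, device="cuda", dtype=torch.half))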
@@ -556,7 +563,8 @@ def main(checkpoint_path,
          dry=False,
          batch_size=1,
          load_fast="",
-         save_fast=""):
+         save_fast="",
+         output_json_path=None):
     if verbose:
         logging.basicConfig(level=logging.INFO,
                             format='%(asctime)s - %(levelname)s - %(message)s',
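With the new keyword argument, a benchmarking run that also records JSON rows would be invoked roughly as below; parameter names other than those visible in this diff, and all of the values, are hypothetical.

main(
    checkpoint_path="checkpoints/sam2.1_hiera_large.pt",  # hypothetical path
    model_type="large",                                   # hypothetical value
    benchmark=True,
    output_json_path="results.json",                      # enables the result-writing block added below
)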
@@ -626,9 +634,9 @@ def main(checkpoint_path,
     if benchmark:
         print(f"batch size {batch_size} dog benchmark")
         if batch_size == 1:
-            benchmark_fn(image_tensor_to_masks, image_tensor, mask_generator)
+            result = benchmark_fn(image_tensor_to_masks, image_tensor, mask_generator)
         else:
-            benchmark_fn(image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)
+            result = benchmark_fn(image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)
 
         for i, shapes in enumerate([example_shapes(), example_shapes_2()]):
             print(f"batch size {batch_size} example shapes {i} benchmark")
@@ -644,6 +652,19 @@ def main(checkpoint_path,
             print("len(random_images): ", len(random_images))
             benchmark_fn(image_tensors_to_masks, random_images, mask_generator)
 
+        if output_json_path:
+            headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
+            name = "sam2-" + model_type
+            arch = get_arch_name()
+            dtype = "autoquant" if use_autoquant else ("compile" if fast else "base")
+            avg_time_per_run, max_memory_allocated_bytes, max_memory_allocated_percentage = result
+            memory_result = [name, dtype, device, arch, "memory(MiB)", max_memory_allocated_bytes, None]
+            memory_percent_result = [name, dtype, device, arch, "memory(%)", max_memory_allocated_percentage, None]
+            performance_result = [name, dtype, device, arch, "time_s(avg)", avg_time_per_run, None]
+            write_json_result(output_json_path, headers, memory_result)
+            write_json_result(output_json_path, headers, memory_percent_result)
+            write_json_result(output_json_path, headers, performance_result)
+
     if profile is not None:
         print(f"Saving profile under {profile}")
         if batch_size == 1:
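write_json_result (imported from torchao._models.utils at the top of the diff) receives parallel headers and row lists, so each metric becomes one record carrying the name/dtype/device/arch metadata and a null target. Its actual output schema is not shown in this diff; the following is a hypothetical stand-in that only illustrates the headers/row contract.

import json

def write_json_result_sketch(output_json_path, headers, row):
    # e.g. {"name": "sam2-large", ..., "metric": "memory(MiB)", "actual": 1234, "target": null}
    record = dict(zip(headers, row))
    with open(output_json_path, "a") as f:
        f.write(json.dumps(record) + "\n")  # append one JSON object per metric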