diff --git a/simulator/profiling/mlp/mlp_impl.py b/simulator/profiling/mlp/mlp_impl.py index 24f2a09..1f8221a 100644 --- a/simulator/profiling/mlp/mlp_impl.py +++ b/simulator/profiling/mlp/mlp_impl.py @@ -4,9 +4,9 @@ # Monkey patching sarathi cuda timer to use our custom timer from simulator.profiling.cuda_timer import CudaTimer -import sarathi.metrics.cuda_timer as sarathi_cuda_timer +import sarathi.metrics.cuda_timer -sarathi_cuda_timer.CudaTimer = CudaTimer +sarathi.metrics.cuda_timer.CudaTimer = CudaTimer from sarathi.model_executor.layers.activation import SiluAndMul from sarathi.model_executor.layers.layernorm import RMSNorm