diff --git a/qllm/quantization/vptq/quant_vptq.py b/qllm/quantization/vptq/quant_vptq.py
index 2bdd891..220f897 100644
--- a/qllm/quantization/vptq/quant_vptq.py
+++ b/qllm/quantization/vptq/quant_vptq.py
@@ -28,6 +28,7 @@ def set_tokenizer(self, tokenizer):
 
     def get_level_order_linear(self, model, model_prefix, dev):
         level_map = {}
+        prevent_release_tensors = []  # keep tensors alive to prevent release, so each id() stays unique
         class Catcher(torch.nn.Module):
             def __init__(self, module):
                 super().__init__()
@@ -40,12 +41,19 @@ def fake_forward(hidden_state, *args, **kwargs):
                 nonlocal level_map
                 if len(level_map) == 0:
                     hidden_state *= 0
-                key = int(hidden_state[..., 0].item())
+                input_shape = hidden_state.shape
+                key = id(hidden_state) + hidden_state.data_ptr()
+                prevent_release_tensors.append(hidden_state)
+                if hidden_state.numel() == 0:
+                    value = 2025 + len(level_map)
+                    input_shape = tuple(i + 1 if i == 0 else i for i in input_shape)
+                else:
+                    value = int(hidden_state[..., 0].item())
                 if key not in level_map:
                     level_map[key] = []
                 level_map[key].append(name)
-                out = torch.ones(hidden_state.shape[:-1]+(out_fetures,), device=hidden_state.device)
-                return out*key + 1
+                out = torch.ones(input_shape[:-1] + (out_fetures,), device=hidden_state.device)
+                return out * value + 1
             fake_forward.layer_name = name
             return fake_forward
 
@@ -67,6 +75,7 @@ def fake_forward_2(hidden_state, *args, **kwargs):
             model(torch.ones([1, 1], dtype=torch.int64).to(dev))
         except ValueError:
             pass
+        attention_layers[1] = attention_layers[1].module
         for _, l_layer in old_forwards.items():
             l_layer[0].forward = l_layer[1]
 
@@ -74,12 +83,13 @@ def fake_forward_2(hidden_state, *args, **kwargs):
 
     def collect_hessian_pre(self, model, model_prefix, dev):
         level_linear_names = self.get_level_order_linear(model, model_prefix, "cpu")
+        logger.info("linear arch:" + str(level_linear_names))
+
        for k in list(level_linear_names.keys()):
             for sub_name in level_linear_names[k]:
                 level_linear_names[sub_name] = level_linear_names[k][0]
             level_linear_names.pop(k)
         self.name2hessian = level_linear_names
-        logger.info("linear arch:"+str(level_linear_names))
         if self.quant_config.hessian_path is not None and self.quant_config.inv_hessian_path is not None:
             logger.info("read cached Hessian data")
             _, attention_layers, layer_input_args = self.hijack_block_inputs(model, [(torch.tensor((1, 1), dtype=torch.int64), )], model_prefix, "cpu")
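
A quick gloss on why `prevent_release_tensors` is needed: CPython may reuse the `id()` of a garbage-collected object, and PyTorch's caching allocator may likewise hand a freed storage's `data_ptr()` to a new tensor, so without a live reference two different forward inputs could collide on the same key. A minimal sketch of the pitfall and the fix (standalone illustration, not part of the patch):

```python
import torch

# Keys are formed the same way as in the patch: id() + data_ptr().
seen = {}
keepalive = []  # drop this list and key reuse/collisions become possible

for step in range(3):
    t = torch.zeros(1, 4)
    key = id(t) + t.data_ptr()
    keepalive.append(t)  # a live reference pins both the id and the storage
    seen.setdefault(key, []).append(step)

assert len(seen) == 3  # with keepalive, every tensor gets a distinct key
```

The empty-input branch is also worth a gloss: a tensor with a zero-sized dimension has no element to read, so the patch fabricates a sentinel `value` (`2025 + len(level_map)`) and bumps zero dims to 1 so the fabricated output stays non-empty for downstream layers. A sketch of that path (the literal `16` stands in for the layer's `out_fetures`):

```python
import torch

hidden_state = torch.ones(1, 0, 8)  # e.g. a zero-length sequence
input_shape = hidden_state.shape
if hidden_state.numel() == 0:
    value = 2025  # sentinel; the patch uses 2025 + len(level_map)
    input_shape = tuple(i + 1 if i == 0 else i for i in input_shape)
else:
    value = int(hidden_state[..., 0].item())

out = torch.ones(input_shape[:-1] + (16,))
print((out * value + 1).shape)  # torch.Size([1, 1, 16])
```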