qllm/quantization/vptq/quant_vptq.py (18 changes: 14 additions & 4 deletions)
@@ -28,6 +28,7 @@ def set_tokenizer(self, tokenizer):
 
     def get_level_order_linear(self, model, model_prefix, dev):
         level_map = {}
+        prevent_release_tensors = []  # keep references so probed tensors are never freed; keys rely on id()/data_ptr() staying unique
         class Catcher(torch.nn.Module):
             def __init__(self, module):
                 super().__init__()
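Why the new prevent_release_tensors list matters: once a tensor is garbage-collected, CPython may recycle its id() and PyTorch's caching allocator may recycle its data_ptr(), so the id(t) + t.data_ptr() key introduced below is only unique while every probed tensor stays referenced. A minimal standalone sketch of the failure mode (illustration only, not part of the PR):

import torch

# Without a live reference, freed tensors let id()/data_ptr() be recycled,
# so keys computed from different tensors can collide.
seen = set()
for _ in range(1000):
    t = torch.ones(1, 4)              # previous tensor is freed each iteration
    seen.add(id(t) + t.data_ptr())
print(len(seen))                      # usually well below 1000: keys collide

# Holding references, as prevent_release_tensors does, keeps the keys distinct.
seen, keep_alive = set(), []
for _ in range(1000):
    t = torch.ones(1, 4)
    keep_alive.append(t)
    seen.add(id(t) + t.data_ptr())
print(len(seen))                      # 1000: every tensor keeps a unique key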
@@ -40,12 +41,19 @@ def fake_forward(hidden_state, *args, **kwargs):
                 nonlocal level_map
                 if len(level_map) == 0:
                     hidden_state *= 0
-                key = int(hidden_state[..., 0].item())
+                input_shape = hidden_state.shape
+                key = id(hidden_state) + hidden_state.data_ptr()
+                prevent_release_tensors.append(hidden_state)
+                if hidden_state.numel() == 0:
+                    value = 2025+len(level_map)
+                    input_shape = tuple(i + 1 if i == 0 else i for i in input_shape)
+                else:
+                    value = int(hidden_state[..., 0].item())
                 if key not in level_map:
                     level_map[key] = []
                 level_map[key].append(name)
-                out = torch.ones(hidden_state.shape[:-1]+(out_fetures,), device=hidden_state.device)
-                return out*key + 1
+                out = torch.ones(input_shape[:-1] + (out_fetures,), device=hidden_state.device)
+                return out * value + 1
             fake_forward.layer_name = name
             return fake_forward

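In the hunk above, the probe now keys each call by the identity of the incoming tensor, so all linears fed the same hidden state (for example q/k/v projections) land in one level; the empty-input branch assigns a sentinel value (2025 plus the current map size) and bumps zero-sized dims so the fake output stays non-empty, while out * value + 1 still propagates a distinct marker to the next level. A simplified sketch of the overall pattern, assuming a plain nn.Module model; trace_linear_levels and its helper names are mine, not the PR's:

import torch

def trace_linear_levels(model):
    level_map, keep_alive, originals = {}, [], {}

    def make_probe(name, module):
        def fake_forward(hidden_state, *args, **kwargs):
            keep_alive.append(hidden_state)            # keep id()/data_ptr() unique
            key = id(hidden_state) + hidden_state.data_ptr()
            level_map.setdefault(key, []).append(name)
            # Return a correctly shaped fake output so downstream layers run.
            out_features = module.out_features
            return torch.ones(hidden_state.shape[:-1] + (out_features,),
                              device=hidden_state.device)
        return fake_forward

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            originals[name] = module.forward           # saved for restoration
            module.forward = make_probe(name, module)
    return level_map, originals

After one fake forward pass, every list in level_map names the linears that consumed the same tensor object, which is exactly the grouping the PR needs regardless of the activation values.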
@@ -67,19 +75,21 @@ def fake_forward_2(hidden_state, *args, **kwargs):
             model(torch.ones([1, 1], dtype=torch.int64).to(dev))
         except ValueError:
             pass
 
+        attention_layers[1] = attention_layers[1].module
         for _, l_layer in old_forwards.items():
             l_layer[0].forward = l_layer[1]
         return level_map
 
     def collect_hessian_pre(self, model, model_prefix, dev):
         level_linear_names = self.get_level_order_linear(model, model_prefix, "cpu")
-        logger.info("linear arch:" + str(level_linear_names))
 
         for k in list(level_linear_names.keys()):
             for sub_name in level_linear_names[k]:
                 level_linear_names[sub_name] = level_linear_names[k][0]
             level_linear_names.pop(k)
+        self.name2hessian = level_linear_names
+        logger.info("linear arch:"+str(level_linear_names))
         if self.quant_config.hessian_path is not None and self.quant_config.inv_hessian_path is not None:
             logger.info("read cached Hessian data")
             _, attention_layers, layer_input_args = self.hijack_block_inputs(model, [(torch.tensor((1, 1), dtype=torch.int64), )], model_prefix, "cpu")
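The loop in collect_hessian_pre flattens the level map in place: each linear name ends up pointing at the first name in its level, i.e. the layer whose Hessian the whole group shares, and the result is stored as self.name2hessian. A worked example with hypothetical keys and layer names:

# Hypothetical input: identity keys map to the linears that share an input.
level_linear_names = {
    140215931: ["attn.q_proj", "attn.k_proj", "attn.v_proj"],
    140216004: ["mlp.gate_proj", "mlp.up_proj"],
}
# Same in-place remap as the PR: every sub-layer now points at its group's
# representative, whose Hessian slot it will reuse.
for k in list(level_linear_names.keys()):
    for sub_name in level_linear_names[k]:
        level_linear_names[sub_name] = level_linear_names[k][0]
    level_linear_names.pop(k)
print(level_linear_names)
# {'attn.q_proj': 'attn.q_proj', 'attn.k_proj': 'attn.q_proj',
#  'attn.v_proj': 'attn.q_proj', 'mlp.gate_proj': 'mlp.gate_proj',
#  'mlp.up_proj': 'mlp.up_proj'}

Iterating over a snapshot of the keys (list(...)) is what makes the in-place mutation safe: the string keys added inside the loop are never revisited.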