[Bugfix] DisableKVCache Context #834

Open · wants to merge 9 commits into base: main
Changes from 3 commits
81 changes: 40 additions & 41 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -21,6 +21,7 @@
 from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor
 from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
 from llmcompressor.utils.fsdp.context import fix_fsdp_module_name
+from llmcompressor.utils.helpers import DisableKVCache
 from llmcompressor.utils.pytorch.module import (
     get_layers,
     get_no_split_params,
@@ -286,48 +287,46 @@ def apply_compression(
         # want to calibrate wrt to these
         self.model.apply(disable_quantization)

-        forward_pass_use_cache = self.model.config.use_cache
-        self.model.config.use_cache = False
+        with DisableKVCache(self.model):
+            # in non-sequential mode we run calibration through the full model
+            # in sequential mode we run calibration up to the first transformer target
+            intermediates = run_calibration_forward(
+                self.model, dataloader, mask_padding=True
+            )
+            self.layer_compressors_[0].clear_early_stop()

-        # in non-sequential mode we run calibration through the full model
-        # in sequential mode we run calibration up to the first transformer target
-        intermediates = run_calibration_forward(
-            self.model, dataloader, mask_padding=True
-        )
-        self.layer_compressors_[0].clear_early_stop()
-
-        # empty cache if not using sequential update
-        if not self.sequential_update:
-            del intermediates
-            gc.collect()
-            torch.cuda.empty_cache()
-
-        num_layers = len(self.compressible_layers_)
-        for idx, layer_compressor in enumerate(self.layer_compressors_):
-            logger.info(f"\n===== Compressing layer {idx+1}/{num_layers} " " =====")
-
-            if self.sequential_update:
-                # in sequential mode we run the forward pass for each transformer layer
-                # one at a time, caching the intermediate outputs between layers
-                logger.info(f"Calibrating {layer_compressor.name}...")
-                layer_compressor.pre_compress()
-                unquantized_outputs = layer_compressor.calibrate_layer(intermediates)
-
-            layer_compressor.compress()
-            layer_compressor.post_compress()
-            layer_compressor.revert_layer_wrappers()
-
-            if self.sequential_update:
-                quantized_outputs = layer_compressor.calibrate_layer(intermediates)
-                error = get_output_error(unquantized_outputs, quantized_outputs)
-                logger.info(f"Mean output error from quantization: {error:.3f}")
-                intermediates = quantized_outputs
-                del unquantized_outputs
-
-            gc.collect()
-            torch.cuda.empty_cache()
-
-        self.model.config.use_cache = forward_pass_use_cache
+            # empty cache if not using sequential update
+            if not self.sequential_update:
+                del intermediates
+                gc.collect()
+                torch.cuda.empty_cache()
+
+            num_layers = len(self.compressible_layers_)
+            for idx, layer_compressor in enumerate(self.layer_compressors_):
+                logger.info(f"\n===== Compressing layer {idx+1}/{num_layers} " " =====")
+
+                if self.sequential_update:
+                    # in sequential mode we run the forward pass for each layer
+                    # one at a time, caching the intermediate outputs between layers
+                    logger.info(f"Calibrating {layer_compressor.name}...")
+                    layer_compressor.pre_compress()
+                    unquantized_outputs = layer_compressor.calibrate_layer(
+                        intermediates
+                    )
+
+                layer_compressor.compress()
+                layer_compressor.post_compress()
+                layer_compressor.revert_layer_wrappers()
+
+                if self.sequential_update:
+                    quantized_outputs = layer_compressor.calibrate_layer(intermediates)
+                    error = get_output_error(unquantized_outputs, quantized_outputs)
+                    logger.info(f"Mean output error from quantization: {error:.3f}")
+                    intermediates = quantized_outputs
+                    del unquantized_outputs
+
+                gc.collect()
+                torch.cuda.empty_cache()

         # re-enable quantization
         self.model.apply(enable_quantization)
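
The net effect of this hunk: the manual save/restore of `config.use_cache` is replaced by a context manager, which also restores the flag if calibration raises. A minimal sketch of the equivalent try/finally pattern (the function and argument names are illustrative, not part of this PR):

def run_without_kv_cache(model, run_calibration):
    # Rough equivalent of `with DisableKVCache(model): run_calibration()`,
    # assuming the config exposes `use_cache` directly (the real helper also
    # checks `config.text_config` for multimodal configs such as MllamaConfig).
    restore_value = model.config.use_cache
    model.config.use_cache = False
    try:
        return run_calibration()  # e.g. the GPTQ calibration loop above
    finally:
        model.config.use_cache = restore_value  # runs on success and on error
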
29 changes: 29 additions & 0 deletions src/llmcompressor/utils/helpers.py
@@ -22,6 +22,7 @@
 from urllib.parse import urlparse

 import numpy
+import torch
 from loguru import logger

 __all__ = [
@@ -59,6 +60,7 @@
     "is_package_available",
     "import_from_path",
     "getattr_chain",
+    "DisableKVCache",
 ]


@@ -1041,3 +1043,30 @@ def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
         res = getattr(res, attr_name)

     return res
+
+
+class DisableKVCache:
+    def __init__(self, model: torch.nn.Module):
+        if hasattr(model.config, "use_cache"):
+            self.config = model.config
+
+        # MllamaConfig
+        elif hasattr(model.config, "text_config") and hasattr(
+            model.config.text_config, "use_cache"
+        ):
+            self.config = model.config.text_config
+
+        # unknown config structure
+        else:
+            raise NotImplementedError(
+                f"Cannot find `use_cache` for config of type {type(model.config)}"
+            )
+
+        self.restore_value = self.config.use_cache
+
+    def __enter__(self):
+        self.restore_value = self.config.use_cache
+        self.config.use_cache = False
+
+    def __exit__(self, _exc_type, _exc_val, _exc_tb):
+        self.config.use_cache = self.restore_value
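
A minimal usage sketch of the new context manager outside the GPTQ flow (the checkpoint name is illustrative and not taken from this PR): on entry it sets `use_cache = False` on the resolved config, falling back to `config.text_config` for multimodal configs such as MllamaConfig, and on exit it restores the original value.

from transformers import AutoModelForCausalLM

from llmcompressor.utils.helpers import DisableKVCache

# Illustrative checkpoint; any HF model whose config (or config.text_config)
# exposes `use_cache` works.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

print(model.config.use_cache)      # original value, typically True
with DisableKVCache(model):
    print(model.config.use_cache)  # False while inside the context
print(model.config.use_cache)      # restored to the original value on exit
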