test(model): add device memory test
dacorvo committed Feb 23, 2024
1 parent 64bbc1f commit 98c401a
Showing 2 changed files with 55 additions and 1 deletion.
12 changes: 12 additions & 0 deletions test/helpers.py
@@ -1,4 +1,5 @@
import functools
import gc

import pytest
import torch
@@ -82,3 +83,14 @@ def assert_similar(a, b, atol=1e-6, rtol=1e-5):
    if not torch.allclose(sim, torch.tensor(1.0, dtype=sim.dtype), atol=atol, rtol=rtol):
        max_deviation = torch.min(sim)
        raise ValueError(f"Alignment {max_deviation:.8f} deviates too much from 1.0 with atol={atol}, rtol={rtol}")


def get_device_memory(device):
    # Collect garbage first so that freed tensors are actually released
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
        return torch.cuda.memory_allocated()
    elif device.type == "mps":
        torch.mps.empty_cache()
        return torch.mps.current_allocated_memory()
    # No per-process memory accounting available for other device types
    return None
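
Note: a minimal usage sketch of the new helper outside the test suite (it assumes a CUDA device with no other live allocations; an MPS device behaves the same way):

import torch

from helpers import get_device_memory

device = torch.device("cuda")
x = torch.zeros(1024, 1024, device=device)
assert get_device_memory(device) > 0  # the tensor is resident on the device
del x
assert get_device_memory(device) == 0  # released after gc + cache flush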
44 changes: 43 additions & 1 deletion test/model/test_quantize_mlp.py
@@ -2,7 +2,7 @@

import pytest
import torch
from helpers import assert_similar, random_qtensor, random_tensor
from helpers import assert_similar, get_device_memory, random_qtensor, random_tensor

from quanto import Calibration, QLinear, QTensor, freeze, qfloat8_e4m3fn, qfloat8_e5m2, qint8, quantize
from quanto.nn import QModuleMixin
@@ -105,3 +105,45 @@ def test_serialize_quantized_mlp(weights, dtype, device):
    assert torch.equal(module_reloaded.weight._scale, module.weight._scale)
    assert torch.equal(module_reloaded.input_scale, module.input_scale)
    assert torch.equal(module_reloaded.output_scale, module.output_scale)


@pytest.mark.skip_device("cpu")
@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
def test_quantized_mlp_device_memory(weights, dtype, device):
    # Check we start from a clean state
    assert get_device_memory(device) == 0
    input_features = 1024
    hidden_features = 2048
    output_features = 1024
    model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)
    full_precision_memory = get_device_memory(device)
    assert full_precision_memory > 0
    quantize(model, weights=weights)
    freeze(model)
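    # After freeze, weights are stored as qint8: one byte per element vs two (fp16) or four (fp32)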
    quantized_memory = get_device_memory(device)
    assert quantized_memory > 0
    assert quantized_memory < full_precision_memory
    # Serialize model
    b = io.BytesIO()
    torch.save(model.state_dict(), b)
    # Free device memory
    del model
    assert get_device_memory(device) == 0
    # Reload state dict on CPU
    b.seek(0)
    state_dict = torch.load(b, map_location=torch.device("cpu"))
    assert get_device_memory(device) == 0
    # Create an empty model and quantize it with the same parameters
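    # torch.device("meta") makes parameter creation record shapes and dtypes only, allocating no storage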
    with torch.device("meta"):
        model_reloaded = MLP(input_features, hidden_features, output_features)
    assert get_device_memory(device) == 0
    quantize(model_reloaded, weights=weights)
    assert get_device_memory(device) == 0
    # Reload the state dict, still on CPU
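    # assign=True replaces the storage-less meta parameters with the checkpoint tensors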
    model_reloaded.load_state_dict(state_dict, assign=True)
    assert get_device_memory(device) == 0
    # Finally, move the model to the device
    model_reloaded.to(device)
    reloaded_memory = get_device_memory(device)
    assert reloaded_memory == quantized_memory
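
A note on the reload pattern exercised above: building the module under torch.device("meta") allocates only shape and dtype metadata, and load_state_dict(..., assign=True) (available since PyTorch 2.1) replaces those storage-less meta parameters with the checkpoint tensors rather than copying into them. A standalone sketch of the same pattern, independent of quanto:

import torch

with torch.device("meta"):
    net = torch.nn.Linear(8, 8)  # parameters live on the meta device: no storage
assert net.weight.device.type == "meta"
state_dict = {"weight": torch.randn(8, 8), "bias": torch.randn(8)}
net.load_state_dict(state_dict, assign=True)  # meta params swapped for real CPU tensors
assert net.weight.device.type == "cpu"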
