int4_weight_only API raises an error when saving transformers models #1704

Open

jiqing-feng opened this issue Feb 12, 2025 · 2 comments

When I load an int4 CPU-quantized model and try to save it, I get this error: TypeError: Object of type Int4CPULayout is not JSON serializable

To reproduce it:

import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.dtypes import Int4CPULayout

model_name = "meta-llama/Llama-3.1-8B-Instruct"
# We support int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight
# More examples and documentations for arguments can be found in https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
device_map = "cpu"
quantization_config = TorchAoConfig("int4_weight_only", group_size=128, layout=Int4CPULayout())
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map=device_map, quantization_config=quantization_config)
quantized_model.save_pretrained("./llama3-8b-ao-int4", safe_serialization=False)

output:

Traceback (most recent call last):
  File "/home/jiqingfe/test_torchao.py", line 11, in <module>
    quantized_model.save_pretrained("./llama3-8b-ao-int4", safe_serialization=False)
  File "/home/jiqingfe/transformers/src/transformers/modeling_utils.py", line 2800, in save_pretrained
    model_to_save.config.save_pretrained(save_directory)
  File "/home/jiqingfe/transformers/src/transformers/configuration_utils.py", line 419, in save_pretrained
    self.to_json_file(output_config_file, use_diff=True)
  File "/home/jiqingfe/transformers/src/transformers/configuration_utils.py", line 941, in to_json_file
    writer.write(self.to_json_string(use_diff=use_diff))
  File "/home/jiqingfe/transformers/src/transformers/configuration_utils.py", line 927, in to_json_string
    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
  File "/usr/lib/python3.10/json/__init__.py", line 238, in dumps
    **kw).encode(obj)
  File "/usr/lib/python3.10/json/encoder.py", line 201, in encode
    chunks = list(chunks)
  File "/usr/lib/python3.10/json/encoder.py", line 431, in _iterencode
    yield from _iterencode_dict(o, _current_indent_level)
  File "/usr/lib/python3.10/json/encoder.py", line 405, in _iterencode_dict
    yield from chunks
  File "/usr/lib/python3.10/json/encoder.py", line 405, in _iterencode_dict
    yield from chunks
  File "/usr/lib/python3.10/json/encoder.py", line 405, in _iterencode_dict
    yield from chunks
  File "/usr/lib/python3.10/json/encoder.py", line 438, in _iterencode
    o = _default(o)
  File "/usr/lib/python3.10/json/encoder.py", line 179, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type Int4CPULayout is not JSON serializable
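
For what it's worth, the failure isolates cleanly: torchao's layout classes are dataclasses (at least at the time of writing), so json.dumps rejects the instance itself but accepts its flattened fields. A minimal sketch, assuming Int4CPULayout stays a dataclass:

import dataclasses
import json

from torchao.dtypes import Int4CPULayout

layout = Int4CPULayout()
print(dataclasses.is_dataclass(layout))        # True
# json.dumps(layout)                           # raises the TypeError above
print(json.dumps(dataclasses.asdict(layout)))  # "{}" -- Int4CPULayout has no fields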

I was wondering if we could switch to a more JSON-friendly data structure for saving the layout data.
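
A sketch of what that could look like: pass a default hook to json.dumps in the config serialization so layout dataclasses flatten into a tagged dict. This is illustrative only, not the actual transformers implementation; layout_default and the "layout_type" key are made-up names:

import dataclasses
import json

def layout_default(o):
    # Hypothetical handler: flatten torchao layout dataclasses into a
    # tagged dict that a loader could use to reconstruct the object.
    if dataclasses.is_dataclass(o) and not isinstance(o, type):
        return {"layout_type": type(o).__name__, **dataclasses.asdict(o)}
    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")

# e.g. in configuration_utils.to_json_string:
# json.dumps(config_dict, indent=2, sort_keys=True, default=layout_default)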

jiqing-feng (Author) commented Feb 12, 2025

The same error occurs on CUDA; we cannot save the model if we pass a layout to the config.

import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.dtypes import TensorCoreTiledLayout

model_name = "meta-llama/Llama-3.1-8B-Instruct"
device_map = "cuda:0"
# We support int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight
# More examples and documentations for arguments can be found in https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
quantization_config = TorchAoConfig("int4_weight_only", group_size=128, layout=TensorCoreTiledLayout())
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map=device_map, quantization_config=quantization_config)
quantized_model.save_pretrained("./llama3-8B-ao-int4", safe_serialization=False)

error:

Traceback (most recent call last):
  File "/workspace/jiqing/test_torchao.py", line 11, in <module>
    quantized_model.save_pretrained("./llama3-8B-ao-int4", safe_serialization=False)
  File "/workspace/jiqing/transformers/src/transformers/modeling_utils.py", line 2800, in save_pretrained
    model_to_save.config.save_pretrained(save_directory)
  File "/workspace/jiqing/transformers/src/transformers/configuration_utils.py", line 419, in save_pretrained
    self.to_json_file(output_config_file, use_diff=True)
  File "/workspace/jiqing/transformers/src/transformers/configuration_utils.py", line 941, in to_json_file
    writer.write(self.to_json_string(use_diff=use_diff))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/jiqing/transformers/src/transformers/configuration_utils.py", line 927, in to_json_string
    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/json/__init__.py", line 238, in dumps
    **kw).encode(obj)
          ^^^^^^^^^^^
  File "/usr/lib/python3.12/json/encoder.py", line 202, in encode
    chunks = list(chunks)
             ^^^^^^^^^^^^
  File "/usr/lib/python3.12/json/encoder.py", line 432, in _iterencode
    yield from _iterencode_dict(o, _current_indent_level)
  File "/usr/lib/python3.12/json/encoder.py", line 406, in _iterencode_dict
    yield from chunks
  File "/usr/lib/python3.12/json/encoder.py", line 406, in _iterencode_dict
    yield from chunks
  File "/usr/lib/python3.12/json/encoder.py", line 406, in _iterencode_dict
    yield from chunks
  File "/usr/lib/python3.12/json/encoder.py", line 439, in _iterencode
    o = _default(o)
        ^^^^^^^^^^^
  File "/usr/lib/python3.12/json/encoder.py", line 180, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type TensorCoreTiledLayout is not JSON serializable
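
Until this is fixed, a hedged workaround sketch (not an official API; the quant_type_kwargs attribute path matches transformers' TorchAoConfig at the time of writing, but verify against your version):

# Temporarily drop the non-serializable layout object before saving,
# then restore it so the in-memory model keeps working.
qc = quantized_model.config.quantization_config
layout = qc.quant_type_kwargs.pop("layout", None)  # assumed attribute path
quantized_model.save_pretrained("./llama3-8B-ao-int4", safe_serialization=False)
if layout is not None:
    qc.quant_type_kwargs["layout"] = layout

Note that the saved config then omits the layout, so reloading presumably falls back to the default layout for int4_weight_only.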

supriyar (Contributor) commented

cc @jerryzh168 @andrewor14
