migration of quantize_ workflow configuration from callables to configs #1690

Open
vkuzo opened this issue Feb 10, 2025 · 0 comments
vkuzo commented Feb 10, 2025

summary

This issue tracks the migration of quantize_ per-workflow configuration from Callables to configs.

Motivation: passing direct configuration is intuitive and widely used in similar contexts across various projects. Passing configuration wrapped in a callable is not intuitive, hard to understand and debug, and we have evidence that it pushes a portion of users away from building on top of torchao.

We will keep the old callable syntax supported by quantize_ for one release cycle, and delete it afterwards. We will keep the old names as aliases for new names going forward (example: int4_weight_only as an alias of Int4WeightOnlyConfig) to keep existing callsites working without changes.
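One way the BC aliases could work is to turn the old function name into a thin factory for the new config class, so existing callsites keep working while now returning a config object. A minimal sketch (class and function names follow this issue; the internals are illustrative, not torchao's actual code):

```python
from dataclasses import dataclass

# New user-facing config class (name from the issue; fields illustrative)
@dataclass
class Int4WeightOnlyConfig:
    group_size: int = 128

# BC alias: the old name now constructs and returns the config object
# instead of a Callable, so quantize_(m, int4_weight_only(group_size=32))
# keeps working unchanged at the callsite.
def int4_weight_only(group_size: int = 128) -> Int4WeightOnlyConfig:
    return Int4WeightOnlyConfig(group_size=group_size)
```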

impact on API users

If you are just using the torchao quantize_ API as specified in the README, this is not BC-breaking. For example, the following syntax will keep working.

quantize_(model, int8_weight_only())

Note that the type of the object created by int8_weight_only() will change from a Callable to a config. You have the option to migrate to the explicit config creation, as follows:

quantize_(model, Int8WeightOnlyConfig())

user facing API changes

signature of quantize_

#
# before
#
def quantize_(
    model: torch.nn.Module,
    apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
    ...,
): ...

#
# after - intermediate state, support both old and new for one release
#
def quantize_(
    model: torch.nn.Module,
    config: Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]],
    ...,
): ...

#
# after - long term state
#
def quantize_(
    model: torch.nn.Module,
    config: AOBaseConfig,
    ...,
): ...
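During the intermediate state, quantize_ has to route on the type of its second argument. A minimal sketch of that dispatch, assuming an AOBaseConfig base class (the stand-in class and the routing logic here are illustrative, not torchao's actual internals):

```python
from typing import Any, Callable, Union

# Stand-in for the real abstract base class of all workflow configs
class AOBaseConfig:
    pass

def quantize_(model: Any, config: Union[AOBaseConfig, Callable]) -> Any:
    if isinstance(config, AOBaseConfig):
        # new path: look up the registered transform for this config type
        return f"config path: {type(config).__name__}"
    elif callable(config):
        # legacy path: config is a module-transforming callable
        return config(model)
    raise TypeError(f"unsupported config type: {type(config)}")
```

Checking isinstance before callable matters: a config object could in principle define `__call__`, and the new path should win.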

usage example

An example for int4_weight_only

#
# before
#
quantize_(m, int4_weight_only(group_size=32))

#
# after, with new user facing names
#
quantize_(m, Int4WeightOnlyConfig(group_size=32))

#
# after, with the BC alias names (still supported)
#
quantize_(m, int4_weight_only(group_size=32))

developer facing changes

See the PR details for examples, but they can be summarized as:

#
# old
#

# quantize_ applies the callable returned by this function to each module of the model
def int4_weight_only(group_size: int, ...) -> Callable:

    def new_callable(weight: torch.Tensor):
        # configuration is captured here via local variables
        ...
        
    # return type is a Callable
    return _get_linear_subclass_inserter(new_callable)

#
# new
#

# config base class
class AOBaseConfig(abc.ABC):
    pass

# user facing configuration of a workflow
@dataclass
class Int4WeightOnlyConfig(AOBaseConfig):
    group_size: int = 128
    ...

# transform of a module according to a workflow's configuration (not user facing)
@register_quantize_module_handler(Int4WeightOnlyConfig)
def _int4_weight_only_transform(
    module: torch.nn.Module, 
    config: Int4WeightOnlyConfig,
) -> torch.nn.Module:
    # map to AQT, not user facing
    ...

migration status

• workflow configuration
• tutorials (replace with new registration API)
• replace docblocks and public facing descriptions with new names
• verify partner integrations still work
  • confirmed two out of three here: vkuzo/pytorch_scripts#28
• delete old path (one version after migration)
  • not done, upcoming in version TBD
vkuzo self-assigned this Feb 10, 2025
vkuzo added a commit to vkuzo/pytorch_scripts that referenced this issue Feb 12, 2025
Summary: Testing for pytorch/ao#1690. Convenient to have this here to test on torchao main vs torchao experiment.

vkuzo added a commit to vkuzo/pytorch_scripts that referenced this issue Feb 13, 2025
vkuzo changed the title from "placeholder for migrating workflow configuration to AOBaseConfig" to "migration of quantize_ workflow configuration from callables to configs" on Feb 13, 2025