Cleanup optimizer (#1904)
Cleanup optimizer by moving older proto-based optimizations into a
_legacy folder, renaming files to distinguish internal implementation
files, and other minor restructuring.
gramalingam authored Oct 14, 2024
1 parent 8fef233 commit 4578142
Showing 20 changed files with 225 additions and 464 deletions.
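
For orientation, here is a sketch of the optimizer package layout implied by this commit. It is inferred only from the import and path changes visible in the diff below, so files the diff never references are omitted:

onnxscript/optimizer/
    __init__.py                  # public entry point: optimize() dispatching on model type
    _optimizer.py                # IR-based optimize_ir and _DEFAULT_REWRITE_RULES
    _constant_folding.py         # IR-based constant-folding internals
    _remove_unused.py
    _remove_unused_function.py
    _legacy/
        _optimizer.py            # older proto-based optimize()
        constant_folding.py
        evaluator.py
        _remove_unused_proto.py
        _simple_function_folding.py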
5 changes: 2 additions & 3 deletions .lintrunner.toml
@@ -46,12 +46,11 @@ exclude_patterns = [
     'onnxscript/onnx_types.py',
     'onnxscript/**/*_test.py', # Skip linting test files for speed
     'onnxscript/function_libs/torch_lib/ops/**', # Operators typing do not play well with mypy
-    'onnxscript/optimizer/evaluator.py', # FIXME
-    'onnxscript/optimizer/constant_folding.py', # FIXME
+    'onnxscript/optimizer/_legacy/evaluator.py', # FIXME
+    'onnxscript/optimizer/_legacy/constant_folding.py', # FIXME
     'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME
     'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME
     'onnxscript/_legacy_ir/irbuilder.py', # FIXME
-    'onnxscript/optimizer/fold_constants_v0.py', # FIXME
     'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
     'onnxscript/tools/function_unittest_producer.py', # FIXME
     'onnxscript/_legacy_ir/visitor.py', # FIXME
160 changes: 11 additions & 149 deletions onnxscript/optimizer/__init__.py
@@ -2,160 +2,22 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
-import logging
-from typing import Any
-
 import onnx
-import onnx.shape_inference
 
-from onnxscript import ir, rewriter
-from onnxscript.optimizer import _constant_folding, _inliner
-from onnxscript.optimizer.constant_folding import fold_constants
-from onnxscript.optimizer.remove_unused import remove_unused_nodes
-from onnxscript.optimizer.remove_unused_function import remove_unused_functions
-from onnxscript.optimizer.simple_function_folding import (
-    inline_functions_with_unused_outputs,
-    inline_simple_functions,
-)
-from onnxscript.rewriter import (
-    broadcast_to_matmul,
-    cast_constant_of_shape,
-    gemm_to_matmul_add,
-    no_op,
-)
-
-logger = logging.getLogger(__name__)
-
-_DEFAULT_REWRITE_RULES = [
-    *no_op.rules.rules,  # TODO: merge this rule into constant folding?
-    *broadcast_to_matmul.rules.rules,
-    gemm_to_matmul_add.rule,
-    *cast_constant_of_shape.rules.rules,
-]
-
-
-def optimize(
-    model: onnx.ModelProto,
-    num_iterations: int = 2,
-    *,
-    onnx_shape_inference: bool = True,
-    stop_if_no_change: bool = True,
-    external_data_folder: str = "",
-    **kwargs: Any,
-) -> onnx.ModelProto:
-    """Optimize the model. Perform optimizations and clean-ups such as constant folding, dead code elimination, etc.
-    Args:
-        model (onnx.ModelProto): The model to optimize.
-        num_iterations (int, optional): Number of iterations to perform.
-        onnx_shape_inference (bool, optional): Whether to perform onnx shape inference on the model.
-            Set this to False to turn off onnx shape inference, and rely on model carried shapes and types.
-            This is useful for models produced by PyTorch 2.2+ dynamo onnx exporter, where the model carries
-            the symbolic shapes recorded from dynamo tracing.
-        stop_if_no_change (bool, optional): Whether to stop if no change is detected.
-        external_data_folder (str, optional): The folder to store external data.
-        **kwargs: Additional keyword arguments. For BC purposes.
-    """
-    if kwargs.pop("function_aware_folding", None) is not None:
-        logger.warning(
-            "'function_aware_folding' is deprecated. 'optimize' now supports both fully inlined models and models with functions. "
-            "To achieve the same behavior as 'function_aware_folding=True' before, set 'onnx_shape_inference=False'. "
-            "This would turn off incremental onnx shape inference and rely on model carried shapes and types. "
-            "See 'onnx_shape_inference' for more details."
-        )
-    for _ in range(num_iterations):
-        if onnx_shape_inference:
-            if model.ByteSize() < 1024 * 1024 * 1024 * 2:
-                # NOTE: strict mode is disabled because it crashes on the models
-                # that have different shapes inferred from the model carried shapes.
-                # The case can be found in:
-                # https://github.com/microsoft/onnxscript/issues/1443
-                model = onnx.shape_inference.infer_shapes(
-                    model, check_type=True, strict_mode=False, data_prop=True
-                )
-            else:
-                logger.warning(
-                    "The model size is too large for full model shape inference. "
-                    "Skipping this step."
-                )
-
-        inline_simple_functions(model)
-        modified = fold_constants(
-            model, external_data_folder, onnx_shape_inference=onnx_shape_inference
-        )
-
-        remove_unused_nodes(model)
-        inline_simple_functions(model)
-        model = remove_unused_functions(model)
-        inline_functions_with_unused_outputs(model)
-        # NOTE: This is general rewrite rules
-        model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
-        if stop_if_no_change and not modified:
-            logger.debug("Stopping after %d iterations.", _)
-            break
-
-    for node in model.graph.node:
-        logger.debug("Node %s::%s name %s.", node.domain, node.op_type, node.name)
-
-    for function in model.functions:
-        for node in function.node:
-            logger.debug(
-                "Function %s::%s node %s::%s name %s.",
-                function.domain,
-                function.name,
-                node.domain,
-                node.op_type,
-                node.name,
-            )
-
-    return model
-
-
-_DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = (
-    _constant_folding._DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT
-)
-
-_DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = (
-    _constant_folding._DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT
-)
-
-
-def optimize_ir(
-    model: ir.Model,
-    num_iterations: int = 2,
-    *,
-    onnx_shape_inference: bool = True,
-    stop_if_no_change: bool = True,
-    input_size_limit: int = _DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
-    output_size_limit: int = _DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
-) -> None:
-    """Optimizes a model.
+import onnxscript.optimizer._legacy._optimizer as legacy_optimizer
+from onnxscript import ir
+from onnxscript.optimizer._constant_folding import basic_constant_propagation
+from onnxscript.optimizer._legacy.constant_folding import fold_constants
+from onnxscript.optimizer._optimizer import optimize_ir
+from onnxscript.optimizer._remove_unused import remove_unused_nodes
 
-    Args:
-        model: The model to be optimized.
-        num_iterations: Number of times the optimization loop is repeated.
-        onnx_shape_inference: Applies node-level shape-inference as part of optimization
-        input_size_limit: Will not apply constant folding to ops with any input of size
-            greater than this. Does not apply to special ops like Shape() and Size().
-        output_size_limit: Will not rewrite any foldable-op into a Constant op if the size
-            of the output tensor is greater than this.
-        stop_if_no_change: Not supported currently (has no effect). Meant to stop the
-            outer optimization loop if no change is detected in one iteration.
-    """
-    del stop_if_no_change  # Looks like rewriter doesn't support this yet.
-    _inliner.inline(model)
-    for _ in range(num_iterations):
-        _constant_folding.fold_constants(
-            model,
-            onnx_shape_inference=onnx_shape_inference,
-            input_size_limit=input_size_limit,
-            output_size_limit=output_size_limit,
-        )
-        rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
-        remove_unused_nodes(model)
 
+def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs):
+    if isinstance(model, ir.Model):
+        return optimize_ir(model, *args, **kwargs)
+    else:
+        return legacy_optimizer.optimize(model, *args, **kwargs)
 
-basic_constant_propagation = _constant_folding.basic_constant_propagation
 
 __all__ = [
     "fold_constants",
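For context, a minimal usage sketch of the reorganized entry point (not part of the commit): optimize now dispatches on the input type, so ir.Model and onnx.ModelProto inputs share one call. The model.onnx path is hypothetical, and ir.serde.deserialize_model is assumed to be the usual proto-to-IR conversion:

import onnx

import onnxscript.optimizer as optimizer
from onnxscript import ir

proto = onnx.load("model.onnx")  # hypothetical path

# A ModelProto dispatches to the proto-based legacy pipeline and returns
# a new onnx.ModelProto.
optimized_proto = optimizer.optimize(proto)

# An ir.Model dispatches to optimize_ir, which rewrites the model in place
# (optimize_ir is declared "-> None", so there is no return value to keep).
model = ir.serde.deserialize_model(proto)  # assumed conversion helper
optimizer.optimize(model)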
28 changes: 18 additions & 10 deletions onnxscript/optimizer/_constant_folding.py
@@ -17,20 +17,32 @@
 
 import onnxscript.ir as ir
 import onnxscript.ir._convenience as _convenience
-import onnxscript.optimizer.constant_folding as constant_folding
 import onnxscript.rewriter.pattern as orp
 import onnxscript.utils.utils as utils
 
+DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 1024
+
+DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 1024 * 1024
+
 
 def is_control_flow_op(node: ir.Node) -> bool:
     graph_types = {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
     return any(attr.type in graph_types for attr in node.attributes.values())
 
 
+non_deterministic_ops = frozenset(
+    {
+        "RandomUniform",
+        "RandomNormal",
+        "RandomUniformLike",
+        "RandomNormalLike",
+        "Multinomial",
+    }
+)
+
+
 def is_non_deterministic_op(node: ir.Node) -> bool:
-    return node.op_type in constant_folding.non_deterministic_ops and utils.is_onnx_domain(
-        node.domain
-    )
+    return node.op_type in non_deterministic_ops and utils.is_onnx_domain(node.domain)
 
 
 def is_onnx_op(node: ir.Node, op_type: str) -> bool:
@@ -43,10 +55,6 @@ def is_constant_op(node: ir.Node) -> bool:
     )
 
 
-_DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 1024
-
-_DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT
-
 logger = logging.getLogger(__name__)
 
 # "Standard" evaluators are used to perform constant-folding.
@@ -787,8 +795,8 @@ def fold_constants(
     external_data_folder: str = "",
     *,
     onnx_shape_inference: bool = False,
-    input_size_limit: int = _DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
-    output_size_limit: int = _DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
+    input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
+    output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
 ) -> bool:
     """
     Applies constant folding optimization to the model.
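A hedged sketch of driving the IR-level folder directly with the renamed limits (again, model.onnx is a hypothetical path and ir.serde.deserialize_model an assumed conversion helper; per the signature above, fold_constants returns True when it modified the model):

import onnx

from onnxscript import ir
from onnxscript.optimizer import _constant_folding

model = ir.serde.deserialize_model(onnx.load("model.onnx"))  # hypothetical path

# The limits bound the tensors involved in folding: ops with an input larger
# than input_size_limit are skipped, and results larger than output_size_limit
# are not materialized as Constant nodes (per the docstrings in this commit).
modified = _constant_folding.fold_constants(
    model,
    onnx_shape_inference=False,
    input_size_limit=_constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
    output_size_limit=_constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
)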
@@ -8,7 +8,8 @@
 
 import onnxscript.optimizer as optimizer
 from onnxscript.ir import serde
-from onnxscript.optimizer import _constant_folding, constant_folding
+from onnxscript.optimizer import _constant_folding
+from onnxscript.optimizer._legacy import constant_folding
 
 
 @parameterized.parameterized_class(("using_ir",), [(False,), (True,)])
98 changes: 98 additions & 0 deletions onnxscript/optimizer/_legacy/_optimizer.py
@@ -0,0 +1,98 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import logging
from typing import Any

import onnx
import onnx.shape_inference

from onnxscript import rewriter
from onnxscript.optimizer._legacy._simple_function_folding import (
    inline_functions_with_unused_outputs,
    inline_simple_functions,
)
from onnxscript.optimizer._legacy.constant_folding import fold_constants
from onnxscript.optimizer._optimizer import _DEFAULT_REWRITE_RULES
from onnxscript.optimizer._remove_unused import remove_unused_nodes
from onnxscript.optimizer._remove_unused_function import remove_unused_functions

logger = logging.getLogger(__name__)


def optimize(
    model: onnx.ModelProto,
    num_iterations: int = 2,
    *,
    onnx_shape_inference: bool = True,
    stop_if_no_change: bool = True,
    external_data_folder: str = "",
    **kwargs: Any,
) -> onnx.ModelProto:
    """Optimize the model. Perform optimizations and clean-ups such as constant folding, dead code elimination, etc.
    Args:
        model (onnx.ModelProto): The model to optimize.
        num_iterations (int, optional): Number of iterations to perform.
        onnx_shape_inference (bool, optional): Whether to perform onnx shape inference on the model.
            Set this to False to turn off onnx shape inference, and rely on model carried shapes and types.
            This is useful for models produced by PyTorch 2.2+ dynamo onnx exporter, where the model carries
            the symbolic shapes recorded from dynamo tracing.
        stop_if_no_change (bool, optional): Whether to stop if no change is detected.
        external_data_folder (str, optional): The folder to store external data.
        **kwargs: Additional keyword arguments. For BC purposes.
    """
    if kwargs.pop("function_aware_folding", None) is not None:
        logger.warning(
            "'function_aware_folding' is deprecated. 'optimize' now supports both fully inlined models and models with functions. "
            "To achieve the same behavior as 'function_aware_folding=True' before, set 'onnx_shape_inference=False'. "
            "This would turn off incremental onnx shape inference and rely on model carried shapes and types. "
            "See 'onnx_shape_inference' for more details."
        )
    for _ in range(num_iterations):
        if onnx_shape_inference:
            if model.ByteSize() < 1024 * 1024 * 1024 * 2:
                # NOTE: strict mode is disabled because it crashes on the models
                # that have different shapes inferred from the model carried shapes.
                # The case can be found in:
                # https://github.com/microsoft/onnxscript/issues/1443
                model = onnx.shape_inference.infer_shapes(
                    model, check_type=True, strict_mode=False, data_prop=True
                )
            else:
                logger.warning(
                    "The model size is too large for full model shape inference. "
                    "Skipping this step."
                )

        inline_simple_functions(model)
        modified = fold_constants(
            model, external_data_folder, onnx_shape_inference=onnx_shape_inference
        )

        remove_unused_nodes(model)
        inline_simple_functions(model)
        model = remove_unused_functions(model)
        inline_functions_with_unused_outputs(model)
        # NOTE: This is general rewrite rules
        model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
        if stop_if_no_change and not modified:
            logger.debug("Stopping after %d iterations.", _)
            break

    for node in model.graph.node:
        logger.debug("Node %s::%s name %s.", node.domain, node.op_type, node.name)

    for function in model.functions:
        for node in function.node:
            logger.debug(
                "Function %s::%s node %s::%s name %s.",
                function.domain,
                function.name,
                node.domain,
                node.op_type,
                node.name,
            )

    return model
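
A usage sketch of this legacy proto path (model.onnx and the output filename are hypothetical; setting onnx_shape_inference=False follows the docstring's advice for dynamo-exported models that already carry symbolic shapes):

import onnx

from onnxscript.optimizer._legacy._optimizer import optimize

proto = onnx.load("model.onnx")  # hypothetical path
optimized = optimize(proto, num_iterations=2, onnx_shape_inference=False)
onnx.save(optimized, "model_optimized.onnx")  # hypothetical output path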
File renamed without changes.
@@ -11,7 +11,7 @@
 
 import onnxscript._legacy_ir as ir
 from onnxscript._legacy_ir import visitor
-from onnxscript.optimizer import remove_unused_proto
+from onnxscript.optimizer._legacy import _remove_unused_proto
 
 logger = logging.getLogger(__name__)
 
@@ -168,7 +168,7 @@ def _find_nodes_with_any_unused_output(
                 # All unused output means the node is not used at all.
                 # Hence do not update used_values with the node's inputs.
                 continue
-            used_values |= remove_unused_proto.compute_used_in_node(node)
+            used_values |= _remove_unused_proto.compute_used_in_node(node)
         return target_nodes
 
     def visit_model(self, model: onnx.ModelProto) -> None:
