A couple of extensions to rewriter #1912

Merged: gramalingam merged 6 commits into main from rama/minor on Oct 23, 2024

Conversation

@gramalingam (Collaborator) commented on Oct 22, 2024

A couple of extensions to the rewriter, motivated by fusion optimization experimentation with SmoLLM.

  • Support lists of constants in match patterns (a sketch follows this list).
  • Handle one multi-output scenario with the single-output pattern matcher (e.g., when defining a fusion rule for SkipNormalization): namely, when the extra outputs are intermediate values used in the computation of the first output. The algorithm is extended to handle this scenario using the efficient single-output matching algorithm.
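
For the first point, a minimal sketch of what a constant-list operand in a match pattern might look like; the exact matching semantics here are an assumption made for illustration, not code from this PR:

def reshape_pattern(op, x):
    # Hypothetical pattern: a plain Python list used as a constant operand,
    # intended to match a Reshape whose shape input is a constant holding
    # exactly these values.
    return op.Reshape(x, [1, -1])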

An example for the second point is the following pattern:

def skip_norm_pattern(op, input, skip, gamma, epsilon, stash_type):
    skip_sum = op.Add(input, skip)
    normalized = op.SimplifiedLayerNormalization(
        skip_sum,
        gamma,
        axis=-1,
        epsilon=epsilon,
        stash_type=stash_type,
        _domain="com.microsoft")
    return normalized, skip_sum

If we successfully find a match for normalized (which transitively matches the entire pattern subgraph leading up to normalized), we have also found a match for skip_sum, so there is no need for a multi-output match.
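
For illustration only, a replacement function for this pattern might bind both pattern outputs to a single fused node. The fused op name, its output signature, and the _outputs keyword below are assumptions made for this sketch, not part of this PR:

def skip_norm_replacement(op, input, skip, gamma, epsilon, stash_type):
    # Hypothetical fused replacement: assumes a com.microsoft
    # SkipSimplifiedLayerNormalization-style op whose outputs include both
    # the normalized result and the input+skip sum, so the two pattern
    # outputs (normalized, skip_sum) can be rebound to one node.
    normalized, _mean, _inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization(
        input,
        skip,
        gamma,
        epsilon=epsilon,
        _domain="com.microsoft",
        _outputs=4,
    )
    return normalized, skip_sum

Because skip_sum lies on the path to normalized in the pattern, a single-output match rooted at normalized already identifies the node defining skip_sum; the multi-output return only records which intermediate value the rewrite must also rebind.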

(Will add test-cases later, as I work through the fusion optimizations I am experimenting with.)

codecov bot commented Oct 22, 2024

❌ 15 Tests Failed:

Tests completed: 14182 | Failed: 15 | Passed: 14167 | Skipped: 1625
The 3 ❄️ flaky tests:
tests.eager_mode_test.TestEagerModeArguments_0_reference_runtime test_function_input_and_attribute_by_kwargs_out_of_order

Flake rate in main: 33.84% (Passed 3411 times, Failed 1745 times)

Stack Traces | 0.003s run time
..../test_torch_nightly/lib/python3.12.../reference/ops/_op.py:91: in run
    res = self._run(x, y)
..../test_torch_nightly/lib/python3.12.../reference/ops/_op.py:139: in _run
    res = (convert_from_ml_dtypes(res[0]),)
..../test_torch_nightly/lib/python3.12.../onnx/reference/custom_element_types.py:50: in convert_from_ml_dtypes
    return array.view(dtype=dtype)
E   ValueError: Changing the dtype of a 0d array is only supported if the itemsize is unchanged

The above exception was the direct cause of the following exception:
tests/eager_mode_test.py:115: in test_function_input_and_attribute_by_kwargs_out_of_order
    self.assertEqual(add_with_alpha(alpha=3.0, other=2.0, this=1.0), 7.0)
onnxscript/values.py:529: in __call__
    return evaluator.default().eval_function(self, args, kwargs)
onnxscript/evaluator.py:307: in eval_function
    result = function.function(*adapted_args, **adapted_kwargs)
tests/eager_mode_test.py:59: in add_with_alpha
    other = op.Mul(other, alpha)
.../onnx_opset/_impl/opset14.py:696: in Mul
    return op(*self._prepare_inputs(schema, A, B))
onnxscript/values.py:301: in __call__
    return evaluator.default().eval(schema, args, kwargs)
onnxscript/evaluator.py:194: in eval
    outputs = self._eval(schema, inputs, attributes, closure)
onnxscript/evaluator.py:524: in _eval
    result = session.run(None, session_run_input)
..../test_torch_nightly/lib/python3.12.../onnx/reference/reference_evaluator.py:599: in run
    outputs = node.run(*inputs, **linked_attributes)
..../test_torch_nightly/lib/python3.12.../reference/ops/_op.py:114: in run
    res = OpRunBinary.run(self, x, y)
..../test_torch_nightly/lib/python3.12.../reference/ops/_op.py:93: in run
    raise TypeError(
E   TypeError: Issues with types <class 'numpy.ndarray'>, <class 'numpy.ndarray'> (binary operator 'Mul').
tests.eager_mode_test.TestEagerModeArguments_0_reference_runtime test_function_some_input_by_kwargs

Flake rate in main: 33.84% (Passed 3411 times, Failed 1745 times)

Stack Traces | 0.003s run time
..../test_torch_nightly/lib/python3.12.../reference/ops/_op.py:91: in run
    res = self._run(x, y)
..../test_torch_nightly/lib/python3.12.../reference/ops/_op.py:139: in _run
    res = (convert_from_ml_dtypes(res[0]),)
..../test_torch_nightly/lib/python3.12.../onnx/reference/custom_element_types.py:50: in convert_from_ml_dtypes
    return array.view(dtype=dtype)
E   ValueError: Changing the dtype of a 0d array is only supported if the itemsize is unchanged

The above exception was the direct cause of the following exception:
tests/eager_mode_test.py:106: in test_function_some_input_by_kwargs
    self.assertEqual(add_with_alpha(1.0, other=2.0), 3.0)
onnxscript/values.py:529: in __call__
    return evaluator.default().eval_function(self, args, kwargs)
onnxscript/evaluator.py:307: in eval_function
    result = function.function(*adapted_args, **adapted_kwargs)
tests/eager_mode_test.py:59: in add_with_alpha
    other = op.Mul(other, alpha)
.../onnx_opset/_impl/opset14.py:696: in Mul
    return op(*self._prepare_inputs(schema, A, B))
onnxscript/values.py:301: in __call__
    return evaluator.default().eval(schema, args, kwargs)
onnxscript/evaluator.py:194: in eval
    outputs = self._eval(schema, inputs, attributes, closure)
onnxscript/evaluator.py:524: in _eval
    result = session.run(None, session_run_input)
..../test_torch_nightly/lib/python3.12.../onnx/reference/reference_evaluator.py:599: in run
    outputs = node.run(*inputs, **linked_attributes)
..../test_torch_nightly/lib/python3.12.../reference/ops/_op.py:114: in run
    res = OpRunBinary.run(self, x, y)
..../test_torch_nightly/lib/python3.12.../reference/ops/_op.py:93: in run
    raise TypeError(
E   TypeError: Issues with types <class 'numpy.ndarray'>, <class 'numpy.ndarray'> (binary operator 'Mul').
tests.eager_mode_test.TestEagerModeArguments_0_reference_runtime test_function_all_input_by_kwargs

Flake rate in main: 33.84% (Passed 3411 times, Failed 1745 times)

Stack Traces | 0.003s run time
..../test_torch_nightly/lib/python3.12.../reference/ops/_op.py:91: in run
    res = self._run(x, y)
..../test_torch_nightly/lib/python3.12.../reference/ops/_op.py:139: in _run
    res = (convert_from_ml_dtypes(res[0]),)
..../test_torch_nightly/lib/python3.12.../onnx/reference/custom_element_types.py:50: in convert_from_ml_dtypes
    return array.view(dtype=dtype)
E   ValueError: Changing the dtype of a 0d array is only supported if the itemsize is unchanged

The above exception was the direct cause of the following exception:
tests/eager_mode_test.py:109: in test_function_all_input_by_kwargs
    self.assertEqual(add_with_alpha(this=1.0, other=2.0), 3.0)
onnxscript/values.py:529: in __call__
    return evaluator.default().eval_function(self, args, kwargs)
onnxscript/evaluator.py:307: in eval_function
    result = function.function(*adapted_args, **adapted_kwargs)
tests/eager_mode_test.py:59: in add_with_alpha
    other = op.Mul(other, alpha)
.../onnx_opset/_impl/opset14.py:696: in Mul
    return op(*self._prepare_inputs(schema, A, B))
onnxscript/values.py:301: in __call__
    return evaluator.default().eval(schema, args, kwargs)
onnxscript/evaluator.py:194: in eval
    outputs = self._eval(schema, inputs, attributes, closure)
onnxscript/evaluator.py:524: in _eval
    result = session.run(None, session_run_input)
..../test_torch_nightly/lib/python3.12.../onnx/reference/reference_evaluator.py:599: in run
    outputs = node.run(*inputs, **linked_attributes)
..../test_torch_nightly/lib/python3.12.../reference/ops/_op.py:114: in run
    res = OpRunBinary.run(self, x, y)
..../test_torch_nightly/lib/python3.12.../reference/ops/_op.py:93: in run
    raise TypeError(
E   TypeError: Issues with types <class 'numpy.ndarray'>, <class 'numpy.ndarray'> (binary operator 'Mul').

To view individual test run time comparison to the main branch, go to the Test Analytics Dashboard

@justinchuby (Collaborator) left a comment:

I am still a little confused by the second case. Would it be helpful to make the example more concrete in the PR description? Thanks!

@gramalingam (Collaborator, Author) replied:

> I am still a little confused by the second case. Would it be helpful to make the example more concrete in the PR description? Thanks!

Added an example

@gramalingam enabled auto-merge (squash) October 23, 2024 03:16
@gramalingam merged commit f18dadc into main on Oct 23, 2024
25 of 41 checks passed
@gramalingam deleted the rama/minor branch October 23, 2024 03:34