Skip to content

Commit a13b56a

Browse files
authored
[OP] Add rms_norm into TOPI (apache#15326)
This PR introduces the root mean square normalization operator, `rms_norm`, into TOPI.
1 parent 4b183da commit a13b56a

File tree

7 files changed

+269
-0
lines changed

7 files changed

+269
-0
lines changed

include/tvm/topi/nn/rms_norm.h

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \brief root mean square normalization op constructions
22+
* \file nn/rms_norm.h
23+
*/
24+
#ifndef TVM_TOPI_NN_RMS_NORM_H_
25+
#define TVM_TOPI_NN_RMS_NORM_H_
26+
27+
#include <tvm/te/operation.h>
28+
#include <tvm/topi/reduction.h>
29+
#include <tvm/topi/tags.h>
30+
31+
#include <string>
32+
33+
namespace tvm {
34+
namespace topi {
35+
namespace nn {
36+
37+
using namespace tvm::te;
38+
39+
/*!
40+
* \brief Root mean square normalization.
41+
* \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
42+
* \param weight K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
43+
* d_{axis_k} == r_k
44+
* \param bias Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where
45+
* d_{axis_k} == r_k
46+
* \param axis The axis to normalize over.
47+
* \param epsilon The epsilon value to avoid division by zero.
48+
* \param name The name of the operation.
49+
* \param tag The tag to mark the operation.
50+
* \return The normalized tensor, with the same shape as data.
51+
*/
52+
inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Tensor& bias,
53+
const Array<Integer>& axis, double epsilon, std::string name = "T_rms_norm",
54+
std::string tag = kInjective) {
55+
const auto& data_type = data->dtype;
56+
const auto& weight_type = weight.defined() ? weight->dtype : data_type;
57+
ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type";
58+
const auto& bias_type = bias.defined() ? bias->dtype : data_type;
59+
ICHECK(data_type == bias_type) << "rms_norm: data and bias must have the same type";
60+
61+
auto square = multiply(data, data);
62+
auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);
63+
64+
auto ndim = data->shape.size();
65+
ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
66+
auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
67+
auto reduce_extent = make_const(data->dtype, 1);
68+
for (int i : real_axis) {
69+
reduce_extent *= data->shape[i];
70+
}
71+
auto rms_norm_func = [&](const Array<Var>& indices) {
72+
Array<Var> reduce_indices, non_reduce_indices;
73+
for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
74+
if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
75+
reduce_indices.push_back(indices[i]);
76+
} else {
77+
non_reduce_indices.push_back(indices[i]);
78+
}
79+
}
80+
auto output =
81+
data(indices) * weight(reduce_indices) *
82+
tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
83+
if (bias.defined()) {
84+
output += bias(reduce_indices);
85+
}
86+
return output;
87+
};
88+
auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag);
89+
return rms_norm;
90+
}
91+
92+
} // namespace nn
93+
} // namespace topi
94+
} // namespace tvm
95+
96+
#endif // TVM_TOPI_NN_RMS_NORM_H_

python/tvm/topi/nn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from .instance_norm import instance_norm
4242
from .layer_norm import layer_norm
4343
from .group_norm import group_norm
44+
from .rms_norm import rms_norm
4445
from .local_response_norm import *
4546
from .bitserial_conv2d import *
4647
from .bitserial_dense import *

python/tvm/topi/nn/rms_norm.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Root mean square normalization operator."""
18+
from .. import cpp
19+
20+
21+
def rms_norm(data, weight, bias, axis, epsilon=1e-5):
    """Apply root mean square normalization to ``data``.

    This is a thin wrapper that forwards all arguments unchanged to the C++
    implementation ``cpp.nn.rms_norm``. The output has the same data type as
    the input.

    Parameters
    ----------
    data : tvm.te.Tensor
        N-D with shape (d_0, d_1, ..., d_{N-1})

    weight : tvm.te.Tensor
        K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k

    bias : tvm.te.Tensor
        Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and
        d_{axis_k} == r_k

    axis : list of int
        Axes over which the normalization is applied

    epsilon : float
        The epsilon value to avoid division by zero.

    Returns
    -------
    result : tvm.te.Tensor
        N-D with shape (d_0, d_1, ..., d_{N-1})
    """
    normalized = cpp.nn.rms_norm(data, weight, bias, axis, epsilon)
    return normalized

python/tvm/topi/testing/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from .instance_norm_python import instance_norm_python
4747
from .layer_norm_python import layer_norm_python
4848
from .group_norm_python import group_norm_python
49+
from .rms_norm_python import rms_norm_python
4950
from .lrn_python import lrn_python
5051
from .l2_normalize_python import l2_normalize_python
5152
from .gather_python import gather_python
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
18+
"""Root mean square normalization in python"""
19+
import numpy as np
20+
21+
22+
def rms_norm_python(data, weight, bias, axis, epsilon=1e-5):
    """Root mean square normalization operator in Python.

    Computes ``data * weight / sqrt(mean(data**2, axis) + epsilon) + bias``.
    ``weight`` and ``bias`` are aligned with the normalized axes of ``data``
    (their k-th dimension is placed on the k-th smallest normalized axis), so
    the result is correct even when ``axis`` does not refer to the trailing
    dimensions of ``data``.

    Parameters
    ----------
    data : numpy.ndarray
        N-D with shape (d_0, d_1, ..., d_{N-1})

    weight : numpy.ndarray
        K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k

    bias : numpy.ndarray
        Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and
        d_{axis_k} == r_k. Pass None to skip the bias term.

    axis : int or tuple of ints
        Axes over which the normalization is applied

    epsilon : float
        The epsilon value to avoid division by zero.

    Returns
    -------
    result : np.ndarray
        N-D with shape (d_0, d_1, ..., d_{N-1})
    """
    # Computed first so that an invalid `axis` is rejected by NumPy itself.
    square_mean = np.mean(np.square(data), axis, keepdims=True)

    ndim = np.ndim(data)
    if axis is None:
        norm_axes = tuple(range(ndim))
    elif np.isscalar(axis):
        norm_axes = (int(axis) % ndim,)
    else:
        norm_axes = tuple(sorted(int(a) % ndim for a in axis))

    def _align(param):
        """Place the dimensions of a K-D parameter on the normalized axes of data."""
        # Parameters whose rank does not match the number of normalized axes
        # (e.g. scalars) are left to plain NumPy broadcasting.
        if not norm_axes or np.ndim(param) != len(norm_axes):
            return param
        aligned_shape = [1] * ndim
        for extent, dim in zip(np.shape(param), norm_axes):
            aligned_shape[dim] = extent
        return np.reshape(param, aligned_shape)

    result = data * _align(weight) / np.sqrt(square_mean + epsilon)
    if bias is not None:
        result += _align(bias)
    return result

src/topi/nn.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include <tvm/topi/nn/local_response_norm.h>
3636
#include <tvm/topi/nn/mapping.h>
3737
#include <tvm/topi/nn/pooling.h>
38+
#include <tvm/topi/nn/rms_norm.h>
3839
#include <tvm/topi/nn/softmax.h>
3940

4041
namespace tvm {
@@ -176,5 +177,10 @@ TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetVal
176177
*rv = nn::instance_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4]));
177178
});
178179

180+
/* Ops from nn/rms_norm.h */
181+
TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body([](TVMArgs args, TVMRetValue* rv) {
182+
*rv = nn::rms_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4]));
183+
});
184+
179185
} // namespace topi
180186
} // namespace tvm
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Test code for rms_norm."""
18+
import numpy as np
19+
import pytest
20+
import tvm
21+
from tvm import te
22+
from tvm import topi
23+
from tvm.topi.utils import get_const_tuple
24+
import tvm.topi.testing
25+
26+
import tvm.testing
27+
28+
29+
# Schedule lookup table handed to tvm.topi.testing.dispatch in test_rms_norm.
_rms_norm_schedule = dict(generic=topi.generic.schedule_injective)
32+
33+
34+
# only test on llvm because schedule is missing
@tvm.testing.parametrize_targets("llvm")
@pytest.mark.parametrize(
    "shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2)), ([("a", 4), ("b", 16)], (1,))]
)
@pytest.mark.parametrize("dtype", ["float32", "float16"])
def test_rms_norm(target, dev, shape, axis, dtype, episilon=1e-5, rtol=5e-3, atol=1e-4):
    """Compare topi.nn.rms_norm against the NumPy reference rms_norm_python.

    Entries of ``shape`` are either plain ints or ``(name, value)`` tuples. A
    tuple becomes a symbolic ``te.var(name)`` in the TE shape and ``value`` in
    the concrete NumPy shape.
    """
    # Build the TE computation; weight and bias take the extents of the normalized axes.
    shape_te = [te.var(v[0]) if isinstance(v, tuple) else v for v in shape]
    scale_shape_te = [shape_te[dim] for dim in axis]
    data = te.placeholder(shape_te, dtype=dtype, name="data")
    weight = te.placeholder(scale_shape_te, dtype=dtype, name="weight")
    # Give the bias input its own name instead of reusing "weight".
    bias = te.placeholder(scale_shape_te, dtype=dtype, name="bias")
    B = topi.nn.rms_norm(data, weight, bias, axis, episilon)

    # Random inputs with concrete shapes, and the expected output from the NumPy reference.
    shape_np = [v[1] if isinstance(v, tuple) else v for v in shape]
    scale_shape_np = [shape_np[dim] for dim in axis]
    data_np = np.random.uniform(size=shape_np).astype(dtype)
    weight_np = np.random.uniform(size=scale_shape_np).astype(dtype)
    bias_np = np.random.uniform(size=scale_shape_np).astype(dtype)
    b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, bias_np, axis, episilon)

    # Schedule, build and run on the device, then compare within tolerance.
    with tvm.target.Target(target):
        s_func = tvm.topi.testing.dispatch(target, _rms_norm_schedule)
        s = s_func([B])
    data_tvm = tvm.nd.array(data_np, dev)
    weight_tvm = tvm.nd.array(weight_np, dev)
    bias_tvm = tvm.nd.array(bias_np, dev)
    b_tvm = tvm.nd.array(np.zeros(shape_np, dtype=dtype), dev)
    f = tvm.build(s, [data, weight, bias, B], target)
    f(data_tvm, weight_tvm, bias_tvm, b_tvm)
    tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)
65+
66+
67+
if __name__ == "__main__":
68+
tvm.testing.main()

0 commit comments

Comments
 (0)