
Commit fd41dc5

Authored by rianbrooksflynn, Ethan0Jiang, JanFSchulte, LostEcho365, and Rian Flynn
Add LayerNorm support for Vivado (#1110)
* paser_mht
* change parser and modify keras_to_hls
* IR_mutihead_attention
* IR done
* create mha file in template
* mha .h file dummy algo
* config of mha
* update mha config
* dummy mha
* add transpose into mha
* projection_of_qkv_in_mha
* mha_first_draft
* able to predict model correct
* delete some unnassary comments
* delete comments
* resource strategy of transformer
* change sm lagacy
* update MHA, optimized
* support resource
* update
* dense_muti_dim_support
* parallel execute dense
* updates
* add_layerNorm_support
* MHA updated
* LayerNorm_bug_fix
* update bit precision
* config update
* add some comment
* run pre-commit
* Added support on QMultiHeadAttention, QLayerNormalization, and quantized_softmax
* updated on hls4ml transformer
* trying to clean the diff
* trying to clean the diff
* trying to clean the diff
* trying to clean the diff
* trying to clean the diff
* undo vhdl -> verilog change
* halfway working layernorm + test
* layernorm is now pretty functional
* layernorm on pytorch also
* minor cleanup
* more cleanup, pre-commit
* test for mha which kinda works maybe if you squint
* multihead attention working on keras and pytorch
* fiddly precision / accuracy changes for layernorm
* fix lookup table and label loops
* remove dense_seq
* undo qkeras changes
* fix merge conflict residue
* remove non-layernorm changes
* change to uniform LUT and fix precision
* [pre-commit.ci] auto fixes from pre-commit hooks
* fix encodings issue with dos2unix
* add Vitis as another tested backend
* Address PR feedback
* [pre-commit.ci] auto fixes from pre-commit hooks
* fix too-long lines
* fix merge issue
* trigger pre-commit
* re-add missing math import
* [pre-commit.ci] auto fixes from pre-commit hooks
* addressing Vladimir's latest comments
* change also pytorch test for layernorm and revert change to build command
* sideport changes to channels-last converter from 1352

---------

Co-authored-by: Ethan <[email protected]>
Co-authored-by: Jan-Frederik Schulte <[email protected]>
Co-authored-by: LostEcho365 <[email protected]>
Co-authored-by: Rian Flynn <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e02b511 commit fd41dc5

13 files changed: +527 additions, −6 deletions

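As context for the diffs below, a minimal sketch of how the new LayerNormalization support would typically be exercised end to end, assuming the standard hls4ml Keras flow with backend='Vivado' (the commit message notes Vitis is also tested); the model, sizes, and output directory are invented for illustration:

import numpy as np
import tensorflow as tf
import hls4ml

# Toy model: a single LayerNormalization over the last axis of a (seq_len=8, features=16) input.
model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(8, 16)),
        tf.keras.layers.LayerNormalization(),
    ]
)

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(
    model, hls_config=config, backend='Vivado', output_dir='hls4ml_layernorm_prj'
)
hls_model.compile()

X = np.random.rand(100, 8, 16)
y_keras = model.predict(X)
y_hls = hls_model.predict(X)
print(np.max(np.abs(y_keras - y_hls.reshape(y_keras.shape))))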

hls4ml/backends/fpga/fpga_backend.py

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,7 @@
     GarNetStack,
     GlobalPooling1D,
     GlobalPooling2D,
+    LayerNormalization,
     MatMul,
     Merge,
     Pooling1D,
@@ -73,6 +74,7 @@ def __init__(self, name):
             Dot,
             Conv,
             MatMul,
+            LayerNormalization,
         ]

         for layer in accum_layers:

hls4ml/backends/vivado/passes/core_templates.py

Lines changed: 62 additions & 1 deletion
@@ -2,7 +2,16 @@
 
 from hls4ml.backends.backend import get_backend
 from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
-from hls4ml.model.layers import Activation, BatchNormalization, Dense, HardActivation, ParametrizedActivation, PReLU, Softmax
+from hls4ml.model.layers import (
+    Activation,
+    BatchNormalization,
+    Dense,
+    HardActivation,
+    LayerNormalization,
+    ParametrizedActivation,
+    PReLU,
+    Softmax,
+)
 from hls4ml.model.optimizer.passes.hgq_proxy_model import UnaryLUT
 
 # Dense templates
@@ -136,6 +145,58 @@ def format(self, node):
         return self.template.format(**params)
 
 
+# LayerNormalization templates
+
+layernorm_config_template = """struct config{index} : nnet::layernorm_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned seq_len = {seq_len};
+    static const unsigned axis = {axis};
+    static const unsigned epsilon_power_of_10 = {epsilon_power_of_10};
+    static const unsigned table_range_power2 = {table_range_power2};
+    static const unsigned table_size = {table_size};
+    typedef {accum_t.name} accum_t;
+    typedef {bias_t.name} bias_t;
+    typedef {scale_t.name} scale_t;
+    typedef {table_t.name} table_t;
+    static const unsigned io_type = nnet::{iotype};
+    static const unsigned reuse_factor = {reuse};
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
+}};\n"""
+
+layernorm_function_template = 'nnet::layernormalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});'
+
+layernorm_include_list = ['nnet_utils/nnet_layernorm.h']
+
+
+class LayerNormalizationConfigTemplate(LayerConfigTemplate):
+    def __init__(self):
+        super().__init__(LayerNormalization)
+        self.template = layernorm_config_template
+
+    def format(self, node):
+        params = self._default_config_params(node)
+        params['n_in'] = node.get_input_variable().size_cpp()
+        params['product_type'] = get_backend('vivado').product_type(
+            node.get_input_variable().type.precision, node.get_weights('scale').type.precision
+        )
+
+        return self.template.format(**params)
+
+
+class LayerNormalizationFunctionTemplate(FunctionCallTemplate):
+    def __init__(self):
+        super().__init__(LayerNormalization, include_header=layernorm_include_list)
+        self.template = layernorm_function_template
+
+    def format(self, node):
+        params = self._default_function_params(node)
+        params['scale'] = node.get_weights('scale').name
+        params['bias'] = node.get_weights('bias').name
+
+        return self.template.format(**params)
+
+
 # Activation templates
 
 activ_config_template = """struct {type}_config{index} : nnet::activ_config {{
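To make the placeholders concrete, a quick sketch of how the config template renders; every value here is invented for illustration, and only the placeholder names come from layernorm_config_template in the diff above (the import path assumes the module layout introduced by this commit):

from collections import namedtuple

from hls4ml.backends.vivado.passes.core_templates import layernorm_config_template

# Stand-in for the typed objects hls4ml passes in; only the .name attribute is used here.
T = namedtuple('T', 'name')

example_params = {
    'index': 2,
    'n_in': 16,
    'seq_len': 8,
    'axis': 2,
    'epsilon_power_of_10': 3,      # i.e. epsilon = 1e-3
    'table_range_power2': 0,
    'table_size': 4096,
    'accum_t': T('ap_fixed<14,4>'),
    'bias_t': T('ap_fixed<16,6>'),
    'scale_t': T('ap_fixed<16,6>'),
    'table_t': T('ap_ufixed<8,5>'),
    'iotype': 'io_parallel',
    'reuse': 1,
    'product_type': 'mult',
}

print(layernorm_config_template.format(**example_params))

The {product_type} placeholder is filled from the input and scale precisions by product_type(), as LayerNormalizationConfigTemplate.format above shows.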

hls4ml/backends/vivado/vivado_backend.py

Lines changed: 28 additions & 1 deletion
@@ -24,6 +24,7 @@
     GarNet,
     GarNetStack,
     Layer,
+    LayerNormalization,
     Pooling1D,
     Pooling2D,
     SeparableConv1D,
@@ -32,7 +33,7 @@
     TimeDistributed,
 )
 from hls4ml.model.optimizer import get_backend_passes, layer_optimizer
-from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType, PackedType
+from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType, PackedType, RoundingMode, SaturationMode
 from hls4ml.report import parse_vivado_report
 from hls4ml.utils import attribute_descriptions as descriptions
 from hls4ml.utils.einsum_utils import parse_einsum
@@ -101,6 +102,32 @@ def _register_layer_attributes(self):
             )
             self.attribute_map[layer] = attrs
 
+        # Add LayerNorm attributes
+        ln_layers = [LayerNormalization]
+        for layer in ln_layers:
+            attrs = self.attribute_map.get(layer, [])
+            attrs.append(ConfigurableAttribute('table_range_power2', default=0, description=descriptions.table_range_power2))
+            attrs.append(ConfigurableAttribute('table_size', default=4096, description=descriptions.table_size))
+            attrs.append(
+                TypeAttribute(
+                    'table',
+                    default=FixedPrecisionType(
+                        8, 5, signed=False, rounding_mode=RoundingMode.RND_CONV, saturation_mode=SaturationMode.SAT
+                    ),
+                    description=descriptions.table_type,
+                )
+            )
+            attrs.append(
+                TypeAttribute(
+                    'accum',
+                    default=FixedPrecisionType(
+                        14, 4, signed=True, rounding_mode=RoundingMode.RND_CONV, saturation_mode=SaturationMode.SAT
+                    ),
+                    description=descriptions.accum_type,
+                )
+            )
+            self.attribute_map[layer] = attrs
+
         # Add TimeStepLoopParallelism to TimeDistributed
         attrs = self.attribute_map.get(TimeDistributed, [])
         attrs.append(
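A small sketch of the default table/accum precisions registered above, built from the same constructor calls as the diff (no new values, only annotations; the Vivado type names in the comments are the usual hls4ml mapping):

from hls4ml.model.types import FixedPrecisionType, RoundingMode, SaturationMode

# Default LUT entry type: unsigned fixed-point, 8 bits total, 5 integer bits,
# convergent rounding and saturation (roughly ap_ufixed<8,5,AP_RND_CONV,AP_SAT>).
table_t = FixedPrecisionType(
    8, 5, signed=False, rounding_mode=RoundingMode.RND_CONV, saturation_mode=SaturationMode.SAT
)

# Default accumulator type: signed fixed-point, 14 bits total, 4 integer bits
# (roughly ap_fixed<14,4,AP_RND_CONV,AP_SAT>).
accum_t = FixedPrecisionType(
    14, 4, signed=True, rounding_mode=RoundingMode.RND_CONV, saturation_mode=SaturationMode.SAT
)

print(table_t, accum_t)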

hls4ml/converters/keras/core.py

Lines changed: 36 additions & 0 deletions
@@ -1,3 +1,5 @@
+import math
+
 from hls4ml.converters.keras_v2_to_hls import get_weights_data, keras_handler, parse_default_keras_layer
 from hls4ml.model.quantizers import BinaryQuantizer, TernaryQuantizer
 from hls4ml.model.types import IntegerPrecisionType
@@ -131,6 +133,40 @@ def parse_batchnorm_layer(keras_layer, input_names, input_shapes, data_reader):
     return layer, [shape for shape in input_shapes[0]]
 
 
+@keras_handler('LayerNormalization')
+def parse_layernorm_layer(keras_layer, input_names, input_shapes, data_reader):
+    assert 'LayerNormalization' in keras_layer['class_name']
+
+    layer = parse_default_keras_layer(keras_layer, input_names)
+
+    in_size = 1
+    for dim in input_shapes[0][1:]:
+        in_size *= dim
+    layer['n_in'] = layer['n_out'] = in_size
+
+    if not ((len(input_shapes[0])) == 3):
+        raise Exception(
+            'input size is not currently supported by hls4ml; '
+            'only three-dimensional input (including batch dimension) is supported'
+        )
+    layer['seq_len'] = input_shapes[0][-2]
+
+    if not (keras_layer['config']['axis'][0] == 2):
+        raise Exception('assigning the axis is not currently supported by hls4ml; only axis 2 is supported')
+    layer['axis'] = keras_layer['config']['axis'][0]
+
+    layer['gamma_data'] = get_weights_data(data_reader, layer['name'], 'gamma')
+    layer['beta_data'] = get_weights_data(data_reader, layer['name'], 'beta')
+
+    if keras_layer['config']['epsilon'] <= 0:
+        raise Exception('epsilon must be positive')
+    layer['epsilon_power_of_10'] = -round(math.log10(keras_layer['config']['epsilon']))
+    if layer['epsilon_power_of_10'] <= 0:
+        raise Exception('epsilon must be less than 1e-1')
+
+    return layer, [shape for shape in input_shapes[0]]
+
+
 @keras_handler('Embedding')
 def parse_embedding_layer(keras_layer, input_names, input_shapes, data_reader):
     assert 'Embedding' in keras_layer['class_name']
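Both converters store epsilon as a rounded negative power of ten rather than a float: Keras' default epsilon of 1e-3 maps to 3 and PyTorch's default eps of 1e-5 maps to 5. A standalone sketch of that mapping (the helper name is hypothetical; the logic mirrors the checks above):

import math

def epsilon_power_of_10(epsilon):
    # Mirrors the parser checks: epsilon must be positive and below 1e-1,
    # and is stored as the rounded negative base-10 exponent.
    if epsilon <= 0:
        raise ValueError('epsilon must be positive')
    power = -round(math.log10(epsilon))
    if power <= 0:
        raise ValueError('epsilon must be less than 1e-1')
    return power

print(epsilon_power_of_10(1e-3))   # 3 (Keras default epsilon)
print(epsilon_power_of_10(1e-5))   # 5 (PyTorch default eps)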

hls4ml/converters/pytorch/core.py

Lines changed: 38 additions & 0 deletions
@@ -1,3 +1,5 @@
+import math
+
 import numpy as np
 
 from hls4ml.converters.pytorch_to_hls import pytorch_handler
@@ -160,6 +162,42 @@ def parse_batchnorm_layer(operation, layer_name, input_names, input_shapes, node
     return layer, [shape for shape in input_shapes[0]]
 
 
+@pytorch_handler('LayerNorm')
+def parse_layernorm_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
+    assert 'LayerNorm' in operation
+
+    layer = {}
+
+    layer['class_name'] = 'LayerNormalization'
+    layer['name'] = layer_name
+    layer['inputs'] = input_names
+
+    in_size = 1
+    for dim in input_shapes[0][1:]:
+        in_size *= dim
+    layer['n_in'] = layer['n_out'] = in_size
+
+    if not ((len(input_shapes[0])) == 3):
+        raise Exception(
+            f'Input shape {input_shapes[0]} is not currently supported for LayerNorm; '
+            'only three-dimensional inputs (including batch dimension) are supported'
+        )
+    layer['seq_len'] = input_shapes[0][-2]
+
+    layer['axis'] = 2
+
+    layer['gamma_data'] = class_object.weight.data.numpy()
+    layer['beta_data'] = class_object.bias.data.numpy()
+
+    if class_object.eps <= 0:
+        raise Exception('epsilon must be positive')
+    layer['epsilon_power_of_10'] = -round(math.log10(class_object.eps))
+    if layer['epsilon_power_of_10'] <= 0:
+        raise Exception('epsilon must be less than 1e-1')
+
+    return layer, [shape for shape in input_shapes[0]]
+
+
 @pytorch_handler('einsum')
 def parse_einsum_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
    assert 'einsum' in operation
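A minimal PyTorch sketch of the module attributes the handler above reads; the input sizes are invented, and the comments note which layer dict entries each attribute feeds:

import torch
import torch.nn as nn

x = torch.randn(1, 8, 16)        # (batch, seq_len, features), made-up sizes
ln = nn.LayerNorm(16)            # normalizes over the last dimension

print(ln.weight.data.numpy().shape)   # (16,)  -> 'gamma_data' / scale weights
print(ln.bias.data.numpy().shape)     # (16,)  -> 'beta_data' / bias weights
print(ln.eps)                         # 1e-05  -> epsilon_power_of_10 = 5
print(ln(x).shape)                    # torch.Size([1, 8, 16])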

hls4ml/model/layers.py

Lines changed: 26 additions & 0 deletions
@@ -1132,6 +1132,31 @@ def add_bias(self, bias, quantizer=None, precision=None):
         self.add_weights_variable(name='bias', var_name='b{index}', data=bias, quantizer=quantizer, precision=precision)
 
 
+class LayerNormalization(Layer):
+    _expected_attributes = [
+        Attribute('n_in'),
+        Attribute('seq_len'),
+        Attribute('axis', value_type=int, default=2),
+        Attribute('epsilon_power_of_10', value_type=int, default=3),
+        WeightAttribute('scale'),
+        WeightAttribute('bias'),
+        TypeAttribute('scale'),
+        TypeAttribute('bias'),
+    ]
+
+    def initialize(self):
+        inp = self.get_input_variable()
+        shape = inp.shape
+        dims = inp.dim_names
+        self.add_output_variable(shape, dims)
+
+        scale = self.get_attr('gamma_data')
+        bias = self.get_attr('beta_data')
+
+        self.add_weights_variable(name='scale', var_name='s{index}', data=scale)
+        self.add_weights_variable(name='bias', var_name='b{index}', data=bias)
+
+
 class Merge(Layer):
     def initialize(self):
         assert len(self.inputs) == 2
@@ -1902,6 +1927,7 @@ def initialize(self):
     'BatchNormOnnx': BatchNormOnnx,
     'LayerGroup': LayerGroup,
     'SymbolicExpression': SymbolicExpression,
+    'LayerNormalization': LayerNormalization,
     'EinsumDense': EinsumDense,
     'Einsum': Einsum,
     # TensorFlow-specific layers:
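For reference, the computation this IR layer represents, written as a plain numpy sketch over the last axis with the same scale (gamma) / bias (beta) naming; the shapes and the epsilon value (Keras' default 1e-3, matching the attribute default of 3 above) are illustrative:

import numpy as np

def layernorm_reference(x, scale, bias, epsilon=1e-3):
    # Normalize each vector along the last axis, then apply the learned
    # scale (gamma) and bias (beta).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return scale * (x - mean) / np.sqrt(var + epsilon) + bias

x = np.random.rand(8, 16).astype(np.float32)   # (seq_len, n_features), made-up sizes
scale = np.ones(16, dtype=np.float32)          # gamma
bias = np.zeros(16, dtype=np.float32)          # beta
print(layernorm_reference(x, scale, bias).shape)   # (8, 16)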

hls4ml/model/optimizer/passes/convert_to_channels_last.py

Lines changed: 20 additions & 3 deletions
@@ -2,7 +2,7 @@
 # Based on https://github.com/fastmachinelearning/qonnx/blob/
 # 12c96a3ded06beacab08e0f554e4ed014476c0aa/src/qonnx/transformation/channels_last.py
 
-from hls4ml.model.layers import Concatenate, Dense, Input, Reshape, Transpose
+from hls4ml.model.layers import Concatenate, Dense, Input, LayerNormalization, Reshape, Transpose
 from hls4ml.model.optimizer import OptimizerPass
 from hls4ml.model.types import WeightVariable
 
@@ -13,8 +13,9 @@ class ChannelsLastConverter(OptimizerPass):
 
     def match(self, node):
         # If this parameter has not been set, this model does not need to be converted
-        if 'ChannelsLastConversion' not in node.model.config.config['HLSConfig']['Model']:
-            return False  # No littering of unused property
+        do_convert = node.model.config.config['HLSConfig']['Model'].get('ChannelsLastConversion', 'off')
+        if do_convert == 'off':
+            return False
         if not hasattr(node, 'channels_last_converted'):
             return True
 
@@ -44,6 +45,22 @@ def transform(self, model, node):
             node.get_output_variable().shape = input_shape
             dim_names = [f'N_INPUT_{i}_{node.index}' for i in range(1, len(input_shape) + 1)]
             node.get_output_variable().dim_names = dim_names
+        elif isinstance(node, LayerNormalization):
+            # LayerNorm only works on the last dimension in PyTorch
+            perm = [1, 0]
+            pre_transpose = model.make_node(
+                'Transpose', f'pre_transpose_for_{node.get_attr("name")}', {'perm': perm}, [node.get_input_node().name]
+            )
+            pre_transpose.channels_last_converted = True
+            model.insert_node(pre_transpose)
+
+            # If not the output layer, transpose again
+            if not node.get_attr('name') in model.outputs or model.config.config['HLSConfig']['Model']['TransposeOutputs']:
+                post_transpose = model.make_node(
+                    'Transpose', f'post_transpose_for_{node.get_attr("name")}', {'perm': perm}, [node.name]
+                )
+                post_transpose.channels_last_converted = True
+                model.insert_node(post_transpose)
         else:
            # Transpose weight tensors
            tensors = ['weight', 'depthwise', 'pointwise', 'zero_bias', 'scale', 'recurrent_weight']
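A small numpy sketch of what the inserted perm=[1, 0] transposes do to a single (non-batched) 2-D frame; the starting layout is assumed purely for illustration, and the point is that the normalized feature axis ends up last, which is where PyTorch's LayerNorm operates:

import numpy as np

frame = np.arange(16 * 8, dtype=np.float32).reshape(16, 8)  # assumed (features, seq_len) layout

pre = np.transpose(frame, (1, 0))       # perm = [1, 0] -> (seq_len, features)

# Reference LayerNorm over the (now last) feature axis.
mean = pre.mean(axis=-1, keepdims=True)
var = pre.var(axis=-1, keepdims=True)
normed = (pre - mean) / np.sqrt(var + 1e-5)

post = np.transpose(normed, (1, 0))     # post-transpose restores the original layout
print(frame.shape, pre.shape, post.shape)   # (16, 8) (8, 16) (16, 8)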

hls4ml/model/optimizer/passes/infer_precision.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def _infer_precision(self, node, types_to_infer):
         if node_class in ['Dense']:
             return self._infer_dense_precision(node, types_to_infer)
 
-        if node_class in ['BatchNormalization', 'ApplyAlpha']:
+        if node_class in ['BatchNormalization', 'ApplyAlpha', 'LayerNormalization']:
             return self._infer_bn_precision(node, types_to_infer)
 
         if node_class in ['Conv1D', 'Conv2D', 'PointwiseConv1D', 'PointwiseConv2D', 'Conv2DBatchnorm']:

hls4ml/model/profiling.py

Lines changed: 13 additions & 0 deletions
@@ -293,6 +293,18 @@ def _keras_layer(layer):
     return layer.get_weights(), ['w', 'b']
 
 
+def _keras_layernorm(layer):
+    weights = layer.get_weights()
+
+    gamma = weights[0]
+    beta = weights[1]
+
+    scale = gamma
+    bias = beta
+
+    return [scale, bias], ['s', 'b']
+
+
 def _keras_lstm(layer):
     return layer.get_weights(), ['w', 'u', 'b']
 
@@ -302,6 +314,7 @@ def _keras_lstm(layer):
     {
         'BatchNormalization': _keras_batchnorm,
         'QBatchNormalization': _keras_batchnorm,
+        'LayerNormalization': _keras_layernorm,
         'LSTM': _keras_lstm,
         'QLSTM': _keras_lstm,
     },