|
9 | 9 | static const unsigned n_out = {n_out}; |
10 | 10 | static const unsigned reuse_factor = {reuse}; |
11 | 11 | static const unsigned strategy = nnet::{strategy}; |
12 | | - static const unsigned n_zeros = 0; |
| 12 | + static const unsigned n_zeros = {nzeros}; |
13 | 13 | static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor; |
14 | 14 | typedef {accum_t.name} accum_t; |
15 | 15 | typedef {bias_t.name} bias_t; |
@@ -83,6 +83,7 @@ def format(self, node): |
83 | 83 | mult_params = self._default_config_params(node) |
84 | 84 | mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width') |
85 | 85 | mult_params['n_out'] = node.get_attr('n_filt') |
| 86 | + mult_params['nzeros'] = node.get_weights('weight').nzeros |
86 | 87 | mult_params['product_type'] = get_backend('vivado').product_type( |
87 | 88 | node.get_input_variable().type.precision, node.get_weights('weight').type.precision |
88 | 89 | ) |
@@ -189,6 +190,7 @@ def format(self, node): |
189 | 190 | mult_params = self._default_config_params(node) |
190 | 191 | mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_height') * node.get_attr('filt_width') |
191 | 192 | mult_params['n_out'] = node.get_attr('n_filt') |
| 193 | + mult_params['nzeros'] = node.get_weights('weight').nzeros |
192 | 194 | mult_params['product_type'] = get_backend('vivado').product_type( |
193 | 195 | node.get_input_variable().type.precision, node.get_weights('weight').type.precision |
194 | 196 | ) |
@@ -274,6 +276,7 @@ def format(self, node): |
274 | 276 | mult_params['index'] = str(node.index) + '_depthwise' |
275 | 277 | mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width') |
276 | 278 | mult_params['n_out'] = node.get_attr('n_chan') |
| 279 | + mult_params['nzeros'] = node.get_weights('depthwise').nzeros |
277 | 280 | mult_params['weight_t'] = node.get_weights('depthwise').type |
278 | 281 | mult_params['product_type'] = get_backend('vivado').product_type( |
279 | 282 | node.get_input_variable().type.precision, node.get_weights('depthwise').type.precision |
@@ -313,6 +316,7 @@ def format(self, node): |
313 | 316 | mult_params['index'] = str(node.index) + '_pointwise' |
314 | 317 | mult_params['n_in'] = node.get_attr('n_chan') |
315 | 318 | mult_params['n_out'] = node.get_attr('n_filt') |
| 319 | + mult_params['nzeros'] = node.get_weights('pointwise').nzeros |
316 | 320 | mult_params['weight_t'] = node.get_weights('pointwise').type |
317 | 321 | mult_params['product_type'] = get_backend('vivado').product_type( |
318 | 322 | node.get_input_variable().type.precision, node.get_weights('pointwise').type.precision |
@@ -395,6 +399,7 @@ def format(self, node): |
395 | 399 | mult_params['index'] = str(node.index) + '_depthwise' |
396 | 400 | mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_height') * node.get_attr('filt_width') |
397 | 401 | mult_params['n_out'] = node.get_attr('n_chan') |
| 402 | + mult_params['nzeros'] = node.get_weights('depthwise').nzeros |
398 | 403 | mult_params['weight_t'] = node.get_weights('depthwise').type |
399 | 404 | mult_params['product_type'] = get_backend('vivado').product_type( |
400 | 405 | node.get_input_variable().type.precision, node.get_weights('depthwise').type.precision |
@@ -438,6 +443,7 @@ def format(self, node): |
438 | 443 | mult_params['index'] = str(node.index) + '_pointwise' |
439 | 444 | mult_params['n_in'] = node.get_attr('n_chan') |
440 | 445 | mult_params['n_out'] = node.get_attr('n_filt') |
| 446 | + mult_params['nzeros'] = node.get_weights('pointwise').nzeros |
441 | 447 | mult_params['weight_t'] = node.get_weights('pointwise').type |
442 | 448 | mult_params['product_type'] = get_backend('vivado').product_type( |
443 | 449 | node.get_input_variable().type.precision, node.get_weights('pointwise').type.precision |
|
0 commit comments