Skip to content

Commit f9e22d5

Browse files
committed
fix global pooling bw inference
1 parent a558b49 commit f9e22d5

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

hls4ml/model/optimizer/passes/bit_exact.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -407,10 +407,7 @@ def _(layer: Conv1D | Conv2D):
407407
@_produce_kif.register(GlobalPooling1D)
408408
@_produce_kif.register(GlobalPooling2D)
409409
def _(layer: Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D):
410-
if isinstance(layer, (Pooling1D, GlobalPooling1D)):
411-
px_shape = (layer.attributes['pool_width'],)
412-
else:
413-
px_shape = (layer.attributes['pool_height'], layer.attributes['pool_width'])
410+
px_shape = _get_px_shape(layer)
414411
ch_out = ch_in = layer.attributes['n_filt']
415412

416413
im2col_shape = *px_shape, ch_in, ch_out # conv kernel shape
@@ -432,6 +429,8 @@ def _(layer: Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D):
432429
raise ValueError('Average pooling with non-power-of-2 pool size cannot be bit-exact')
433430
f_out += int(f_add)
434431

432+
if isinstance(layer, (GlobalPooling1D, GlobalPooling2D)):
433+
k_out, i_out, f_out = k_out[0], i_out[0], f_out[0]
435434
return k_out, i_out, f_out
436435

437436

@@ -665,6 +664,22 @@ def _(node: UnaryLUT):
665664
default_register_precision(node)
666665

667666

667+
def _get_px_shape(node: Layer):
668+
if isinstance(node, Pooling1D):
669+
px_shape = (node.attributes['pool_width'],)
670+
elif isinstance(node, GlobalPooling1D):
671+
inp_shape = get_input_shapes(node)[0]
672+
px_shape = (inp_shape[0],)
673+
elif isinstance(node, Pooling2D):
674+
px_shape = (node.attributes['pool_height'], node.attributes['pool_width'])
675+
elif isinstance(node, GlobalPooling2D):
676+
inp_shape = get_input_shapes(node)[0]
677+
px_shape = (inp_shape[0], inp_shape[1])
678+
else:
679+
raise ValueError(f'Layer {node.class_name} is not supported for pooling precision derivation')
680+
return px_shape
681+
682+
668683
@register_precision.register(Pooling1D)
669684
@register_precision.register(Pooling2D)
670685
@register_precision.register(GlobalPooling1D)
@@ -674,10 +689,7 @@ def _(node: Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D):
674689
pool_op = node.attributes['pool_op']
675690
if pool_op != 'Average':
676691
return
677-
if isinstance(node, (Pooling1D, GlobalPooling1D)):
678-
px_shape = (node.attributes['pool_width'],)
679-
else:
680-
px_shape = (node.attributes['pool_height'], node.attributes['pool_width'])
692+
px_shape = _get_px_shape(node)
681693
i_add = int(log2(prod(px_shape)))
682694
node.attributes['accum_t'].precision.width += i_add
683695
node.attributes['accum_t'].precision.integer += i_add

0 commit comments

Comments
 (0)