Skip to content

Commit d0fbbf1

Browse files
committed
fix pooling layer accum_t
1 parent d59d246 commit d0fbbf1

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

hls4ml/model/optimizer/passes/bit_exact.py

+18
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,24 @@ def _(node: UnaryLUT):
542542
default_register_precision(node)
543543

544544

545+
@register_precision.register(Pooling1D)
546+
@register_precision.register(Pooling2D)
547+
@register_precision.register(GlobalPooling1D)
548+
@register_precision.register(GlobalPooling2D)
549+
def _(node: Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D):
550+
default_register_precision(node)
551+
pool_op = node.attributes['pool_op']
552+
if pool_op != 'Average':
553+
return
554+
if isinstance(node, (Pooling1D, GlobalPooling1D)):
555+
px_shape = (node.attributes['pool_width'],)
556+
else:
557+
px_shape = (node.attributes['pool_height'], node.attributes['pool_width'])
558+
i_add = int(log2(prod(px_shape)))
559+
node.attributes['accum_t'].precision.width += i_add
560+
node.attributes['accum_t'].precision.integer += i_add
561+
562+
545563
class BitExact(ModelOptimizerPass):
546564
def __init__(self):
547565
pass

0 commit comments

Comments
 (0)