@@ -407,10 +407,7 @@ def _(layer: Conv1D | Conv2D):
407
407
@_produce_kif .register (GlobalPooling1D )
408
408
@_produce_kif .register (GlobalPooling2D )
409
409
def _ (layer : Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D ):
410
- if isinstance (layer , (Pooling1D , GlobalPooling1D )):
411
- px_shape = (layer .attributes ['pool_width' ],)
412
- else :
413
- px_shape = (layer .attributes ['pool_height' ], layer .attributes ['pool_width' ])
410
+ px_shape = _get_px_shape (layer )
414
411
ch_out = ch_in = layer .attributes ['n_filt' ]
415
412
416
413
im2col_shape = * px_shape , ch_in , ch_out # conv kernel shape
@@ -432,6 +429,8 @@ def _(layer: Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D):
432
429
raise ValueError ('Average pooling with non-power-of-2 pool size cannot be bit-exact' )
433
430
f_out += int (f_add )
434
431
432
+ if isinstance (layer , (GlobalPooling1D , GlobalPooling2D )):
433
+ k_out , i_out , f_out = k_out [0 ], i_out [0 ], f_out [0 ]
435
434
return k_out , i_out , f_out
436
435
437
436
@@ -665,6 +664,22 @@ def _(node: UnaryLUT):
665
664
default_register_precision (node )
666
665
667
666
667
+ def _get_px_shape (node : Layer ):
668
+ if isinstance (node , Pooling1D ):
669
+ px_shape = (node .attributes ['pool_width' ],)
670
+ elif isinstance (node , GlobalPooling1D ):
671
+ inp_shape = get_input_shapes (node )[0 ]
672
+ px_shape = (inp_shape [0 ],)
673
+ elif isinstance (node , Pooling2D ):
674
+ px_shape = (node .attributes ['pool_height' ], node .attributes ['pool_width' ])
675
+ elif isinstance (node , GlobalPooling2D ):
676
+ inp_shape = get_input_shapes (node )[0 ]
677
+ px_shape = (inp_shape [0 ], inp_shape [1 ])
678
+ else :
679
+ raise ValueError (f'Layer { node .class_name } is not supported for pooling precision derivation' )
680
+ return px_shape
681
+
682
+
668
683
@register_precision .register (Pooling1D )
669
684
@register_precision .register (Pooling2D )
670
685
@register_precision .register (GlobalPooling1D )
@@ -674,10 +689,7 @@ def _(node: Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D):
674
689
pool_op = node .attributes ['pool_op' ]
675
690
if pool_op != 'Average' :
676
691
return
677
- if isinstance (node , (Pooling1D , GlobalPooling1D )):
678
- px_shape = (node .attributes ['pool_width' ],)
679
- else :
680
- px_shape = (node .attributes ['pool_height' ], node .attributes ['pool_width' ])
692
+ px_shape = _get_px_shape (node )
681
693
i_add = int (log2 (prod (px_shape )))
682
694
node .attributes ['accum_t' ].precision .width += i_add
683
695
node .attributes ['accum_t' ].precision .integer += i_add
0 commit comments