
Commit 72a8d64

pooling layers

1 parent 84e9a55

1 file changed

hls4ml/model/optimizer/passes/bit_exact.py

Lines changed: 41 additions & 51 deletions
@@ -1,7 +1,7 @@
 import typing
 from copy import copy
 from functools import reduce, singledispatch
-from math import ceil, log2
+from math import ceil, log2, prod
 from typing import Sequence
 from warnings import warn
 
@@ -17,10 +17,12 @@
     Einsum,
     EinsumDense,
     GlobalPooling1D,
+    GlobalPooling2D,
     Input,
     Layer,
     Merge,
     Pooling1D,
+    Pooling2D,
     Reshape,
     Softmax,
 )
@@ -101,52 +103,6 @@ def _(layer: FixedPointQuantizer):
     return ((k, i, f),)
 
 
-@request_kif.register(Pooling1D)
-# @request_kif.register(Pooling2D)
-@request_kif.register(GlobalPooling1D)
-# @request_kif.register(GlobalPooling2D)
-def _(layer: Pooling1D | GlobalPooling1D):
-    # inp_shape = get_input_shapes(layer)[0]
-    out_shape = get_output_shape(layer)
-    pool_width = layer.attributes.attributes['pool_width']
-    stride_width = layer.attributes.attributes['stride_width']
-    pool_op = layer.attributes.attributes['pool_op']
-    if isinstance(layer, Pooling1D):
-        pad_0_0: int = layer.attributes.attributes['pad_left']
-    else:
-        pad_0_0 = 0
-    is_ch_last = layer.attributes.attributes['data_format'] == 'channels_last'
-
-    k = np.ones(out_shape, dtype=np.int8)
-    i = np.full(out_shape, -127, dtype=np.int8)
-    f = np.full(out_shape, 126, dtype=np.int8)
-
-    _, i_out, f_out = requested_kif(layer)
-
-    if not is_ch_last:
-        i = np.moveaxis(i, 0, -1)
-        f = np.moveaxis(f, 0, -1)
-
-    for idx_out in range(k.shape[-1]):
-        i_in_0 = i_out * stride_width - pad_0_0
-        i_in_1 = i_in_0 + pool_width
-        if i_in_0 < 0:
-            i_in_0 = 0
-        i[..., i_in_0:i_in_1] = i_out[..., idx_out]
-        f[..., i_in_0:i_in_1] = f_out[..., idx_out]
-
-    if not is_ch_last:
-        i = np.moveaxis(i, -1, 0)
-        f = np.moveaxis(f, -1, 0)
-
-    if pool_op == 'Average':
-        ln2_size = np.log2(pool_width)
-        i += np.ceil(ln2_size).astype(np.int8)
-        if not ln2_size.is_integer():
-            f[:] = 126
-    return ((k, i, f),)
-
-
 @request_kif.register
 def _(layer: Reshape):
     inp_shape = get_input_shapes(layer)[0]
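A note for readers of this diff: the `(k, i, f)` triplets used throughout `bit_exact.py` describe a per-element fixed-point format, where `k` appears to flag whether a sign bit is kept, `i` the number of integer bits, and `f` the number of fraction bits. That reading is inferred from the surrounding code, not stated in the commit; a minimal sketch of the range such a triplet would encode under that assumption:

```python
# Hedged sketch: one way to read a (k, i, f) triplet as a fixed-point
# format, roughly ap_fixed<k + i + f, k + i>. The mapping is an assumption
# inferred from this file, not part of the commit.
def kif_range(k: int, i: int, f: int) -> tuple[float, float]:
    step = 2.0**-f                 # resolution: value of the last fraction bit
    lo = -(2.0**i) if k else 0.0   # most negative value if a sign bit is kept
    hi = 2.0**i - step             # largest representable value
    return lo, hi

print(kif_range(1, 3, 2))  # (-8.0, 7.75): a 6-bit signed fixed-point format
```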
@@ -332,15 +288,15 @@ def im2col(kernel_size: Sequence[int], *arrs: np.ndarray):
 
 def pad_arrs(node: Layer, pad_val: float = 0, *arrs: np.ndarray):
     out_arrs = []
-    if node.class_name.endswith('Conv2D'):
+    if node.class_name.endswith('2D'):
         pad_top = node.attributes.attributes['pad_top']
         pad_bottom = node.attributes.attributes['pad_bottom']
         pad_left = node.attributes.attributes['pad_left']
         pad_right = node.attributes.attributes['pad_right']
         for arr in arrs:
             r = np.pad(arr, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), constant_values=pad_val)
             out_arrs.append(r)
-    elif node.class_name.endswith('Conv1D'):
+    elif node.class_name.endswith('1D'):
         pad_left = node.attributes.attributes['pad_left']
         pad_right = node.attributes.attributes['pad_right']
         for arr in arrs:
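Relaxing `endswith('Conv2D')` to `endswith('2D')` lets this padding helper serve the pooling layers added in this commit (`Pooling2D`, `GlobalPooling2D`) as well as convolutions. A minimal sketch of the padding performed by the 2D branch on a channels-last array, with made-up pad values:

```python
import numpy as np

# Sketch of pad_arrs' 2D branch on a channels-last (H, W, C) array.
# The pad amounts here are hypothetical; the real values come from the
# layer's pad_top/pad_bottom/pad_left/pad_right attributes.
arr = np.arange(6, dtype=np.int8).reshape(2, 3, 1)
pad_top, pad_bottom, pad_left, pad_right = 1, 1, 0, 0
padded = np.pad(arr, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), constant_values=0)
print(padded.shape)  # (4, 3, 1): only the spatial axes grow; channels untouched
```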
@@ -352,11 +308,11 @@ def pad_arrs(node: Layer, pad_val: float = 0, *arrs: np.ndarray):
 
 
 def stride_arrs(node: Layer, *arrs: np.ndarray):
-    if node.class_name.endswith('Conv2D'):
+    if node.class_name.endswith('2D'):
         st_h = node.attributes.attributes['stride_height']
         st_w = node.attributes.attributes['stride_width']
         return tuple(arr[::st_h, ::st_w] for arr in arrs)
-    if node.class_name.endswith('Conv1D'):
+    if node.class_name.endswith('1D'):
         st_w = node.attributes.attributes['stride_width']
         return tuple(arr[::st_w] for arr in arrs)
     raise ValueError(f'Layer {node.class_name} is not supported for stride_arrs')
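The same generalization applies to the stride helper; striding is plain slicing over the spatial axes. A small sketch with hypothetical stride values:

```python
import numpy as np

# Sketch of stride_arrs' 2D branch: a stride of (2, 2) keeps every other
# row and column of the spatial grid (stride values are hypothetical).
arr = np.arange(16).reshape(4, 4)
st_h, st_w = 2, 2
print(arr[::st_h, ::st_w])  # [[ 0  2]
                            #  [ 8 10]]
```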
@@ -365,6 +321,7 @@ def stride_arrs(node: Layer, *arrs: np.ndarray):
 @produce_kif.register(Conv1D)
 @produce_kif.register(Conv2D)
 def _(layer: Conv1D | Conv2D):
+    assert layer.attributes.attributes['data_format'] == 'channels_last', 'Only channels_last format is supported'
     kernel = layer.attributes.attributes['weight'].data
     _bias = layer.attributes.attributes['bias']
     bias = _bias.data if _bias is not None else 0
@@ -380,6 +337,39 @@ def _(layer: Conv1D | Conv2D):
     return k.astype(np.int8), i, f
 
 
+@produce_kif.register(Pooling1D)
+@produce_kif.register(Pooling2D)
+@produce_kif.register(GlobalPooling1D)
+@produce_kif.register(GlobalPooling2D)
+def _(layer: Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D):
+    if isinstance(layer, (Pooling1D, GlobalPooling1D)):
+        px_shape = (layer.attributes['pool_width'],)
+    else:
+        px_shape = (layer.attributes['pool_height'], layer.attributes['pool_width'])
+    ch_out = ch_in = layer.attributes['n_filt']
+
+    im2col_shape = *px_shape, ch_in, ch_out  # conv kernel shape
+    k_in, i_in, f_in = get_input_kifs(layer)[0]
+    if isinstance(layer, (Pooling1D, Pooling2D)):
+        k_in, i_in, f_in = pad_arrs(layer, 0, k_in, i_in, f_in)
+    k_in, i_in, f_in = im2col(im2col_shape, k_in, i_in, f_in)
+    if isinstance(layer, (Pooling1D, Pooling2D)):
+        k_in, i_in, f_in = stride_arrs(layer, k_in, i_in, f_in)
+
+    k_out = k_in.reshape(*k_in.shape[:-1], -1, ch_in).max(axis=-2).astype(np.int8)
+    i_out = i_in.reshape(*i_in.shape[:-1], -1, ch_in).max(axis=-2).astype(np.int8)
+    f_out = f_in.reshape(*f_in.shape[:-1], -1, ch_in).max(axis=-2).astype(np.int8)
+
+    pool_op = layer.attributes['pool_op']
+    if pool_op == 'Average':
+        f_add = log2(prod(px_shape))
+        if not f_add.is_integer():
+            raise ValueError('Average pooling with non-power-of-2 pool size cannot be bit-exact')
+        f_out += int(f_add)
+
+    return k_out, i_out, f_out
+
+
 @produce_kif.register
 def _(layer: BatchNormalization):
     k_in, i_in, f_in = get_input_kifs(layer)[0]
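Two details of the new handler deserve a worked example. Taking the per-channel max of `k`/`i`/`f` over each pooled window is a safe bound for both pool ops: max pooling returns one of its inputs, and an average never exceeds the widest input in magnitude. For `Average`, dividing the window sum by a power-of-two pool size is a pure binary-point shift, so exactness costs `log2(prod(px_shape))` extra fraction bits; any other size yields a non-terminating binary fraction, which is why the handler raises. A numeric sketch of that bookkeeping:

```python
from math import log2, prod

# Sketch of the Average-pool fraction-bit bookkeeping in the new handler.
# A 2x2 average divides the window sum by 4, i.e. shifts the binary point
# right by log2(4) = 2 bits, so a bit-exact result needs 2 extra fraction bits.
px_shape = (2, 2)  # hypothetical (pool_height, pool_width)
f_add = log2(prod(px_shape))
assert f_add.is_integer() and int(f_add) == 2

# Dividing by a non-power-of-2 size such as 3 cannot terminate in binary
# (1/3 = 0b0.010101...), so no finite fraction width is exact -> ValueError.
assert not log2(prod((3, 1))).is_integer()
```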
