@@ -1,7 +1,7 @@
 import typing
 from copy import copy
 from functools import reduce, singledispatch
-from math import ceil, log2
+from math import ceil, log2, prod
 from typing import Sequence
 from warnings import warn
 
@@ -17,10 +17,12 @@
     Einsum,
     EinsumDense,
     GlobalPooling1D,
+    GlobalPooling2D,
     Input,
     Layer,
     Merge,
     Pooling1D,
+    Pooling2D,
     Reshape,
     Softmax,
 )
@@ -101,52 +103,6 @@ def _(layer: FixedPointQuantizer):
     return ((k, i, f),)
 
 
-@request_kif.register(Pooling1D)
-# @request_kif.register(Pooling2D)
-@request_kif.register(GlobalPooling1D)
-# @request_kif.register(GlobalPooling2D)
-def _(layer: Pooling1D | GlobalPooling1D):
-    # inp_shape = get_input_shapes(layer)[0]
-    out_shape = get_output_shape(layer)
-    pool_width = layer.attributes.attributes['pool_width']
-    stride_width = layer.attributes.attributes['stride_width']
-    pool_op = layer.attributes.attributes['pool_op']
-    if isinstance(layer, Pooling1D):
-        pad_0_0: int = layer.attributes.attributes['pad_left']
-    else:
-        pad_0_0 = 0
-    is_ch_last = layer.attributes.attributes['data_format'] == 'channels_last'
-
-    k = np.ones(out_shape, dtype=np.int8)
-    i = np.full(out_shape, -127, dtype=np.int8)
-    f = np.full(out_shape, 126, dtype=np.int8)
-
-    _, i_out, f_out = requested_kif(layer)
-
-    if not is_ch_last:
-        i = np.moveaxis(i, 0, -1)
-        f = np.moveaxis(f, 0, -1)
-
-    for idx_out in range(k.shape[-1]):
-        i_in_0 = i_out * stride_width - pad_0_0
-        i_in_1 = i_in_0 + pool_width
-        if i_in_0 < 0:
-            i_in_0 = 0
-        i[..., i_in_0:i_in_1] = i_out[..., idx_out]
-        f[..., i_in_0:i_in_1] = f_out[..., idx_out]
-
-    if not is_ch_last:
-        i = np.moveaxis(i, -1, 0)
-        f = np.moveaxis(f, -1, 0)
-
-    if pool_op == 'Average':
-        ln2_size = np.log2(pool_width)
-        i += np.ceil(ln2_size).astype(np.int8)
-        if not ln2_size.is_integer():
-            f[:] = 126
-    return ((k, i, f),)
-
-
 @request_kif.register
 def _(layer: Reshape):
     inp_shape = get_input_shapes(layer)[0]
@@ -332,15 +288,17 @@ def im2col(kernel_size: Sequence[int], *arrs: np.ndarray):
 
 def pad_arrs(node: Layer, pad_val: float = 0, *arrs: np.ndarray):
     out_arrs = []
-    if node.class_name.endswith('Conv2D'):
+    if node.class_name.endswith('2D'):
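+        # endswith('2D') matches Conv2D and Pooling2D alike, so pooling layers
+        # reuse the same spatial padding logic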
         pad_top = node.attributes.attributes['pad_top']
         pad_bottom = node.attributes.attributes['pad_bottom']
         pad_left = node.attributes.attributes['pad_left']
         pad_right = node.attributes.attributes['pad_right']
         for arr in arrs:
             r = np.pad(arr, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), constant_values=pad_val)
             out_arrs.append(r)
-    elif node.class_name.endswith('Conv1D'):
+    elif node.class_name.endswith('1D'):
         pad_left = node.attributes.attributes['pad_left']
         pad_right = node.attributes.attributes['pad_right']
         for arr in arrs:
@@ -352,11 +308,11 @@ def pad_arrs(node: Layer, pad_val: float = 0, *arrs: np.ndarray):
 
 
 def stride_arrs(node: Layer, *arrs: np.ndarray):
-    if node.class_name.endswith('Conv2D'):
+    if node.class_name.endswith('2D'):
         st_h = node.attributes.attributes['stride_height']
         st_w = node.attributes.attributes['stride_width']
         return tuple(arr[::st_h, ::st_w] for arr in arrs)
-    if node.class_name.endswith('Conv1D'):
+    if node.class_name.endswith('1D'):
         st_w = node.attributes.attributes['stride_width']
         return tuple(arr[::st_w] for arr in arrs)
     raise ValueError(f'Layer {node.class_name} is not supported for stride_arrs')
@@ -365,6 +321,9 @@ def stride_arrs(node: Layer, *arrs: np.ndarray):
 @produce_kif.register(Conv1D)
 @produce_kif.register(Conv2D)
 def _(layer: Conv1D | Conv2D):
+    assert layer.attributes.attributes['data_format'] == 'channels_last', 'Only channels_last format is supported'
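+    # pad_arrs/im2col/stride_arrs all index the channel axis last, hence the
+    # channels_last requirement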
     kernel = layer.attributes.attributes['weight'].data
     _bias = layer.attributes.attributes['bias']
     bias = _bias.data if _bias is not None else 0
@@ -380,6 +337,47 @@ def _(layer: Conv1D | Conv2D):
     return k.astype(np.int8), i, f
 
 
+@produce_kif.register(Pooling1D)
+@produce_kif.register(Pooling2D)
+@produce_kif.register(GlobalPooling1D)
+@produce_kif.register(GlobalPooling2D)
+def _(layer: Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D):
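+    # A KIF tuple (k, i, f) holds per-element keep-negative (sign), integer,
+    # and fraction bit counts. Pooling is treated like a convolution: im2col
+    # gathers each window's input KIFs; a max over the window bounds the output.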
+    if isinstance(layer, (Pooling1D, GlobalPooling1D)):
+        px_shape = (layer.attributes['pool_width'],)
+    else:
+        px_shape = (layer.attributes['pool_height'], layer.attributes['pool_width'])
+    ch_out = ch_in = layer.attributes['n_filt']
+
+    im2col_shape = *px_shape, ch_in, ch_out  # conv kernel shape
+    k_in, i_in, f_in = get_input_kifs(layer)[0]
+    if isinstance(layer, (Pooling1D, Pooling2D)):
+        k_in, i_in, f_in = pad_arrs(layer, 0, k_in, i_in, f_in)
+    k_in, i_in, f_in = im2col(im2col_shape, k_in, i_in, f_in)
+    if isinstance(layer, (Pooling1D, Pooling2D)):
+        k_in, i_in, f_in = stride_arrs(layer, k_in, i_in, f_in)
+
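+    # The last axis now holds window_size * n_filt entries per output
+    # position; regroup by channel and take the per-channel max, i.e. the
+    # widest type seen anywhere in the window.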
+    k_out = k_in.reshape(*k_in.shape[:-1], -1, ch_in).max(axis=-2).astype(np.int8)
+    i_out = i_in.reshape(*i_in.shape[:-1], -1, ch_in).max(axis=-2).astype(np.int8)
+    f_out = f_in.reshape(*f_in.shape[:-1], -1, ch_in).max(axis=-2).astype(np.int8)
+
+    pool_op = layer.attributes['pool_op']
+    if pool_op == 'Average':
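+        # Dividing by the window size is exact only for power-of-two sizes,
+        # where it is a right-shift costing log2(size) extra fraction bits.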
+        f_add = log2(prod(px_shape))
+        if not f_add.is_integer():
+            raise ValueError('Average pooling with non-power-of-2 pool size cannot be bit-exact')
+        f_out += int(f_add)
+
+    return k_out, i_out, f_out
+
+
 @produce_kif.register
 def _(layer: BatchNormalization):
     k_in, i_in, f_in = get_input_kifs(layer)[0]