Skip to content

Commit 7d62e1a

Browse files
committed
bit-exact concatenate
1 parent d0fbbf1 commit 7d62e1a

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

hls4ml/model/optimizer/passes/bit_exact.py

+26
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from hls4ml.model.layers import (
1212
Activation,
1313
BatchNormalization,
14+
Concatenate,
1415
Conv1D,
1516
Conv2D,
1617
Dense,
@@ -126,6 +127,20 @@ def _(layer: Activation):
126127
return (_maximum_kif_at_shape(inp_shape),)
127128

128129

130+
@request_kif.register
131+
def _(layer: Concatenate):
132+
inp_shape0, inp_shape1 = get_input_shapes(layer)
133+
k, i, f = requested_kif(layer)
134+
ax = layer.attributes['axis']
135+
n_split = inp_shape0[ax]
136+
137+
k0, k1 = np.split(k, [n_split], axis=ax)
138+
i0, i1 = np.split(i, [n_split], axis=ax)
139+
f0, f1 = np.split(f, [n_split], axis=ax)
140+
141+
return ((k0, i0, f0), (k1, i1, f1))
142+
143+
129144
def requested_kif(layer: Layer) -> KIF_t:
130145
out_layers = get_output_layers(layer)
131146
out_shape = get_output_shape(layer)
@@ -403,6 +418,17 @@ def _(layer: Softmax):
403418
return k, i, f
404419

405420

421+
@produce_kif.register
422+
def _(layer: Concatenate):
423+
kifs_in = get_input_kifs(layer)
424+
ks, is_, fs = zip(*kifs_in)
425+
ax = layer.attributes.attributes['axis']
426+
k = np.concatenate(ks, axis=ax)
427+
i = np.concatenate(is_, axis=ax)
428+
f = np.concatenate(fs, axis=ax)
429+
return k, i, f
430+
431+
406432
@produce_kif.register
407433
def _(layer: Activation):
408434
fn_name = layer.attributes.attributes['activation']

0 commit comments

Comments
 (0)