Skip to content

Commit 84e9a55

Browse files
committed
switch to model opt
1 parent 8200142 commit 84e9a55

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

hls4ml/model/optimizer/passes/bit_exact.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
Reshape,
2525
Softmax,
2626
)
27-
from hls4ml.model.optimizer import OptimizerPass
27+
from hls4ml.model.optimizer import ModelOptimizerPass, OptimizerPass
2828
from hls4ml.model.optimizer.passes.hgq_proxy_model import FixedPointQuantizer, UnaryLUT
2929
from hls4ml.model.types import FixedPrecisionType, NamedType, RoundingMode, SaturationMode, WeightVariable
3030
from hls4ml.utils.qinterval import QIntervalArray, einsum, minimal_kif
@@ -545,22 +545,32 @@ def _(node: Softmax):
545545

546546
@register_precision.register
547547
def _(node: UnaryLUT):
548-
k, i, f = minimal_kif(node.attributes['table'].data)
548+
k, i, f = minimal_kif(node.attributes['table'].data) # type: ignore
549549
k, i, f = bool(np.max(k)), int(np.max(i)), int(np.max(f))
550550
table_t = to_hls4ml_fixed(k, i, f, f'{node.name}_table_t')
551551
node.attributes['table_t'] = table_t
552552
default_register_precision(node)
553553

554554

555-
class BitExact(OptimizerPass):
556-
def match(self, node):
557-
if node.attributes.get('bit_exact_transformed'):
555+
class BitExact(ModelOptimizerPass):
556+
def __init__(self):
557+
pass
558+
559+
def _match(self, model: 'ModelGraph'):
560+
if not any(isinstance(node, FixedPointQuantizer) for node in model.graph.values()):
558561
return False
559562
return True
560563

561-
def transform(self, model, node):
562-
register_precision(node)
563-
node.attributes['bit_exact_transformed'] = True
564+
def transform(self, model):
565+
if not self._match(model):
566+
return False
567+
568+
for node in model.graph.values():
569+
if node.attributes.get('bit_exact_transformed'):
570+
return False
571+
register_precision(node)
572+
node.attributes['bit_exact_transformed'] = True
573+
564574
return False
565575

566576

0 commit comments

Comments
 (0)