|
24 | 24 | Reshape,
|
25 | 25 | Softmax,
|
26 | 26 | )
|
27 |
| -from hls4ml.model.optimizer import OptimizerPass |
| 27 | +from hls4ml.model.optimizer import ModelOptimizerPass, OptimizerPass |
28 | 28 | from hls4ml.model.optimizer.passes.hgq_proxy_model import FixedPointQuantizer, UnaryLUT
|
29 | 29 | from hls4ml.model.types import FixedPrecisionType, NamedType, RoundingMode, SaturationMode, WeightVariable
|
30 | 30 | from hls4ml.utils.qinterval import QIntervalArray, einsum, minimal_kif
|
@@ -545,22 +545,32 @@ def _(node: Softmax):
|
545 | 545 |
|
546 | 546 | @register_precision.register
|
547 | 547 | def _(node: UnaryLUT):
|
548 |
| - k, i, f = minimal_kif(node.attributes['table'].data) |
| 548 | + k, i, f = minimal_kif(node.attributes['table'].data) # type: ignore |
549 | 549 | k, i, f = bool(np.max(k)), int(np.max(i)), int(np.max(f))
|
550 | 550 | table_t = to_hls4ml_fixed(k, i, f, f'{node.name}_table_t')
|
551 | 551 | node.attributes['table_t'] = table_t
|
552 | 552 | default_register_precision(node)
|
553 | 553 |
|
554 | 554 |
|
555 |
| -class BitExact(OptimizerPass): |
556 |
| - def match(self, node): |
557 |
| - if node.attributes.get('bit_exact_transformed'): |
| 555 | +class BitExact(ModelOptimizerPass): |
| 556 | + def __init__(self): |
| 557 | + pass |
| 558 | + |
| 559 | + def _match(self, model: 'ModelGraph'): |
| 560 | + if not any(isinstance(node, FixedPointQuantizer) for node in model.graph.values()): |
558 | 561 | return False
|
559 | 562 | return True
|
560 | 563 |
|
561 |
| - def transform(self, model, node): |
562 |
| - register_precision(node) |
563 |
| - node.attributes['bit_exact_transformed'] = True |
| 564 | + def transform(self, model): |
| 565 | + if not self._match(model): |
| 566 | + return False |
| 567 | + |
| 568 | + for node in model.graph.values(): |
| 569 | + if node.attributes.get('bit_exact_transformed'): |
| 570 | + return False |
| 571 | + register_precision(node) |
| 572 | + node.attributes['bit_exact_transformed'] = True |
| 573 | + |
564 | 574 | return False
|
565 | 575 |
|
566 | 576 |
|
|
0 commit comments