Skip to content

Commit c9c10a2

Browse files
pre-commit-ci[bot]dimdano
authored andcommitted
[pre-commit.ci] auto fixes from pre-commit hooks
1 parent e50140f commit c9c10a2

File tree

8 files changed

+207
-201
lines changed

8 files changed

+207
-201
lines changed

docs/ir/multimodelgraph.rst

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
MultiModelGraph Class
33
=======================
44

5-
This page documents the ``MultiModelGraph`` class, which enables handling multiple subgraphs (each represented as a ``ModelGraph``) derived from a single original model.
6-
The central concept here is the division of a larger model into multiple smaller subgraphs at given layers which can be useful for:
5+
This page documents the ``MultiModelGraph`` class, which enables handling multiple subgraphs (each represented as a ``ModelGraph``) derived from a single original model.
6+
The central concept here is the division of a larger model into multiple smaller subgraphs at given layers which can be useful for:
77

88
* Very large models
99
* Step-wise optimization
@@ -26,8 +26,8 @@ For example, when converting a Keras model, you can specify the layers at which
2626
config = hls4ml.utils.config_from_keras_model(model, granularity='model')
2727
2828
hls_model = hls4ml.converters.convert_from_keras_model(
29-
model,
30-
hls_config=config,
29+
model,
30+
hls_config=config,
3131
backend='vitis',
3232
split_layer_names = ['layer3', 'layer7']
3333
)
@@ -39,10 +39,10 @@ Here, the ``hls_model`` is actually a ``MultiModelGraph`` containing three subgr
3939
Key Methods for MultiModelGraph
4040
----------------------------------
4141

42-
* :ref:`compile <mmg-compile-method>`
43-
* :ref:`predict <mmg-predict-method>`
44-
* :ref:`build <mmg-build-method>`
45-
* :ref:`trace <mmg-trace-method>`
42+
* :ref:`compile <mmg-compile-method>`
43+
* :ref:`predict <mmg-predict-method>`
44+
* :ref:`build <mmg-build-method>`
45+
* :ref:`trace <mmg-trace-method>`
4646
* :ref:`make_multi_graph <make_multi_graph-method>`
4747

4848
----

hls4ml/backends/vitis/vitis_backend.py

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
1-
import os
2-
import sys
3-
import subprocess
41
import importlib.util
52
import json
3+
import os
64
import shutil
5+
import subprocess
6+
import sys
77

88
from hls4ml.backends import VivadoBackend
99
from hls4ml.model.flow import get_flow, register_flow
10-
from hls4ml.report import parse_vivado_report, aggregate_graph_reports
11-
from hls4ml.utils.simulation_utils import write_verilog_testbench, read_testbench_log, write_testbench_input, prepare_testbench_input, prepare_zero_input
10+
from hls4ml.report import aggregate_graph_reports, parse_vivado_report
11+
from hls4ml.utils.simulation_utils import (
12+
prepare_testbench_input,
13+
prepare_zero_input,
14+
read_testbench_log,
15+
write_testbench_input,
16+
write_verilog_testbench,
17+
)
1218

1319

1420
class VitisBackend(VivadoBackend):
@@ -81,7 +87,18 @@ def create_initial_config(
8187

8288
return config
8389

84-
def build(self, model, reset=False, csim=True, synth=True, cosim=False, validation=False, export=False, vsynth=False, log_to_stdout=True):
90+
def build(
91+
self,
92+
model,
93+
reset=False,
94+
csim=True,
95+
synth=True,
96+
cosim=False,
97+
validation=False,
98+
export=False,
99+
vsynth=False,
100+
log_to_stdout=True,
101+
):
85102
if 'linux' in sys.platform:
86103
found = os.system('command -v vitis_hls > /dev/null')
87104
if found != 0:
@@ -95,18 +112,13 @@ def build(self, model, reset=False, csim=True, synth=True, cosim=False, validati
95112
output_dir = model.config.get_output_dir()
96113
stdout_log = os.path.join(output_dir, 'build_stdout.log')
97114
stderr_log = os.path.join(output_dir, 'build_stderr.log')
98-
115+
99116
stdout_target = None if log_to_stdout else open(stdout_log, 'w')
100117
stderr_target = None if log_to_stdout else open(stderr_log, 'w')
101118

102119
try:
103120
process = subprocess.Popen(
104-
build_command,
105-
shell=True,
106-
cwd=output_dir,
107-
stdout=stdout_target,
108-
stderr=stderr_target,
109-
text=True
121+
build_command, shell=True, cwd=output_dir, stdout=stdout_target, stderr=stderr_target, text=True
110122
)
111123
process.communicate()
112124

@@ -118,15 +130,16 @@ def build(self, model, reset=False, csim=True, synth=True, cosim=False, validati
118130
stderr_target.close()
119131

120132
return parse_vivado_report(output_dir)
121-
133+
122134
def build_stitched_design(
123135
self,
124136
stitch_design=True,
125137
sim_stitched_design=False,
126138
export_stitched_design=False,
127139
nn_config=None,
128140
graph_reports=None,
129-
simulation_input_data=None):
141+
simulation_input_data=None,
142+
):
130143

131144
os.makedirs(nn_config['OutputDir'], exist_ok=True)
132145
stitched_design_dir = os.path.join(nn_config['OutputDir'], nn_config['StitchedProjectName'])
@@ -136,11 +149,11 @@ def build_stitched_design(
136149

137150
spec = importlib.util.find_spec('hls4ml')
138151
hls4ml_path = os.path.dirname(spec.origin)
139-
ip_stitcher_path = os.path.join(hls4ml_path, 'templates/vivado/ip_stitcher.tcl')
152+
ip_stitcher_path = os.path.join(hls4ml_path, 'templates/vivado/ip_stitcher.tcl')
140153
stdout_log = os.path.join(stitched_design_dir, 'stitcher_stdout.log')
141154
stderr_log = os.path.join(stitched_design_dir, 'stitcher_stderr.log')
142155
nn_config_path = os.path.join(stitched_design_dir, 'nn_config.json')
143-
testbench_path = os.path.join(stitched_design_dir, 'testbench.v')
156+
testbench_path = os.path.join(stitched_design_dir, 'testbench.v')
144157
testbench_log_path = os.path.join(stitched_design_dir, 'testbench_log.csv')
145158

146159
try:
@@ -151,8 +164,8 @@ def build_stitched_design(
151164
if nn_config:
152165
with open(nn_config_path, "w") as file:
153166
json.dump(nn_config, file, indent=4)
154-
155-
if(sim_stitched_design):
167+
168+
if sim_stitched_design:
156169
write_verilog_testbench(nn_config, testbench_path)
157170
# Produce a testbench input file for every input layer
158171
for i, layer in enumerate(nn_config['inputs']):
@@ -165,30 +178,33 @@ def build_stitched_design(
165178
# Handles both single and multi-layer cases. First dim should always be batch size
166179
data = simulation_input_data[i]
167180
input_data_reshaped = prepare_testbench_input(data, layer['fifo_depth'], layer['batch_size'])
168-
write_testbench_input(input_data_reshaped, testbench_input_path, layer['integer_bits'], layer['fractional_bits'])
181+
write_testbench_input(
182+
input_data_reshaped, testbench_input_path, layer['integer_bits'], layer['fractional_bits']
183+
)
169184
print('Verilog testbench and its input data were generated.')
170185

171186
print('Running build process of stitched IP...\n')
172187
stitch_command = [
173-
'vivado', '-mode', 'batch', '-nojournal', '-nolog', '-notrace',
174-
'-source', ip_stitcher_path,
188+
'vivado',
189+
'-mode',
190+
'batch',
191+
'-nojournal',
192+
'-nolog',
193+
'-notrace',
194+
'-source',
195+
ip_stitcher_path,
175196
'-tclargs',
176197
f'stitch_design={int(stitch_design)}',
177198
f'sim_design={int(sim_stitched_design)}',
178199
f'export_design={int(export_stitched_design)}',
179200
f"stitch_project_name={nn_config['StitchedProjectName']}",
180201
f"original_project_name={nn_config['OriginalProjectName']}",
181-
f'sim_verilog_file=testbench.v'
202+
f'sim_verilog_file=testbench.v',
182203
]
183-
204+
184205
with open(stdout_log, 'w') as stdout_file, open(stderr_log, 'w') as stderr_file:
185206
process = subprocess.Popen(
186-
stitch_command,
187-
cwd=stitched_design_dir,
188-
stdout=stdout_file,
189-
stderr=stderr_file,
190-
text=True,
191-
shell=False
207+
stitch_command, cwd=stitched_design_dir, stdout=stdout_file, stderr=stderr_file, text=True, shell=False
192208
)
193209
process.communicate()
194210
if process.returncode != 0:

hls4ml/converters/keras_to_hls.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,16 +322,18 @@ def parse_keras_model(model_arch, reader):
322322
return layer_list, input_layers, output_layers, output_shapes
323323

324324

325-
def keras_to_hls(config, split_layer_names = None):
325+
def keras_to_hls(config, split_layer_names=None):
326326
model_arch, reader = get_model_arch(config)
327327
layer_list, input_layers, output_layers, output_shapes = parse_keras_model(model_arch, reader)
328-
328+
329329
print('Creating HLS model...')
330330
merge_layers = ['add', 'subtract', 'multiply', 'average', 'maximum', 'minimum', 'concatenate', 'dot']
331331
if split_layer_names:
332332
if any(any(layer in name for layer in merge_layers) for name in split_layer_names):
333333
raise ValueError(f'Split layer must not be a merge layer')
334-
hls_model = ModelGraph.make_multi_graph(config, layer_list, input_layers, output_layers, output_shapes, split_layer_names)
334+
hls_model = ModelGraph.make_multi_graph(
335+
config, layer_list, input_layers, output_layers, output_shapes, split_layer_names
336+
)
335337
print('Multi-graph HLS model created.')
336338
else:
337339
hls_model = ModelGraph(config, layer_list, input_layers, output_layers)

0 commit comments

Comments
 (0)