Support for multi graph build #1174

Status: Open. Wants to merge 49 commits into base branch main.

Commits (49)
ee3b51d
test commit
dimdano Oct 11, 2024
cbeee24
split ModelGraph at specified layer name
dimdano Oct 11, 2024
03111c9
feat: add make_multi_graph classmethod to ModelGraph
dimdano Oct 14, 2024
f4a77bb
make_multi_graph can now support arbitrary number of graphs
dimdano Oct 15, 2024
851e835
Pass output_shapes to make_multi_graph to detect input shapes of spli…
dimdano Oct 17, 2024
0e0cf11
fixed layer index in the newly created graph
dimdano Oct 17, 2024
323236b
fix minor mistakes
dimdano Oct 18, 2024
f759a3e
Add TCL script for automatic connection of subgraph IPs in Vivado
dimdano Oct 24, 2024
a5f8277
some minor fixes in tcl script and make_multi_graph
dimdano Oct 29, 2024
07d23ae
support for parallel subgraph builds. Also, make_multi_graph now retu…
dimdano Oct 31, 2024
5dc4ac6
new tcl script
dimdano Nov 12, 2024
202991d
connected external and control signals
dimdano Nov 13, 2024
dc60722
integrate ip_stitcher tcl script in hls4ml
dimdano Nov 14, 2024
bba704b
fix in tcl. folder creation for stitch project
dimdano Nov 18, 2024
da3efb0
package final stitched ip in hls4ml
dimdano Nov 22, 2024
0f40e2a
support for multiple inputs/outputs in first/last layer of stitched ip
dimdano Dec 2, 2024
d24c42b
initial support for stitched ip simulation
dimdano Dec 3, 2024
6e8f462
generate verilog testbench for stitched ip
dimdano Dec 6, 2024
27c76b3
read testbench output
dimdano Dec 9, 2024
704a874
minor changes
dimdano Dec 10, 2024
9d69355
improvements in testbench generation and build interface
dimdano Dec 11, 2024
d1dd0fd
general improvements
dimdano Dec 12, 2024
0bb10df
only simulate stitched_design, better verilog testbench
dimdano Dec 17, 2024
f1e2e57
prepare testbench input from user
dimdano Dec 18, 2024
55db302
support for user-defined input in verilog testbench of stitched IP
dimdano Dec 19, 2024
0af75e7
fix for multi input/output layers in graph splitting
dimdano Dec 19, 2024
db95628
documentation for MultiModelGraph flow
dimdano Dec 20, 2024
738d489
faster rtl simulation
dimdano Jan 8, 2025
7829e41
unwrap list if it has single element
dimdano Jan 10, 2025
f9fd4c0
Make MultiModelGraph adaptable to user-defined names
dimdano Jan 15, 2025
05ea6c9
stitch script time verbose
dimdano Jan 15, 2025
193381d
fix with existing stitch project folder
dimdano Jan 15, 2025
04ac0f4
initial support for multigraph compilation in bridge file
dimdano Jan 16, 2025
10e95a8
stitched report fix for VivadoSynth aggregate
dimdano Jan 17, 2025
8c5a13b
use log_to_stdout flag for parallel builds
dimdano Jan 21, 2025
4a7e6c3
small change
dimdano Jan 24, 2025
d6c19d5
remove bridged multigraph compilation for now
dimdano Jan 24, 2025
0225845
[pre-commit.ci] auto fixes from pre-commit hooks
pre-commit-ci[bot] Jan 24, 2025
89f5eb3
fix 'ap_rst' port polarity for active high case
dimdano Jan 28, 2025
e21cb53
support for partition interface in verilog testbench
dimdano Jan 29, 2025
e070ea1
support for MultiModelGraph predict using chained bridge file
dimdano Feb 14, 2025
7fbf439
Add pytest for multi-graph and fix minor issues
dimdano Mar 3, 2025
ba86132
pre-commit fixes
dimdano Mar 4, 2025
773c411
removed pandas dependency in read_testbench_log
dimdano Mar 10, 2025
b91f97a
Ensure stitched RTL simulation results align with CSim output
dimdano Mar 14, 2025
3dcd0d5
parallel subgraph compilation
dimdano Apr 16, 2025
fa3e679
added additional checks in ip_stitcher
dimdano Apr 16, 2025
05d22d3
small improvements on MultiModelGraph
dimdano Apr 16, 2025
3a74eea
correct AXIS port slicing for Verilog simulation
dimdano Apr 30, 2025
Binary file added docs/img/logo_small.png
137 changes: 137 additions & 0 deletions docs/ir/multimodelgraph.rst
@@ -0,0 +1,137 @@
=======================
MultiModelGraph Class
=======================

This page documents the ``MultiModelGraph`` class, which enables handling multiple subgraphs (each represented as a ``ModelGraph``) derived from a single original model.
The central concept is the division of a larger model into multiple smaller subgraphs at chosen layers, which can be useful for:

* Very large models
* Step-wise optimization
* Modular design flows

A ``MultiModelGraph`` manages these subgraphs, facilitating:

* Parallel building and synthesis
* Stitched designs (merging the subgraphs in HW after synthesis)
* Simulation and performance estimation of the stitched design

--------------
Keras Example
--------------

For example, when converting a Keras model, you can specify the layers at which to split the model directly:

.. code-block:: python

config = hls4ml.utils.config_from_keras_model(model, granularity='model')

hls_model = hls4ml.converters.convert_from_keras_model(
model,
hls_config=config,
backend='vitis',
split_layer_names=['layer3', 'layer7']
)

Here, the ``hls_model`` is actually a ``MultiModelGraph`` containing three subgraphs. Each subgraph is a ``ModelGraph`` accessible via indexing: ``hls_model[i]``.
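The mapping from split points to subgraphs can be illustrated with a minimal, self-contained sketch. This is not the hls4ml implementation; it assumes each named split layer starts a new subgraph (whether the split lands before or after the named layer is an implementation detail of hls4ml, so treat the boundary convention here as an assumption).

```python
def split_layers(layer_names, split_layer_names):
    """Split a flat layer sequence at the given layer names.

    Assumption (illustrative only): each split layer becomes the first
    layer of a new subgraph, so N split names yield N + 1 subgraphs.
    """
    subgraphs, current = [], []
    for name in layer_names:
        if name in split_layer_names and current:
            subgraphs.append(current)
            current = []
        current.append(name)
    subgraphs.append(current)
    return subgraphs

layers = ['layer1', 'layer2', 'layer3', 'layer4', 'layer5', 'layer6', 'layer7', 'layer8']
print(split_layers(layers, ['layer3', 'layer7']))
# → three subgraphs: layers 1-2, layers 3-6, and layers 7-8
```

With two split names, the result has three parts, matching the three ``ModelGraph`` instances described above.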


----------------------------------
Key Methods for MultiModelGraph
----------------------------------

* :ref:`compile <mmg-compile-method>`
* :ref:`predict <mmg-predict-method>`
* :ref:`build <mmg-build-method>`
* :ref:`trace <mmg-trace-method>`
* :ref:`make_multi_graph <make_multi_graph-method>`

----

.. _make_multi_graph-method:

``make_multi_graph`` method
===========================

The ``make_multi_graph`` method of ``ModelGraph`` takes a configuration, a full list of layers, the output shapes, and a list of split layers. It returns a ``MultiModelGraph`` that contains multiple ``ModelGraph`` instances.

.. code-block:: python

from my_hls4ml_lib.modelgraph import ModelGraph
multi_graph = ModelGraph.make_multi_graph(config, layer_list, output_shapes, split_layer_names=['fc2', 'fc3'])

This allows modular design flows and easier debugging of large models.

----

.. _mmg-compile-method:

``compile`` method
==================

Compiles all the individual ``ModelGraph`` subgraphs within the ``MultiModelGraph``. It also compiles a chained bridge file that links all subgraphs together, which is used by the ``predict`` function.

.. code-block:: python

multi_graph.compile()

----

.. _mmg-build-method:

``build`` method
================

Builds all subgraphs in parallel, each as if it were a standalone ``ModelGraph`` project. Returns reports for each subgraph. If configured, it then runs the stitching flow in Vivado, connecting the individual exported IPs and allowing you to simulate the stitched design at the RTL level.

.. code-block:: python

report = multi_graph.build(export=True, stitch_design=True)

The returned ``report`` contains data from each subgraph's build and, if stitching was performed, a combined report of the stitched design.
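As a rough sketch of how such a report might be consumed, the snippet below uses a hypothetical report layout (the key names ``graph1``, ``CSynthesisReport``, ``StitchedDesignReport``, etc. are illustrative assumptions, not guaranteed by hls4ml) to derive a naive latency bound for a sequential chain of subgraphs.

```python
# Hypothetical report structure: one entry per subgraph plus, when
# stitching ran, an aggregated entry for the stitched design.
# All key names below are assumptions for illustration only.
report = {
    'graph1': {'CSynthesisReport': {'BestLatency': 10, 'WorstLatency': 12}},
    'graph2': {'CSynthesisReport': {'BestLatency': 8, 'WorstLatency': 9}},
    'StitchedDesignReport': {'BestLatency': 20, 'WorstLatency': 23},
}

# For a purely sequential chain, summing per-subgraph worst-case
# latencies gives a naive upper bound on end-to-end latency.
subgraph_worst = sum(
    v['CSynthesisReport']['WorstLatency']
    for k, v in report.items()
    if k.startswith('graph')
)
print(subgraph_worst)  # 21
```

The stitched design's measured latency can differ from this bound, since the stitched report reflects actual inter-IP handshaking rather than a simple sum.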


----

.. _mmg-predict-method:

``predict`` method
==================

Performs a forward pass through the chained bridge file using C simulation (``sim='csim'``). Data is automatically passed from one subgraph's output to the next subgraph's input. For large stitched designs, you can also use RTL simulation (``sim='rtl'``) to perform the forward pass at the register-transfer level. In this case, a Verilog testbench is generated and executed against the stitched IP design, providing a behavioral simulation that verifies both latency and output at the hardware level. Note that the input data for the RTL simulation must have a single batch dimension.

.. code-block:: python

# Perform prediction using C-simulation (default)
y_csim = hls_model.predict(X, sim='csim')

# Perform prediction using RTL simulation (behavioral)
y_rtl = hls_model.predict(X, sim='rtl')


.. _mmg-trace-method:

``trace`` method [TODO]
=======================

Provides detailed layer-by-layer outputs across all sub-models, which is essential for debugging or tuning quantization and precision settings.

.. code-block:: python

final_output, trace_outputs = hls_model.trace(X)

``trace_outputs`` includes intermediate results from each subgraph, enabling insights into the data flow.
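Since this method is still marked TODO, the structure below is only a sketch of what the trace output might look like once implemented: a mapping from subgraph/layer names to intermediate arrays. All names and the nesting scheme are assumptions for illustration.

```python
import numpy as np

# Hypothetical trace structure (names are assumptions, not the hls4ml API):
# keys identify the subgraph and layer, values are intermediate outputs.
trace_outputs = {
    'graph1/fc1': np.zeros((1, 64)),
    'graph1/fc2': np.zeros((1, 32)),
    'graph2/fc3': np.zeros((1, 10)),
}

# Inspecting shapes is a quick way to verify that data flows consistently
# from one subgraph's output into the next subgraph's input.
shapes = {name: out.shape for name, out in trace_outputs.items()}
print(shapes['graph2/fc3'])  # (1, 10)
```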

--------------------------
Summary
--------------------------

The ``MultiModelGraph`` class is a tool for modular hardware design. By splitting a large neural network into multiple subgraphs, building each independently, and then stitching them together, you gain flexibility and parallelism while enabling hierarchical design, incremental optimization, and integrated system-level simulation.

--------------------------
Other Notes
--------------------------

* Branch Splitting Limitation: Splitting in the middle of a branched architecture (e.g., ResNet skip connections or multi-path networks) is currently unsupported. Each split subgraph must also have a single input and a single output.
* Handling Multiple NN Inputs & Outputs: The final NN output can support multiple output layers. However, for networks with multiple input layers, proper synchronization is required to drive the inputs, especially for stream interfaces. A fork-join mechanism in the Verilog testbench can help manage input synchronization effectively.
* RTL Simulation Issue: RTL simulation of stitched IPs with ``io_type='io_parallel'`` and a split at the flatten layer leads to improper simulation behavior and should be avoided.
* Array Partitioning for Parallel I/O: For ``io_parallel`` interfaces, all IPs must use the ``partition`` pragma instead of ``reshape``.
163 changes: 144 additions & 19 deletions hls4ml/backends/vitis/vitis_backend.py
@@ -1,9 +1,21 @@
import importlib.util
import json
import os
import shutil
import subprocess
import sys

from hls4ml.backends import VivadoBackend
from hls4ml.model.flow import get_flow, register_flow
from hls4ml.report import parse_vivado_report
from hls4ml.report import aggregate_graph_reports, parse_vivado_report
from hls4ml.utils.simulation_utils import (
annotate_axis_stream_widths,
prepare_testbench_input,
prepare_zero_input,
read_testbench_log,
write_testbench_input,
write_verilog_testbench,
)


class VitisBackend(VivadoBackend):
@@ -94,29 +106,142 @@ def build(
export=False,
vsynth=False,
fifo_opt=False,
log_to_stdout=True,
):
if 'linux' in sys.platform:
found = os.system('command -v vitis_hls > /dev/null')
if found != 0:
raise Exception('Vitis HLS installation not found. Make sure "vitis_hls" is on PATH.')

curr_dir = os.getcwd()
os.chdir(model.config.get_output_dir())
os.system(
(
'vitis_hls -f build_prj.tcl "reset={reset} csim={csim} synth={synth} cosim={cosim} '
'validation={validation} export={export} vsynth={vsynth} fifo_opt={fifo_opt}"'
).format(
reset=reset,
csim=csim,
synth=synth,
cosim=cosim,
validation=validation,
export=export,
vsynth=vsynth,
fifo_opt=fifo_opt,
)
build_command = (
'vitis_hls -f build_prj.tcl "reset={reset} csim={csim} synth={synth} cosim={cosim} '
'validation={validation} export={export} vsynth={vsynth} fifo_opt={fifo_opt}"'
).format(
reset=reset,
csim=csim,
synth=synth,
cosim=cosim,
validation=validation,
export=export,
vsynth=vsynth,
fifo_opt=fifo_opt,
)
os.chdir(curr_dir)

return parse_vivado_report(model.config.get_output_dir())
output_dir = model.config.get_output_dir()
stdout_log = os.path.join(output_dir, 'build_stdout.log')
stderr_log = os.path.join(output_dir, 'build_stderr.log')

stdout_target = None if log_to_stdout else open(stdout_log, 'w')
stderr_target = None if log_to_stdout else open(stderr_log, 'w')

try:
process = subprocess.Popen(
build_command, shell=True, cwd=output_dir, stdout=stdout_target, stderr=stderr_target, text=True
)
process.communicate()

if process.returncode != 0:
raise Exception(f'Build failed for {model.config.get_project_name()}. See logs for details.')
finally:
if not log_to_stdout:
stdout_target.close()
stderr_target.close()

return parse_vivado_report(output_dir)

def build_stitched_design(
self,
model,
stitch_design=True,
sim_stitched_design=False,
export_stitched_design=False,
nn_config=None,
graph_reports=None,
simulation_input_data=None,
):

os.makedirs(nn_config['OutputDir'], exist_ok=True)
stitched_design_dir = os.path.join(nn_config['OutputDir'], nn_config['StitchedProjectName'])
if stitch_design:
if not os.path.exists(stitched_design_dir):
os.makedirs(stitched_design_dir)

spec = importlib.util.find_spec('hls4ml')
hls4ml_path = os.path.dirname(spec.origin)
ip_stitcher_path = os.path.join(hls4ml_path, 'templates/vivado/ip_stitcher.tcl')
stdout_log = os.path.join(stitched_design_dir, 'stitcher_stdout.log')
stderr_log = os.path.join(stitched_design_dir, 'stitcher_stderr.log')
nn_config_path = os.path.join(stitched_design_dir, 'nn_config.json')
testbench_path = os.path.join(stitched_design_dir, 'testbench.v')
testbench_log_path = os.path.join(stitched_design_dir, 'testbench_log.csv')

try:
shutil.copy(ip_stitcher_path, stitched_design_dir)
except Exception as e:
print(f"Error: {e}. Cannot copy 'ip_stitcher.tcl' to {nn_config['StitchedProjectName']} folder.")

if nn_config:
if nn_config['outputs'][0]['pragma'] == 'stream':
last_graph_project_path = os.path.join(
model.graphs[-1].config.get_output_dir(), model.graphs[-1].config.get_project_dir()
)
annotate_axis_stream_widths(nn_config, last_graph_project_path)
with open(nn_config_path, "w") as file:
json.dump(nn_config, file, indent=4)

if sim_stitched_design:
write_verilog_testbench(nn_config, testbench_path)
# Produce a testbench input file for every input layer
for i, layer in enumerate(nn_config['inputs']):
testbench_input_path = os.path.join(stitched_design_dir, f"{layer['name']}_input_data.txt")
# We reshape input simulation data to (fifo_depth, batch_size)
if simulation_input_data is None:
input_data_reshaped = prepare_zero_input(layer)
print("No simulation input provided. Using zero-filled inputs.")
else:
# Handles both single and multi-layer cases. First dim should always be batch size
data = simulation_input_data[i]
input_data_reshaped = prepare_testbench_input(data, layer['fifo_depth'], layer['batch_size'])
write_testbench_input(
input_data_reshaped, testbench_input_path, layer['integer_bits'], layer['fractional_bits']
)
print('Verilog testbench and its input data were generated.')

print('Running build process of stitched IP...\n')
stitch_command = [
Review comment (Contributor): I would suggest exporting this into the build_prj.tcl and invoking it from there, as having hls4ml create the model and then moving it to another machine for HLS/logic synthesis could be a common workflow.

Reply (Author): The stitch_command is relatively fast and only runs after all the individual subgraph builds are complete. However, since hls4ml manages these builds in parallel using a Python thread pool, supporting this workflow on a remote server would require a Python script that mimics this behavior, essentially looping over each subgraph directory and running its corresponding build_prj.tcl in parallel using threads or processes. It's not hard to set up and I will do it once we finalize the flow.

'vivado',
'-mode',
'batch',
'-nojournal',
'-nolog',
'-notrace',
'-source',
ip_stitcher_path,
'-tclargs',
f'stitch_design={int(stitch_design)}',
f'sim_design={int(sim_stitched_design)}',
f'export_design={int(export_stitched_design)}',
f"stitch_project_name={nn_config['StitchedProjectName']}",
f"original_project_name={nn_config['OriginalProjectName']}",
'sim_verilog_file=testbench.v',
]

with open(stdout_log, 'w') as stdout_file, open(stderr_log, 'w') as stderr_file:
process = subprocess.Popen(
stitch_command, cwd=stitched_design_dir, stdout=stdout_file, stderr=stderr_file, text=True, shell=False
)
process.communicate()
if process.returncode != 0:
raise Exception(f"Stitching failed for {nn_config['StitchedProjectName']}. See logs for details.")

stitched_report = {'StitchedDesignReport': {}}
if stitch_design:
stitched_report = aggregate_graph_reports(graph_reports)

if sim_stitched_design:
testbench_output = read_testbench_log(testbench_log_path, nn_config['outputs'])
stitched_report['BehavSimResults'] = testbench_output['BehavSimResults']
stitched_report['StitchedDesignReport']['BestLatency'] = testbench_output['BestLatency']
stitched_report['StitchedDesignReport']['WorstLatency'] = testbench_output['WorstLatency']

return stitched_report
1 change: 1 addition & 0 deletions hls4ml/backends/vivado/passes/transform_types.py
@@ -31,6 +31,7 @@ def transform(self, model, node):
new_var = self.array_var_converter.convert(var, pragma='stream')
elif io_type == 'io_parallel':
if out_name in node.model.inputs:
# NOTE this needs to be changed to partition
new_var = self.array_var_converter.convert(var, pragma='reshape')
elif isinstance(var, InplaceTensorVariable):
new_var = self.inplace_array_var_converter.convert(var, pragma='')
5 changes: 4 additions & 1 deletion hls4ml/converters/__init__.py
@@ -214,7 +214,10 @@ def convert_from_keras_model(

_check_hls_config(config, hls_config)

return keras_to_hls(config)
# Retrieve 'split_layer_names' from kwargs, if provided, for multi-graph creation
split_layer_names = kwargs.get('split_layer_names', [])
Review comment (Contributor): For clarity, I would suggest split_before_layer instead, as it is not intuitively clear whether the listed layers end up in the graph before or after the split.

return keras_to_hls(config, split_layer_names=split_layer_names)


@requires('_torch')
17 changes: 13 additions & 4 deletions hls4ml/converters/keras_to_hls.py
@@ -322,9 +322,18 @@ def parse_keras_model(model_arch, reader):
return layer_list, input_layers, output_layers, output_shapes


def keras_to_hls(config):
def keras_to_hls(config, split_layer_names=None):
model_arch, reader = get_model_arch(config)
layer_list, input_layers, output_layers, _ = parse_keras_model(model_arch, reader)
print('Creating HLS model')
hls_model = ModelGraph(config, layer_list, input_layers, output_layers)
layer_list, input_layers, output_layers, output_shapes = parse_keras_model(model_arch, reader)

print('Creating HLS model...')
if split_layer_names:
hls_model = ModelGraph.make_multi_graph(
config, layer_list, input_layers, output_layers, output_shapes, split_layer_names
)
print('Multi-graph HLS model created.')
else:
hls_model = ModelGraph(config, layer_list, input_layers, output_layers)
print('HLS model created.')

return hls_model