Skip to content

Commit 3c16abe

Browse files
committed
support for MultiModelGraph predict using chained bridge file
1 parent c929f78 commit 3c16abe

File tree

4 files changed

+246
-253
lines changed

4 files changed

+246
-253
lines changed

hls4ml/model/graph.py

+58-164
Original file line numberDiff line numberDiff line change
@@ -978,10 +978,9 @@ def make_multi_graph(cls, config, layer_list, input_layers, output_layers, outpu
978978
if previous_layer_name in sub_config['HLSConfig']['LayerName']:
979979
prev_layer_config = sub_config['HLSConfig']['LayerName'][previous_layer_name]
980980
new_layer_config = {}
981-
new_layer_config['Precision'] = prev_layer_config['Precision']
981+
new_layer_config['Precision'] = last_output_precision if last_output_precision is not None else 'auto'
982982
# NOTE - We copy Trace as well but it might be better to reset it
983983
new_layer_config['Trace'] = prev_layer_config['Trace']
984-
# copy last layer config from previous graph to the new input layer config of current graph
985984
sub_config['HLSConfig']['LayerName'][input_layer_name] = new_layer_config
986985
else:
987986
raise KeyError(f"Layer '{previous_layer_name}' not found in subconfig.")
@@ -994,23 +993,19 @@ def make_multi_graph(cls, config, layer_list, input_layers, output_layers, outpu
994993
sub_config, sub_layer_list, graph_input_layers, graph_output_layers, initial_index=current_index
995994
)
996995

997-
# After creating subgraph, get the precision from the last layer's output.
996+
# After creating the subgraph, extract the actual precision from the last layer's result.
998997
if hls_model.graph:
999-
try:
1000-
last_layer = next(reversed(hls_model.graph.values()))
1001-
last_output_precision = last_layer.attributes['precision']['result']
1002-
except (KeyError, AttributeError):
1003-
warnings.warn(
1004-
"Could not find precision in the last layer. " "Setting 'last_output_precision' to 'auto'."
1005-
)
1006-
last_output_precision = 'auto'
998+
last_layer = next(reversed(hls_model.graph.values()))
999+
last_prec = last_layer.attributes.get('result_t')
1000+
last_output_precision = (last_prec.precision if hasattr(last_prec, 'precision') else last_prec) if last_prec is not None else 'auto'
1001+
if last_output_precision == 'auto' or last_output_precision is None:
1002+
raise ValueError("Could not extract a valid precision from the last layer!")
10071003

1008-
# Update the current index for the next graph
1009-
# Get the index of the last element in the graph
1004+
# Update current_index based on the new graph (accounting for the inserted input layer).
10101005
layer_indices = [layer.index for layer in hls_model.graph.values()]
10111006
if layer_indices:
10121007
max_index = max(layer_indices)
1013-
current_index = max_index - 1 # we have the new input layer as well
1008+
current_index = max_index - 1
10141009
model_graphs.append(hls_model)
10151010

10161011
return MultiModelGraph(model_graphs)
@@ -1019,32 +1014,45 @@ def make_multi_graph(cls, config, layer_list, input_layers, output_layers, outpu
10191014
class MultiModelGraph:
10201015
def __init__(self, graphs):
10211016
self.graphs = graphs
1022-
self.config = copy.copy(self.graphs[0].config)
1023-
self._deepcopy_config_names(self.graphs[0].config.config)
10241017
self._initialize_config(graphs[0])
1025-
self.config.config['StitchedProjectName'] = 'vivado_stitched_design'
1026-
self.backend = graphs[0].config.backend
1018+
self._bind_modelgraph_methods()
1019+
self._initialize_io_attributes(graphs)
1020+
1021+
def _initialize_config(self, first_graph):
1022+
self.config = copy.copy(first_graph.config)
1023+
# Deep copy only 'ProjectName' and 'OutputDir', shallow copy others
1024+
keys_to_deepcopy = ['ProjectName', 'OutputDir']
1025+
self.config.config = {
1026+
k: copy.deepcopy(first_graph.config.config[k]) if k in keys_to_deepcopy else first_graph.config.config[k]
1027+
for k in first_graph.config.config
1028+
}
1029+
self._update_project_config(first_graph)
1030+
self.backend = first_graph.config.backend
1031+
1032+
def _bind_modelgraph_methods(self):
1033+
# Bind necessary ModelGraph methods to this instance
1034+
self._compile = ModelGraph._compile.__get__(self, MultiModelGraph)
1035+
self.get_output_variables = ModelGraph.get_output_variables.__get__(self, MultiModelGraph)
1036+
self._compute_n_samples = ModelGraph._compute_n_samples.__get__(self, MultiModelGraph)
1037+
self._get_top_function = ModelGraph._get_top_function.__get__(self, MultiModelGraph)
1038+
self._predict = ModelGraph.predict.__get__(self, MultiModelGraph)
1039+
self.trace = ModelGraph.trace.__get__(self, MultiModelGraph)
1040+
1041+
def _initialize_io_attributes(self, graphs):
10271042
self.graph_reports = None
10281043
self._top_function_lib = None
1029-
self.config.config['Stamp'] = '64616e'
10301044
self.inputs = graphs[0].inputs
10311045
self.outputs = graphs[-1].outputs
1032-
self._compile = ModelGraph._compile.__get__(self, MultiModelGraph)
1046+
self.output_vars = graphs[-1].output_vars
10331047

1034-
def _initialize_config(self, first_graph):
1035-
"""
1036-
Initialize the configuration using details from the first graph
1037-
"""
1048+
def _update_project_config(self, first_graph):
10381049
original_project_name = first_graph.config.get_project_name().partition('_graph')[0]
10391050
self.config.config['ProjectName'] = f"{original_project_name}_stitched"
10401051
self.config.config['OriginalProjectName'] = original_project_name
10411052
original_output_dir = first_graph.config.get_output_dir().partition('/graph')[0]
10421053
self.config.config['OutputDir'] = os.path.join(original_output_dir, 'stitched')
1043-
1044-
def _deepcopy_config_names(self, config):
1045-
# Deep copy only 'ProjectName' and 'OutputDir', shallow copy others
1046-
keys_to_deepcopy = ['ProjectName', 'OutputDir']
1047-
self.config.config = {k: copy.deepcopy(config[k]) if k in keys_to_deepcopy else config[k] for k in config}
1054+
self.config.config['StitchedProjectName'] = 'vivado_stitched_design'
1055+
self.config.config['Stamp'] = '64616e'
10481056

10491057
def __getitem__(self, index):
10501058
return self.graphs[index]
@@ -1137,6 +1145,9 @@ def build_wrapper(idx, g, **kwargs):
11371145

11381146
if stitch_design or sim_stitched_design or export_stitched_design:
11391147
self._assert_consistent_pragmas()
1148+
vivado_folder = os.path.join(self.config.config['OutputDir'], self.config.config['StitchedProjectName'])
1149+
if os.path.exists(vivado_folder):
1150+
raise FileExistsError(f"Vivado stitched project folder '{vivado_folder}' already exists.")
11401151
nn_config = self.parse_nn_config()
11411152
stitched_report = self.backend.build_stitched_design(
11421153
stitch_design=stitch_design,
@@ -1152,18 +1163,13 @@ def build_wrapper(idx, g, **kwargs):
11521163
def compile(self):
11531164
for g in self.graphs:
11541165
g.compile()
1155-
# TODO
1156-
# self.write_build_script()
1157-
# self.write_bridge()
1158-
# self._compile()
1166+
# Bypass VitisWriter and invoke write_hls directly from VivadoWriter
1167+
super(self.backend.writer.__class__, self.backend.writer).write_hls(self, is_multigraph=True)
1168+
self._compile()
11591169

11601170
def predict(self, x, sim='csim'):
11611171
if sim == 'csim':
1162-
input_data = x
1163-
for g in self.graphs:
1164-
output_data = g.predict(input_data)
1165-
input_data = output_data
1166-
return output_data
1172+
return self._predict(x)
11671173
elif sim == 'rtl':
11681174
nn_config = self.parse_nn_config()
11691175
stitched_report = self.backend.build_stitched_design(
@@ -1177,134 +1183,22 @@ def predict(self, x, sim='csim'):
11771183
return stitched_report['BehavSimResults']
11781184
else:
11791185
print('Unknown simulation option given.')
1180-
1186+
11811187
def trace(self, x):
1182-
# TODO: finish trace function
1183-
input_data = x
1184-
trace_output = []
1185-
for g in self.graphs:
1186-
output_data, curr_trace_output = g.trace(input_data)
1187-
input_data = output_data
1188-
trace_output.append(curr_trace_output)
1189-
return output_data, trace_output
1190-
1191-
def write_build_script(self):
1192-
# NOTE we need to move this function to Vivado writer with each graph object
1193-
spec = importlib.util.find_spec('hls4ml')
1194-
hls4ml_path = os.path.dirname(spec.origin)
1195-
build_lib_src = os.path.join(hls4ml_path, 'templates/vivado/build_lib_multigraph.sh')
1196-
os.makedirs(self.config.config['OutputDir'], exist_ok=True)
1197-
build_lib_dst = os.path.join(self.config.config['OutputDir'], 'build_lib.sh')
1198-
graph_project_names = ' '.join(f"\"{g.config.get_output_dir().split('/')[-1]}\"" for g in self.graphs)
1199-
with open(build_lib_src) as src, open(build_lib_dst, 'w') as dst:
1200-
for line in src.readlines():
1201-
line = line.replace('myproject', self.config.config['OriginalProjectName'])
1202-
line = line.replace('myproject_stitched', self.config.config['ProjectName'])
1203-
line = line.replace('mystamp', self.config.config['Stamp'])
1204-
line = line.replace('mygraph_name_list', graph_project_names)
1205-
dst.write(line)
1206-
os.chmod(build_lib_dst, os.stat(build_lib_dst).st_mode | stat.S_IEXEC)
1207-
1208-
def write_bridge(self):
1209-
# NOTE we need to move this function to Vivado writer with each graph object
1210-
"""Write the Python-C++ bridge (myproject_bridge.cpp)
1211-
Args:
1212-
model (ModelGraph): the hls4ml model.
1213-
"""
1214-
1215-
filedir = os.path.dirname(os.path.abspath(__file__))
1216-
f = open(os.path.join(filedir, '../templates/vivado/myproject_bridge_multigraph.cpp'))
1217-
fout = open(f"{self.config.get_output_dir()}/{self.config.config['ProjectName']}_bridge.cpp", 'w')
1218-
model_inputs = self.graphs[0].get_input_variables()
1219-
model_outputs = self.graphs[-1].get_output_variables()
1220-
model_brams = [var for var in self.graphs[0].get_weight_variables() if var.storage.lower() == 'bram']
1221-
1222-
indent = ' '
1223-
1224-
for line in f.readlines():
1225-
newline = ''
1226-
if 'MYPROJECT' in line:
1227-
newline = line.replace('MYPROJECT', format(self.config.config['ProjectName'].upper()))
1228-
elif 'firmware/myproject' in line:
1229-
for graph_idx in range(len(self.graphs)):
1230-
newline += line.replace('myproject', format(self.graphs[graph_idx].config.config['ProjectName']))
1231-
newline += '\n#undef DEFINES_H_\n' if graph_idx < len(self.graphs) - 1 else ''
1232-
elif 'myproject' in line:
1233-
newline = line.replace('myproject', format(self.graphs[0].config.config['ProjectName']))
1234-
1235-
elif '// hls-fpga-machine-learning insert bram' in line:
1236-
newline = line
1237-
for bram in model_brams:
1238-
newline += f'#include \"firmware/weights/{bram.name}.h\"\n'
1239-
1240-
elif '// hls-fpga-machine-learning insert header' in line:
1241-
dtype = line.split('#', 1)[1].strip()
1242-
inputs_str = ', '.join([f'{dtype} {i.name}[{i.size_cpp()}]' for i in model_inputs])
1243-
outputs_str = ', '.join([f'{dtype} {o.name}[{o.size_cpp()}]' for o in model_outputs])
1244-
1245-
newline = ''
1246-
newline += indent + inputs_str + ',\n'
1247-
newline += indent + outputs_str + '\n'
1248-
1249-
elif '// hls-fpga-machine-learning insert wrapper' in line:
1250-
dtype = line.split('#', 1)[1].strip()
1251-
newline = ''
1252-
for i in model_inputs:
1253-
newline += indent + '{var};\n'.format(var=i.definition_cpp(name_suffix='_ap'))
1254-
newline += indent + 'nnet::convert_data<{}, {}, {}>({}, {}_ap);\n'.format(
1255-
dtype, i.type.name, i.size_cpp(), i.name, i.name
1256-
)
1257-
newline += '\n'
1258-
1259-
for o in model_outputs:
1260-
newline += indent + '{var};\n'.format(var=o.definition_cpp(name_suffix='_ap'))
1261-
1262-
newline += '\n'
1263-
1264-
input_vars = ','.join([i.name + '_ap' for i in model_inputs])
1265-
bram_vars = ','.join([b.name for b in model_brams])
1266-
output_vars = ','.join([o.name + '_ap' for o in model_outputs])
1267-
1268-
# Concatenate the input, output, and bram variables. Filter out empty/null values
1269-
all_vars = ','.join(filter(None, [input_vars, output_vars, bram_vars]))
1270-
1271-
top_level = indent + f"//{self.config.config['ProjectName']}({all_vars});\n"
1272-
newline += top_level
1273-
1274-
newline += '\n'
1275-
1276-
for o in model_outputs:
1277-
newline += indent + 'nnet::convert_data<{}, {}, {}>({}_ap, {});\n'.format(
1278-
o.type.name, dtype, o.size_cpp(), o.name, o.name
1279-
)
1280-
1281-
elif '// hls-fpga-machine-learning insert trace_outputs' in line:
1282-
newline = ''
1283-
for layer in self.graphs[0].get_layers():
1284-
func = layer.get_attr('function_cpp', None)
1285-
if func and self.graphs[0].config.trace_output and layer.get_attr('trace', False):
1286-
vars = layer.get_variables()
1287-
for var in vars:
1288-
newline += (
1289-
indent
1290-
+ 'nnet::trace_outputs->insert(std::pair<std::string, void *>('
1291-
+ f'"{layer.name}", (void *) malloc({var.size_cpp()} * element_size)));\n'
1292-
)
1293-
1294-
elif '// hls-fpga-machine-learning insert namespace' in line:
1295-
newline = ''
1296-
1297-
namespace = self.config.get_writer_config().get('Namespace', None)
1298-
if namespace is not None:
1299-
newline += indent + f'using namespace {namespace};\n'
1300-
1301-
else:
1302-
newline = line
1303-
fout.write(newline)
1304-
1305-
f.close()
1306-
fout.close()
1188+
raise NotImplementedError("Trace function has not been implemented yet for MultiModelGraph.")
13071189

1190+
def get_input_variables(self):
1191+
variables = []
1192+
for inp in self.inputs:
1193+
variables.append(self.graphs[0].graph[inp].get_output_variable())
1194+
return variables
1195+
1196+
def get_layers(self):
1197+
all_values = []
1198+
for g in self.graphs:
1199+
all_values.extend(g.graph.values())
1200+
return dict(zip(all_values, all_values)).values()
1201+
13081202
def _get_pragma_details(self, pragma):
13091203
"""
13101204
Extracts the pragma type and FIFO depth from the given pragma.

hls4ml/templates/vivado/build_lib_multigraph.sh

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,19 @@ ORIGINAL_PROJECT=myproject
1515
PROJECT=myproject_stitched
1616
LIB_STAMP=mystamp
1717
BASEDIR="$(cd "$(dirname "$0")" && cd .. && pwd)"
18-
AP_TYPES_PATH="-I${BASEDIR}/${graph_project_names[0]}/firmware/ap_types/"
1918
INCFLAGS=""
2019
OUTPUT_DIR="${BASEDIR}/stitched/firmware"
20+
WEIGHTS_DIR="\"${BASEDIR}/stitched/firmware/weights\""
2121

2222
mkdir -p "${OUTPUT_DIR}"
2323

2424
# Compile all graphs
2525
OBJECT_FILES=()
2626
for g in "${graph_project_names[@]}"; do
27-
WEIGHTS_DIR="\"${BASEDIR}/${g}/firmware/weights\""
2827
SRC_FILE="${g}/firmware/${ORIGINAL_PROJECT}_${g}.cpp"
2928
OBJ_FILE="${ORIGINAL_PROJECT}_${g}.o"
30-
29+
AP_TYPES_PATH="-I${BASEDIR}/${g}/firmware/ap_types/"
30+
3131
${CC} ${CFLAGS} ${AP_TYPES_PATH} -D WEIGHTS_DIR="${WEIGHTS_DIR}" -c "${BASEDIR}/${SRC_FILE}" -o "${OBJ_FILE}"
3232
OBJECT_FILES+=("${OBJ_FILE}")
3333
INCFLAGS+="-I${BASEDIR}/${g}/ "

hls4ml/templates/vivado/myproject_bridge_multigraph.cpp

-70
This file was deleted.

0 commit comments

Comments
 (0)