@@ -978,10 +978,9 @@ def make_multi_graph(cls, config, layer_list, input_layers, output_layers, outpu
978
978
if previous_layer_name in sub_config ['HLSConfig' ]['LayerName' ]:
979
979
prev_layer_config = sub_config ['HLSConfig' ]['LayerName' ][previous_layer_name ]
980
980
new_layer_config = {}
981
- new_layer_config ['Precision' ] = prev_layer_config [ 'Precision' ]
981
+ new_layer_config ['Precision' ] = last_output_precision if last_output_precision is not None else 'auto'
982
982
# NOTE - We copy Trace as well but it might be better to reset it
983
983
new_layer_config ['Trace' ] = prev_layer_config ['Trace' ]
984
- # copy last layer config from previous graph to the new input layer config of current graph
985
984
sub_config ['HLSConfig' ]['LayerName' ][input_layer_name ] = new_layer_config
986
985
else :
987
986
raise KeyError (f"Layer '{ previous_layer_name } ' not found in subconfig." )
@@ -994,23 +993,19 @@ def make_multi_graph(cls, config, layer_list, input_layers, output_layers, outpu
994
993
sub_config , sub_layer_list , graph_input_layers , graph_output_layers , initial_index = current_index
995
994
)
996
995
997
- # After creating subgraph, get the precision from the last layer's output .
996
+ # After creating the subgraph, extract the actual precision from the last layer's result .
998
997
if hls_model .graph :
999
- try :
1000
- last_layer = next (reversed (hls_model .graph .values ()))
1001
- last_output_precision = last_layer .attributes ['precision' ]['result' ]
1002
- except (KeyError , AttributeError ):
1003
- warnings .warn (
1004
- "Could not find precision in the last layer. " "Setting 'last_output_precision' to 'auto'."
1005
- )
1006
- last_output_precision = 'auto'
998
+ last_layer = next (reversed (hls_model .graph .values ()))
999
+ last_prec = last_layer .attributes .get ('result_t' )
1000
+ last_output_precision = (last_prec .precision if hasattr (last_prec , 'precision' ) else last_prec ) if last_prec is not None else 'auto'
1001
+ if last_output_precision == 'auto' or last_output_precision is None :
1002
+ raise ValueError ("Could not extract a valid precision from the last layer!" )
1007
1003
1008
- # Update the current index for the next graph
1009
- # Get the index of the last element in the graph
1004
+ # Update current_index based on the new graph (accounting for the inserted input layer).
1010
1005
layer_indices = [layer .index for layer in hls_model .graph .values ()]
1011
1006
if layer_indices :
1012
1007
max_index = max (layer_indices )
1013
- current_index = max_index - 1 # we have the new input layer as well
1008
+ current_index = max_index - 1
1014
1009
model_graphs .append (hls_model )
1015
1010
1016
1011
return MultiModelGraph (model_graphs )
@@ -1019,32 +1014,45 @@ def make_multi_graph(cls, config, layer_list, input_layers, output_layers, outpu
1019
1014
class MultiModelGraph :
1020
1015
def __init__ (self , graphs ):
1021
1016
self .graphs = graphs
1022
- self .config = copy .copy (self .graphs [0 ].config )
1023
- self ._deepcopy_config_names (self .graphs [0 ].config .config )
1024
1017
self ._initialize_config (graphs [0 ])
1025
- self .config .config ['StitchedProjectName' ] = 'vivado_stitched_design'
1026
- self .backend = graphs [0 ].config .backend
1018
+ self ._bind_modelgraph_methods ()
1019
+ self ._initialize_io_attributes (graphs )
1020
+
1021
+ def _initialize_config (self , first_graph ):
1022
+ self .config = copy .copy (first_graph .config )
1023
+ # Deep copy only 'ProjectName' and 'OutputDir', shallow copy others
1024
+ keys_to_deepcopy = ['ProjectName' , 'OutputDir' ]
1025
+ self .config .config = {
1026
+ k : copy .deepcopy (first_graph .config .config [k ]) if k in keys_to_deepcopy else first_graph .config .config [k ]
1027
+ for k in first_graph .config .config
1028
+ }
1029
+ self ._update_project_config (first_graph )
1030
+ self .backend = first_graph .config .backend
1031
+
1032
+ def _bind_modelgraph_methods (self ):
1033
+ # Bind necessary ModelGraph methods to this instance
1034
+ self ._compile = ModelGraph ._compile .__get__ (self , MultiModelGraph )
1035
+ self .get_output_variables = ModelGraph .get_output_variables .__get__ (self , MultiModelGraph )
1036
+ self ._compute_n_samples = ModelGraph ._compute_n_samples .__get__ (self , MultiModelGraph )
1037
+ self ._get_top_function = ModelGraph ._get_top_function .__get__ (self , MultiModelGraph )
1038
+ self ._predict = ModelGraph .predict .__get__ (self , MultiModelGraph )
1039
+ self .trace = ModelGraph .trace .__get__ (self , MultiModelGraph )
1040
+
1041
+ def _initialize_io_attributes (self , graphs ):
1027
1042
self .graph_reports = None
1028
1043
self ._top_function_lib = None
1029
- self .config .config ['Stamp' ] = '64616e'
1030
1044
self .inputs = graphs [0 ].inputs
1031
1045
self .outputs = graphs [- 1 ].outputs
1032
- self ._compile = ModelGraph . _compile . __get__ ( self , MultiModelGraph )
1046
+ self .output_vars = graphs [ - 1 ]. output_vars
1033
1047
1034
- def _initialize_config (self , first_graph ):
1035
- """
1036
- Initialize the configuration using details from the first graph
1037
- """
1048
+ def _update_project_config (self , first_graph ):
1038
1049
original_project_name = first_graph .config .get_project_name ().partition ('_graph' )[0 ]
1039
1050
self .config .config ['ProjectName' ] = f"{ original_project_name } _stitched"
1040
1051
self .config .config ['OriginalProjectName' ] = original_project_name
1041
1052
original_output_dir = first_graph .config .get_output_dir ().partition ('/graph' )[0 ]
1042
1053
self .config .config ['OutputDir' ] = os .path .join (original_output_dir , 'stitched' )
1043
-
1044
- def _deepcopy_config_names (self , config ):
1045
- # Deep copy only 'ProjectName' and 'OutputDir', shallow copy others
1046
- keys_to_deepcopy = ['ProjectName' , 'OutputDir' ]
1047
- self .config .config = {k : copy .deepcopy (config [k ]) if k in keys_to_deepcopy else config [k ] for k in config }
1054
+ self .config .config ['StitchedProjectName' ] = 'vivado_stitched_design'
1055
+ self .config .config ['Stamp' ] = '64616e'
1048
1056
1049
1057
def __getitem__ (self , index ):
1050
1058
return self .graphs [index ]
@@ -1137,6 +1145,9 @@ def build_wrapper(idx, g, **kwargs):
1137
1145
1138
1146
if stitch_design or sim_stitched_design or export_stitched_design :
1139
1147
self ._assert_consistent_pragmas ()
1148
+ vivado_folder = os .path .join (self .config .config ['OutputDir' ], self .config .config ['StitchedProjectName' ])
1149
+ if os .path .exists (vivado_folder ):
1150
+ raise FileExistsError (f"Vivado stitched project folder '{ vivado_folder } ' already exists." )
1140
1151
nn_config = self .parse_nn_config ()
1141
1152
stitched_report = self .backend .build_stitched_design (
1142
1153
stitch_design = stitch_design ,
@@ -1152,18 +1163,13 @@ def build_wrapper(idx, g, **kwargs):
1152
1163
def compile (self ):
1153
1164
for g in self .graphs :
1154
1165
g .compile ()
1155
- # TODO
1156
- # self.write_build_script()
1157
- # self.write_bridge()
1158
- # self._compile()
1166
+ # Bypass VitisWriter and invoke write_hls directly from VivadoWriter
1167
+ super (self .backend .writer .__class__ , self .backend .writer ).write_hls (self , is_multigraph = True )
1168
+ self ._compile ()
1159
1169
1160
1170
def predict (self , x , sim = 'csim' ):
1161
1171
if sim == 'csim' :
1162
- input_data = x
1163
- for g in self .graphs :
1164
- output_data = g .predict (input_data )
1165
- input_data = output_data
1166
- return output_data
1172
+ return self ._predict (x )
1167
1173
elif sim == 'rtl' :
1168
1174
nn_config = self .parse_nn_config ()
1169
1175
stitched_report = self .backend .build_stitched_design (
@@ -1177,134 +1183,22 @@ def predict(self, x, sim='csim'):
1177
1183
return stitched_report ['BehavSimResults' ]
1178
1184
else :
1179
1185
print ('Unknown simulation option given.' )
1180
-
1186
+
1181
1187
def trace (self , x ):
1182
- # TODO: finish trace function
1183
- input_data = x
1184
- trace_output = []
1185
- for g in self .graphs :
1186
- output_data , curr_trace_output = g .trace (input_data )
1187
- input_data = output_data
1188
- trace_output .append (curr_trace_output )
1189
- return output_data , trace_output
1190
-
1191
- def write_build_script (self ):
1192
- # NOTE we need to move this function to Vivado writer with each graph object
1193
- spec = importlib .util .find_spec ('hls4ml' )
1194
- hls4ml_path = os .path .dirname (spec .origin )
1195
- build_lib_src = os .path .join (hls4ml_path , 'templates/vivado/build_lib_multigraph.sh' )
1196
- os .makedirs (self .config .config ['OutputDir' ], exist_ok = True )
1197
- build_lib_dst = os .path .join (self .config .config ['OutputDir' ], 'build_lib.sh' )
1198
- graph_project_names = ' ' .join (f"\" { g .config .get_output_dir ().split ('/' )[- 1 ]} \" " for g in self .graphs )
1199
- with open (build_lib_src ) as src , open (build_lib_dst , 'w' ) as dst :
1200
- for line in src .readlines ():
1201
- line = line .replace ('myproject' , self .config .config ['OriginalProjectName' ])
1202
- line = line .replace ('myproject_stitched' , self .config .config ['ProjectName' ])
1203
- line = line .replace ('mystamp' , self .config .config ['Stamp' ])
1204
- line = line .replace ('mygraph_name_list' , graph_project_names )
1205
- dst .write (line )
1206
- os .chmod (build_lib_dst , os .stat (build_lib_dst ).st_mode | stat .S_IEXEC )
1207
-
1208
- def write_bridge (self ):
1209
- # NOTE we need to move this function to Vivado writer with each graph object
1210
- """Write the Python-C++ bridge (myproject_bridge.cpp)
1211
- Args:
1212
- model (ModelGraph): the hls4ml model.
1213
- """
1214
-
1215
- filedir = os .path .dirname (os .path .abspath (__file__ ))
1216
- f = open (os .path .join (filedir , '../templates/vivado/myproject_bridge_multigraph.cpp' ))
1217
- fout = open (f"{ self .config .get_output_dir ()} /{ self .config .config ['ProjectName' ]} _bridge.cpp" , 'w' )
1218
- model_inputs = self .graphs [0 ].get_input_variables ()
1219
- model_outputs = self .graphs [- 1 ].get_output_variables ()
1220
- model_brams = [var for var in self .graphs [0 ].get_weight_variables () if var .storage .lower () == 'bram' ]
1221
-
1222
- indent = ' '
1223
-
1224
- for line in f .readlines ():
1225
- newline = ''
1226
- if 'MYPROJECT' in line :
1227
- newline = line .replace ('MYPROJECT' , format (self .config .config ['ProjectName' ].upper ()))
1228
- elif 'firmware/myproject' in line :
1229
- for graph_idx in range (len (self .graphs )):
1230
- newline += line .replace ('myproject' , format (self .graphs [graph_idx ].config .config ['ProjectName' ]))
1231
- newline += '\n #undef DEFINES_H_\n ' if graph_idx < len (self .graphs ) - 1 else ''
1232
- elif 'myproject' in line :
1233
- newline = line .replace ('myproject' , format (self .graphs [0 ].config .config ['ProjectName' ]))
1234
-
1235
- elif '// hls-fpga-machine-learning insert bram' in line :
1236
- newline = line
1237
- for bram in model_brams :
1238
- newline += f'#include \" firmware/weights/{ bram .name } .h\" \n '
1239
-
1240
- elif '// hls-fpga-machine-learning insert header' in line :
1241
- dtype = line .split ('#' , 1 )[1 ].strip ()
1242
- inputs_str = ', ' .join ([f'{ dtype } { i .name } [{ i .size_cpp ()} ]' for i in model_inputs ])
1243
- outputs_str = ', ' .join ([f'{ dtype } { o .name } [{ o .size_cpp ()} ]' for o in model_outputs ])
1244
-
1245
- newline = ''
1246
- newline += indent + inputs_str + ',\n '
1247
- newline += indent + outputs_str + '\n '
1248
-
1249
- elif '// hls-fpga-machine-learning insert wrapper' in line :
1250
- dtype = line .split ('#' , 1 )[1 ].strip ()
1251
- newline = ''
1252
- for i in model_inputs :
1253
- newline += indent + '{var};\n ' .format (var = i .definition_cpp (name_suffix = '_ap' ))
1254
- newline += indent + 'nnet::convert_data<{}, {}, {}>({}, {}_ap);\n ' .format (
1255
- dtype , i .type .name , i .size_cpp (), i .name , i .name
1256
- )
1257
- newline += '\n '
1258
-
1259
- for o in model_outputs :
1260
- newline += indent + '{var};\n ' .format (var = o .definition_cpp (name_suffix = '_ap' ))
1261
-
1262
- newline += '\n '
1263
-
1264
- input_vars = ',' .join ([i .name + '_ap' for i in model_inputs ])
1265
- bram_vars = ',' .join ([b .name for b in model_brams ])
1266
- output_vars = ',' .join ([o .name + '_ap' for o in model_outputs ])
1267
-
1268
- # Concatenate the input, output, and bram variables. Filter out empty/null values
1269
- all_vars = ',' .join (filter (None , [input_vars , output_vars , bram_vars ]))
1270
-
1271
- top_level = indent + f"//{ self .config .config ['ProjectName' ]} ({ all_vars } );\n "
1272
- newline += top_level
1273
-
1274
- newline += '\n '
1275
-
1276
- for o in model_outputs :
1277
- newline += indent + 'nnet::convert_data<{}, {}, {}>({}_ap, {});\n ' .format (
1278
- o .type .name , dtype , o .size_cpp (), o .name , o .name
1279
- )
1280
-
1281
- elif '// hls-fpga-machine-learning insert trace_outputs' in line :
1282
- newline = ''
1283
- for layer in self .graphs [0 ].get_layers ():
1284
- func = layer .get_attr ('function_cpp' , None )
1285
- if func and self .graphs [0 ].config .trace_output and layer .get_attr ('trace' , False ):
1286
- vars = layer .get_variables ()
1287
- for var in vars :
1288
- newline += (
1289
- indent
1290
- + 'nnet::trace_outputs->insert(std::pair<std::string, void *>('
1291
- + f'"{ layer .name } ", (void *) malloc({ var .size_cpp ()} * element_size)));\n '
1292
- )
1293
-
1294
- elif '// hls-fpga-machine-learning insert namespace' in line :
1295
- newline = ''
1296
-
1297
- namespace = self .config .get_writer_config ().get ('Namespace' , None )
1298
- if namespace is not None :
1299
- newline += indent + f'using namespace { namespace } ;\n '
1300
-
1301
- else :
1302
- newline = line
1303
- fout .write (newline )
1304
-
1305
- f .close ()
1306
- fout .close ()
1188
+ raise NotImplementedError ("Trace function has not been implemented yet for MultiModelGraph." )
1307
1189
1190
+ def get_input_variables (self ):
1191
+ variables = []
1192
+ for inp in self .inputs :
1193
+ variables .append (self .graphs [0 ].graph [inp ].get_output_variable ())
1194
+ return variables
1195
+
1196
+ def get_layers (self ):
1197
+ all_values = []
1198
+ for g in self .graphs :
1199
+ all_values .extend (g .graph .values ())
1200
+ return dict (zip (all_values , all_values )).values ()
1201
+
1308
1202
def _get_pragma_details (self , pragma ):
1309
1203
"""
1310
1204
Extracts the pragma type and FIFO depth from the given pragma.
0 commit comments