@@ -566,6 +566,11 @@ def main():
                 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
 
+    model_patch_size = None
+    if args.naflex_loader:
+        # NaFlexVit models have embeds.patch_size. Needs to be extracted here before mutating the model.
+        model_patch_size = getattr(getattr(model, "embeds", None), "patch_size", None)
+
     if args.torchscript:
         assert not args.torchcompile
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
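Note on the hoisted extraction: per the comment in the hunk, `embeds.patch_size` must be read before the `torchscript`/`torchcompile` paths below mutate the model and potentially hide its original attributes. A minimal standalone sketch of the nested `getattr` pattern (the `DummyEmbed`/`DummyModel` classes are hypothetical stand-ins, not timm code):

```python
# Hypothetical stand-ins illustrating why the nested getattr never raises:
# each level falls back to None when the attribute is missing.
class DummyEmbed:
    patch_size = (16, 16)

class DummyModel:
    def __init__(self):
        self.embeds = DummyEmbed()

model = DummyModel()
print(getattr(getattr(model, "embeds", None), "patch_size", None))  # (16, 16)

plain = object()  # a model without an `embeds` attribute
print(getattr(getattr(plain, "embeds", None), "patch_size", None))  # None, no AttributeError
```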
@@ -762,7 +767,6 @@ def main():
     )
 
     naflex_mode = False
-    model_patch_size = None
     if args.naflex_loader:
         if utils.is_primary(args):
             _logger.info('Using NaFlex loader')
@@ -775,11 +779,8 @@ def main():
             mixup_args.pop('cutmix_minmax')  # not supported
             naflex_mixup_fn = NaFlexMixup(**mixup_args)
 
-        # Extract model's patch size for NaFlex mode
-        if hasattr(model, 'embeds') and hasattr(model.embeds, 'patch_size'):
-            # NaFlexVit models have embeds.patch_size
-            model_patch_size = model.embeds.patch_size
-        else:
+        # Check if we have model's patch size for NaFlex mode
+        if model_patch_size is None:
             # Fallback to default
             model_patch_size = (16, 16)
             if utils.is_primary(args):
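Why the patch size matters here: the NaFlex loader sizes its variable-resolution batches in patch units rather than fixed pixel dimensions, so pixel sizes must be converted to patch counts. A hedged sketch of that arithmetic (the `num_patches` helper is illustrative, not the timm API):

```python
import math

def num_patches(img_h, img_w, patch_size=(16, 16)):
    # Patch tokens occupied by an (img_h, img_w) image at the given patch size.
    ph, pw = patch_size
    return math.ceil(img_h / ph) * math.ceil(img_w / pw)

print(num_patches(224, 224))            # 196 tokens with the (16, 16) fallback
print(num_patches(224, 224, (14, 14)))  # 256 tokens, so the model's real patch size matters
```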
@@ -1197,6 +1198,7 @@ def _backward(_loss):
                 dist_scale = args.world_size * batch_size / global_batch_size
             else:
                 dist_scale = None
+                global_batch_size = batch_size
 
             if has_no_sync and not need_update:
                 with model.no_sync():
@@ -1212,7 +1214,10 @@ def _backward(_loss):
                     scaled_loss *= dist_scale
                 _backward(scaled_loss)
         else:
-            batch_size = input.shape[0]
+            global_batch_size = batch_size = input.shape[0]
+            if args.distributed:
+                global_batch_size *= args.world_size
+
             if has_no_sync and not need_update:
                 with model.no_sync():
                     loss = _forward()
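These two hunks make `global_batch_size` available in both branches: without distribution it equals the per-rank `batch_size`, and under DDP it is scaled by `args.world_size` so sample counting reflects all ranks. A hedged sketch with illustrative values (in `train.py` these come from `args` and `input.shape`):

```python
world_size = 4       # number of DDP processes (args.world_size)
distributed = True
batch_size = 32      # samples this rank sees per step (input.shape[0])

global_batch_size = batch_size
if distributed:
    global_batch_size *= world_size

print(global_batch_size)  # 128 samples consumed per step across all ranks
```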
@@ -1222,7 +1227,7 @@ def _backward(_loss):
                 _backward(loss)
 
         losses_m.update(loss.item() * accum_steps, batch_size)
-        update_sample_count += batch_size
+        update_sample_count += global_batch_size
 
         if not need_update:
             data_start_time = time.time()
@@ -1240,7 +1245,7 @@ def _backward(_loss):
             torch.npu.synchronize()
         time_now = time.time()
 
-        update_time_m.update((time.time() - update_start_time) / update_sample_count, update_sample_count)
+        update_time_m.update(time.time() - update_start_time)
         update_start_time = time_now
 
         if update_idx % args.log_interval == 0:
@@ -1252,15 +1257,14 @@ def _backward(_loss):
                 # synchronize current step and avg loss, each process keeps its own running avg
                 loss_avg = utils.reduce_tensor(loss.new([loss_avg]), args.world_size).item()
                 loss_now = utils.reduce_tensor(loss.new([loss_now]), args.world_size).item()
-                update_sample_count *= args.world_size
 
             if utils.is_primary(args):
                 _logger.info(
                     f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
                     f'({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)] '
                     f'Loss: {loss_now:#.3g} ({loss_avg:#.3g}) '
-                    f'Time: {update_time_m.val:.3f}s, {1 / update_time_m.val:>7.2f}/s '
-                    f'({update_time_m.avg:.3f}s, {1 / update_time_m.avg:>7.2f}/s) '
+                    f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
+                    f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
                     f'LR: {lr:.3e} '
                     f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})'
                 )
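The logging change pairs with the earlier `update_time_m.update(...)` fix: the meter now records raw seconds per optimizer update instead of pre-divided seconds-per-sample, and the samples/s rate is derived at log time from `update_sample_count`, which is already world-size scaled per batch (hence the dropped `update_sample_count *= args.world_size` line). A hedged sketch of the math, assuming timm's `AverageMeter` semantics (`val` is the last recorded value, `avg` the running mean):

```python
from timm.utils import AverageMeter

update_time_m = AverageMeter()
update_sample_count = 256    # samples per optimizer update, already world-size scaled

update_time_m.update(0.125)  # raw seconds for one optimizer update
update_time_m.update(0.135)  # ... and the next

# Rate is derived at log time, mirroring the f-string above:
print(f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
      f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s)')
# -> Time: 0.135s, 1896.30/s (0.130s, 1969.23/s)
```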