Commit 55d4b5d

Fix train script issues: patch_size extraction in naflex mode + distributed, and update time calculations after adding naflex mode. Fix #2556
1 parent d08d5a0 commit 55d4b5d


train.py

Lines changed: 16 additions & 12 deletions
@@ -566,6 +566,11 @@ def main():
                 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
 
+    model_patch_size = None
+    if args.naflex_loader:
+        # NaFlexVit models have embeds.patch_size; it needs to be extracted here, before the model is mutated below.
+        model_patch_size = getattr(getattr(model, "embeds", None), "patch_size", None)
+
     if args.torchscript:
         assert not args.torchcompile
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
@@ -762,7 +767,6 @@ def main():
     )
 
     naflex_mode = False
-    model_patch_size = None
     if args.naflex_loader:
         if utils.is_primary(args):
             _logger.info('Using NaFlex loader')
@@ -775,11 +779,8 @@ def main():
             mixup_args.pop('cutmix_minmax')  # not supported
             naflex_mixup_fn = NaFlexMixup(**mixup_args)
 
-        # Extract model's patch size for NaFlex mode
-        if hasattr(model, 'embeds') and hasattr(model.embeds, 'patch_size'):
-            # NaFlexVit models have embeds.patch_size
-            model_patch_size = model.embeds.patch_size
-        else:
+        # Check that we have the model's patch size for NaFlex mode
+        if model_patch_size is None:
             # Fallback to default
             model_patch_size = (16, 16)
             if utils.is_primary(args):
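The first three hunks relocate the patch_size lookup so it runs before the model is wrapped by torchscript, torch.compile, or DistributedDataParallel. The nested getattr chain returns None instead of raising when embeds is unreachable, and the later `if model_patch_size is None` check then applies the (16, 16) fallback deliberately rather than silently. A minimal sketch of the failure mode the relocation fixes; `NaFlexVitStub` and `Wrapper` are hypothetical stand-ins for a NaFlexVit model and a DDP-style wrapper:

```python
# Hypothetical stand-ins: a NaFlexVit-like model exposing embeds.patch_size,
# and a wrapper (like DistributedDataParallel) that holds the model at .module
# without forwarding attribute access.
class Embeds:
    patch_size = (14, 14)

class NaFlexVitStub:
    embeds = Embeds()

class Wrapper:
    def __init__(self, module):
        self.module = module  # wrapped model; no attribute forwarding

model = NaFlexVitStub()
# Before wrapping, the lookup succeeds:
print(getattr(getattr(model, 'embeds', None), 'patch_size', None))    # (14, 14)

model = Wrapper(model)
# After wrapping, `embeds` is no longer reachable on the outer object, so the
# old post-wrapping extraction saw nothing and silently fell back to (16, 16):
print(getattr(getattr(model, 'embeds', None), 'patch_size', None))    # None
```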
@@ -1197,6 +1198,7 @@ def _backward(_loss):
                 dist_scale = args.world_size * batch_size / global_batch_size
             else:
                 dist_scale = None
+                global_batch_size = batch_size
 
             if has_no_sync and not need_update:
                 with model.no_sync():
@@ -1212,7 +1214,10 @@ def _backward(_loss):
                     scaled_loss *= dist_scale
                 _backward(scaled_loss)
         else:
-            batch_size = input.shape[0]
+            global_batch_size = batch_size = input.shape[0]
+            if args.distributed:
+                global_batch_size *= args.world_size
+
             if has_no_sync and not need_update:
                 with model.no_sync():
                     loss = _forward()
@@ -1222,7 +1227,7 @@ def _backward(_loss):
                 _backward(loss)
 
         losses_m.update(loss.item() * accum_steps, batch_size)
-        update_sample_count += batch_size
+        update_sample_count += global_batch_size
 
         if not need_update:
            data_start_time = time.time()
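The middle hunks fix the sample accounting: both the NaFlex path and the fixed-batch path now maintain a global_batch_size, and update_sample_count accumulates that instead of the per-rank batch size, which is why the `update_sample_count *= args.world_size` at log time is dropped below. A small sketch of the arithmetic, with assumed values standing in for input.shape[0] and args.world_size:

```python
# Assumed values; in train.py these come from input.shape[0] and args.world_size.
batch_size = 128      # samples processed by this rank in one step
world_size = 8        # number of distributed ranks
distributed = True

global_batch_size = batch_size
if distributed:
    global_batch_size *= world_size   # samples per optimizer update across all ranks

update_sample_count = 0
update_sample_count += global_batch_size
print(update_sample_count)            # 1024, not 128
```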
@@ -1240,7 +1245,7 @@ def _backward(_loss):
                 torch.npu.synchronize()
         time_now = time.time()
 
-        update_time_m.update((time.time() - update_start_time) / update_sample_count, update_sample_count)
+        update_time_m.update(time.time() - update_start_time)
         update_start_time = time_now
 
         if update_idx % args.log_interval == 0:
@@ -1252,15 +1257,14 @@ def _backward(_loss):
                 # synchronize current step and avg loss, each process keeps its own running avg
                 loss_avg = utils.reduce_tensor(loss.new([loss_avg]), args.world_size).item()
                 loss_now = utils.reduce_tensor(loss.new([loss_now]), args.world_size).item()
-                update_sample_count *= args.world_size
 
             if utils.is_primary(args):
                 _logger.info(
                     f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
                     f'({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)] '
                     f'Loss: {loss_now:#.3g} ({loss_avg:#.3g}) '
-                    f'Time: {update_time_m.val:.3f}s, {1 / update_time_m.val:>7.2f}/s '
-                    f'({update_time_m.avg:.3f}s, {1 / update_time_m.avg:>7.2f}/s) '
+                    f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
+                    f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
                     f'LR: {lr:.3e} '
                     f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})'
                 )
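The final two hunks change what update_time_m records: raw seconds per optimizer update rather than a value pre-divided by the sample count, with throughput derived only at log time. A sketch of the revised scheme, using `AverageMeterStub` as a hypothetical stand-in with the same val/avg interface as timm's utils.AverageMeter:

```python
import time

class AverageMeterStub:
    """Tracks .val (last value) and .avg (running mean), like timm's AverageMeter."""
    def __init__(self):
        self.val = self.sum = self.count = self.avg = 0.0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

update_time_m = AverageMeterStub()
update_sample_count = 1024            # global samples in this update (see above)

update_start_time = time.time()
time.sleep(0.05)                      # stands in for the forward/backward/step work
update_time_m.update(time.time() - update_start_time)   # raw seconds per update

# Throughput is computed at log time instead of being baked into the meter:
print(f'Time: {update_time_m.val:.3f}s, '
      f'{update_sample_count / update_time_m.val:>7.2f}/s')
```

Keeping the meter in plain seconds also keeps its running average meaningful when the per-update sample count varies, as it does with the NaFlex loader's variable batch sizes.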
