Skip to content

Commit 71a1aa2

Browse files
authored
skip adding param group when no param (#1797)
* skip adding param group when no param
* fix test case
1 parent ac142f0 commit 71a1aa2

File tree

2 files changed

+30
-56
lines changed

2 files changed

+30
-56
lines changed

alf/optimizers/optimizers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,10 @@ def add_param_group(self, param_group):
490490
if (not self._ignore_param_not_requiring_grad or p.requires_grad)
491491
]
492492

493+
if len(param_group['params']) == 0:
494+
# If no params are in this group, ignore adding
495+
return
496+
493497
lr_scheduler = param_group.get('lr_scheduler', None)
494498
if isinstance(lr_scheduler, Callable):
495499
self._lr_schedulers.append(lr_scheduler)

alf/utils/checkpoint_utils_test.py

Lines changed: 26 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -366,68 +366,38 @@ def test_with_cycle(self):
366366
'_sub_alg2._optimizers.0',
367367
{
368368
'state': {},
369-
'param_groups': [
370-
{
371-
'lr': 0.2,
372-
'betas': (0.9, 0.999),
373-
'capturable': False,
374-
'differentiable': False,
375-
'eps': 1e-08,
376-
'foreach': None,
377-
'fused': None,
378-
'weight_decay': 0,
379-
'amsgrad': False,
380-
'maximize': False,
381-
'params': []
382-
},
383-
{
384-
'lr': 0.2,
385-
'betas': (0.9, 0.999),
386-
'capturable': False,
387-
'differentiable': False,
388-
'eps': 1e-08,
389-
'foreach': None,
390-
'fused': None,
391-
'weight_decay': 0,
392-
'amsgrad': False,
393-
'maximize': False,
394-
'params': [0] # order index instead of id
395-
}
396-
]
369+
'param_groups': [{
370+
'lr': 0.2,
371+
'betas': (0.9, 0.999),
372+
'capturable': False,
373+
'differentiable': False,
374+
'eps': 1e-08,
375+
'foreach': None,
376+
'fused': None,
377+
'weight_decay': 0,
378+
'amsgrad': False,
379+
'maximize': False,
380+
'params': [0] # order index instead of id
381+
}]
397382
}),
398383
('_param_list.0', torch.tensor([0.])),
399384
(
400385
'_optimizers.0',
401386
{
402387
'state': {},
403-
'param_groups': [
404-
{
405-
'lr': 0.1,
406-
'betas': (0.9, 0.999),
407-
'capturable': False,
408-
'differentiable': False,
409-
'eps': 1e-08,
410-
'foreach': None,
411-
'fused': None,
412-
'weight_decay': 0,
413-
'amsgrad': False,
414-
'maximize': False,
415-
'params': []
416-
},
417-
{
418-
'lr': 0.1,
419-
'betas': (0.9, 0.999),
420-
'capturable': False,
421-
'differentiable': False,
422-
'eps': 1e-08,
423-
'foreach': None,
424-
'fused': None,
425-
'weight_decay': 0,
426-
'amsgrad': False,
427-
'maximize': False,
428-
'params': [0, 1] # order indices instead of id
429-
}
430-
]
388+
'param_groups': [{
389+
'lr': 0.1,
390+
'betas': (0.9, 0.999),
391+
'capturable': False,
392+
'differentiable': False,
393+
'eps': 1e-08,
394+
'foreach': None,
395+
'fused': None,
396+
'weight_decay': 0,
397+
'amsgrad': False,
398+
'maximize': False,
399+
'params': [0, 1] # order indices instead of id
400+
}]
431401
})
432402
])
433403

0 commit comments

Comments
 (0)