Skip to content

Commit c40bcd3

Browse files
authored
Add TrainerConfig.offline_loss_weight (#1791)
1 parent 5f13499 commit c40bcd3

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

alf/algorithms/algorithm.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2145,6 +2145,18 @@ def _hybrid_update(self, experience, batch_info, offline_experience,
21452145
offline_valid_masks,
21462146
offline_batch_info)
21472147

2148+
# Weight the offline loss
2149+
offline_loss_weight = self._config.offline_loss_weight()
2150+
offline_loss_info = offline_loss_info._replace(
2151+
loss=offline_loss_info.loss *
2152+
self._config.offline_loss_weight())
2153+
2154+
# If the weight becomes 0, we'll switch from hybrid to online
2155+
# updates and release the offline buffer
2156+
if offline_loss_weight == 0:
2157+
self._has_offline = False
2158+
self._offline_replay_buffer = None
2159+
21482160
if loss_info is not None:
21492161
if self.is_rl():
21502162
valid_masks = (experience.step_type

alf/algorithms/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def __init__(self,
8585
priority_replay_eps=1e-6,
8686
offline_buffer_dir=None,
8787
offline_buffer_length=None,
88+
offline_loss_weight=1.0,
8889
rl_train_after_update_steps=0,
8990
rl_train_every_update_steps=1,
9091
empty_cache: bool = False,
@@ -335,6 +336,12 @@ def __init__(self,
335336
buffer length is offline_buffer_length * len(offline_buffer_dir).
336337
If None, all the samples from all the provided replay buffer
337338
checkpoints will be loaded.
339+
offline_loss_weight (float|Scheduler): weight for the offline loss.
340+
Current behavior is that whenever the offline loss weight becomes 0,
341+
the offline replay buffer will be released, and we switch from
342+
performing hybrid updates to online updates for speed. In
343+
other words, it is assumed the weight will never go to zero
344+
and climb back up after.
338345
rl_train_after_update_steps (int): only used in the hybrid training
339346
mode. It is used as a starting criteria for the normal (non-offline)
340347
part of the RL training, which only starts after so many number
@@ -437,6 +444,7 @@ def __init__(self,
437444
# offline options
438445
self.offline_buffer_dir = offline_buffer_dir
439446
self.offline_buffer_length = offline_buffer_length
447+
self.offline_loss_weight = as_scheduler(offline_loss_weight)
440448
self.rl_train_after_update_steps = rl_train_after_update_steps
441449
self.rl_train_every_update_steps = rl_train_every_update_steps
442450
self.empty_cache = empty_cache

0 commit comments

Comments
 (0)