@@ -85,6 +85,7 @@ def __init__(self,
                  priority_replay_eps=1e-6,
                  offline_buffer_dir=None,
                  offline_buffer_length=None,
+                 offline_loss_weight=1.0,
                  rl_train_after_update_steps=0,
                  rl_train_every_update_steps=1,
                  empty_cache: bool = False,
@@ -335,6 +336,12 @@ def __init__(self,
                 buffer length is offline_buffer_length * len(offline_buffer_dir).
                 If None, all the samples from all the provided replay buffer
                 checkpoints will be loaded.
+            offline_loss_weight (float|Scheduler): weight for the offline loss.
+                The current behavior is that whenever the offline loss weight
+                becomes 0, the offline replay buffer will be released and we
+                switch from performing hybrid updates to online updates for
+                speed. In other words, it is assumed that the weight never goes
+                to zero and then climbs back up afterwards.
             rl_train_after_update_steps (int): only used in the hybrid training
                 mode. It is used as a starting criterion for the normal (non-offline)
                 part of the RL training, which only starts after so many number
@@ -437,6 +444,7 @@ def __init__(self,
         # offline options
         self.offline_buffer_dir = offline_buffer_dir
         self.offline_buffer_length = offline_buffer_length
+        self.offline_loss_weight = as_scheduler(offline_loss_weight)
         self.rl_train_after_update_steps = rl_train_after_update_steps
         self.rl_train_every_update_steps = rl_train_every_update_steps
         self.empty_cache = empty_cache
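
The new parameter accepts either a plain float or a scheduler, and `as_scheduler` normalizes it before use. Below is a minimal, self-contained sketch of the intended semantics; `as_scheduler`, `HybridTrainerSketch`, and `update` here are simplified stand-ins for illustration, not the library's actual implementation. A float is wrapped into a constant schedule, and once the scheduled weight reaches 0 the offline buffer is released and all subsequent updates are online-only.

```python
# Hypothetical, simplified sketch of a scheduler-valued offline_loss_weight.
# The real as_scheduler() and hybrid-update logic live in the library; all
# names below are illustrative assumptions.

def as_scheduler(value):
    """Pass schedulers (callables) through; wrap a plain float as a constant."""
    if callable(value):
        return value
    return lambda: value


class HybridTrainerSketch:
    def __init__(self, offline_loss_weight=1.0):
        self.offline_loss_weight = as_scheduler(offline_loss_weight)
        self._offline_buffer = ["offline batch"]  # placeholder for the offline replay buffer

    def update(self, online_loss, offline_loss):
        w = self.offline_loss_weight()
        if w == 0 and self._offline_buffer is not None:
            # Once the weight hits 0, the offline buffer is released and we
            # permanently switch from hybrid updates to online-only updates.
            self._offline_buffer = None
        if self._offline_buffer is None:
            return online_loss
        return online_loss + w * offline_loss


# Example: a weight that linearly decays to 0 over 100 training iterations.
step = 0
decaying_weight = lambda: max(0.0, 1.0 - step / 100)
trainer = HybridTrainerSketch(offline_loss_weight=decaying_weight)
print(trainer.update(online_loss=0.5, offline_loss=0.3))  # hybrid: 0.5 + 1.0 * 0.3
```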