
Commit ac142f0

Summarize the first importance ratio for PPO (#1795)
* Summarize the first importance ratio for PPO

  The importance ratio for the first gradient update in each iteration should be exactly 1 (or very close to it), since the policy has not yet changed since the rollout. An importance ratio away from 1 indicates something is wrong, so it is useful to summarize it to uncover hidden bugs.

* Address comments

  Also summarize importance_ratio - 1 for better visualization.
1 parent 0682417 commit ac142f0
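
To make the check concrete, here is a minimal sketch (not part of this commit; the tensor values are made up) of why the first-step ratio is 1 and why plotting importance_ratio - 1 is easier to read:

import torch

# Minimal sketch, not code from this commit. The PPO importance ratio is
# exp(log_prob_new - log_prob_old). At the first gradient step of an
# iteration, the policy computing log_prob_new is still identical to the
# rollout policy that produced log_prob_old, so the ratio is 1 up to
# floating point error.
log_prob_old = torch.randn(5)        # log-probs recorded at rollout time
log_prob_new = log_prob_old.clone()  # first gradient step: same policy
importance_ratio = torch.exp(log_prob_new - log_prob_old)

# Summarizing importance_ratio - 1 centers the histogram at 0, so tiny
# deviations from 1 are much easier to spot in the summaries.
print((importance_ratio - 1).abs().max())  # expected: 0 (in training, merely very close to 0)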

File tree: 3 files changed, +58 -3 lines changed

  alf/algorithms/algorithm.py
  alf/algorithms/ppo_loss.py
  alf/summary/summary_ops.py

alf/algorithms/algorithm.py

Lines changed: 3 additions & 0 deletions
@@ -1640,12 +1640,15 @@ def _train_experience(self,
         if self._config.empty_cache:
             torch.cuda.empty_cache()

+        grad_step = 0
         indices = None
         for u in range(num_updates):
             if mini_batch_size < batch_size:
                 indices = torch.randperm(batch_size,
                                          device=experience.step_type.device)
             for b in range(0, batch_size, mini_batch_size):
+                alf.summary.set_grad_step_counter(grad_step)
+                grad_step += 1

                 is_last_mini_batch = (u == num_updates - 1
                                       and b + mini_batch_size >= batch_size)
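
For illustration, a tiny standalone sketch of how grad_step advances inside this loop (the numbers are made up; the real values come from the trainer config and the batch):

# Hypothetical values only, to show how the counter advances; in ALF they
# come from TrainerConfig and the sampled batch, not from these literals.
num_updates, batch_size, mini_batch_size = 2, 6, 3

grad_step = 0
for u in range(num_updates):
    for b in range(0, batch_size, mini_batch_size):
        # Here the real code calls alf.summary.set_grad_step_counter(grad_step).
        print(u, b, grad_step)
        grad_step += 1
# grad_step takes the values 0, 1, 2, 3 within one training iteration; only
# step 0 uses the policy exactly as it was at rollout time, which is why the
# importance ratio check in ppo_loss.py only applies there.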

alf/algorithms/ppo_loss.py

Lines changed: 22 additions & 3 deletions
@@ -135,6 +135,22 @@ def _pg_loss(self, info, advantages):
             log_prob_clipping=self._log_prob_clipping,
             check_numerics=self._check_numerics,
             debug_summaries=self._debug_summaries)
+        if alf.summary.get_grad_step_counter() == 0:
+            # For the first gradient step in one iteration, the importance ratios
+            # should be 1. Summarize them so that we can notice something is wrong
+            # if they are not 1. Note that due to floating point precision,
+            # importance_ratio0 may not be exactly 1, but it should be very close
+            # to 1.
+            global_step = alf.summary.get_global_counter()
+            summary_interval = alf.get_config_value(
+                'TrainerConfig.summary_interval')
+            if global_step < summary_interval or global_step % summary_interval == 0:
+                with alf.summary.record_if(lambda: True), scope:
+                    alf.summary.histogram('importance_ratio0_minus1',
+                                          importance_ratio - 1)
+                    alf.summary.scalar('importance_ratio0_minus1_abs',
+                                       (importance_ratio - 1).abs().mean())
+
         # Pessimistically choose the maximum objective value for clipped and
         # unclipped importance ratios.
         pg_objective = -importance_ratio * advantages
@@ -143,9 +159,12 @@ def _pg_loss(self, info, advantages):

         if self._debug_summaries and alf.summary.should_record_summaries():
             with scope:
-                alf.summary.histogram('pg_objective', pg_objective)
-                alf.summary.histogram('pg_objective_clipped',
-                                      pg_objective_clipped)
+                alf.summary.scalar('pg_objective', pg_objective.mean())
+                alf.summary.scalar('pg_objective_clipped',
+                                   pg_objective_clipped.mean())
+                alf.summary.scalar('objective_clip_fraction',
+                                   (pg_objective_clipped
+                                    > pg_objective).float().mean())

         if self._check_numerics:
             assert torch.all(torch.isfinite(policy_gradient_loss))
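
A small standalone sketch (with a made-up summary_interval) of the gating condition used above, which writes the extra summary during the first summary_interval iterations and then once every summary_interval iterations:

# Assumed value for illustration; the real one is read from
# TrainerConfig.summary_interval via alf.get_config_value().
summary_interval = 100

def should_write(global_step):
    # Same condition as in the diff: always write during the warm-up phase,
    # then only on iterations that are multiples of the interval.
    return global_step < summary_interval or global_step % summary_interval == 0

print([s for s in range(0, 301, 50) if should_write(s)])
# -> [0, 50, 100, 200, 300]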

alf/summary/summary_ops.py

Lines changed: 33 additions & 0 deletions
@@ -336,6 +336,39 @@ def set_global_counter(counter):
     update_progress("global_counter", counter)


+_grad_step_counter = 0
+
+
+def get_grad_step_counter():
+    """Get which gradient step we are at in the current global step.
+
+    For many algorithms, each global step (iteration) may have multiple gradient
+    steps. Typically, only the last gradient step will be recorded in summary.
+    This function returns the current gradient step counter. If an algorithm needs
+    to record summaries at a different gradient step, it can use
+    `with record_if(lambda: alf.summary.get_grad_step_counter() == n):`
+    to record summaries at gradient step `n`.
+
+    Returns:
+        int: the current gradient step counter. The first gradient step in a
+        global step is 0, the second is 1, etc.
+    """
+    return _grad_step_counter
+
+
+def set_grad_step_counter(counter):
+    """Set the current gradient step counter.
+
+    This function is used by the ALF framework to set the gradient step counter
+    before running the gradient step.
+
+    Args:
+        counter (int): the gradient step counter to set
+    """
+    global _grad_step_counter
+    _grad_step_counter = counter
+
+
 class record_if(object):
     """Context manager to set summary recording on or off according to `cond`."""
