
Commit 0276de5

Summary for first importance ratio for PPO

The importance ratio for the first gradient update in each iteration should be exactly 1 (or very close to it), because the policy has not yet been updated within the iteration. An importance ratio that deviates from 1 indicates something is wrong, so it is useful to summarize it to uncover hidden bugs.
1 parent da93236 commit 0276de5
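For context, a minimal sketch of why the ratio should be 1 on the first gradient step of an iteration. It assumes the importance ratio is computed as exp(new_log_prob - old_log_prob), as in typical PPO implementations; the variable names are illustrative, not ALF's actual ones:

    import torch

    # Log-probabilities of the taken actions under the rollout (old) policy,
    # recorded when the experience was collected.
    old_log_prob = torch.randn(256)

    # On gradient step 0 of an iteration no parameter update has been applied yet,
    # so re-evaluating the same actions under the current policy reproduces the
    # rollout log-probabilities exactly (up to numerical noise).
    new_log_prob = old_log_prob.clone()

    importance_ratio = (new_log_prob - old_log_prob).exp()
    assert torch.allclose(importance_ratio, torch.ones_like(importance_ratio))

If the summarized importance_ratio0 drifts away from 1, the rollout policy and the policy used at the first gradient step are not actually identical, which points to a hidden bug.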

File tree

alf/algorithms/algorithm.py
alf/algorithms/ppo_loss.py
alf/summary/summary_ops.py

3 files changed: +50 −3 lines
alf/algorithms/algorithm.py

Lines changed: 3 additions & 0 deletions
@@ -1640,12 +1640,15 @@ def _train_experience(self,
         if self._config.empty_cache:
             torch.cuda.empty_cache()

+        grad_step = 0
         indices = None
         for u in range(num_updates):
             if mini_batch_size < batch_size:
                 indices = torch.randperm(batch_size,
                                          device=experience.step_type.device)
             for b in range(0, batch_size, mini_batch_size):
+                alf.summary.set_grad_step_counter(grad_step)
+                grad_step += 1

                 is_last_mini_batch = (u == num_updates - 1
                                       and b + mini_batch_size >= batch_size)
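To make the new counter concrete, here is a minimal standalone sketch of the loop structure above, with illustrative sizes; grad_step is 0 exactly for the first mini-batch of an iteration, before any gradient update of that iteration has been applied:

    num_updates = 4        # passes over the batch in one training iteration
    batch_size = 1024
    mini_batch_size = 256

    grad_step = 0
    for u in range(num_updates):
        for b in range(0, batch_size, mini_batch_size):
            # grad_step runs 0, 1, ..., num_updates * (batch_size // mini_batch_size) - 1
            # within the iteration; only step 0 still uses the rollout parameters.
            print(u, b, grad_step)
            grad_step += 1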

alf/algorithms/ppo_loss.py

Lines changed: 15 additions & 3 deletions
@@ -135,6 +135,15 @@ def _pg_loss(self, info, advantages):
             log_prob_clipping=self._log_prob_clipping,
             check_numerics=self._check_numerics,
             debug_summaries=self._debug_summaries)
+        if alf.summary.get_grad_step_counter() == 0:
+            # For the first gradient step in one iteration, the importance ratios
+            # should be 1. Summarize them so that we can notice something is wrong
+            # if they are not 1.
+            with alf.summary.record_if(lambda: True), scope:
+                alf.summary.histogram('importance_ratio0', importance_ratio)
+                alf.summary.scalar('importance_ratio0_mean',
+                                   importance_ratio.mean())
+
         # Pessimistically choose the maximum objective value for clipped and
         # unclipped importance ratios.
         pg_objective = -importance_ratio * advantages
@@ -143,9 +152,12 @@ def _pg_loss(self, info, advantages):

         if self._debug_summaries and alf.summary.should_record_summaries():
             with scope:
-                alf.summary.histogram('pg_objective', pg_objective)
-                alf.summary.histogram('pg_objective_clipped',
-                                      pg_objective_clipped)
+                alf.summary.scalar('pg_objective', pg_objective.mean())
+                alf.summary.scalar('pg_objective_clipped',
+                                   pg_objective_clipped.mean())
+                alf.summary.scalar('objective_clip_fraction',
+                                   (pg_objective_clipped
+                                    > pg_objective).float().mean())

         if self._check_numerics:
             assert torch.all(torch.isfinite(policy_gradient_loss))
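For readers less familiar with the summarized quantities, a minimal self-contained sketch of the clipped surrogate, with illustrative tensors and clipping threshold (ALF computes importance_ratio_clipped elsewhere; the element-wise max follows the "pessimistically choose the maximum objective" comment in the diff above):

    import torch

    epsilon = 0.2  # illustrative clipping threshold
    importance_ratio = torch.tensor([0.7, 1.0, 1.5])
    advantages = torch.tensor([1.0, -2.0, 0.5])

    importance_ratio_clipped = importance_ratio.clamp(1 - epsilon, 1 + epsilon)

    # ALF minimizes a loss, so the objectives carry a minus sign; taking the
    # element-wise max of the two gives the pessimistic (clipped) PPO surrogate.
    pg_objective = -importance_ratio * advantages
    pg_objective_clipped = -importance_ratio_clipped * advantages
    policy_gradient_loss = torch.max(pg_objective, pg_objective_clipped)

    # Fraction of elements where clipping changed the objective in the pessimistic
    # direction; this is the quantity summarized as objective_clip_fraction.
    clip_fraction = (pg_objective_clipped > pg_objective).float().mean()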

alf/summary/summary_ops.py

Lines changed: 32 additions & 0 deletions
@@ -336,6 +336,38 @@ def set_global_counter(counter):
     update_progress("global_counter", counter)


+_grad_step_counter = 0
+
+
+def get_grad_step_counter():
+    """Get which gradient step we are at in the current global step.
+
+    For many algorithms, each global step (iteration) may have multiple gradient
+    steps. Typically, only the last gradient step will be recorded in summary.
+    This function returns the current gradient step counter. If an algorithm needs
+    to record summaries at a different gradient step, it can use
+    `with record_if(lambda: alf.summary.get_grad_step_counter() == n):`
+    to record summaries at gradient step `n`.
+
+    Returns:
+        int: the current gradient step counter
+    """
+    return _grad_step_counter
+
+
+def set_grad_step_counter(counter):
+    """Set the current gradient step counter.
+
+    This function is used by the ALF framework to set the gradient step counter
+    before running the gradient step.
+
+    Args:
+        counter (int): the gradient step counter to set
+    """
+    global _grad_step_counter
+    _grad_step_counter = counter
+
+
 class record_if(object):
     """Context manager to set summary recording on or off according to `cond`."""

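A usage sketch of the new counter outside the PPO loss, following the pattern suggested in the get_grad_step_counter() docstring; the helper function and summary names here are hypothetical:

    import alf
    import torch

    def summarize_first_step_ratio(importance_ratio: torch.Tensor):
        # Hypothetical helper: record a histogram only on gradient step 0 of the
        # current training iteration instead of the default last gradient step.
        with alf.summary.record_if(
                lambda: alf.summary.get_grad_step_counter() == 0):
            alf.summary.histogram('importance_ratio_first_step', importance_ratio)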
