diff --git a/Chapter05/sac_agent.py b/Chapter05/sac_agent.py index 5c6d20e..4ff3b2d 100644 --- a/Chapter05/sac_agent.py +++ b/Chapter05/sac_agent.py @@ -130,7 +130,7 @@ def process_actions(self, mean, log_std, test=False, eps=1e-6): log_prob_u = tfp.distributions.Normal(loc=mean, scale=std).log_prob(raw_actions) actions = tf.math.tanh(raw_actions) - log_prob = tf.reduce_sum(log_prob_u - tf.math.log(1 - actions ** 2 + eps)) + log_prob = tf.reduce_sum(log_prob_u - tf.math.log(1 - actions ** 2 + eps), axis=1) actions = actions * self.action_bound + self.action_shift