Skip to content

Commit 1addbc2

Browse files
committed
[Big tensor] fix nan in cross_entropy
1 parent 1107fe4 commit 1addbc2

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

python/paddle/nn/functional/loss.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3074,6 +3074,10 @@ def cross_entropy(
30743074
# numerator: loss's weighted sum
30753075
# denominator: the sum of the weights of the samples whose class_index != ignore_index
30763076
if ignore_index >= 0: # ignore label
3077+
out_type = out.dtype
3078+
if out_type == paddle.float16:
3079+
out = paddle.cast(out, dtype=paddle.float32)
3080+
30773081
out_sum = _C_ops.sum(out, [], None, False)
30783082
# for each label[i], set mask[i] to 1 or 0 according to ignore_index
30793083
# mask[i]=0, if label[i]==ignore_index
@@ -3093,6 +3097,7 @@ def cross_entropy(
30933097
weight_sum
30943098
+ (weight_sum == 0.0).astype(weight_sum.dtype)
30953099
)
3100+
ret = paddle.cast(ret, dtype=out_type)
30963101
return ret
30973102
elif weight is not None:
30983103
out_sum = _C_ops.sum(out, [], None, False)

0 commit comments

Comments
 (0)