File tree Expand file tree Collapse file tree 1 file changed +5
-0
lines changed
python/paddle/nn/functional Expand file tree Collapse file tree 1 file changed +5
-0
lines changed Original file line number Diff line number Diff line change @@ -3074,6 +3074,10 @@ def cross_entropy(
3074
3074
# numerator: loss's weighted sum
3075
3075
# denominator: cal the sum of weight where the sample's class_index!=ignore_index
3076
3076
if ignore_index >= 0 : # ignore label
3077
+ out_type = out .dtype
3078
+ if out_type == paddle .float16 :
3079
+ out = paddle .cast (out , dtype = paddle .float32 )
3080
+
3077
3081
out_sum = _C_ops .sum (out , [], None , False )
3078
3082
# for each label[i],set 1 or 0, according to ignore_index
3079
3083
# mask[i]=0, if label[i]==ignore_index
@@ -3093,6 +3097,7 @@ def cross_entropy(
3093
3097
weight_sum
3094
3098
+ (weight_sum == 0.0 ).astype (weight_sum .dtype )
3095
3099
)
3100
+ ret = paddle .cast (ret , dtype = out_type )
3096
3101
return ret
3097
3102
elif weight is not None :
3098
3103
out_sum = _C_ops .sum (out , [], None , False )
You can’t perform that action at this time.
0 commit comments