Skip to content

Commit 85d3234

Browse files
committed
support ctx save tensor with grad
1 parent a2f407d commit 85d3234

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

python/paddle/distributed/fleet/recompute/recompute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,7 @@ def inner_pack(inner_x):
471471
if inner_x is None:
472472
storage[holder_list[unpack_counter - 1]()] = None
473473
return
474-
if hasattr(inner_x, "main_grad"):
474+
if hasattr(inner_x, "main_grad") or inner_x.grad is not None:
475475
storage[holder_list[unpack_counter - 1]()] = inner_x
476476
else:
477477
if inner_x.is_dist():

0 commit comments

Comments
 (0)