Skip to content

Commit 53ed70b

Browse files
authored
Merge pull request #426 from brainpy/doc
fix autograd bugs
2 parents 4703ddf + bd3b2ec commit 53ed70b

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

brainpy/_src/math/object_transform/autograd.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,17 @@ def __call__(self, *args, **kwargs):
224224
)
225225
cache_stack(self.target, stack)
226226

227-
self._dyn_vars = stack
228-
self._dyn_vars.remove_var_by_id(*[id(v) for v in self._grad_vars])
229-
self._eval_dyn_vars = True
227+
self._dyn_vars = stack
228+
self._dyn_vars.remove_var_by_id(*[id(v) for v in self._grad_vars])
229+
self._eval_dyn_vars = True
230230

231-
# if not the outermost transformation
232-
if current_transform_number():
233-
return self._return(rets)
231+
# if not the outermost transformation
232+
if current_transform_number():
233+
return self._return(rets)
234+
else:
235+
self._dyn_vars = stack
236+
self._dyn_vars.remove_var_by_id(*[id(v) for v in self._grad_vars])
237+
self._eval_dyn_vars = True
234238

235239
rets = self._transform(
236240
[v.value for v in self._grad_vars], # variables for gradients

brainpy/_src/math/object_transform/tests/test_autograd.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,4 +1149,25 @@ def test_debug_correctness2(self):
11491149
self.assertTrue(bm.allclose(r1[1], r2[1]))
11501150
self.assertTrue(bm.allclose(r1[2], r2[2]))
11511151

1152+
def test_cache1(self):
1153+
file = tempfile.TemporaryFile(mode='w+')
1154+
1155+
def f(a, b):
1156+
print('compiling f ...', file=file)
1157+
return a + b
1158+
1159+
grad1 = bm.grad(f)(1., 2.) # call "f" twice, one for Variable finding, one for compiling
1160+
grad2 = bm.vector_grad(f)(1., 2.) # call "f" once for compiling
1161+
1162+
file.seek(0)
1163+
print(file.read().strip())
1164+
1165+
expect_res = '''
1166+
compiling f ...
1167+
compiling f ...
1168+
compiling f ...
1169+
'''
1170+
file.seek(0)
1171+
self.assertTrue(file.read().strip() == expect_res.strip())
1172+
11521173

0 commit comments

Comments
 (0)