Skip to content

Commit 22449b9

Browse files
authored
Merge pull request #293 from nspope/numerical-fixes
Numerical fixes (closes #286 and #289)
2 parents 4a1890c + f3fcc90 commit 22449b9

File tree

3 files changed

+64
-21
lines changed

3 files changed

+64
-21
lines changed

tests/test_hypergeo.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,28 @@ def test_2f1_grad(self, muts, hyp2f1_func, pars):
240240
offset = self._2f1_validate(*pars)
241241
check = self._2f1_grad_validate(*pars, offset=offset)
242242
assert np.allclose(grad, check)
243+
244+
245+
@pytest.mark.parametrize(
246+
"pars",
247+
[
248+
# taken from examples in issues tsdate/286, tsdate/289
249+
[1.104, 0.0001125, 118.1396, 0.009052, 1.0, 0.001404],
250+
[2.7481, 0.001221, 344.94083, 0.02329, 3.0, 0.00026624],
251+
],
252+
)
253+
class TestSingular2F1:
254+
"""
255+
Test detection of cases where 2F1 is close to singular and DLMF 15.8.3
256+
suffers from catastrophic cancellation: in these cases, use DLMF 15.8.1
257+
even though it takes much longer to converge.
258+
"""
259+
260+
def test_dlmf1583_throws_exception(self, pars):
261+
with pytest.raises(Exception, match="is singular"):
262+
hypergeo._hyp2f1_dlmf1583(*pars)
263+
264+
def test_exception_uses_dlmf1581(self, pars):
265+
v1, *_ = hypergeo._hyp2f1(*pars)
266+
v2, *_ = hypergeo._hyp2f1_dlmf1581(*pars)
267+
assert np.isclose(v1, v2)

tsdate/approx.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,6 @@ def sufficient_statistics(a_i, b_i, a_j, b_j, y_ij, mu_ij):
111111
112112
:return: normalizing constant, E[t_i], E[log t_i], E[t_j], E[log t_j]
113113
"""
114-
assert a_i > 0 and b_i >= 0, "Invalid parent parameters"
115-
assert a_j > 0 and b_j >= 0, "Invalid child parameters"
116-
assert y_ij >= 0 and mu_ij > 0, "Invalid edge parameters"
117114

118115
a = a_i + a_j + y_ij
119116
b = a_j
@@ -124,9 +121,6 @@ def sufficient_statistics(a_i, b_i, a_j, b_j, y_ij, mu_ij):
124121
a_i, b_i, a_j, b_j, y_ij, mu_ij
125122
)
126123

127-
if sign_f <= 0:
128-
raise hypergeo.Invalid2F1("Singular hypergeometric function")
129-
130124
logconst = (
131125
log_f + hypergeo._betaln(y_ij + 1, b) + hypergeo._gammaln(a) - a * np.log(t)
132126
)
@@ -142,6 +136,10 @@ def sufficient_statistics(a_i, b_i, a_j, b_j, y_ij, mu_ij):
142136
- hypergeo._digamma(c)
143137
)
144138

139+
# check that Jensen's inequality holds
140+
assert np.log(t_i) > ln_t_i
141+
assert np.log(t_j) > ln_t_j
142+
145143
return logconst, t_i, ln_t_i, t_j, ln_t_j
146144

147145

tsdate/hypergeo.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@
2929
import numpy as np
3030
from numba.extending import get_cython_function_address
3131

32+
# TODO: these are reasonable defaults, but could
33+
# be made settable via a control dict
3234
_HYP2F1_TOL = np.sqrt(np.finfo(np.float64).eps)
33-
_HYP2F1_CHECK = np.sqrt(_HYP2F1_TOL)
3435
_HYP2F1_MAXTERM = int(1e6)
3536

3637
_PTR = ctypes.POINTER
@@ -115,7 +116,7 @@ def _is_valid_2f1(f1, f2, a, b, c, z):
115116
See Eq. 6 in https://doi.org/10.1016/j.cpc.2007.11.007
116117
"""
117118
if z == 0.0:
118-
return np.abs(f1 - a * b / c) < _HYP2F1_CHECK
119+
return np.abs(f1 - a * b / c) < _HYP2F1_TOL
119120
u = c - (a + b + 1) * z
120121
v = a * b
121122
w = z * (1 - z)
@@ -124,7 +125,7 @@ def _is_valid_2f1(f1, f2, a, b, c, z):
124125
numer = np.abs(u * f1 - v)
125126
else:
126127
numer = np.abs(f2 + u / w * f1 - v / w)
127-
return numer / denom < _HYP2F1_CHECK
128+
return numer / denom < _HYP2F1_TOL
128129

129130

130131
@numba.njit("UniTuple(float64, 7)(float64, float64, float64, float64)")
@@ -255,7 +256,7 @@ def _hyp2f1_recurrence(a, b, c, z):
255256

256257

257258
@numba.njit(
258-
"UniTuple(float64, 6)(float64, float64, float64, float64, float64, float64)"
259+
"UniTuple(float64, 7)(float64, float64, float64, float64, float64, float64)"
259260
)
260261
def _hyp2f1_dlmf1583_first(a_i, b_i, a_j, b_j, y, mu):
261262
"""
@@ -287,21 +288,26 @@ def _hyp2f1_dlmf1583_first(a_i, b_i, a_j, b_j, y, mu):
287288
)
288289

289290
# 2F1(a, -y; c; z) via backwards recurrence
290-
val, sign, da, _, dc, dz, _ = _hyp2f1_recurrence(a, y, c, z)
291+
val, sign, da, _, dc, dz, d2z = _hyp2f1_recurrence(a, y, c, z)
291292

292293
# map gradient to parameters
293294
da_i = dc - _digamma(a_i + a_j) + _digamma(a_i)
294295
da_j = da + dc - np.log(s) + _digamma(a_j + y + 1) - _digamma(a_i + a_j)
295296
db_i = dz / (b_j - mu) + a_j / (mu + b_i)
296297
db_j = dz * (1 - z) / (b_j - mu) - a_j / s / (mu + b_i)
297298

299+
# needed to verify result
300+
d2b_j = (1 - z) / (b_j - mu) ** 2 * (d2z * (1 - z) - 2 * dz * (1 + a_j)) + (
301+
1 + a_j
302+
) * a_j / (b_j - mu) ** 2
303+
298304
val += scale
299305

300-
return val, sign, da_i, db_i, da_j, db_j
306+
return val, sign, da_i, db_i, da_j, db_j, d2b_j
301307

302308

303309
@numba.njit(
304-
"UniTuple(float64, 6)(float64, float64, float64, float64, float64, float64)"
310+
"UniTuple(float64, 7)(float64, float64, float64, float64, float64, float64)"
305311
)
306312
def _hyp2f1_dlmf1583_second(a_i, b_i, a_j, b_j, y, mu):
307313
"""
@@ -320,18 +326,24 @@ def _hyp2f1_dlmf1583_second(a_i, b_i, a_j, b_j, y, mu):
320326
)
321327

322328
# 2F1(a, y+1; c; z) via series expansion
323-
val, sign, da, _, dc, dz, _ = _hyp2f1_taylor_series(a, y + 1, c, z)
329+
val, sign, da, _, dc, dz, d2z = _hyp2f1_taylor_series(a, y + 1, c, z)
324330

325331
# map gradient to parameters
326332
da_i = da + np.log(z) + dc + _digamma(a_i) - _digamma(a_i + y + 1)
327333
da_j = da + np.log(z) + _digamma(a_j + y + 1) - _digamma(a_j)
328334
db_i = (1 - z) * (dz + a / z) / (b_i + b_j)
329335
db_j = -z * (dz + a / z) / (b_i + b_j)
330336

337+
# needed to verify result
338+
d2b_j = (
339+
z / (b_i + b_j) ** 2 * (d2z * z + 2 * dz * (1 + a))
340+
+ a * (1 + a) / (b_i + b_j) ** 2
341+
)
342+
331343
sign *= (-1) ** (y + 1)
332344
val += scale
333345

334-
return val, sign, da_i, db_i, da_j, db_j
346+
return val, sign, da_i, db_i, da_j, db_j, d2b_j
335347

336348

337349
@numba.njit(
@@ -345,18 +357,14 @@ def _hyp2f1_dlmf1583(a_i, b_i, a_j, b_j, y, mu):
345357
assert 0 <= mu <= b_j
346358
assert y >= 0 and y % 1.0 == 0.0
347359

348-
f_1, s_1, da_i_1, db_i_1, da_j_1, db_j_1 = _hyp2f1_dlmf1583_first(
360+
f_1, s_1, da_i_1, db_i_1, da_j_1, db_j_1, d2b_j_1 = _hyp2f1_dlmf1583_first(
349361
a_i, b_i, a_j, b_j, y, mu
350362
)
351363

352-
f_2, s_2, da_i_2, db_i_2, da_j_2, db_j_2 = _hyp2f1_dlmf1583_second(
364+
f_2, s_2, da_i_2, db_i_2, da_j_2, db_j_2, d2b_j_2 = _hyp2f1_dlmf1583_second(
353365
a_i, b_i, a_j, b_j, y, mu
354366
)
355367

356-
if np.abs(f_1 - f_2) < _HYP2F1_TOL:
357-
# TODO: detect a priori if this will occur
358-
raise Invalid2F1("Singular hypergeometric function")
359-
360368
f_0 = max(f_1, f_2)
361369
f_1 = np.exp(f_1 - f_0) * s_1
362370
f_2 = np.exp(f_2 - f_0) * s_2
@@ -366,10 +374,22 @@ def _hyp2f1_dlmf1583(a_i, b_i, a_j, b_j, y, mu):
366374
db_i = (db_i_1 * f_1 + db_i_2 * f_2) / f
367375
da_j = (da_j_1 * f_1 + da_j_2 * f_2) / f
368376
db_j = (db_j_1 * f_1 + db_j_2 * f_2) / f
377+
d2b_j = (d2b_j_1 * f_1 + d2b_j_2 * f_2) / f
369378

370379
sign = np.sign(f)
371380
val = np.log(np.abs(f)) + f_0
372381

382+
# use first/second derivatives to check that result is non-singular
383+
dz = -db_j * (mu + b_i)
384+
d2z = d2b_j * (mu + b_i) ** 2
385+
if (
386+
not _is_valid_2f1(
387+
dz, d2z, a_j, a_i + a_j + y, a_j + y + 1, (mu - b_j) / (mu + b_i)
388+
)
389+
or sign <= 0
390+
):
391+
raise Invalid2F1("Hypergeometric series is singular")
392+
373393
return val, sign, da_i, db_i, da_j, db_j
374394

375395

0 commit comments

Comments
 (0)