29
29
import numpy as np
30
30
from numba .extending import get_cython_function_address
31
31
32
+ # TODO: these are reasonable defaults, but could
33
+ # be made settable via a control dict
32
34
_HYP2F1_TOL = np .sqrt (np .finfo (np .float64 ).eps )
33
- _HYP2F1_CHECK = np .sqrt (_HYP2F1_TOL )
34
35
_HYP2F1_MAXTERM = int (1e6 )
35
36
36
37
_PTR = ctypes .POINTER
@@ -115,7 +116,7 @@ def _is_valid_2f1(f1, f2, a, b, c, z):
115
116
See Eq. 6 in https://doi.org/10.1016/j.cpc.2007.11.007
116
117
"""
117
118
if z == 0.0 :
118
- return np .abs (f1 - a * b / c ) < _HYP2F1_CHECK
119
+ return np .abs (f1 - a * b / c ) < _HYP2F1_TOL
119
120
u = c - (a + b + 1 ) * z
120
121
v = a * b
121
122
w = z * (1 - z )
@@ -124,7 +125,7 @@ def _is_valid_2f1(f1, f2, a, b, c, z):
124
125
numer = np .abs (u * f1 - v )
125
126
else :
126
127
numer = np .abs (f2 + u / w * f1 - v / w )
127
- return numer / denom < _HYP2F1_CHECK
128
+ return numer / denom < _HYP2F1_TOL
128
129
129
130
130
131
@numba .njit ("UniTuple(float64, 7)(float64, float64, float64, float64)" )
@@ -255,7 +256,7 @@ def _hyp2f1_recurrence(a, b, c, z):
255
256
256
257
257
258
@numba .njit (
258
- "UniTuple(float64, 6 )(float64, float64, float64, float64, float64, float64)"
259
+ "UniTuple(float64, 7 )(float64, float64, float64, float64, float64, float64)"
259
260
)
260
261
def _hyp2f1_dlmf1583_first (a_i , b_i , a_j , b_j , y , mu ):
261
262
"""
@@ -287,21 +288,26 @@ def _hyp2f1_dlmf1583_first(a_i, b_i, a_j, b_j, y, mu):
287
288
)
288
289
289
290
# 2F1(a, -y; c; z) via backwards recurrence
290
- val , sign , da , _ , dc , dz , _ = _hyp2f1_recurrence (a , y , c , z )
291
+ val , sign , da , _ , dc , dz , d2z = _hyp2f1_recurrence (a , y , c , z )
291
292
292
293
# map gradient to parameters
293
294
da_i = dc - _digamma (a_i + a_j ) + _digamma (a_i )
294
295
da_j = da + dc - np .log (s ) + _digamma (a_j + y + 1 ) - _digamma (a_i + a_j )
295
296
db_i = dz / (b_j - mu ) + a_j / (mu + b_i )
296
297
db_j = dz * (1 - z ) / (b_j - mu ) - a_j / s / (mu + b_i )
297
298
299
+ # needed to verify result
300
+ d2b_j = (1 - z ) / (b_j - mu ) ** 2 * (d2z * (1 - z ) - 2 * dz * (1 + a_j )) + (
301
+ 1 + a_j
302
+ ) * a_j / (b_j - mu ) ** 2
303
+
298
304
val += scale
299
305
300
- return val , sign , da_i , db_i , da_j , db_j
306
+ return val , sign , da_i , db_i , da_j , db_j , d2b_j
301
307
302
308
303
309
@numba .njit (
304
- "UniTuple(float64, 6 )(float64, float64, float64, float64, float64, float64)"
310
+ "UniTuple(float64, 7 )(float64, float64, float64, float64, float64, float64)"
305
311
)
306
312
def _hyp2f1_dlmf1583_second (a_i , b_i , a_j , b_j , y , mu ):
307
313
"""
@@ -320,18 +326,24 @@ def _hyp2f1_dlmf1583_second(a_i, b_i, a_j, b_j, y, mu):
320
326
)
321
327
322
328
# 2F1(a, y+1; c; z) via series expansion
323
- val , sign , da , _ , dc , dz , _ = _hyp2f1_taylor_series (a , y + 1 , c , z )
329
+ val , sign , da , _ , dc , dz , d2z = _hyp2f1_taylor_series (a , y + 1 , c , z )
324
330
325
331
# map gradient to parameters
326
332
da_i = da + np .log (z ) + dc + _digamma (a_i ) - _digamma (a_i + y + 1 )
327
333
da_j = da + np .log (z ) + _digamma (a_j + y + 1 ) - _digamma (a_j )
328
334
db_i = (1 - z ) * (dz + a / z ) / (b_i + b_j )
329
335
db_j = - z * (dz + a / z ) / (b_i + b_j )
330
336
337
+ # needed to verify result
338
+ d2b_j = (
339
+ z / (b_i + b_j ) ** 2 * (d2z * z + 2 * dz * (1 + a ))
340
+ + a * (1 + a ) / (b_i + b_j ) ** 2
341
+ )
342
+
331
343
sign *= (- 1 ) ** (y + 1 )
332
344
val += scale
333
345
334
- return val , sign , da_i , db_i , da_j , db_j
346
+ return val , sign , da_i , db_i , da_j , db_j , d2b_j
335
347
336
348
337
349
@numba .njit (
@@ -345,18 +357,14 @@ def _hyp2f1_dlmf1583(a_i, b_i, a_j, b_j, y, mu):
345
357
assert 0 <= mu <= b_j
346
358
assert y >= 0 and y % 1.0 == 0.0
347
359
348
- f_1 , s_1 , da_i_1 , db_i_1 , da_j_1 , db_j_1 = _hyp2f1_dlmf1583_first (
360
+ f_1 , s_1 , da_i_1 , db_i_1 , da_j_1 , db_j_1 , d2b_j_1 = _hyp2f1_dlmf1583_first (
349
361
a_i , b_i , a_j , b_j , y , mu
350
362
)
351
363
352
- f_2 , s_2 , da_i_2 , db_i_2 , da_j_2 , db_j_2 = _hyp2f1_dlmf1583_second (
364
+ f_2 , s_2 , da_i_2 , db_i_2 , da_j_2 , db_j_2 , d2b_j_2 = _hyp2f1_dlmf1583_second (
353
365
a_i , b_i , a_j , b_j , y , mu
354
366
)
355
367
356
- if np .abs (f_1 - f_2 ) < _HYP2F1_TOL :
357
- # TODO: detect a priori if this will occur
358
- raise Invalid2F1 ("Singular hypergeometric function" )
359
-
360
368
f_0 = max (f_1 , f_2 )
361
369
f_1 = np .exp (f_1 - f_0 ) * s_1
362
370
f_2 = np .exp (f_2 - f_0 ) * s_2
@@ -366,10 +374,22 @@ def _hyp2f1_dlmf1583(a_i, b_i, a_j, b_j, y, mu):
366
374
db_i = (db_i_1 * f_1 + db_i_2 * f_2 ) / f
367
375
da_j = (da_j_1 * f_1 + da_j_2 * f_2 ) / f
368
376
db_j = (db_j_1 * f_1 + db_j_2 * f_2 ) / f
377
+ d2b_j = (d2b_j_1 * f_1 + d2b_j_2 * f_2 ) / f
369
378
370
379
sign = np .sign (f )
371
380
val = np .log (np .abs (f )) + f_0
372
381
382
+ # use first/second derivatives to check that result is non-singular
383
+ dz = - db_j * (mu + b_i )
384
+ d2z = d2b_j * (mu + b_i ) ** 2
385
+ if (
386
+ not _is_valid_2f1 (
387
+ dz , d2z , a_j , a_i + a_j + y , a_j + y + 1 , (mu - b_j ) / (mu + b_i )
388
+ )
389
+ or sign <= 0
390
+ ):
391
+ raise Invalid2F1 ("Hypergeometric series is singular" )
392
+
373
393
return val , sign , da_i , db_i , da_j , db_j
374
394
375
395
0 commit comments