Skip to content

Commit 358c42e

Browse files
committed
Simplified type 3. Add test for cuda type 3 nufft. Unified verify_type3 with CPU version gen_coef_ind
1 parent a439acf commit 358c42e

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

python/cufinufft/tests/test_simple.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,39 @@ def test_simple_type2(to_gpu, to_cpu, dtype, shape, n_trans, M, tol, output_arg)
8484
c = to_cpu(c_gpu)
8585

8686
utils.verify_type2(k, fk, c, tol)
87+
88+
89+
@pytest.mark.parametrize("dtype", DTYPES)
90+
@pytest.mark.parametrize("dim", list(set(len(shape) for shape in SHAPES)))
91+
@pytest.mark.parametrize("n_source_pts", MS)
92+
@pytest.mark.parametrize("n_target_pts", MS)
93+
@pytest.mark.parametrize("n_trans", N_TRANS)
94+
@pytest.mark.parametrize("tol", TOLS)
95+
@pytest.mark.parametrize("output_arg", OUTPUT_ARGS)
96+
def test_cufinufft3_simple(to_gpu, to_cpu, dtype, dim, n_source_pts, n_target_pts, n_trans, tol, output_arg):
97+
complex_dtype = utils._complex_dtype(dtype)
98+
99+
fun = {1: cufinufft.nufft1d3,
100+
2: cufinufft.nufft2d3,
101+
3: cufinufft.nufft3d3}[dim]
102+
103+
source_pts, source_coefs, target_pts = utils.type3_problem(
104+
complex_dtype, dim, n_source_pts, n_target_pts, n_trans
105+
)
106+
107+
108+
source_pts_gpu = to_gpu(source_pts)
109+
source_coefs_gpu = to_gpu(source_coefs)
110+
target_pts_gpu = to_gpu(target_pts)
111+
112+
if output_arg:
113+
target_coefs_gpu = _compat.array_empty_like(
114+
source_coefs_gpu, n_trans + (n_target_pts,), dtype=complex_dtype)
115+
116+
fun(*source_pts_gpu, source_coefs_gpu, *target_pts_gpu, out=target_coefs_gpu, eps=tol)
117+
else:
118+
target_coefs_gpu = fun(*source_pts_gpu, source_coefs_gpu, *target_pts_gpu, eps=tol)
119+
120+
target_coefs = to_cpu(target_coefs_gpu)
121+
122+
utils.verify_type3(source_pts, source_coefs, target_pts, target_coefs, tol)

python/cufinufft/tests/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ def _real_dtype(complex_dtype):
2222

2323
return real_dtype
2424

25+
def gen_coef_ind(n_pts, n_tr):
26+
ind = tuple(np.random.randint(0, n) for n in n_tr + (n_pts,))
27+
return ind
2528

2629
def gen_nu_pts(M, dim=3, seed=0):
2730
np.random.seed(seed)
@@ -155,10 +158,8 @@ def verify_type3(source_pts, source_coef, target_pts, target_coef, tol):
155158
n_source_pts = source_pts.shape[-1]
156159
n_target_pts = target_pts.shape[-1]
157160
n_tr = source_coef.shape[:-1]
158-
159161
assert target_coef.shape == n_tr + (n_target_pts,)
160-
161-
ind = (int(0.1789 * n_target_pts),)
162+
ind = gen_coef_ind(n_target_pts, n_tr)
162163

163164
target_est = target_coef[ind]
164165
target_true = direct_type3(source_pts, source_coef, target_pts, ind)

0 commit comments

Comments
 (0)