Skip to content

Add the capability to do adjoint transforms #633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 20 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ If not stated, FINUFFT is assumed (cuFINUFFT <=1.3 is listed separately).

Master, using release name V 2.4.0 (4/23/25)

* Added functionality for adjoint execution of FINUFFT plans (Reinecke #633,
addresses #566 and #571).
* Update CUDA version to 12.4 for cufinufft (Andén).
* Binary Python wheels for Windows and musllinux (Barbone).
* fix CMake MATLAB build (final *.m copy), PR 667 (Barnett).
Expand Down
14 changes: 12 additions & 2 deletions include/finufft/fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ template<> struct Finufft_FFT_plan<float> {
unlock();
}
void execute [[maybe_unused]] () { fftwf_execute(plan_); }
void execute [[maybe_unused]] (std::complex<float> *data) {
fftwf_execute_dft(plan_, reinterpret_cast<fftwf_complex *>(data),
reinterpret_cast<fftwf_complex *>(data));
}

static void forget_wisdom [[maybe_unused]] () { fftwf_forget_wisdom(); }
static void cleanup [[maybe_unused]] () { fftwf_cleanup(); }
Expand Down Expand Up @@ -152,6 +156,10 @@ template<> struct Finufft_FFT_plan<double> {
unlock();
}
void execute [[maybe_unused]] () { fftw_execute(plan_); }
void execute [[maybe_unused]] (std::complex<double> *data) {
fftw_execute_dft(plan_, reinterpret_cast<fftw_complex *>(data),
reinterpret_cast<fftw_complex *>(data));
}

static void forget_wisdom [[maybe_unused]] () { fftw_forget_wisdom(); }
static void cleanup [[maybe_unused]] () { fftw_cleanup(); }
Expand Down Expand Up @@ -179,7 +187,9 @@ static inline void finufft_fft_cleanup_threads [[maybe_unused]] () {
Finufft_FFT_plan<double>::cleanup_threads();
}
template<typename TF> struct FINUFFT_PLAN_T;
template<typename TF> std::vector<int> gridsize_for_fft(FINUFFT_PLAN_T<TF> *p);
template<typename TF> void do_fft(FINUFFT_PLAN_T<TF> *p);
template<typename TF> std::vector<int> gridsize_for_fft(const FINUFFT_PLAN_T<TF> &p);
template<typename TF>
void do_fft(const FINUFFT_PLAN_T<TF> &p, std::complex<TF> *fwBatch, int ntrans_actual,
bool adjoint);

#endif // FINUFFT_INCLUDE_FINUFFT_FFT_H
15 changes: 7 additions & 8 deletions include/finufft/finufft_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <xsimd/xsimd.hpp>

#include <array>
#include <finufft_errors.h>
#include <memory>

Expand Down Expand Up @@ -172,10 +173,6 @@ template<typename TF> struct FINUFFT_PLAN_T { // the main plan class, fully C++

std::array<std::vector<TF>, 3> phiHat; // FT of kernel in t1,2, on x,y,z-axis mode grid

// fwBatch: (batches of) fine working grid(s) for the FFT to plan & act on.
// Usually the largest internal array. Its allocator is 64-byte (cache-line) aligned:
std::vector<TC, xsimd::aligned_allocator<TC, 64>> fwBatch;

std::vector<BIGINT> sortIndices; // precomputed NU pt permutation, speeds spread/interp
bool didSort; // whether binsorting used (false: identity perm used)

Expand All @@ -188,12 +185,11 @@ template<typename TF> struct FINUFFT_PLAN_T { // the main plan class, fully C++
// arrays (no new allocs)
std::vector<TC> prephase; // pre-phase, for all input NU pts
std::vector<TC> deconv; // reciprocal of kernel FT, phase, all output NU pts
std::vector<TC> CpBatch; // working array of prephased strengths
std::array<std::vector<TF>, 3> XYZp; // internal primed NU points (x'_j, etc)
std::array<std::vector<TF>, 3> STUp; // internal primed targs (s'_k, etc)
type3params<TF> t3P; // groups together type 3 shift, scale, phase, parameters
std::unique_ptr<FINUFFT_PLAN_T<TF>> innerT2plan; // ptr used for type 2 in step 2 of
// type 3
std::unique_ptr<const FINUFFT_PLAN_T<TF>> innerT2plan; // ptr used for type 2 in step 2
// of type 3

// other internal structs
std::unique_ptr<Finufft_FFT_plan<TF>> fftPlan;
Expand All @@ -202,7 +198,10 @@ template<typename TF> struct FINUFFT_PLAN_T { // the main plan class, fully C++

// Remaining actions (not create/delete) in guru interface are now methods...
int setpts(BIGINT nj, TF *xj, TF *yj, TF *zj, BIGINT nk, TF *s, TF *t, TF *u);
int execute(std::complex<TF> *cj, std::complex<TF> *fk);
int execute_internal(TC *cj, TC *fk, bool adjoint = false, int ntrans_actual = -1,
TC *aligned_scratch = nullptr, size_t scratch_size = 0) const;
int execute(TC *cj, TC *fk) const { return execute_internal(cj, fk, false); }
int execute_adjoint(TC *cj, TC *fk) const { return execute_internal(cj, fk, true); }
};

void finufft_default_opts_t(finufft_opts *o);
Expand Down
2 changes: 1 addition & 1 deletion include/finufft/spreadinterp.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ FINUFFT_EXPORT int FINUFFT_CDECL spreadinterpSorted(
const std::vector<BIGINT> &sort_indices, const UBIGINT N1, const UBIGINT N2,
const UBIGINT N3, T *data_uniform, const UBIGINT M, T *FINUFFT_RESTRICT kx,
T *FINUFFT_RESTRICT ky, T *FINUFFT_RESTRICT kz, T *FINUFFT_RESTRICT data_nonuniform,
const finufft_spread_opts &opts, int did_sort);
const finufft_spread_opts &opts, int did_sort, bool adjoint);
template<typename T>
FINUFFT_EXPORT T FINUFFT_CDECL evaluate_kernel(T x, const finufft_spread_opts &opts);
template<typename T>
Expand Down
2 changes: 2 additions & 0 deletions include/finufft_eitherprec.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ FINUFFT_EXPORT int FINUFFT_CDECL FINUFFTIFY(_setpts)(
FINUFFT_FLT *zj, FINUFFT_BIGINT N, FINUFFT_FLT *s, FINUFFT_FLT *t, FINUFFT_FLT *u);
FINUFFT_EXPORT int FINUFFT_CDECL FINUFFTIFY(_execute)(
FINUFFT_PLAN plan, FINUFFT_CPX *weights, FINUFFT_CPX *result);
FINUFFT_EXPORT int FINUFFT_CDECL FINUFFTIFY(_execute_adjoint)(
FINUFFT_PLAN plan, FINUFFT_CPX *weights, FINUFFT_CPX *result);
FINUFFT_EXPORT int FINUFFT_CDECL FINUFFTIFY(_destroy)(FINUFFT_PLAN plan);

// ----------------- the 18 simple interfaces -------------------------------
Expand Down
8 changes: 8 additions & 0 deletions python/finufft/finufft/_finufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ class FinufftOpts(ctypes.Structure):
_executef.argtypes = [c_void_p, c_void_p, c_void_p]
_executef.restype = c_int

_execute_adjoint = lib.finufft_execute_adjoint
_execute_adjoint.argtypes = [c_void_p, c_void_p, c_void_p]
_execute_adjoint.restype = c_int

_execute_adjointf = lib.finufftf_execute_adjoint
_execute_adjointf.argtypes = [c_void_p, c_void_p, c_void_p]
_execute_adjointf.restype = c_int

_destroy = lib.finufft_destroy
_destroy.argtypes = [c_void_p]
_destroy.restype = c_int
Expand Down
58 changes: 58 additions & 0 deletions python/finufft/finufft/_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,13 @@ def __init__(self,nufft_type,n_modes_or_dim,n_trans=1,eps=1e-6,isign=None,dtype=
self._makeplan = _finufft._makeplanf
self._setpts = _finufft._setptsf
self._execute = _finufft._executef
self._execute_adjoint = _finufft._execute_adjointf
self._destroy = _finufft._destroyf
else:
self._makeplan = _finufft._makeplan
self._setpts = _finufft._setpts
self._execute = _finufft._execute
self._execute_adjoint = _finufft._execute_adjoint
self._destroy = _finufft._destroy

ier = self._makeplan(nufft_type, dim, n_modes, isign, n_trans, eps,
Expand Down Expand Up @@ -305,6 +307,62 @@ def execute(self,data,out=None):

return _out

### execute_adjoint
def execute_adjoint(self,data,out=None):
_data = _ensure_array_type(data, "data", self._dtype)
_out = _ensure_array_type(out, "out", self._dtype, output=True)

tp = self._type
n_trans = self._n_trans
nj = self._nj
nk = self._nk
dim = self._dim

if tp==1 or tp==2:
ms, mt, mu = [*self._n_modes, *([1]*(3-len(self._n_modes)))]

# input shape and size check
if tp==1:
valid_fshape(data.shape,n_trans,dim,ms,mt,mu,None,2)
if tp==2:
valid_cshape(data.shape,nj,n_trans)
if tp==3:
valid_cshape(data.shape,nk,n_trans)

# out shape and size check
if out is not None:
if tp==1:
valid_cshape(out.shape,nj,n_trans)
if tp==2:
valid_fshape(out.shape,n_trans,dim,ms,mt,mu,None,1)
if tp==3:
valid_cshape(out.shape,nj,n_trans)

# allocate out if None
if out is None:
if tp==1:
_out = np.ones([*data.shape[:-dim], nj], dtype=self._dtype, order='C')
if tp==2:
_out = 2*np.ones([*data.shape[:-1], *self._n_modes[::-1]], dtype=self._dtype, order='C')
if tp==3:
_out = 3*np.ones([*data.shape[:-1], nj], dtype=self._dtype, order='C')

# call execute based on type and precision type
if tp==1 or tp==3:
ier = self._execute_adjoint(self._inner_plan,
_out.ctypes.data_as(c_void_p),
_data.ctypes.data_as(c_void_p))
elif tp==2:
ier = self._execute_adjoint(self._inner_plan,
_data.ctypes.data_as(c_void_p),
_out.ctypes.data_as(c_void_p))

# check error
if ier != 0:
err_handler(ier)

return _out


def __del__(self):
destroy(self)
Expand Down
47 changes: 47 additions & 0 deletions python/finufft/test/test_finufft_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,22 @@ def test_finufft1_plan(dtype, shape, n_pts, output_arg, modeord):

utils.verify_type1(pts, coefs, shape, sig, 1e-6)

# test adjoint type 2
plan = Plan(2, shape, dtype=dtype, modeord=modeord)

plan.setpts(*pts)

if not output_arg:
sig = plan.execute_adjoint(coefs)
else:
sig = np.empty(shape, dtype=dtype)
plan.execute_adjoint(coefs, out=sig)

if modeord == 1:
sig = np.fft.fftshift(sig)

utils.verify_type1(pts, coefs, shape, sig, 1e-6)


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
Expand All @@ -64,6 +80,24 @@ def test_finufft2_plan(dtype, shape, n_pts, output_arg, modeord):

utils.verify_type2(pts, sig, coefs, 1e-6)

# test adjoint type 1
plan = Plan(1, shape, dtype=dtype, modeord=modeord)

plan.setpts(*pts)

if modeord == 1:
_sig = np.fft.ifftshift(sig)
else:
_sig = sig

if not output_arg:
coefs = plan.execute_adjoint(_sig)
else:
coefs = np.empty(n_pts, dtype=dtype)
plan.execute_adjoint(_sig, out=coefs)

utils.verify_type2(pts, sig, coefs, 1e-6)


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("dim", list(set(len(shape) for shape in SHAPES)))
Expand All @@ -86,6 +120,19 @@ def test_finufft3_plan(dtype, dim, n_source_pts, n_target_pts, output_arg):

utils.verify_type3(source_pts, source_coefs, target_pts, target_coefs, 1e-6)

# test adjoint type 3
plan = Plan(3, dim, dtype=dtype, isign=-1, eps=1e-5)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm increasing eps from 1e-6 to 1e-5 here, because I get occasional failures with single precision otherwise. Given that 1e-6 is uncomfortably close to machine epsilon, I'm not too worried about this change.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. type 3 errors are usually 2-3x bigger than type 1 or 2 at the same tolerance.


plan.setpts(*target_pts, *((None,) * (3 - dim)), *source_pts)

if not output_arg:
target_coefs = plan.execute_adjoint(source_coefs)
else:
target_coefs = np.empty(n_target_pts, dtype=dtype)
plan.execute_adjoint(source_coefs, out=target_coefs)

utils.verify_type3(source_pts, source_coefs, target_pts, target_coefs, 1e-5)


def test_finufft_plan_errors():
with pytest.raises(RuntimeError, match="must be single or double"):
Expand Down
6 changes: 3 additions & 3 deletions python/finufft/test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def verify_type1(pts, coefs, shape, sig_est, tol):

type1_rel_err = np.linalg.norm(fk_target - fk_est) / np.linalg.norm(fk_target)

assert type1_rel_err < 25 * tol
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switching from assert to np.testing.assert_allclose here, because the latter will provide more information in case of failure, which speeds up debugging a lot.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea

np.testing.assert_allclose(type1_rel_err, 0, rtol=0, atol=25*tol)


def verify_type2(pts, sig, coefs_est, tol):
Expand All @@ -172,7 +172,7 @@ def verify_type2(pts, sig, coefs_est, tol):

type2_rel_err = np.linalg.norm(c_target - c_est) / np.linalg.norm(c_target)

assert type2_rel_err < 25 * tol
np.testing.assert_allclose(type2_rel_err, 0, rtol=0, atol=25*tol)


def verify_type3(source_pts, source_coef, target_pts, target_coef, tol):
Expand All @@ -191,4 +191,4 @@ def verify_type3(source_pts, source_coef, target_pts, target_coef, tol):

type3_rel_err = np.linalg.norm(target_est - target_true) / np.linalg.norm(target_true)

assert type3_rel_err < 100 * tol
np.testing.assert_allclose(type3_rel_err, 0, rtol=0, atol=100*tol)
6 changes: 6 additions & 0 deletions src/c_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ int finufft_execute(finufft_plan p, c128 *cj, c128 *fk) {
int finufftf_execute(finufftf_plan p, c64 *cj, c64 *fk) {
return reinterpret_cast<FINUFFT_PLAN_T<f32> *>(p)->execute(cj, fk);
}
int finufft_execute_adjoint(finufft_plan p, c128 *cj, c128 *fk) {
return reinterpret_cast<FINUFFT_PLAN_T<f64> *>(p)->execute_adjoint(cj, fk);
}
int finufftf_execute_adjoint(finufftf_plan p, c64 *cj, c64 *fk) {
return reinterpret_cast<FINUFFT_PLAN_T<f32> *>(p)->execute_adjoint(cj, fk);
}

int finufft_destroy(finufft_plan p)
// Free everything we allocated inside of finufft_plan pointed to by p.
Expand Down
Loading
Loading