Skip to content

Commit 4984a57

Browse files
authored
Remove runtime checks on memory spaces (#1018)
(1) allow a mix of memory spaces inside the sparse tensor; ultimately, only the visibility matters, not the exact consistency (2) remove the runtime check on offset = {-1,0,+1} for now; I see no other solution than copying this to host or having some kernel check on device, which all seems too costly; given the very limited usage of UST, we just assume the solve is only called on tridiag with these offsets
1 parent bfe088c commit 4984a57

File tree

3 files changed

+5
-11
lines changed

3 files changed

+5
-11
lines changed

include/matx/core/sparse_tensor.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,6 @@ class sparse_tensor_t
145145
for (int l = 0; l < LVL; l++) {
146146
c[l] = coordinates_[l].data();
147147
p[l] = positions_[l].data();
148-
// All non-null data resides in same space.
149-
if (v) {
150-
assert(!c[l] || GetPointerKind(c[l]) == GetPointerKind(v));
151-
assert(!p[l] || GetPointerKind(p[l]) == GetPointerKind(v));
152-
}
153148
}
154149
this->SetSparseData(v, c, p);
155150
}

include/matx/operators/solve.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ class SolveOp : public BaseOp<SolveOp<OpA, OpB>> {
106106
void Exec([[maybe_unused]] Out &&out, [[maybe_unused]] Executor &&ex) const {
107107
static_assert(!is_sparse_tensor_v<OpB>, "sparse rhs not implemented");
108108
if constexpr (is_sparse_tensor_v<OpA>) {
109+
// Note that diagonal solve assumes TRI-diagonal storage currently.
109110
if constexpr (OpA::Format::isDIAI() || OpA::Format::isDIAJ()) {
110111
sparse_dia_solve_impl(cuda::std::get<0>(out), a_, b_, ex);
111112
} else if constexpr (OpA::Format::isBatchedDIAIUniform()) {

include/matx/transforms/solve/solve_cusparse.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,8 @@ void sparse_dia_solve_impl(TensorTypeC &C, const TensorTypeA &a,
210210
using CRD = typename atype::crd_type;
211211
CRD *diags = a.CRDData(0);
212212
const index_t numD = a.crdSize(0);
213-
if (numD != 3 || diags[0] != -1 || diags[1] != 0 || diags[2] != 1) {
214-
MATX_THROW(matxNotSupported, "Only tridiagonal solve supported");
215-
}
213+
// TODO: we should also check that offsets = {-1,0,1} (host and device)?
214+
MATX_ASSERT(numD == 3, matxInvalidParameter);
216215
using T = std::conditional_t<
217216
std::is_same_v<TA, cuda::std::complex<double>>, cuDoubleComplex,
218217
std::conditional_t<std::is_same_v<TA, cuda::std::complex<float>>,
@@ -280,9 +279,8 @@ void sparse_batched_dia_solve_impl(TensorTypeC &C, const TensorTypeA &a,
280279
using CRD = typename atype::crd_type;
281280
CRD *diags = a.CRDData(0);
282281
const index_t numD = a.crdSize(0);
283-
if (numD != 3 || diags[0] != -1 || diags[1] != 0 || diags[2] != 1) {
284-
MATX_THROW(matxNotSupported, "Only tridiagonal solve supported");
285-
}
282+
// TODO: we should also check that offsets = {-1,0,1} (host and device)?
283+
MATX_ASSERT(numD == 3, matxInvalidParameter);
286284
using T = std::conditional_t<
287285
std::is_same_v<TA, cuda::std::complex<double>>, cuDoubleComplex,
288286
std::conditional_t<std::is_same_v<TA, cuda::std::complex<float>>,

0 commit comments

Comments
 (0)