Skip to content

Commit 1d58a6d

Browse files
committed
[BUG] Updated sbd_distance() to handle multivariate series (aeon-toolkit#2674)
* Updated sbd_distance() to handle multivariate data consistently with tslearn and other implementations. * Added _multivariate_sbd_distance(), which computes the cross-correlation of each channel, sums the correlations across channels, and then normalizes by the product of the norms of the two multivariate series. a61927f
1 parent a2de234 commit 1d58a6d

File tree

1 file changed

+32
-11
lines changed

1 file changed

+32
-11
lines changed

aeon/distances/_sbd.py

+32-11
Original file line numberDiff line numberDiff line change
@@ -98,17 +98,7 @@ def sbd_distance(x: np.ndarray, y: np.ndarray, standardize: bool = True) -> floa
9898
if x.ndim == 1 and y.ndim == 1:
9999
return _univariate_sbd_distance(x, y, standardize)
100100
if x.ndim == 2 and y.ndim == 2:
101-
if x.shape[0] == 1 and y.shape[0] == 1:
102-
_x = x.ravel()
103-
_y = y.ravel()
104-
return _univariate_sbd_distance(_x, _y, standardize)
105-
else:
106-
# independent (time series should have the same number of channels!)
107-
nchannels = min(x.shape[0], y.shape[0])
108-
distance = 0.0
109-
for i in range(nchannels):
110-
distance += _univariate_sbd_distance(x[i], y[i], standardize)
111-
return distance / nchannels
101+
return _multivariate_sbd_distance(x, y, standardize)
112102

113103
raise ValueError("x and y must be 1D or 2D")
114104

@@ -245,3 +235,34 @@ def _univariate_sbd_distance(x: np.ndarray, y: np.ndarray, standardize: bool) ->
245235

246236
b = np.sqrt(np.dot(x, x) * np.dot(y, y))
247237
return np.abs(1.0 - np.max(a / b))
238+
239+
@njit(cache=True, fastmath=True)
def _multivariate_sbd_distance(
    x: np.ndarray, y: np.ndarray, standardize: bool
) -> float:
    """Compute the shape-based distance between two multivariate series.

    The cross-correlation sequence is computed for each channel, summed over
    channels, and normalised by the product of the norms of the two
    (optionally standardised) series. The distance is the absolute value of
    one minus the largest normalised correlation over all lags.

    Parameters
    ----------
    x : np.ndarray
        2D array indexed as ``(channel, time point)``.
    y : np.ndarray
        2D array indexed as ``(channel, time point)``. Channels are paired
        with those of ``x`` by index; the two series are assumed to have the
        same number of channels (this is not checked here).
    standardize : bool
        If True, each series is z-normalised with the mean and standard
        deviation taken over all of its values (not per channel).

    Returns
    -------
    float
        The shape-based distance. Returns 0.0 when ``standardize`` is True
        and either series holds a single value.
    """
    x = x.astype(np.float64)
    y = y.astype(np.float64)

    # Put time along axis 0 and channels along axis 1.
    x = np.transpose(x, (1, 0))
    y = np.transpose(y, (1, 0))

    if standardize:
        if x.size == 1 or y.size == 1:
            return 0.0

        # NOTE(review): a constant series has zero standard deviation, so the
        # division below produces non-finite values before the denominator
        # guard further down can take effect - confirm the intended result.
        x = (x - np.mean(x)) / np.std(x)
        y = (y - np.mean(y)) / np.std(y)

    # np.linalg.norm on a 2D array is the norm over all of its entries.
    denom = np.linalg.norm(x) * np.linalg.norm(y)
    if denom < 1e-9:  # Avoid NaNs
        denom = np.inf

    n_channels = x.shape[1]
    # scipy's FFT correlation is run in object mode; only the channel-summed
    # correlation sequence is handed back to the compiled code.
    with objmode(cc="float64[:]"):
        cc = np.sum(
            np.array(
                [
                    correlate(x[:, i], y[:, i], mode="full", method="fft")
                    for i in range(n_channels)
                ]
            ),
            axis=0,
        )

    # mode="full" already yields every lag (len(x) + len(y) - 1 of them), so
    # the maximum is taken over the whole sequence. Re-slicing by len(x) alone
    # would discard lags whenever y is longer than x.
    return np.abs(1.0 - np.max(cc / denom))

0 commit comments

Comments
 (0)