
More stable frechet distance computation #117



Open
AIWanderer-X opened this issue May 16, 2025 · 0 comments


Comments

@AIWanderer-X

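    covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

With this line I kept getting "Imaginary component" errors that drove me crazy. I'm not sure whether a particular NumPy version causes this. A likely culprit is that `linalg.sqrtm` is applied to the product of two covariance matrices, which is generally not symmetric, so round-off can leave small imaginary parts in the result. Either way, I found a more stable way to compute FID: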
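For reference, the quantity being computed is the Fréchet distance between two Gaussians. The workaround relies on a standard identity: when Σ1 is positive definite (which the `eps` shift ensures), Σ1Σ2 is similar to the symmetric positive semi-definite matrix Σ1^{1/2} Σ2 Σ1^{1/2}, so the trace of its square root can be read off from real, non-negative eigenvalues:

```latex
d^2\big((\mu_1,\Sigma_1),(\mu_2,\Sigma_2)\big)
  = \lVert \mu_1 - \mu_2 \rVert_2^2
  + \operatorname{Tr}(\Sigma_1) + \operatorname{Tr}(\Sigma_2)
  - 2\,\operatorname{Tr}\big((\Sigma_1 \Sigma_2)^{1/2}\big)

\Sigma_1 \Sigma_2 = \Sigma_1^{1/2}\,\big(\Sigma_1^{1/2} \Sigma_2 \Sigma_1^{1/2}\big)\,\Sigma_1^{-1/2}
\;\Longrightarrow\;
\operatorname{Tr}\big((\Sigma_1 \Sigma_2)^{1/2}\big) = \sum_i \sqrt{\lambda_i},
\qquad \lambda_i = \operatorname{eig}_i\big(\Sigma_1^{1/2} \Sigma_2 \Sigma_1^{1/2}\big) \ge 0
```

In the code, `diff.dot(diff)` is the first term and `trace_sqrt` is the final sum of square-rooted eigenvalues.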

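    import warnings

    import numpy as np
    from scipy import linalg


    def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
        # Hypothetical signature: the original snippet is only a function body,
        # and the eps default is assumed
        mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
        sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
        diff = mu1 - mu2

        # Ensure covariance matrices are positive definite
        sigma1 = sigma1 + np.eye(sigma1.shape[0]) * eps
        sigma2 = sigma2 + np.eye(sigma2.shape[0]) * eps

        # Matrix square root of sigma1 (real for a symmetric positive definite input)
        sigma1_sqrt = linalg.sqrtm(sigma1)
        # Drop any round-off imaginary part explicitly, then ensure double precision
        sigma1_sqrt = np.real(sigma1_sqrt).astype(np.float64)

        # sqrt(sigma1) @ sigma2 @ sqrt(sigma1) has the same eigenvalues as
        # sigma1 @ sigma2 but, unlike that product, is symmetric
        product = sigma1_sqrt @ sigma2 @ sigma1_sqrt

        # Symmetrize to remove round-off asymmetry before calling eigvalsh
        product = (product + product.T) / 2

        # Compute the trace directly from eigenvalues for better numerical stability
        try:
            s = np.linalg.eigvalsh(product)  # real eigenvalues of a symmetric matrix
            trace_sqrt = np.sum(np.sqrt(np.maximum(s, 0)))  # clip tiny negatives
        except np.linalg.LinAlgError:
            # Fall back to the original sqrtm approach with a warning
            warnings.warn("Eigendecomposition failed, falling back to sqrtm")
            covmean = linalg.sqrtm(product)
            if np.iscomplexobj(covmean):
                covmean = covmean.real
            trace_sqrt = np.trace(covmean)

        # Note: the eps regularization also adds d * eps to each trace term
        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * trace_sqrt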
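A quick numerical check of the identity behind this workaround, as a sketch on synthetic data: the random matrices, size `d`, and seed are arbitrary illustrative choices, not taken from any real FID run. It shows that `sigma1 @ sigma2` is not symmetric, while `sqrt(sigma1) @ sigma2 @ sqrt(sigma1)` is, and that both share eigenvalues and give the same trace of the matrix square root when `sqrtm` behaves well.

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
d = 6
a = rng.standard_normal((d, d))
b = rng.standard_normal((d, d))
# Two well-conditioned symmetric positive definite "covariances" (synthetic)
sigma1 = a @ a.T + d * np.eye(d)
sigma2 = b @ b.T + d * np.eye(d)

# The product fed to sqrtm in the original code is generally NOT symmetric
prod = sigma1 @ sigma2
print("sigma1 @ sigma2 symmetric:", np.allclose(prod, prod.T))

# Symmetric square root of sigma1 from its eigendecomposition
w, v = np.linalg.eigh(sigma1)
root1 = (v * np.sqrt(w)) @ v.T
sym = root1 @ sigma2 @ root1
print("sqrt(sigma1) @ sigma2 @ sqrt(sigma1) symmetric:", np.allclose(sym, sym.T))

# Both matrices have the same (real, positive) eigenvalues
eig_prod = np.sort(np.linalg.eigvals(prod).real)
eig_sym = np.sort(np.linalg.eigvalsh(sym))
print("same eigenvalues:", np.allclose(eig_prod, eig_sym))

# So both routes give the same trace of the matrix square root
trace_eig = np.sum(np.sqrt(np.maximum(eig_sym, 0)))
trace_sqrtm = np.trace(linalg.sqrtm(prod)).real
print("same trace:", np.isclose(trace_eig, trace_sqrtm))
```

On ill-conditioned covariances (for example, when there are fewer samples than feature dimensions) the non-symmetric route is where `sqrtm` tends to pick up imaginary parts; the symmetric route only ever calls `eigvalsh`, whose eigenvalues are real by construction.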
@AIWanderer-X AIWanderer-X changed the title More stable frechet distance compution More stable frechet distance computation May 16, 2025
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant