Public Member Functions
def __init__(self, num_directions=10, n_jobs=None)
def fit(self, X, y=None)
def transform(self, X)
def __call__(self, diag1, diag2)
This is a class for computing the sliced Wasserstein distance matrix from a list of persistence diagrams. The Sliced Wasserstein distance is computed by projecting the persistence diagrams onto lines, comparing the projections with the 1-norm, and finally integrating over all possible lines. See http://proceedings.mlr.press/v70/carriere17a.html for more details.
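In formulas, the description above can be written as follows. The notation is introduced here for illustration only, and the normalization follows the averaging over num_directions lines rather than the exact constant of the reference paper. For a line with angle θ, let π_θ(x) = ⟨x, (cos θ, sin θ)⟩, and let π_Δ(b, d) = ((b+d)/2, (b+d)/2) be the orthogonal projection onto the diagonal. Assuming, as in the reference paper, that each diagram is augmented with the diagonal projections of the other diagram's points (multiset unions), both projected sets have the same size, so the 1-norm between their sorted projections is well defined. M denotes num_directions.

```latex
\mathrm{SW}(D_1, D_2) = \frac{1}{\pi} \int_{-\pi/2}^{\pi/2}
  \bigl\lVert \operatorname{sort}\,\pi_\theta\bigl(D_1 \cup \pi_\Delta(D_2)\bigr)
            - \operatorname{sort}\,\pi_\theta\bigl(D_2 \cup \pi_\Delta(D_1)\bigr) \bigr\rVert_1 \, d\theta
\approx \frac{1}{M} \sum_{k=0}^{M-1}
  \bigl\lVert \operatorname{sort}\,\pi_{\theta_k}\bigl(D_1 \cup \pi_\Delta(D_2)\bigr)
            - \operatorname{sort}\,\pi_{\theta_k}\bigl(D_2 \cup \pi_\Delta(D_1)\bigr) \bigr\rVert_1,
\qquad \theta_k = -\frac{\pi}{2} + \frac{k\pi}{M}
```

Pairing sorted values in order is the optimal matching in one dimension, which is why the 1-norm of the sorted projections equals the 1-Wasserstein distance between the projected point sets.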
def gudhi.representations.metrics.SlicedWassersteinDistance.__init__(self, num_directions=10, n_jobs=None)
Constructor for the SlicedWassersteinDistance class.
Parameters:
num_directions (int): number of lines evenly sampled from [-pi/2, pi/2] in order to approximate and speed up the distance computation (default 10).
n_jobs (int): number of jobs to use for the computation. See pairwise_persistence_diagram_distances() for details.
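A minimal sketch of how the num_directions lines could be sampled, assuming evenly spaced angles over the half-open interval [-pi/2, pi/2). The endpoint pi/2 describes the same line as -pi/2, so including it would count one line twice. The helper name `sample_directions` is hypothetical and not part of GUDHI.

```python
import math
from typing import List, Tuple


def sample_directions(num_directions: int = 10) -> List[Tuple[float, float]]:
    """Return unit vectors (cos t, sin t) for evenly spaced angles t in [-pi/2, pi/2).

    Hypothetical helper: t_k = -pi/2 + k * pi / num_directions for
    k = 0 .. num_directions - 1, so the endpoint pi/2 (same line as -pi/2)
    is skipped.
    """
    if num_directions < 1:
        raise ValueError("num_directions must be at least 1")
    step = math.pi / num_directions
    thetas = [-math.pi / 2 + k * step for k in range(num_directions)]
    # Each angle is turned into a unit direction vector for the projection.
    return [(math.cos(t), math.sin(t)) for t in thetas]
```

With the default of 10 directions, consecutive lines are pi/10 radians apart; more directions give a finer approximation of the integral at a proportionally higher cost.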
def gudhi.representations.metrics.SlicedWassersteinDistance.__call__(self, diag1, diag2)
Apply SlicedWassersteinDistance to a single pair of persistence diagrams and output the result.
Parameters:
diag1 (n x 2 numpy array): first input persistence diagram.
diag2 (m x 2 numpy array): second input persistence diagram; it may contain a different number of points than diag1.
Returns:
float: sliced Wasserstein distance.
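The single-pair computation can be sketched in pure Python as follows. This is an illustration under stated assumptions, not GUDHI's implementation: it assumes that, as in the reference paper, each diagram is augmented with the orthogonal projections onto the diagonal of the other diagram's points, so that both projected sets have the same size on every line, and that the 1-norm between the sorted projections is averaged over lines at angles -pi/2 + k*pi/num_directions. The function name `sliced_wasserstein` is hypothetical.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]  # (birth, death)


def sliced_wasserstein(diag1: Sequence[Point], diag2: Sequence[Point],
                       num_directions: int = 10) -> float:
    """Approximate sliced Wasserstein distance between two persistence diagrams.

    Sketch only: lines are sampled at angles -pi/2 + k * pi / num_directions.
    """
    if num_directions < 1:
        raise ValueError("num_directions must be at least 1")

    # Orthogonal projection of (birth, death) onto the diagonal y = x.
    def to_diagonal(points):
        return [((b + d) / 2, (b + d) / 2) for b, d in points]

    # Augment each diagram with the diagonal projections of the other one,
    # so both sets have len(diag1) + len(diag2) points.
    set1 = list(diag1) + to_diagonal(diag2)
    set2 = list(diag2) + to_diagonal(diag1)

    total = 0.0
    for k in range(num_directions):
        theta = -math.pi / 2 + k * math.pi / num_directions
        c, s = math.cos(theta), math.sin(theta)
        # Project onto the line and sort; in 1D the optimal matching
        # pairs the sorted values in order.
        p1 = sorted(x * c + y * s for x, y in set1)
        p2 = sorted(x * c + y * s for x, y in set2)
        total += sum(abs(a - b) for a, b in zip(p1, p2))
    # Averaging over evenly spaced lines approximates the integral over all lines.
    return total / num_directions
```

For example, with a single line (angle -pi/2), the diagrams [(0.0, 2.0)] and [(0.0, 3.0)] project to the sorted values [-2, -1.5] and [-3, -1], giving a distance of 1 + 0.5 = 1.5.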
def gudhi.representations.metrics.SlicedWassersteinDistance.fit(self, X, y=None)
Fit the SlicedWassersteinDistance class on a list of persistence diagrams: the persistence diagrams are projected onto the different lines, and the diagrams themselves and their projections are stored in numpy arrays called **diagrams_** and **approx_diag_**.
Parameters:
X (list of n x 2 numpy arrays): input persistence diagrams.
y (n x 1 array): persistence diagram labels (unused).
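The projection step performed by fit() can be sketched as below. This assumes the projections are kept per diagram and per line, both for the diagram's own points and for their orthogonal projections onto the diagonal (presumably what an attribute like approx_diag_ holds). The function `project_diagrams` and its list-of-lists layout are hypothetical; GUDHI stores numpy arrays.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]  # (birth, death)
Projections = List[List[float]]  # one list of projected values per line


def project_diagrams(diagrams: Sequence[Sequence[Point]], num_directions: int = 10
                     ) -> Tuple[List[Projections], List[Projections]]:
    """Project every diagram, and its diagonal projections, onto each line.

    Returns (approx, approx_diag): approx[i][k] holds the projections of the
    points of diagrams[i] onto line k, and approx_diag[i][k] the projections
    of their orthogonal projections onto the diagonal.
    """
    step = math.pi / num_directions
    thetas = [-math.pi / 2 + k * step for k in range(num_directions)]
    lines = [(math.cos(t), math.sin(t)) for t in thetas]

    approx, approx_diag = [], []
    for diag in diagrams:
        # Orthogonal projection of (birth, death) onto the diagonal y = x.
        on_diag = [((b + d) / 2, (b + d) / 2) for b, d in diag]
        approx.append([[x * c + y * s for x, y in diag] for c, s in lines])
        approx_diag.append([[x * c + y * s for x, y in on_diag] for c, s in lines])
    return approx, approx_diag
```

Projecting each diagram once at fit time means later distance computations only need to combine and sort these cached values.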
def gudhi.representations.metrics.SlicedWassersteinDistance.transform(self, X)
Compute all sliced Wasserstein distances between the persistence diagrams stored by the fit() method and a given list of (possibly different) persistence diagrams.
Parameters:
X (list of n x 2 numpy arrays): input persistence diagrams.
Returns:
numpy array of shape (number of diagrams in **diagrams_**) x (number of diagrams in X): matrix of pairwise sliced Wasserstein distances.
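Given projections precomputed per diagram and per line, the distance matrix can be assembled without re-projecting anything. The sketch below is hypothetical: the function name and the list-of-lists layout (approx[i][k] holding the projections of diagram i on line k, diag[i][k] those of its diagonal projections) are assumptions, and it follows the shape stated above, with one row per fitted diagram and one column per diagram in X.

```python
from typing import List

Projections = List[List[float]]  # projections[k] = projected values on line k


def distance_matrix(fit_approx: List[Projections], fit_diag: List[Projections],
                    new_approx: List[Projections], new_diag: List[Projections]
                    ) -> List[List[float]]:
    """Sliced Wasserstein distances: fitted diagrams (rows) x new diagrams (columns)."""
    def pair(a_pts, a_diag, b_pts, b_diag):
        num_lines = len(a_pts)
        total = 0.0
        for k in range(num_lines):
            # Each side gets its own points plus the other's diagonal projections,
            # so both sorted lists have the same length.
            p1 = sorted(a_pts[k] + b_diag[k])
            p2 = sorted(b_pts[k] + a_diag[k])
            total += sum(abs(u - v) for u, v in zip(p1, p2))
        # Average of the per-line 1-norms.
        return total / num_lines

    return [[pair(fa, fd, na, nd) for na, nd in zip(new_approx, new_diag)]
            for fa, fd in zip(fit_approx, fit_diag)]
```

Each pair still costs one sort per line, but the trigonometric projections are computed once per diagram rather than once per pair.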
GUDHI Version 3.3.0 - C++ library for Topological Data Analysis (TDA) and Higher Dimensional Geometry Understanding. - Copyright : MIT. Generated on Tue Aug 11 2020 11:58:59 for GUDHI by Doxygen 1.8.18