deepof.models.get_vade
- deepof.models.get_vade(input_shape: tuple, edge_feature_shape: tuple, adjacency_matrix: ndarray, latent_dim: int, use_gnn: bool, n_components: int, batch_size: int = 64, kl_warmup: int = 15, kl_annealing_mode: str = 'sigmoid', mc_kl: int = 100, kmeans_loss: float = 1.0, reg_cluster_variance: bool = False, encoder_type: str = 'recurrent', interaction_regularization: float = 0.0)
Build a Gaussian mixture variational autoencoder (VaDE) model, adapted to the DeepOF setting.
- Parameters:
input_shape (tuple) – shape of the input data.
edge_feature_shape (tuple) – shape of the edge feature matrix used for graph representations.
adjacency_matrix (np.ndarray) – adjacency matrix of the connectivity graph to use.
latent_dim (int) – dimensionality of the latent space.
use_gnn (bool) – If True, the encoder uses a graph representation of the input, with coordinates and speeds as node attributes, and distances as edge attributes. If False, a regular 3D tensor is used as input.
n_components (int) – number of components in the Gaussian mixture.
batch_size (int) – batch size for training.
kl_warmup (int) – number of iterations over which the weight of the KL divergence term is warmed up.
kl_annealing_mode (str) – mode to use for annealing the KL divergence. Must be either “linear” or “sigmoid” (default).
mc_kl (int) – number of Monte Carlo samples to use for computing the KL divergence.
kmeans_loss (float) – weight of the Gram matrix loss as described in deepof.model_utils.compute_kmeans_loss.
reg_cluster_variance (bool) – whether to penalize uneven cluster variances in the latent space.
encoder_type (str) – type of encoder to use. Can be set to “recurrent” (default), “TCN”, or “transformer”.
interaction_regularization (float) – weight of the interaction regularization term.
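To make the interplay between kl_warmup and kl_annealing_mode concrete, the following is a minimal sketch of a generic KL warm-up schedule. It is not DeepOF's actual implementation: the function name kl_weight, the steepness constant, and the exact shape of the sigmoid curve are assumptions chosen for illustration only.

```python
import math


def kl_weight(iteration: int, kl_warmup: int = 15, mode: str = "sigmoid") -> float:
    """Return an illustrative KL weight in [0, 1] for a given iteration.

    Hypothetical helper: DeepOF's own schedule may differ in shape and scale.
    """
    if mode not in ("linear", "sigmoid"):
        raise ValueError("mode must be 'linear' or 'sigmoid'")
    # No warm-up requested, or warm-up already finished: full KL weight.
    if kl_warmup <= 0 or iteration >= kl_warmup:
        return 1.0
    # Fraction of the warm-up period completed so far, clipped at 0.
    progress = max(iteration / kl_warmup, 0.0)
    if mode == "linear":
        return progress
    # Logistic curve centred on the midpoint of the warm-up period.
    steepness = 10.0  # assumed value, purely illustrative
    return 1.0 / (1.0 + math.exp(-steepness * (progress - 0.5)))


# Compare both modes over a 15-iteration warm-up.
for step in (0, 5, 10, 15):
    print(step, round(kl_weight(step, 15, "linear"), 3), round(kl_weight(step, 15, "sigmoid"), 3))
```

In this sketch, the linear mode raises the weight at a constant rate, while the sigmoid mode keeps it low early on, raises it quickly around the midpoint, and then levels off.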
- Returns:
encoder (tf.keras.Model) – connected encoder of the VaDE model. Outputs a vector of shape (latent_dim,).
decoder (tf.keras.Model) – connected decoder of the VaDE model.
grouper (tf.keras.Model) – deep clustering branch of the VaDE model. Outputs a vector of shape (n_components,) for each training instance, corresponding to the soft counts for each cluster.
vade (tf.keras.Model) – complete VaDE model.
- Return type:
tuple of tf.keras.Model (encoder, decoder, grouper, vade)
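As a sketch of how the grouper output might be post-processed, the example below turns per-instance soft counts into hard cluster assignments. The soft_counts array is made-up data standing in for a real grouper prediction, and the variable names are hypothetical rather than part of the DeepOF API.

```python
import numpy as np

# Hypothetical stand-in for grouper output: 4 instances, n_components = 3.
soft_counts = np.array(
    [
        [0.70, 0.20, 0.10],
        [0.10, 0.10, 0.80],
        [0.30, 0.40, 0.30],
        [0.05, 0.90, 0.05],
    ]
)

# Hard assignment: index of the cluster with the largest soft count.
hard_assignments = soft_counts.argmax(axis=1)

# Largest soft count per instance, usable as a rough confidence score.
confidence = soft_counts.max(axis=1)

# Number of instances assigned to each of the n_components clusters.
cluster_sizes = np.bincount(hard_assignments, minlength=soft_counts.shape[1])

print(hard_assignments, confidence, cluster_sizes)
```

Instances whose largest soft count is low, such as the third row here, sit between clusters and may deserve separate handling in downstream analyses.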