deepof.models.get_vade

deepof.models.get_vade(input_shape: tuple, edge_feature_shape: tuple, adjacency_matrix: ndarray, latent_dim: int, use_gnn: bool, n_components: int, batch_size: int = 64, kl_warmup: int = 15, kl_annealing_mode: str = 'sigmoid', mc_kl: int = 100, kmeans_loss: float = 1.0, reg_cluster_variance: bool = False, encoder_type: str = 'recurrent', interaction_regularization: float = 0.0)

Build a Gaussian mixture variational autoencoder (VaDE) model, adapted to the DeepOF setting.
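With a Gaussian mixture prior, the KL divergence between the approximate posterior q(z|x) and the prior p(z) has no closed form. That is why the mc_kl parameter below sets a number of Monte Carlo samples. The following is a minimal NumPy sketch of such an estimator. It assumes a diagonal Gaussian posterior and a diagonal Gaussian mixture prior. The helper names (log_diag_normal, log_gmm, mc_kl) are hypothetical, and this is not deepof's TensorFlow implementation.

```python
import numpy as np


def log_diag_normal(z, mean, log_var):
    """Log-density of a diagonal Gaussian, summed over the last (latent) axis."""
    return -0.5 * np.sum(
        np.log(2.0 * np.pi) + log_var + (z - mean) ** 2 / np.exp(log_var), axis=-1
    )


def log_gmm(z, weights, means, log_vars):
    """Log-density of a diagonal Gaussian mixture at samples z of shape (S, D)."""
    # Per-component log-densities, shape (S, K)
    comp = np.stack(
        [log_diag_normal(z, means[k], log_vars[k]) for k in range(len(weights))],
        axis=-1,
    )
    a = comp + np.log(weights)
    # Numerically stable log-sum-exp over the K components
    m = a.max(axis=-1, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=-1, keepdims=True)))[:, 0]


def mc_kl(post_mean, post_log_var, weights, means, log_vars, n_samples=100, seed=0):
    """Monte Carlo estimate of KL(q(z|x) || p(z)); q diagonal Gaussian, p a mixture."""
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((n_samples, post_mean.shape[-1]))
    # Reparameterised samples z ~ q(z|x)
    z = post_mean + np.exp(0.5 * post_log_var) * eps
    log_q = log_diag_normal(z, post_mean, post_log_var)
    log_p = log_gmm(z, weights, means, log_vars)
    # KL is the expectation of log q(z) - log p(z) under q
    return float(np.mean(log_q - log_p))
```

More samples reduce the variance of the estimate at the cost of extra computation per batch. Here, n_samples plays the role that mc_kl describes.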

Parameters:
  • input_shape (tuple) – shape of the input data.

  • edge_feature_shape (tuple) – shape of the edge feature matrix used for graph representations.

  • adjacency_matrix (np.ndarray) – adjacency matrix of the connectivity graph to use.

  • latent_dim (int) – dimensionality of the latent space.

  • use_gnn (bool) – If True, the encoder uses a graph representation of the input, with coordinates and speeds as node attributes, and distances as edge attributes. If False, a regular 3D tensor is used as input.

  • n_components (int) – number of components in the Gaussian mixture.

  • batch_size (int) – batch size for training.

  • kl_warmup (int) – number of iterations during which the weight of the KL divergence term is warmed up.

  • kl_annealing_mode (str) – mode to use for annealing the KL divergence. Must be either “linear” or “sigmoid”.

  • mc_kl (int) – number of Monte Carlo samples to use for computing the KL divergence.

  • kmeans_loss (float) – weight of the Gram matrix loss as described in deepof.model_utils.compute_kmeans_loss.

  • reg_cluster_variance (bool) – whether to penalize uneven cluster variances in the latent space.

  • encoder_type (str) – type of encoder to use. Can be set to “recurrent” (default), “TCN”, or “transformer”.

  • interaction_regularization (float) – weight of the interaction regularization term.
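The parameter descriptions do not specify the exact shape of the two KL warm-up curves. The following is a minimal sketch of what "linear" and "sigmoid" annealing commonly look like. It assumes that the KL weight rises from 0 to 1 over kl_warmup steps and stays at 1 afterwards. The function name kl_weight, the centring at half the warm-up, and the steepness constant 10 are illustrative assumptions, not deepof's schedule.

```python
import math


def kl_weight(step: int, warmup: int, mode: str = "sigmoid") -> float:
    """Hypothetical weight multiplying the KL term at a given training step."""
    if mode not in ("linear", "sigmoid"):
        raise ValueError("mode must be 'linear' or 'sigmoid'")
    if warmup <= 0 or step >= warmup:
        return 1.0  # warm-up finished (or disabled): full KL weight
    progress = step / warmup  # fraction of warm-up completed, in [0, 1)
    if mode == "linear":
        return progress
    # Logistic ramp centred halfway through warm-up (steepness 10 is arbitrary)
    return 1.0 / (1.0 + math.exp(-10.0 * (progress - 0.5)))
```

A linear ramp increases the KL pressure at a constant rate. A sigmoid ramp keeps it near zero early on, so the reconstruction term dominates at first, and then rises quickly around the midpoint.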

Returns:

  • encoder (tf.keras.Model) – connected encoder of the VaDE model. Outputs a vector of shape (latent_dim,).

  • decoder (tf.keras.Model) – connected decoder of the VaDE model.

  • grouper (tf.keras.Model) – deep clustering branch of the VaDE model. Outputs a vector of shape (n_components,) for each training instance, corresponding to the soft counts for each cluster.

  • vade (tf.keras.Model) – complete VaDE model.

Return type:

tuple of tf.keras.Model (encoder, decoder, grouper, vade)
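In the original VaDE formulation, the soft cluster assignment of a latent vector z is the posterior p(c | z), which is proportional to π_c · N(z; μ_c, σ_c²) under the mixture prior. The following is a minimal NumPy sketch of that computation, plus a reduction of soft counts to hard labels and per-cluster occupancy. Whether deepof's grouper computes exactly this posterior is an assumption. The helper names cluster_posterior and hard_assignments are hypothetical.

```python
import numpy as np


def cluster_posterior(z, weights, means, log_vars):
    """Soft cluster assignments p(c | z) under a diagonal Gaussian mixture.

    z: (N, D) latent vectors; weights: (K,); means, log_vars: (K, D).
    Returns an (N, K) array whose rows sum to 1.
    """
    diff = z[:, None, :] - means[None, :, :]  # (N, K, D)
    log_pdf = -0.5 * np.sum(
        np.log(2.0 * np.pi) + log_vars[None] + diff**2 / np.exp(log_vars[None]),
        axis=-1,
    )  # (N, K) component log-densities
    logits = np.log(weights)[None, :] + log_pdf
    logits -= logits.max(axis=1, keepdims=True)  # stabilise before exponentiating
    probs = np.exp(logits)
    return probs / probs.sum(axis=1, keepdims=True)


def hard_assignments(soft_counts):
    """Most likely cluster per instance and number of instances per cluster."""
    labels = np.argmax(soft_counts, axis=1)
    occupancy = np.bincount(labels, minlength=soft_counts.shape[1])
    return labels, occupancy
```

Because the grouper is a tf.keras.Model, its soft counts for a hypothetical input array X could be obtained with grouper.predict(X) and passed to hard_assignments.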