deepof.models.get_vqvae

deepof.models.get_vqvae(input_shape: tuple, edge_feature_shape: tuple, adjacency_matrix: ndarray, latent_dim: int, use_gnn: bool, n_components: int, beta: float = 1.0, kmeans_loss: float = 0.0, encoder_type: str = 'recurrent', interaction_regularization: float = 0.0)

Build a vector-quantized variational autoencoder (VQ-VAE) model, adapted to the DeepOF setting.

Parameters:
  • input_shape (tuple) – shape of the input to the encoder.

  • edge_feature_shape (tuple) – shape of the edge feature matrix used for graph representations.

  • adjacency_matrix (np.ndarray) – adjacency matrix of the connectivity graph to use.

  • latent_dim (int) – dimension of the latent space.

  • use_gnn (bool) – If True, the encoder uses a graph representation of the input, with coordinates and speeds as node attributes, and distances as edge attributes. If False, a regular 3D tensor is used as input.

  • n_components (int) – number of embeddings in the embedding layer.

  • beta (float) – beta parameter of the VQ loss. Defaults to 1.0.

  • kmeans_loss (float) – regularization parameter for the Gram matrix. Defaults to 0.0.

  • encoder_type (str) – type of encoder to use. Can be set to “recurrent” (default), “TCN”, or “transformer”.

  • interaction_regularization (float) – regularization parameter for the interaction features. Defaults to 0.0.
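
To illustrate the adjacency_matrix argument, the following minimal sketch builds a symmetric adjacency matrix from an edge list with NumPy. The node labels, the edges, and the helper name build_adjacency are hypothetical and chosen only for illustration; in practice the matrix would describe the connectivity graph of the tracked body parts.

```python
import numpy as np

# Hypothetical node labels and undirected edges (illustration only,
# not an actual DeepOF labelling scheme).
nodes = ["nose", "left_ear", "right_ear", "spine", "tail_base"]
edges = [
    ("nose", "left_ear"),
    ("nose", "right_ear"),
    ("left_ear", "spine"),
    ("right_ear", "spine"),
    ("spine", "tail_base"),
]


def build_adjacency(nodes, edges):
    """Return a symmetric 0/1 adjacency matrix of shape (n_nodes, n_nodes)."""
    index = {name: i for i, name in enumerate(nodes)}
    adjacency = np.zeros((len(nodes), len(nodes)), dtype=int)
    for a, b in edges:
        adjacency[index[a], index[b]] = 1
        adjacency[index[b], index[a]] = 1  # undirected graph: mirror each edge
    return adjacency


adjacency_matrix = build_adjacency(nodes, edges)
```

The result is a square np.ndarray with one row and one column per node, which is the kind of object the adjacency_matrix parameter describes.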

Returns:
  • encoder (tf.keras.Model) – connected encoder of the VQ-VAE model. Outputs a vector of shape (latent_dim,).

  • decoder (tf.keras.Model) – connected decoder of the VQ-VAE model.

  • grouper (tf.keras.Model) – connected embedder layer of the VQ-VAE model. Outputs cluster indices of shape (batch_size,).

  • vqvae (tf.keras.Model) – complete VQ-VAE model.

Return type:

tuple of four tf.keras.Model objects
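
The way a grouper can map latent vectors to cluster indices is illustrated below with the nearest-codebook lookup used in standard vector quantization. This is a NumPy sketch of that generic technique, not DeepOF's implementation; the function name quantize, the random codebook, and the chosen sizes are assumptions made for illustration.

```python
import numpy as np


def quantize(latents, codebook):
    """Assign each latent vector to its nearest codebook entry.

    latents:  array of shape (batch_size, latent_dim)
    codebook: array of shape (n_components, latent_dim)
    Returns (indices, quantized) with shapes (batch_size,) and
    (batch_size, latent_dim).
    """
    # Squared Euclidean distance between every latent and every codebook entry.
    diffs = latents[:, None, :] - codebook[None, :, :]
    distances = np.sum(diffs ** 2, axis=-1)  # (batch_size, n_components)
    indices = np.argmin(distances, axis=1)   # nearest entry per latent
    return indices, codebook[indices]


# Illustrative sizes only; they mirror the roles of latent_dim and n_components.
rng = np.random.default_rng(0)
latent_dim, n_components, batch_size = 8, 10, 4
codebook = rng.normal(size=(n_components, latent_dim))
latents = rng.normal(size=(batch_size, latent_dim))

indices, quantized = quantize(latents, codebook)
```

Here indices has shape (batch_size,), matching the documented grouper output, and each value lies in the range 0 to n_components - 1.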