deepof.model_utils.embedding_model_fitting

deepof.model_utils.embedding_model_fitting(preprocessed_object: Tuple[ndarray, ndarray, ndarray, ndarray], adjacency_matrix: ndarray, embedding_model: str, encoder_type: str, batch_size: int, latent_dim: int, epochs: int, log_history: bool, log_hparams: bool, n_components: int, output_path: str, kmeans_loss: float, pretrained: str, save_checkpoints: bool, save_weights: bool, input_type: str, kl_annealing_mode: str, kl_warmup: int, reg_cat_clusters: float, recluster: bool, temperature: float, contrastive_similarity_function: str, contrastive_loss_function: str, beta: float, tau: float, interaction_regularization: float, run: int = 0, **kwargs)

Trains the specified embedding model on the preprocessed data.

Parameters:
  • preprocessed_object (tuple) – Tuple of four arrays containing the preprocessed data.

  • adjacency_matrix (np.ndarray) – Adjacency matrix of the connectivity graph to use.

  • embedding_model (str) – Model to use to embed and cluster the data. Must be one of “VQVAE” (default), “VaDE”, or “contrastive”.

  • encoder_type (str) – Encoder architecture to use. Must be one of “recurrent”, “TCN”, or “transformer”.

  • batch_size (int) – Batch size to use for training.

  • latent_dim (int) – Dimensionality of the latent space (encoding size) to use for training.

  • epochs (int) – Number of epochs to train the autoencoder for.

  • log_history (bool) – Whether to log the training history of the autoencoder.

  • log_hparams (bool) – Whether to log the hyperparameters used for training.

  • n_components (int) – Number of components (clusters) to fit to the data.

  • output_path (str) – Path to the output directory.

  • kmeans_loss (float) – Weight of the Gram loss, a regularization term for VQVAE models that penalizes correlation between the dimensions of the latent space.

  • pretrained (str) – Path to the pretrained weights to use for the autoencoder.

  • save_checkpoints (bool) – Whether to save checkpoints during training.

  • save_weights (bool) – Whether to save the weights of the autoencoder after training.

  • input_type (str) – Input type of the TableDict objects used for preprocessing. For logging purposes only.

  • interaction_regularization (float) – Weight of the interaction regularization term (an L1 penalty applied to all features not related to interactions).

  • run (int) – Run number to use for logging.
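The kmeans_loss and interaction_regularization weights scale penalty terms added to the training loss. The following is a minimal, hypothetical sketch of what such terms can look like, assuming a correlation-normalized Gram matrix for the decorrelation penalty and an L1 penalty on the weights attached to non-interaction features. The function names and the interaction_mask argument are invented for this example and are not part of deepof's API; deepof's exact formulas may differ.

```python
import numpy as np


def gram_decorrelation_penalty(latent: np.ndarray, weight: float) -> float:
    """Hypothetical Gram-style penalty on correlated latent dimensions.

    latent: array of shape (batch_size, latent_dim).
    Returns weight times the mean squared off-diagonal entry of the
    correlation-normalized Gram matrix (an assumption, not deepof's formula).
    """
    n_dims = latent.shape[1]
    if n_dims < 2:
        return 0.0  # no pairs of dimensions to decorrelate
    # Standardize each latent dimension across the batch
    centered = latent - latent.mean(axis=0, keepdims=True)
    z = centered / (centered.std(axis=0, keepdims=True) + 1e-8)
    # Gram matrix between dimensions; after standardization it is a correlation matrix
    gram = z.T @ z / latent.shape[0]
    # Keep only the cross-dimension terms
    off_diag = gram - np.diag(np.diag(gram))
    return weight * float(np.sum(off_diag**2) / (n_dims * (n_dims - 1)))


def interaction_l1_penalty(
    weights: np.ndarray, interaction_mask: np.ndarray, weight: float
) -> float:
    """Hypothetical L1 penalty on weights attached to non-interaction features.

    weights: array whose first axis indexes input features (e.g. a first-layer kernel).
    interaction_mask: boolean array of shape (n_features,), True for interaction features.
    """
    # Interaction features are left unpenalized; all others are pushed toward zero
    return weight * float(np.abs(weights[~interaction_mask]).sum())
```

Uncorrelated latent dimensions give a penalty of zero, while perfectly duplicated dimensions give the maximum value of `weight`.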

  VaDE model-specific parameters:

  • kl_annealing_mode (str) – Mode to use for KL annealing. Must be either “linear” (default) or “sigmoid”.

  • kl_warmup (int) – Number of epochs during which the KL term is annealed.

  • reg_cat_clusters (float) – Weight of the regularization term that penalizes uneven cluster membership in the latent space by minimizing the KL divergence between cluster membership and a uniform categorical distribution.

  • recluster (bool) – Whether to recluster the data after each training run using a Gaussian Mixture Model.
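kl_annealing_mode and kl_warmup control how the weight on the KL term ramps up during training, and reg_cat_clusters weights a penalty that pushes cluster usage toward a uniform distribution. Below is a minimal sketch of both ideas; the function names, the sigmoid slope, and the exact shape of the schedules are assumptions for illustration, not deepof's implementation.

```python
import math

import numpy as np


def kl_weight(epoch: int, kl_warmup: int, mode: str = "linear") -> float:
    """Hypothetical KL annealing schedule returning the KL-term weight at `epoch`.

    The weight rises from (near) 0 to 1 over `kl_warmup` epochs and is 1 afterwards.
    """
    if mode not in ("linear", "sigmoid"):
        raise ValueError("mode must be 'linear' or 'sigmoid'")
    if kl_warmup <= 0 or epoch >= kl_warmup:
        return 1.0
    progress = epoch / kl_warmup  # fraction of the warm-up completed, in [0, 1)
    if mode == "linear":
        return progress
    # Logistic curve centred on the middle of the warm-up; slope 10 is an arbitrary choice
    return 1.0 / (1.0 + math.exp(-10.0 * (progress - 0.5)))


def cluster_balance_kl(membership: np.ndarray) -> float:
    """KL divergence between mean cluster membership and a uniform categorical.

    membership: (batch_size, n_components) soft assignments whose rows sum to 1.
    Returns 0 when clusters are used evenly and grows as usage concentrates.
    """
    p = np.clip(membership.mean(axis=0), 1e-12, 1.0)
    n_components = p.shape[0]
    # KL(p || uniform) = sum_k p_k * log(p_k / (1 / K)) = sum_k p_k * log(p_k * K)
    return float(np.sum(p * np.log(p * n_components)))
```

If every sample collapses into a single cluster, the penalty reaches its maximum of log(n_components).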

  Contrastive model-specific parameters:

  • temperature (float) – Temperature parameter for the contrastive loss functions. Higher values put harsher penalties on negative pair similarity.

  • contrastive_similarity_function (str) – Similarity function used to compare positive and negative pairs. Must be one of ‘cosine’ (default), ‘euclidean’, ‘dot’, or ‘edit’.

  • contrastive_loss_function (str) – Contrastive loss function. Must be one of ‘nce’ (default), ‘dcl’, ‘fc’, or ‘hard_dcl’. See the documentation of each loss for details.

  • beta (float) – Beta (concentration) parameter for the hard_dcl contrastive loss. Higher values lead to harder negative samples.

  • tau (float) – Tau parameter for the dcl and hard_dcl contrastive losses, indicating the positive class probability.
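To illustrate the default ‘nce’ loss with ‘cosine’ similarity, the following is a minimal InfoNCE-style sketch. Each anchor's positive is the same-index row of a second view, and all other rows in the batch serve as negatives. Similarities are divided by the temperature here, which is the common InfoNCE convention. The function names are hypothetical, and deepof's own loss may scale, pair, or sample differently.

```python
import numpy as np


def cosine_similarity_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the rows of a and the rows of b."""
    a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-8)
    b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-8)
    return a_norm @ b_norm.T


def nce_loss(anchors: np.ndarray, positives: np.ndarray, temperature: float) -> float:
    """Hypothetical InfoNCE-style loss (not deepof's implementation).

    Row i of `positives` is the positive for row i of `anchors`;
    every other row in the batch acts as a negative.
    """
    # Similarities scaled by 1 / temperature (common convention, assumed here)
    logits = cosine_similarity_matrix(anchors, positives) / temperature
    # Numerically stable log-softmax along each row
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Positive pairs sit on the diagonal; average their negative log-likelihood
    return float(-np.mean(np.diag(log_probs)))
```

As a sanity check, when every row holds the same embedding, all similarities tie and the loss equals log(batch_size).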

Returns:

List of trained models corresponding to the selected model class. The full trained model is last.