deepof.model_utils.embedding_model_fitting
- deepof.model_utils.embedding_model_fitting(preprocessed_object: Tuple[ndarray, ndarray, ndarray, ndarray], adjacency_matrix: ndarray, embedding_model: str, encoder_type: str, batch_size: int, latent_dim: int, epochs: int, log_history: bool, log_hparams: bool, n_components: int, output_path: str, kmeans_loss: float, pretrained: str, save_checkpoints: bool, save_weights: bool, input_type: str, kl_annealing_mode: str, kl_warmup: int, reg_cat_clusters: float, recluster: bool, temperature: float, contrastive_similarity_function: str, contrastive_loss_function: str, beta: float, tau: float, interaction_regularization: float, run: int = 0, **kwargs)
Trains the specified embedding model on the preprocessed data.
- Parameters:
coordinates (np.ndarray) – Coordinates of the data.
preprocessed_object (tuple) – Tuple containing the preprocessed data.
adjacency_matrix (np.ndarray) – Adjacency matrix of the connectivity graph to use.
embedding_model (str) – Model to use to embed and cluster the data. Must be one of VQVAE (default), VaDE, and contrastive.
encoder_type (str) – Encoder architecture to use. Must be one of “recurrent”, “TCN”, and “transformer”.
batch_size (int) – Batch size to use for training.
latent_dim (int) – Encoding size to use for training.
epochs (int) – Number of epochs to train the autoencoder for.
log_history (bool) – Whether to log the history of the autoencoder.
log_hparams (bool) – Whether to log the hyperparameters used for training.
n_components (int) – Number of components to fit to the data.
output_path (str) – Path to the output directory.
kmeans_loss (float) – Weight of the gram loss, a regularization term for VQVAE models that penalizes correlation between the dimensions of the latent space.
pretrained (str) – Path to the pretrained weights to use for the autoencoder.
save_checkpoints (bool) – Whether to save checkpoints during training.
save_weights (bool) – Whether to save the weights of the autoencoder after training.
input_type (str) – Input type of the TableDict objects used for preprocessing. For logging purposes only.
interaction_regularization (float) – Weight of the interaction regularization term (L1 penalization to all features not related to interactions).
run (int) – Run number to use for logging.
VaDE model specific parameters:
kl_annealing_mode (str) – Mode to use for KL annealing. Must be either “linear” (default) or “sigmoid”.
kl_warmup (int) – Number of epochs during which KL is annealed.
reg_cat_clusters (float) – Weight of the penalty on uneven cluster membership in the latent space, which minimizes the KL divergence between cluster membership and a uniform categorical distribution.
recluster (bool) – Whether to recluster the data after each training using a Gaussian Mixture Model.
Contrastive model specific parameters:
temperature (float) – Temperature parameter for the contrastive loss functions. Higher values put harsher penalties on negative pair similarity.
contrastive_similarity_function (str) – Similarity function between positive and negative pairs. Must be one of ‘cosine’ (default), ‘euclidean’, ‘dot’, and ‘edit’.
contrastive_loss_function (str) – Contrastive loss function. Must be one of ‘nce’ (default), ‘dcl’, ‘fc’, and ‘hard_dcl’. See specific documentation for details.
beta (float) – Beta (concentration) parameter for the hard_dcl contrastive loss. Higher values lead to ‘harder’ negative samples.
tau (float) – Tau parameter for the dcl and hard_dcl contrastive losses, indicating positive class probability.
- Returns:
List of trained models corresponding to the selected model class. The full trained model is last.
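Because several string parameters accept only the values listed above, it can help to check them before starting a long training run. The following is a minimal sketch, not part of deepof: `ALLOWED` and `check_fit_kwargs` are hypothetical names, the numeric values are illustrative rather than library defaults, and the commented-out call assumes deepof is installed and that `preprocessed_object` and `adjacency_matrix` have already been built.

```python
from typing import Any, Dict

# Allowed values copied from the parameter descriptions above.
ALLOWED = {
    "embedding_model": {"VQVAE", "VaDE", "contrastive"},
    "encoder_type": {"recurrent", "TCN", "transformer"},
    "kl_annealing_mode": {"linear", "sigmoid"},
    "contrastive_similarity_function": {"cosine", "euclidean", "dot", "edit"},
    "contrastive_loss_function": {"nce", "dcl", "fc", "hard_dcl"},
}


def check_fit_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: reject values the documentation does not list."""
    for name, allowed in ALLOWED.items():
        if name in kwargs and kwargs[name] not in allowed:
            raise ValueError(
                f"{name}={kwargs[name]!r} is not one of {sorted(allowed)}"
            )
    return kwargs


# Illustrative subset of arguments; the values are NOT library defaults.
fit_kwargs = check_fit_kwargs(
    {
        "embedding_model": "VaDE",
        "encoder_type": "recurrent",
        "batch_size": 64,
        "latent_dim": 8,
        "epochs": 10,
        "n_components": 10,
        "kl_annealing_mode": "linear",
        "kl_warmup": 5,
        "run": 0,
    }
)

# The signature lists no defaults except run=0, so the remaining arguments
# (output_path, pretrained, temperature, ...) must be added before calling:
#
# import deepof.model_utils
# trained_models = deepof.model_utils.embedding_model_fitting(
#     preprocessed_object, adjacency_matrix, **fit_kwargs
# )
# full_model = trained_models[-1]  # the full trained model is last
```

Validating up front means a misspelled option fails immediately with a message that names the offending parameter.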