deepof.model_utils.compute_kmeans_loss
- deepof.model_utils.compute_kmeans_loss(latent_means: Tensor, weight: float = 1.0, batch_size: int = 64)
Computes a penalty on the singular values of the Gram matrix of the latent means. Penalizing these singular values helps disentangle the latent space.
Based on https://arxiv.org/pdf/1610.04794.pdf and https://www.biorxiv.org/content/10.1101/2020.05.14.095430v3.
- Parameters:
latent_means (tf.Tensor) – tensor containing the means of the latent distribution
weight (float) – weight of this Gram-matrix (k-means) loss term in the total loss function
batch_size (int) – batch size of the data for which the k-means loss is computed
- Returns:
the computed k-means loss
- Return type:
tf.Tensor
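The description above can be illustrated with a minimal NumPy sketch. It is not the deepof implementation: the function name `gram_singular_value_penalty`, the normalization of the Gram matrix by `batch_size`, the square root applied to the singular values, and the `1e-9` floor are all assumptions made for illustration, and the library itself operates on `tf.Tensor` inputs rather than NumPy arrays.

```python
import numpy as np


def gram_singular_value_penalty(latent_means, weight=1.0, batch_size=64):
    """Hypothetical sketch of a Gram-matrix singular value penalty.

    latent_means: array of shape (batch, latent_dim).
    Assumes the Gram matrix is Z^T Z / batch_size and that the penalty is
    the mean of the square roots of its singular values, scaled by weight.
    """
    z = np.asarray(latent_means, dtype=np.float64)

    # Gram matrix of the latent dimensions, shape (latent_dim, latent_dim).
    gram = z.T @ z / float(batch_size)

    # Singular values only; the singular vectors are not needed here.
    singular_values = np.linalg.svd(gram, compute_uv=False)

    # Floor before the square root to keep the values strictly positive
    # (assumed constant, chosen for numerical stability).
    singular_values = np.sqrt(np.maximum(singular_values, 1e-9))

    return weight * float(np.mean(singular_values))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    z = rng.normal(size=(64, 8))  # stand-in for a batch of latent means
    print(gram_singular_value_penalty(z, weight=1.0, batch_size=64))
```

In a training loop, a term like this would typically be added to the reconstruction and regularization losses, with `weight` controlling how strongly it shapes the latent space.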