deepof.model_utils.compute_kmeans_loss

deepof.model_utils.compute_kmeans_loss(latent_means: Tensor, weight: float = 1.0, batch_size: int = 64)

Compute a penalty on the singular values of the Gram matrix of the latent means. The penalty is intended to help disentangle the latent space.

Based on https://arxiv.org/pdf/1610.04794.pdf and https://www.biorxiv.org/content/10.1101/2020.05.14.095430v3.
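One possible formalization of this penalty is given below. It is an illustrative assumption, not stated on this page: the loss is taken to be the weighted mean of the square roots of the singular values of the batch-normalized Gram matrix. Here Z (B x d) stands for latent_means, w for weight and B for batch_size, and the small numerical floor a real implementation may apply before the square root is ignored.

```latex
\begin{aligned}
G &= \tfrac{1}{B}\, Z^{\top} Z \in \mathbb{R}^{d \times d}, \\
\mathcal{L}_{\text{kmeans}} &= \frac{w}{d} \sum_{i=1}^{d} \sqrt{\sigma_i(G)}
  = \frac{w}{d\sqrt{B}} \sum_{i=1}^{d} \sigma_i(Z)
  = \frac{w}{d\sqrt{B}}\, \lVert Z \rVert_{*}
\end{aligned}
```

The second equality holds because G is symmetric positive semi-definite, so sigma_i(G) = sigma_i(Z)^2 / B (with sigma_i(Z) = 0 for i > min(B, d)). Under this assumption the penalty is a scaled nuclear norm of the latent means.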

Parameters:
  • latent_means (tf.Tensor) – Tensor containing the means of the latent distribution.

  • weight (float) – Weight of the Gram loss in the total loss function.

  • batch_size (int) – Batch size of the data for which the k-means loss is computed.

Returns:

The k-means loss.

Return type:

tf.Tensor
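The following NumPy sketch shows how the computation might look. It is a minimal sketch under stated assumptions, not deepof's TensorFlow implementation. It assumes that the Gram matrix is Z^T Z / batch_size with one row of latent_means per sample, that singular values are floored at 1e-9 before the square root, and that the loss is weight times the mean of the square-rooted singular values. The name `compute_kmeans_loss_sketch` is hypothetical.

```python
import numpy as np


def compute_kmeans_loss_sketch(
    latent_means: np.ndarray, weight: float = 1.0, batch_size: int = 64
) -> float:
    """Hypothetical NumPy sketch of a Gram-matrix singular-value penalty."""
    z = np.asarray(latent_means, dtype=np.float64)
    # Assumed form: Gram matrix of the latent dimensions, normalized by the
    # batch size. Rows of z are samples, so gram has shape (d, d).
    gram = z.T @ z / float(batch_size)
    # Singular values of the (symmetric, positive semi-definite) Gram matrix.
    s = np.linalg.svd(gram, compute_uv=False)
    # Assumed numerical floor so the square root stays well-behaved near zero.
    s = np.sqrt(np.maximum(s, 1e-9))
    # Weighted mean of the square-rooted singular values.
    return float(weight * np.mean(s))


# Usage: a batch of 64 samples in an 8-dimensional latent space.
rng = np.random.default_rng(0)
latent = rng.normal(size=(64, 8))
print(compute_kmeans_loss_sketch(latent, weight=0.1, batch_size=64))
```

In this sketch, weight rescales the result linearly. This matches its documented role as the weight of the Gram loss in the total loss function.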