deepof.model_utils.get_callbacks

deepof.model_utils.get_callbacks(embedding_model: str, encoder_type: str, kmeans_loss: float = 1.0, input_type: str = False, cp: bool = False, logparam: dict | None = None, outpath: str = '.', run: int = False) → List[Any]

Generate callbacks used for model training.

Parameters:
  • embedding_model (str) – name of the embedding model

  • encoder_type (str) – Architecture used for the encoder. Must be one of “recurrent”, “TCN”, or “transformer”.

  • kmeans_loss (float) – Weight of the gram loss

  • input_type (str) – Input type to use for training

  • cp (bool) – Whether to use checkpointing or not

  • logparam (dict) – Dictionary containing the hyperparameters to log in TensorBoard

  • outpath (str) – Path to the output directory

  • run (int) – Run number to use for checkpointing

Returns:

List of callbacks to be used for training

Return type:

List[Any]
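
A minimal sketch of how the arguments documented above might be assembled before calling the function. The helpers `build_callback_kwargs` and `make_callbacks` are hypothetical and not part of deepof; the model name "VaDE", the `logparam` contents, and the output path are illustrative assumptions. The deepof import is deferred so that the argument check can run without the library installed.

```python
from typing import Any, Dict, Optional

# Encoder names taken from the encoder_type parameter description.
VALID_ENCODER_TYPES = ("recurrent", "TCN", "transformer")


def build_callback_kwargs(
    embedding_model: str,
    encoder_type: str,
    kmeans_loss: float = 1.0,
    input_type: Any = False,
    cp: bool = False,
    logparam: Optional[Dict[str, Any]] = None,
    outpath: str = ".",
    run: Any = False,
) -> Dict[str, Any]:
    """Hypothetical helper: check encoder_type, then collect the keyword arguments."""
    if encoder_type not in VALID_ENCODER_TYPES:
        raise ValueError(
            f"encoder_type must be one of {VALID_ENCODER_TYPES}, got {encoder_type!r}"
        )
    return {
        "embedding_model": embedding_model,
        "encoder_type": encoder_type,
        "kmeans_loss": kmeans_loss,
        "input_type": input_type,
        "cp": cp,
        "logparam": logparam,
        "outpath": outpath,
        "run": run,
    }


def make_callbacks(**kwargs: Any) -> list:
    """Hypothetical wrapper: forward checked arguments to deepof (requires deepof)."""
    # Imported lazily so the rest of this sketch runs without deepof installed.
    from deepof.model_utils import get_callbacks

    return get_callbacks(**build_callback_kwargs(**kwargs))


kwargs = build_callback_kwargs(
    embedding_model="VaDE",  # assumed model name, for illustration only
    encoder_type="recurrent",
    cp=True,  # enable checkpointing
    logparam={"latent_dim": 8},  # hypothetical hyperparameters to log
    outpath="./logs",  # hypothetical output directory
    run=1,
)
# With deepof installed, the callbacks could then be created with:
# callbacks = make_callbacks(**kwargs)
```

The returned list would then typically be passed to the training routine's callbacks argument. Checking `encoder_type` up front is a design choice of this sketch; whether `get_callbacks` performs the same check internally is not stated above.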