deepof.visuals.plot_cluster_detection_performance

deepof.visuals.plot_cluster_detection_performance(coordinates: deepof_coordinates, chunk_stats: DataFrame, cluster_gbm_performance: dict, hard_counts: ndarray, groups: list, save: bool = False, visualization: str = 'confusion_matrix', ax: Axes | None = None)

Plot either a confusion matrix or a bar chart of balanced accuracy for cross-validated cluster detection models.

Designed to be run after deepof.post_hoc.train_supervised_cluster_detectors (see documentation for details).

Parameters:
  • coordinates (coordinates) – deepOF project where the data is stored.

  • chunk_stats (pd.DataFrame) – table with descriptive statistics for a series of sequences (‘chunks’).

  • cluster_gbm_performance (dict) – dictionary containing the cross-validated trained estimators and their performance metrics.

  • hard_counts (np.ndarray) – cluster assignments for the corresponding ‘chunk_stats’ table.

  • groups (list) – cross-validation indices. Data from the same animal are never shared between train and test sets.

  • save (bool) – if True, the produced figure is saved to file. Defaults to False.

  • visualization (str) – plot to render. Must be either ‘confusion_matrix’ or ‘balanced_accuracy’. Defaults to ‘confusion_matrix’.

  • ax (plt.Axes) – axes on which to draw the plot. If None, a new figure is created.
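
The two quantities this function can display, and the animal-wise split described for groups, can be illustrated without deepOF. The following is a minimal pure-Python sketch, not deepOF's implementation. The helper names (confusion_matrix, balanced_accuracy, group_folds) and the toy data are hypothetical. The leave-one-group-out scheme is an assumption chosen for simplicity, as the documentation above does not state which fold scheme is used. Balanced accuracy is taken here in its usual sense, the mean of per-class recall.

```python
from collections import defaultdict


def confusion_matrix(y_true, y_pred, n_clusters):
    """Count predictions; rows are true clusters, columns are predicted clusters."""
    matrix = [[0] * n_clusters for _ in range(n_clusters)]
    for true, pred in zip(y_true, y_pred):
        matrix[true][pred] += 1
    return matrix


def balanced_accuracy(matrix):
    """Mean per-cluster recall, skipping clusters that have no true samples."""
    recalls = [row[i] / sum(row) for i, row in enumerate(matrix) if sum(row) > 0]
    return sum(recalls) / len(recalls)


def group_folds(groups):
    """Yield (train_idx, test_idx) pairs, holding out one group per fold.

    Assumption: leave-one-group-out. Any scheme that keeps each animal's
    chunks on a single side of the split satisfies the same constraint.
    """
    by_group = defaultdict(list)
    for idx, group in enumerate(groups):
        by_group[group].append(idx)
    for held_out, test_idx in by_group.items():
        train_idx = [i for g, idxs in by_group.items() if g != held_out for i in idxs]
        yield train_idx, test_idx


# Toy data: six chunks from three animals, assigned to two clusters.
groups = ["a", "a", "b", "b", "c", "c"]
y_true = [0, 0, 0, 1, 1, 0]
y_pred = [0, 0, 1, 1, 1, 0]

matrix = confusion_matrix(y_true, y_pred, n_clusters=2)
score = balanced_accuracy(matrix)
folds = list(group_folds(groups))
```

Because balanced accuracy averages recall across clusters, a rarely visited cluster weighs as much as a frequent one, which makes it a more informative summary than plain accuracy when cluster sizes differ.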