deepof.models.get_transformer_encoder

deepof.models.get_transformer_encoder(input_shape: tuple, edge_feature_shape: tuple, adjacency_matrix: ndarray, latent_dim: int, use_gnn: bool = True, num_layers: int = 4, num_heads: int = 64, dff: int = 128, dropout_rate: float = 0.1, interaction_regularization: float = 0.0)

Build a Transformer encoder.

Based on https://www.tensorflow.org/text/tutorials/transformer. Adapted following https://academic.oup.com/gigascience/article/8/11/giz134/5626377?login=true and https://arxiv.org/abs/1711.03905.

Parameters:
  • input_shape (tuple) – shape of the input data

  • edge_feature_shape (tuple) – shape of the edge features used in the graph attention layers. Should be time x edges x features.

  • adjacency_matrix (np.ndarray) – adjacency matrix of the mouse connectivity graph. Shape should be nodes x nodes.

  • latent_dim (int) – dimensionality of the latent space

  • use_gnn (bool) – If True, the encoder uses a graph representation of the input, with coordinates and speeds as node attributes, and distances as edge attributes. If False, a regular 3D tensor is used as input.

  • num_layers (int) – number of transformer layers to include

  • num_heads (int) – number of heads in the multi-head attention layers of the transformer encoder

  • dff (int) – dimensionality of the token embeddings

  • dropout_rate (float) – dropout rate

  • interaction_regularization (float) – regularization parameter for the interaction features
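The graph arguments must agree with each other: adjacency_matrix is nodes x nodes, while edge_feature_shape is time x edges x features. The sketch below builds both from a small body-part graph, using only NumPy. The node names, edge list, window length, and the choice of one Euclidean distance per undirected edge (counted once) are illustrative assumptions, not deepof's actual connectivity or preprocessing.

```python
import numpy as np

# Hypothetical body-part graph (assumption): names and edges are illustrative
# only and do not reflect deepof's real connectivity graph.
NODES = ["nose", "left_ear", "right_ear", "center", "tail_base"]
EDGES = [
    ("nose", "left_ear"),
    ("nose", "right_ear"),
    ("nose", "center"),
    ("left_ear", "center"),
    ("right_ear", "center"),
    ("center", "tail_base"),
]


def build_adjacency(nodes, edges):
    """Return a symmetric nodes x nodes 0/1 adjacency matrix."""
    index = {name: i for i, name in enumerate(nodes)}
    adj = np.zeros((len(nodes), len(nodes)), dtype=np.float32)
    for a, b in edges:
        # Undirected graph: mark both directions.
        adj[index[a], index[b]] = 1.0
        adj[index[b], index[a]] = 1.0
    return adj


def edge_distances(coords, nodes, edges):
    """Euclidean distance for every edge in every frame.

    coords: array of shape (time, nodes, 2) holding x/y positions.
    Returns an array of shape (time, edges, 1): one distance feature per edge.
    """
    index = {name: i for i, name in enumerate(nodes)}
    src = [index[a] for a, _ in edges]
    dst = [index[b] for _, b in edges]
    diffs = coords[:, src, :] - coords[:, dst, :]  # (time, edges, 2)
    return np.linalg.norm(diffs, axis=-1)[..., np.newaxis]


window = 25  # hypothetical number of frames per input sequence
rng = np.random.default_rng(0)
coords = rng.normal(size=(window, len(NODES), 2))

adjacency_matrix = build_adjacency(NODES, EDGES)
edge_features = edge_distances(coords, NODES, EDGES)
edge_feature_shape = edge_features.shape

print(adjacency_matrix.shape)  # (5, 5)
print(edge_feature_shape)      # (25, 6, 1)
```

Note that the edge axis (6) follows the number of edges, not the number of nodes (5). Under these assumptions, adjacency_matrix and edge_feature_shape are the values one would pass to the corresponding parameters above.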