deepof.models.get_transformer_encoder
- deepof.models.get_transformer_encoder(input_shape: tuple, edge_feature_shape: tuple, adjacency_matrix: ndarray, latent_dim: int, use_gnn: bool = True, num_layers: int = 4, num_heads: int = 64, dff: int = 128, dropout_rate: float = 0.1, interaction_regularization: float = 0.0)
Build a Transformer encoder.
Based on https://www.tensorflow.org/text/tutorials/transformer. Adapted according to https://academic.oup.com/gigascience/article/8/11/giz134/5626377?login=true and https://arxiv.org/abs/1711.03905.
- Parameters:
input_shape (tuple) – shape of the input data
edge_feature_shape (tuple) – shape of the edge feature tensor to use in the graph attention layers. Should be time x edges x features.
adjacency_matrix (np.ndarray) – adjacency matrix for the mice connectivity graph. Shape should be nodes x nodes.
latent_dim (int) – dimensionality of the latent space
use_gnn (bool) – If True, the encoder uses a graph representation of the input, with coordinates and speeds as node attributes, and distances as edge attributes. If False, a regular 3D tensor is used as input.
num_layers (int) – number of transformer layers to include
num_heads (int) – number of attention heads in each multi-head attention layer of the transformer encoder
dff (int) – dimensionality of the token embeddings
dropout_rate (float) – dropout rate
interaction_regularization (float) – regularization parameter for the interaction features
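
The shape requirements on adjacency_matrix (nodes x nodes) and edge_feature_shape (time x edges x features) can be illustrated with a small sketch. It is not part of deepof: the helper name build_graph_inputs, the four-node graph, the window length of 25 and the single feature per edge are all hypothetical values chosen for illustration, and the sketch does not call get_transformer_encoder itself.

```python
import numpy as np


def build_graph_inputs(num_nodes, edges, window_length, num_edge_features):
    """Hypothetical helper: build a symmetric adjacency matrix and the
    matching edge feature shape for an undirected graph."""
    adjacency = np.zeros((num_nodes, num_nodes), dtype=int)
    for i, j in edges:
        # Undirected graph: mark the connection in both directions.
        adjacency[i, j] = 1
        adjacency[j, i] = 1
    # time x edges x features, as described in the parameter list.
    edge_shape = (window_length, len(edges), num_edge_features)
    return adjacency, edge_shape


# Illustrative 4-node graph; node indices and edges are assumptions.
edges = [(0, 1), (0, 2), (1, 3), (2, 3)]
adjacency_matrix, edge_feature_shape = build_graph_inputs(
    num_nodes=4,
    edges=edges,
    window_length=25,     # assumed number of time steps per window
    num_edge_features=1,  # assumed: one distance value per edge
)

print(adjacency_matrix.shape)
print(edge_feature_shape)
```

The two returned values are the kind of objects that would be passed as adjacency_matrix and edge_feature_shape. The remaining arguments, in particular input_shape, depend on the dataset and should be taken from the data being encoded.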