deepof.models.get_transformer_decoder

deepof.models.get_transformer_decoder(input_shape, latent_dim, num_layers=2, num_heads=8, dff=128, dropout_rate=0.1)

Build a Transformer decoder.

Based on https://www.tensorflow.org/text/tutorials/transformer and adapted following https://academic.oup.com/gigascience/article/8/11/giz134/5626377?login=true and https://arxiv.org/abs/1711.03905.
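The TensorFlow tutorial referenced above tells attention layers about token order by adding a sinusoidal positional encoding to the inputs before the first layer. This page does not say whether this function does the same, so the pure-Python sketch below only illustrates the standard encoding from the original Transformer paper (Vaswani et al., 2017): even feature indices use sine, odd indices use cosine, and each pair of indices shares one frequency. The function name `positional_encoding` is hypothetical and not part of deepof.

```python
import math


def positional_encoding(length, depth):
    """Return a length x depth table of sinusoidal position encodings.

    Hypothetical helper for illustration only (not a deepof function).
    Wavelengths form a geometric progression from 2*pi up to 10000*2*pi.
    """
    table = []
    for pos in range(length):
        row = []
        for i in range(depth):
            # Indices 2k and 2k+1 share the same angular frequency.
            angle = pos / (10000 ** ((2 * (i // 2)) / depth))
            # Even indices get sine, odd indices get cosine.
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


pe = positional_encoding(length=4, depth=6)
```

Each row would be added element-wise to the feature vector of the matching timestep, so `depth` has to equal the per-timestep feature size.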

Parameters:
  • input_shape (tuple) – shape of the input data

  • latent_dim (int) – dimensionality of the latent space

  • num_layers (int) – number of transformer layers to include

  • num_heads (int) – number of heads in the multi-head attention layers used in the transformer decoder

  • dff (int) – dimensionality of the token embeddings

  • dropout_rate (float) – dropout rate
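As a rough illustration of how these parameters could map onto the architecture of the TensorFlow tutorial, the NumPy sketch below stacks `num_layers` decoder layers, each made of causal multi-head self-attention, cross-attention over a `memory` sequence, and a feed-forward block, with `dropout_rate` applied before every residual connection. This is not deepof's implementation. The helper names, the random weight initialisation, the post-norm layout, treating the latent sequence as cross-attention `memory`, and using `dff` as the hidden width of the feed-forward block (the tutorial's convention) are all assumptions; learned layer-norm parameters and any final projection back to `input_shape` are omitted.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # Normalise each timestep's features (no learned scale/shift in this sketch).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def dropout(x, rate, rng, training):
    # Inverted dropout: only active in training, rescales the kept units.
    if not training or rate == 0.0:
        return x
    mask = rng.random(x.shape) >= rate
    return x * mask / (1.0 - rate)


def multi_head_attention(q_in, kv_in, w, num_heads, causal=False):
    # w holds query/key/value/output projections, each (d_model, d_model).
    t_q, d_model = q_in.shape
    t_k = kv_in.shape[0]
    head_dim = d_model // num_heads  # d_model must be divisible by num_heads

    def split(x, t):  # (t, d_model) -> (num_heads, t, head_dim)
        return x.reshape(t, num_heads, head_dim).transpose(1, 0, 2)

    q = split(q_in @ w["q"], t_q)
    k = split(kv_in @ w["k"], t_k)
    v = split(kv_in @ w["v"], t_k)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)  # (heads, t_q, t_k)
    if causal:
        # Look-ahead mask: timestep i may only attend to timesteps j <= i.
        mask = np.triu(np.ones((t_q, t_k)), k=1).astype(bool)
        scores = np.where(mask, -1e9, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    heads = weights @ v  # (heads, t_q, head_dim)
    return heads.transpose(1, 0, 2).reshape(t_q, d_model) @ w["o"]


def init_params(rng, num_layers, d_model, dff):
    # Hypothetical random initialisation, scaled by 1/sqrt(fan_in).
    def proj(n_in, n_out):
        return rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_in, n_out))

    return [
        {
            "self": {name: proj(d_model, d_model) for name in "qkvo"},
            "cross": {name: proj(d_model, d_model) for name in "qkvo"},
            "ff1": proj(d_model, dff),
            "ff2": proj(dff, d_model),
        }
        for _ in range(num_layers)
    ]


def transformer_decoder(x, memory, params, num_heads, dropout_rate, rng, training=False):
    for p in params:  # one iteration per decoder layer (num_layers in total)
        # 1) Masked self-attention over the decoder's own sequence.
        a = multi_head_attention(x, x, p["self"], num_heads, causal=True)
        x = layer_norm(x + dropout(a, dropout_rate, rng, training))
        # 2) Cross-attention: queries from x, keys/values from memory.
        a = multi_head_attention(x, memory, p["cross"], num_heads)
        x = layer_norm(x + dropout(a, dropout_rate, rng, training))
        # 3) Position-wise feed-forward block with a dff-wide ReLU hidden layer.
        f = np.maximum(0.0, x @ p["ff1"]) @ p["ff2"]
        x = layer_norm(x + dropout(f, dropout_rate, rng, training))
    return x


rng = np.random.default_rng(0)
timesteps, d_model = 10, 16
params = init_params(rng, num_layers=2, d_model=d_model, dff=128)
x = rng.normal(size=(timesteps, d_model))
memory = rng.normal(size=(timesteps, d_model))
out = transformer_decoder(x, memory, params, num_heads=8, dropout_rate=0.1, rng=rng)
```

Because of the look-ahead mask, changing the last timestep of `x` leaves every earlier output row unchanged, which is the property that lets a decoder of this kind be trained on whole sequences at once.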