torchaudio.prototype.models
conformer_rnnt_model
- torchaudio.prototype.models.conformer_rnnt_model(*, input_dim: int, encoding_dim: int, time_reduction_stride: int, conformer_input_dim: int, conformer_ffn_dim: int, conformer_num_layers: int, conformer_num_heads: int, conformer_depthwise_conv_kernel_size: int, conformer_dropout: float, num_symbols: int, symbol_embedding_dim: int, num_lstm_layers: int, lstm_hidden_dim: int, lstm_layer_norm: bool, lstm_layer_norm_epsilon: float, lstm_dropout: float, joiner_activation: str) → RNNT [source]
Builds a Conformer-based recurrent neural network transducer (RNN-T) model.
- Parameters:
input_dim (int) – dimension of input sequence frames passed to transcription network.
encoding_dim (int) – dimension of transcription- and prediction-network-generated encodings passed to joint network.
time_reduction_stride (int) – factor by which to reduce length of input sequence.
conformer_input_dim (int) – dimension of Conformer input.
conformer_ffn_dim (int) – hidden layer dimension of each Conformer layer’s feedforward network.
conformer_num_layers (int) – number of Conformer layers to instantiate.
conformer_num_heads (int) – number of attention heads in each Conformer layer.
conformer_depthwise_conv_kernel_size (int) – kernel size of each Conformer layer’s depthwise convolution layer.
conformer_dropout (float) – Conformer dropout probability.
num_symbols (int) – cardinality of set of target tokens.
symbol_embedding_dim (int) – dimension of each target token embedding.
num_lstm_layers (int) – number of LSTM layers to instantiate.
lstm_hidden_dim (int) – output dimension of each LSTM layer.
lstm_layer_norm (bool) – if True, enables layer normalization for LSTM layers.
lstm_layer_norm_epsilon (float) – value of epsilon to use in LSTM layer normalization layers.
lstm_dropout (float) – LSTM dropout probability.
joiner_activation (str) – activation function to use in the joiner. Must be one of (“relu”, “tanh”). (Default: “relu”)
- Returns:
Conformer RNN-T model.
- Return type:
RNNT
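A minimal construction sketch follows; the hyperparameter values are illustrative assumptions, not a torchaudio preset (see conformer_rnnt_base below for the library's preset).
>>> # Construction sketch; every hyperparameter value here is an assumption.
>>> from torchaudio.prototype.models import conformer_rnnt_model
>>> rnnt = conformer_rnnt_model(
...     input_dim=80,
...     encoding_dim=1024,
...     time_reduction_stride=4,
...     conformer_input_dim=256,
...     conformer_ffn_dim=1024,
...     conformer_num_layers=16,
...     conformer_num_heads=4,
...     conformer_depthwise_conv_kernel_size=31,
...     conformer_dropout=0.1,
...     num_symbols=1024,
...     symbol_embedding_dim=256,
...     num_lstm_layers=2,
...     lstm_hidden_dim=512,
...     lstm_layer_norm=True,
...     lstm_layer_norm_epsilon=1e-5,
...     lstm_dropout=0.3,
...     joiner_activation="relu",
... )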
conformer_rnnt_base
- torchaudio.prototype.models.conformer_rnnt_base() → RNNT [source]
Builds a basic version of the Conformer RNN-T model.
- Returns:
Conformer RNN-T model.
- Return type:
RNNT
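Usage sketch, assuming the no-argument factory shown above:
>>> from torchaudio.prototype.models import conformer_rnnt_base
>>> rnnt = conformer_rnnt_base()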
ConvEmformer
- class torchaudio.prototype.models.ConvEmformer(input_dim: int, num_heads: int, ffn_dim: int, num_layers: int, segment_length: int, kernel_size: int, dropout: float = 0.0, ffn_activation: str = 'relu', left_context_length: int = 0, right_context_length: int = 0, max_memory_size: int = 0, weight_init_scale_strategy: Optional[str] = 'depthwise', tanh_on_mem: bool = False, negative_inf: float = -100000000.0, conv_activation: str = 'silu')[source]
Implements the convolution-augmented streaming transformer architecture introduced in Streaming Transformer Transducer based Speech Recognition Using Non-Causal Convolution [Shi et al., 2022].
- Parameters:
input_dim (int) – input dimension.
num_heads (int) – number of attention heads in each ConvEmformer layer.
ffn_dim (int) – hidden layer dimension of each ConvEmformer layer’s feedforward network.
num_layers (int) – number of ConvEmformer layers to instantiate.
segment_length (int) – length of each input segment.
kernel_size (int) – size of kernel to use in convolution modules.
dropout (float, optional) – dropout probability. (Default: 0.0)
ffn_activation (str, optional) – activation function to use in feedforward networks. Must be one of (“relu”, “gelu”, “silu”). (Default: “relu”)
left_context_length (int, optional) – length of left context. (Default: 0)
right_context_length (int, optional) – length of right context. (Default: 0)
max_memory_size (int, optional) – maximum number of memory elements to use. (Default: 0)
weight_init_scale_strategy (str or None, optional) – per-layer weight initialization scaling strategy. Must be one of (“depthwise”, “constant”, None). (Default: “depthwise”)
tanh_on_mem (bool, optional) – if True, applies tanh to memory elements. (Default: False)
negative_inf (float, optional) – value to use for negative infinity in attention weights. (Default: -1e8)
conv_activation (str, optional) – activation function to use in convolution modules. Must be one of (“relu”, “gelu”, “silu”). (Default: “silu”)
Examples
>>> conv_emformer = ConvEmformer(80, 4, 1024, 12, 16, 8, right_context_length=4)
>>> input = torch.rand(10, 200, 80)
>>> lengths = torch.randint(1, 200, (10,))
>>> output, lengths = conv_emformer(input, lengths)
>>> input = torch.rand(4, 20, 80)
>>> lengths = torch.ones(4) * 20
>>> output, lengths, states = conv_emformer.infer(input, lengths, None)
- forward(input: Tensor, lengths: Tensor) → Tuple[Tensor, Tensor]
Forward pass for training and non-streaming inference.
B: batch size; T: max number of input frames in batch; D: feature dimension of each frame.
- Parameters:
input (torch.Tensor) – utterance frames right-padded with right context frames, with shape (B, T + right_context_length, D).
lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid utterance frames for i-th batch element in input.
- Returns:
- Tensor
output frames, with shape (B, T, D).
- Tensor
output lengths, with shape (B,) and i-th element representing number of valid frames for i-th batch element in output frames.
- Return type:
(Tensor, Tensor)
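The shape contract above can be exercised directly. This sketch reuses the conv_emformer instance from the Examples (right_context_length=4), so 200 input frames correspond to T = 196 output frames:
>>> input = torch.rand(10, 196 + 4, 80)  # T=196 utterance frames plus 4 right-context frames
>>> lengths = torch.randint(1, 197, (10,))  # valid utterance frames per batch element
>>> output, out_lengths = conv_emformer(input, lengths)
>>> output.shape
torch.Size([10, 196, 80])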
- infer(input: Tensor, lengths: Tensor, states: Optional[List[List[Tensor]]] = None) → Tuple[Tensor, Tensor, List[List[Tensor]]]
Forward pass for streaming inference.
B: batch size; D: feature dimension of each frame.
- Parameters:
input (torch.Tensor) – utterance frames right-padded with right context frames, with shape (B, segment_length + right_context_length, D).
lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in input.
states (List[List[torch.Tensor]] or None, optional) – list of lists of tensors representing internal state generated in preceding invocation of infer. (Default: None)
- Returns:
- Tensor
output frames, with shape (B, segment_length, D).
- Tensor
output lengths, with shape (B,) and i-th element representing number of valid frames for i-th batch element in output frames.
- List[List[Tensor]]
output states; list of lists of tensors representing internal state generated in current invocation of infer.
- Return type:
(Tensor, Tensor, List[List[Tensor]])
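To make the streaming contract concrete, here is a minimal segment-by-segment loop. It assumes the constructor arguments from the Examples above (segment_length=16, right_context_length=4) and a fabricated feature stream; the states returned by each call feed the next:
>>> conv_emformer = ConvEmformer(80, 4, 1024, 12, 16, 8, right_context_length=4)
>>> stream = torch.rand(1, 100, 80)  # fabricated 100-frame utterance features
>>> states = None
>>> for start in range(0, 80, 16):  # advance one 16-frame segment at a time
...     chunk = stream[:, start : start + 16 + 4]  # segment plus right context
...     chunk_lengths = torch.tensor([16 + 4])
...     output, out_lengths, states = conv_emformer.infer(chunk, chunk_lengths, states)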