Source code for openspeech.criterion.joint_ctc_cross_entropy.joint_ctc_cross_entropy

# MIT License
#
# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch.nn as nn
from typing import Tuple
from omegaconf import DictConfig
from torch import Tensor

from .. import register_criterion
from ..joint_ctc_cross_entropy.configuration import JointCTCCrossEntropyLossConfigs
from ..label_smoothed_cross_entropy.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyLoss
from ...tokenizers.tokenizer import Tokenizer


@register_criterion("joint_ctc_cross_entropy", dataclass=JointCTCCrossEntropyLossConfigs)
class JointCTCCrossEntropyLoss(nn.Module):
    r"""
    Provides the Joint CTC-CrossEntropy Loss function. CTC Loss is applied to the logits from the encoder, and
    Cross Entropy to the logits from the decoder. The auxiliary CTC term makes the encoder more robust.

    Args:
        configs (DictConfig): hydra configuration set
        num_classes (int): the number of classification classes
        tokenizer (Tokenizer): tokenizer is in charge of preparing the inputs for a model.

    Inputs: encoder_logits, logits, output_lengths, targets, target_lengths
        - encoder_logits (torch.FloatTensor): log-probability distribution from the encoder.
            The `FloatTensor` of size ``(input_length, batch, num_classes)``
        - logits (torch.FloatTensor): log-probability distribution from the model.
            The `FloatTensor` of size ``(batch, seq_length, num_classes)``
        - output_lengths (torch.LongTensor): lengths of the model's outputs. The `LongTensor` of size ``(batch)``
        - targets (torch.LongTensor): ground truth encoded as integers that directly index words in the label vocabulary.
            The `LongTensor` of size ``(batch, target_length)``
        - target_lengths (torch.LongTensor): lengths of the targets. The `LongTensor` of size ``(batch)``

    Returns: loss, ctc_loss, cross_entropy_loss
        - loss (float): weighted sum of the two losses, used for training
        - ctc_loss (float): ctc loss for training
        - cross_entropy_loss (float): cross entropy loss for training

    Reference:
        Suyoun Kim et al.: Joint CTC-Attention based End-to-End Speech Recognition using Multi-task Learning
        https://arxiv.org/abs/1609.06773
    """

    def __init__(
        self,
        configs: DictConfig,
        num_classes: int,
        tokenizer: Tokenizer,
    ) -> None:
        super(JointCTCCrossEntropyLoss, self).__init__()
        self.num_classes = num_classes
        self.dim = -1
        self.ignore_index = tokenizer.pad_id
        self.reduction = configs.criterion.reduction.lower()
        self.ctc_weight = configs.criterion.ctc_weight
        self.cross_entropy_weight = configs.criterion.cross_entropy_weight
        self.ctc_loss = nn.CTCLoss(
            blank=tokenizer.blank_id,
            reduction=self.reduction,
            zero_infinity=configs.criterion.zero_infinity,
        )
        if configs.criterion.smoothing > 0.0:
            self.cross_entropy_loss = LabelSmoothedCrossEntropyLoss(
                configs=configs,
                num_classes=num_classes,
                tokenizer=tokenizer,
            )
        else:
            self.cross_entropy_loss = nn.CrossEntropyLoss(reduction=self.reduction, ignore_index=self.ignore_index)

    def forward(
        self,
        encoder_logits: Tensor,
        logits: Tensor,
        output_lengths: Tensor,
        targets: Tensor,
        target_lengths: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        max_target_length = targets.size(1)
        max_logits_length = logits.size(1)

        # Align decoder logits and targets to a common length before cross entropy.
        if max_logits_length > max_target_length:
            logits = logits[:, :max_target_length, :]
            cross_entropy_targets = targets.clone()
        elif max_target_length > max_logits_length:
            cross_entropy_targets = targets[:, :max_logits_length].clone()
        else:
            cross_entropy_targets = targets.clone()

        logits = logits.contiguous().view(-1, logits.size(-1))

        ctc_loss = self.ctc_loss(encoder_logits, targets, output_lengths, target_lengths)
        cross_entropy_loss = self.cross_entropy_loss(logits, cross_entropy_targets.contiguous().view(-1))
        loss = cross_entropy_loss * self.cross_entropy_weight + ctc_loss * self.ctc_weight
        return loss, ctc_loss, cross_entropy_loss
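

# ------------------------------------------------------------------------------
# Usage sketch (not part of the library source). It assumes only what the class
# above actually reads: a DictConfig with a ``criterion`` node carrying
# ``reduction``, ``ctc_weight``, ``cross_entropy_weight``, ``zero_infinity`` and
# ``smoothing``, plus a tokenizer exposing ``pad_id`` and ``blank_id``. The stub
# tokenizer, the weight values and the tensor shapes below are illustrative only.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from omegaconf import OmegaConf

    configs = OmegaConf.create(
        {
            "criterion": {
                "reduction": "mean",
                "ctc_weight": 0.3,
                "cross_entropy_weight": 0.7,
                "zero_infinity": True,
                "smoothing": 0.0,  # 0.0 selects plain nn.CrossEntropyLoss instead of label smoothing
            }
        }
    )

    class StubTokenizer:
        # Hypothetical stand-in for openspeech's Tokenizer: only the two ids used above.
        pad_id = 0
        blank_id = 3

    num_classes, batch, input_length, seq_length = 10, 4, 50, 12
    criterion = JointCTCCrossEntropyLoss(configs=configs, num_classes=num_classes, tokenizer=StubTokenizer())

    # Encoder log-probabilities, shaped (input_length, batch, num_classes) as nn.CTCLoss expects.
    encoder_logits = torch.randn(input_length, batch, num_classes).log_softmax(dim=-1)
    # Decoder log-probabilities, shaped (batch, seq_length, num_classes).
    logits = torch.randn(batch, seq_length, num_classes).log_softmax(dim=-1)
    # Targets avoid pad_id (0) and blank_id (3) so both loss terms see valid labels.
    targets = torch.randint(low=4, high=num_classes, size=(batch, seq_length))
    output_lengths = torch.full((batch,), input_length, dtype=torch.long)
    target_lengths = torch.full((batch,), seq_length, dtype=torch.long)

    loss, ctc_loss, cross_entropy_loss = criterion(
        encoder_logits, logits, output_lengths, targets, target_lengths
    )
    # loss == cross_entropy_loss * cross_entropy_weight + ctc_loss * ctc_weight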