Source code for openspeech.search.beam_search_lstm

# MIT License
#
# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch

from openspeech.search.beam_search_base import OpenspeechBeamSearchBase
from openspeech.decoders import LSTMAttentionDecoder
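

# Illustrative sketch (not part of the original module): each decoding step in
# ``BeamSearchLSTM.forward`` below scores ``beam_size ** 2`` candidate
# extensions per batch element, keeps the top ``beam_size``, and recovers each
# survivor's parent beam by integer division over the flattened candidate
# index. A minimal, standalone version of that bookkeeping, with hypothetical
# sizes:
def _beam_bookkeeping_sketch(batch_size: int = 1, beam_size: int = 2, vocab_size: int = 5):
    # Running log-probability of each ongoing beam: (batch, beam).
    cumulative_ps = torch.full((batch_size, beam_size), -1.0)
    # Next-token log-probabilities for every beam: (batch, beam, vocab).
    step_outputs = torch.randn(batch_size, beam_size, vocab_size).log_softmax(dim=-1)
    # Top-k extensions of each beam, combined with the running scores and
    # flattened to ``beam_size ** 2`` candidates per batch element.
    current_ps, current_vs = step_outputs.topk(beam_size)
    current_ps = (current_ps + cumulative_ps.unsqueeze(2)).view(batch_size, beam_size ** 2)
    current_vs = current_vs.view(batch_size, beam_size ** 2)
    # Keep the best ``beam_size`` candidates overall.
    topk_ps, topk_ids = current_ps.topk(beam_size)
    # Flattened index = parent_beam * beam_size + rank, so ``//`` recovers the parent.
    prev_beam_ids = topk_ids // beam_size
    return topk_ps, current_vs.gather(1, topk_ids), prev_beam_ids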


class BeamSearchLSTM(OpenspeechBeamSearchBase):
    r"""
    LSTM Beam Search Decoder

    Args: decoder, beam_size
        decoder (LSTMAttentionDecoder): base decoder of the LSTM model.
        beam_size (int): size of the beam.

    Inputs: encoder_outputs, targets, encoder_output_lengths, teacher_forcing_ratio
        encoder_outputs (torch.FloatTensor): An output sequence of the encoder. `FloatTensor` of size
            ``(batch, seq_length, dimension)``
        targets (torch.LongTensor): A target sequence passed to the decoder. `LongTensor` of size
            ``(batch, seq_length)``
        encoder_output_lengths (torch.LongTensor): Lengths of the encoder output sequences. `LongTensor` of size
            ``(batch)``
        teacher_forcing_ratio (float): Ratio of teacher forcing.

    Returns:
        * predictions (torch.LongTensor): The most probable token sequences found by the beam search.
    """

    def __init__(self, decoder: LSTMAttentionDecoder, beam_size: int):
        super(BeamSearchLSTM, self).__init__(decoder, beam_size)
        self.hidden_state_dim = decoder.hidden_state_dim
        self.num_layers = decoder.num_layers
        self.validate_args = decoder.validate_args
    def forward(
        self,
        encoder_outputs: torch.Tensor,
        encoder_output_lengths: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Beam search decoding.

        Inputs: encoder_outputs
            encoder_outputs (torch.FloatTensor): An output sequence of the encoder. `FloatTensor` of size
                ``(batch, seq_length, dimension)``

        Returns:
            * predictions (torch.LongTensor): The most probable token sequences found by the beam search.
        """
        batch_size, hidden_states = encoder_outputs.size(0), None

        self.finished = [[] for _ in range(batch_size)]
        self.finished_ps = [[] for _ in range(batch_size)]

        inputs, batch_size, max_length = self.validate_args(None, encoder_outputs, teacher_forcing_ratio=0.0)

        # First decoding step: seed the beams with the top ``beam_size`` tokens.
        step_outputs, hidden_states, attn = self.forward_step(inputs, hidden_states, encoder_outputs)
        self.cumulative_ps, self.ongoing_beams = step_outputs.topk(self.beam_size)

        self.ongoing_beams = self.ongoing_beams.view(batch_size * self.beam_size, 1)
        self.cumulative_ps = self.cumulative_ps.view(batch_size * self.beam_size, 1)

        input_var = self.ongoing_beams

        # Repeat the encoder outputs ``beam_size`` times so every beam can attend to them.
        encoder_dim = encoder_outputs.size(2)
        encoder_outputs = self._inflate(encoder_outputs, self.beam_size, dim=0)
        encoder_outputs = encoder_outputs.view(self.beam_size, batch_size, -1, encoder_dim)
        encoder_outputs = encoder_outputs.transpose(0, 1)
        encoder_outputs = encoder_outputs.reshape(batch_size * self.beam_size, -1, encoder_dim)

        if attn is not None:
            attn = self._inflate(attn, self.beam_size, dim=0)

        if isinstance(hidden_states, tuple):
            hidden_states = tuple(self._inflate(h, self.beam_size, 1) for h in hidden_states)
        else:
            hidden_states = self._inflate(hidden_states, self.beam_size, 1)

        for di in range(max_length - 1):
            if self._is_all_finished(self.beam_size):
                break

            if isinstance(hidden_states, tuple):
                # LSTM states come as an (h, c) tuple; the original dropped this
                # assignment, leaving the states unreshaped.
                hidden_states = tuple(
                    h.view(self.num_layers, batch_size * self.beam_size, self.hidden_state_dim)
                    for h in hidden_states
                )
            else:
                hidden_states = hidden_states.view(self.num_layers, batch_size * self.beam_size, self.hidden_state_dim)
            step_outputs, hidden_states, attn = self.forward_step(input_var, hidden_states, encoder_outputs, attn)

            step_outputs = step_outputs.view(batch_size, self.beam_size, -1)
            current_ps, current_vs = step_outputs.topk(self.beam_size)

            self.cumulative_ps = self.cumulative_ps.view(batch_size, self.beam_size)
            self.ongoing_beams = self.ongoing_beams.view(batch_size, self.beam_size, -1)

            # Add the running beam scores to every candidate extension, giving
            # ``beam_size ** 2`` scored candidates per batch element.
            current_ps = (current_ps.permute(0, 2, 1) + self.cumulative_ps.unsqueeze(1)).permute(0, 2, 1)
            current_ps = current_ps.view(batch_size, self.beam_size ** 2)
            current_vs = current_vs.view(batch_size, self.beam_size ** 2)

            topk_current_ps, topk_status_ids = current_ps.topk(self.beam_size)
            # Integer division recovers which previous beam each survivor extends.
            prev_status_ids = topk_status_ids // self.beam_size

            topk_current_vs = torch.zeros((batch_size, self.beam_size), dtype=torch.long)
            prev_status = torch.zeros(self.ongoing_beams.size(), dtype=torch.long)

            for batch_idx, batch in enumerate(topk_status_ids):
                for idx, topk_status_idx in enumerate(batch):
                    topk_current_vs[batch_idx, idx] = current_vs[batch_idx, topk_status_idx]
                    prev_status[batch_idx, idx] = self.ongoing_beams[batch_idx, prev_status_ids[batch_idx, idx]]

            self.ongoing_beams = torch.cat([prev_status, topk_current_vs.unsqueeze(2)], dim=2)
            self.cumulative_ps = topk_current_ps

            if torch.any(topk_current_vs == self.eos_id):
                finished_ids = torch.where(topk_current_vs == self.eos_id)
                num_successors = [1] * batch_size

                for (batch_idx, idx) in zip(*finished_ids):
                    self.finished[batch_idx].append(self.ongoing_beams[batch_idx, idx])
                    self.finished_ps[batch_idx].append(self.cumulative_ps[batch_idx, idx])

                    if self.beam_size != 1:
                        # Replace the finished beam with the next-best candidate so the
                        # search keeps ``beam_size`` active hypotheses.
                        eos_count = self._get_successor(
                            current_ps=current_ps,
                            current_vs=current_vs,
                            finished_ids=(batch_idx, idx),
                            num_successor=num_successors[batch_idx],
                            eos_count=1,
                            k=self.beam_size,
                        )
                        num_successors[batch_idx] += eos_count

            # Feed the last predicted token of every beam into the next step.
            input_var = self.ongoing_beams[:, :, -1]
            input_var = input_var.view(batch_size * self.beam_size, -1)

        return self._get_hypothesis()
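

# Hedged usage sketch (not part of the original module): assumes ``decoder``
# is an already-constructed ``LSTMAttentionDecoder`` and that ``encoder_dim``
# matches the encoder it was trained with; batch size and sequence length
# below are hypothetical placeholders.
def _beam_search_usage_sketch(decoder: LSTMAttentionDecoder, encoder_dim: int = 512) -> torch.Tensor:
    beam_search = BeamSearchLSTM(decoder, beam_size=3)
    # Dummy encoder outputs of size (batch, seq_length, dimension) and their lengths.
    encoder_outputs = torch.randn(4, 100, encoder_dim)
    encoder_output_lengths = torch.full((4,), 100, dtype=torch.long)
    # Runs the search and returns the best hypothesis per batch element.
    return beam_search(encoder_outputs, encoder_output_lengths)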