# Source code for openspeech.search.beam_search_base

# MIT License
#
# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch
import torch.nn as nn


class OpenspeechBeamSearchBase(nn.Module):
    r"""
    Openspeech's beam-search base class. Implements the beam bookkeeping shared
    by all beam-search decoders; subclasses must implement :meth:`forward`.

    Args:
        decoder: decoder module providing ``sos_id``, ``pad_id``, ``eos_id``
            and a ``forward_step`` callable.
        beam_size (int): number of beams kept per batch element.

    Note:
        Do not use this class directly, use one of the sub classes.
    """

    def __init__(self, decoder, beam_size: int):
        super(OpenspeechBeamSearchBase, self).__init__()
        self.decoder = decoder
        self.beam_size = beam_size
        self.sos_id = decoder.sos_id
        self.pad_id = decoder.pad_id
        self.eos_id = decoder.eos_id
        # Per-batch beam state; populated by subclasses during the search.
        self.ongoing_beams = None
        self.cumulative_ps = None
        self.forward_step = decoder.forward_step

    def _inflate(self, tensor: torch.Tensor, n_repeat: int, dim: int) -> torch.Tensor:
        """Repeat ``tensor`` ``n_repeat`` times along dimension ``dim``."""
        repeat_dims = [1] * tensor.dim()
        repeat_dims[dim] *= n_repeat
        return tensor.repeat(*repeat_dims)

    def _get_successor(
        self,
        current_ps: torch.Tensor,
        current_vs: torch.Tensor,
        finished_ids: tuple,
        num_successor: int,
        eos_count: int,
        k: int,
    ) -> int:
        """
        Replace a finished beam with its next-best successor candidate.

        If the chosen successor itself ends with ``eos_id``, it is appended to
        ``self.finished``/``self.finished_ps`` and the method recurses to fetch
        the next candidate; otherwise the successor overwrites the finished
        slot in ``self.ongoing_beams``/``self.cumulative_ps``.

        Args:
            current_ps (torch.Tensor): cumulative scores of the candidates.
            current_vs (torch.Tensor): vocabulary ids of the candidates.
            finished_ids (tuple): ``(batch_idx, beam_idx)`` of the finished beam.
            num_successor (int): rank offset of the successor to fetch.
            eos_count (int): EOS successors encountered so far.
            k (int): beam size.

        Returns:
            int: updated count of EOS successors encountered.
        """
        finished_batch_idx, finished_idx = finished_ids

        # Take the (k + num_successor)-th best candidate for this batch element.
        successor_ids = current_ps.topk(k + num_successor)[1]
        successor_idx = successor_ids[finished_batch_idx, -1]

        successor_p = current_ps[finished_batch_idx, successor_idx]
        successor_v = current_vs[finished_batch_idx, successor_idx]

        # The flat candidate index encodes (source beam, vocab entry);
        # recover the source beam and extend its prefix with the new token.
        prev_status_idx = (successor_idx // k)
        prev_status = self.ongoing_beams[finished_batch_idx, prev_status_idx]
        prev_status = prev_status.view(-1)[:-1]

        successor = torch.cat([prev_status, successor_v.view(1)])

        if int(successor_v) == self.eos_id:
            self.finished[finished_batch_idx].append(successor)
            self.finished_ps[finished_batch_idx].append(successor_p)
            eos_count = self._get_successor(
                current_ps=current_ps,
                current_vs=current_vs,
                finished_ids=finished_ids,
                num_successor=num_successor + eos_count,
                eos_count=eos_count + 1,
                k=k,
            )
        else:
            self.ongoing_beams[finished_batch_idx, finished_idx] = successor
            self.cumulative_ps[finished_batch_idx, finished_idx] = successor_p

        return eos_count

    def _get_hypothesis(self):
        """Pick the best hypothesis per batch element and pad them to equal length."""
        predictions = []

        for batch_idx, batch in enumerate(self.finished):
            # if there is no terminated sentence, bring the ongoing sentence
            # which has the highest probability instead
            if len(batch) == 0:
                prob_batch = self.cumulative_ps[batch_idx]
                top_beam_idx = int(prob_batch.topk(1)[1])
                predictions.append(self.ongoing_beams[batch_idx, top_beam_idx])
            # bring highest probability sentence
            else:
                top_beam_idx = int(torch.FloatTensor(self.finished_ps[batch_idx]).topk(1)[1])
                predictions.append(self.finished[batch_idx][top_beam_idx])

        return self._fill_sequence(predictions)

    def _is_all_finished(self, k: int) -> bool:
        """Return ``True`` once every batch element has at least ``k`` finished beams."""
        return all(len(done) >= k for done in self.finished)

    def _fill_sequence(self, y_hats: list) -> torch.Tensor:
        """
        Right-pad variable-length predictions with ``pad_id`` into one tensor.

        Args:
            y_hats (list): per-batch 1-D prediction tensors of varying length.

        Returns:
            torch.Tensor: ``(batch, max_length)`` long tensor, padded with ``pad_id``.
        """
        batch_size = len(y_hats)
        # default=0 keeps an empty batch from producing -1 (which would crash
        # the tensor allocation below, as the original zeros(-1) did).
        max_length = max((len(y_hat) for y_hat in y_hats), default=0)

        # Allocate the padded result in one step instead of zeros-then-fill.
        matched = torch.full((batch_size, max_length), int(self.pad_id), dtype=torch.long)
        for batch_idx, y_hat in enumerate(y_hats):
            matched[batch_idx, : len(y_hat)] = y_hat

        return matched

    def forward(self, *args, **kwargs):
        """Run the beam search; must be implemented by subclasses."""
        raise NotImplementedError