# MIT License
#
# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch
import torch.nn as nn
class OpenspeechBeamSearchBase(nn.Module):
    r"""
    Openspeech's beam-search base class. Implements the bookkeeping shared by
    beam-search decoders; subclasses must implement :meth:`forward`.

    Args:
        decoder: decoder module providing ``sos_id``, ``pad_id``, ``eos_id``
            attributes and a ``forward_step`` callable used during decoding.
        beam_size (int): number of beams kept alive per batch element.

    Note:
        Do not use this class directly, use one of the sub classes.
    """
    def __init__(self, decoder, beam_size: int) -> None:
        super(OpenspeechBeamSearchBase, self).__init__()
        self.decoder = decoder
        self.beam_size = beam_size
        self.sos_id = decoder.sos_id
        self.pad_id = decoder.pad_id
        self.eos_id = decoder.eos_id
        # Populated by subclasses during decoding:
        #   ongoing_beams: token ids of unfinished hypotheses, indexed [batch, beam]
        #   cumulative_ps: cumulative (log-)probabilities, indexed [batch, beam]
        self.ongoing_beams = None
        self.cumulative_ps = None
        self.forward_step = decoder.forward_step

    def _inflate(self, tensor: torch.Tensor, n_repeat: int, dim: int) -> torch.Tensor:
        """Return ``tensor`` repeated ``n_repeat`` times along dimension ``dim``."""
        repeat_dims = [1] * len(tensor.size())
        repeat_dims[dim] *= n_repeat
        return tensor.repeat(*repeat_dims)

    def _get_successor(
            self,
            current_ps: torch.Tensor,
            current_vs: torch.Tensor,
            finished_ids: tuple,
            num_successor: int,
            eos_count: int,
            k: int,
    ) -> int:
        """
        Replace a beam that just emitted ``eos`` with its next-best successor.

        Picks the ``(k + num_successor)``-th best candidate for the finished
        batch row; if that candidate also ends in ``eos`` it is archived in
        ``self.finished``/``self.finished_ps`` and the search recurses to the
        next candidate, otherwise it overwrites the finished slot in
        ``self.ongoing_beams``/``self.cumulative_ps``.

        Args:
            current_ps: candidate (log-)probabilities for the current step, per batch row.
            current_vs: candidate token ids for the current step, per batch row.
            finished_ids (tuple): ``(batch_idx, beam_idx)`` of the slot to refill.
            num_successor (int): how many successors have already been consumed.
            eos_count (int): how many eos-terminated candidates were seen so far.
            k (int): beam width used to map a flat candidate index back to its beam.

        Returns:
            int: updated ``eos_count`` after any recursive replacements.
        """
        finished_batch_idx, finished_idx = finished_ids

        # Take the (k + num_successor)-th best candidate for this batch row.
        successor_ids = current_ps.topk(k + num_successor)[1]
        successor_idx = successor_ids[finished_batch_idx, -1]

        successor_p = current_ps[finished_batch_idx, successor_idx]
        successor_v = current_vs[finished_batch_idx, successor_idx]

        # Flat candidate index -> originating beam; extend that beam's prefix
        # (dropping its trailing slot) with the successor token.
        prev_status_idx = (successor_idx // k)
        prev_status = self.ongoing_beams[finished_batch_idx, prev_status_idx]
        prev_status = prev_status.view(-1)[:-1]

        successor = torch.cat([prev_status, successor_v.view(1)])

        if int(successor_v) == self.eos_id:
            self.finished[finished_batch_idx].append(successor)
            self.finished_ps[finished_batch_idx].append(successor_p)
            # NOTE(review): recursing with num_successor + eos_count (not +1)
            # mirrors upstream Openspeech; kept as-is — confirm intent upstream.
            eos_count = self._get_successor(
                current_ps=current_ps,
                current_vs=current_vs,
                finished_ids=finished_ids,
                num_successor=num_successor + eos_count,
                eos_count=eos_count + 1,
                k=k,
            )
        else:
            self.ongoing_beams[finished_batch_idx, finished_idx] = successor
            self.cumulative_ps[finished_batch_idx, finished_idx] = successor_p

        return eos_count

    def _get_hypothesis(self) -> torch.Tensor:
        """
        Pick the best hypothesis per batch element and pad them to equal length.

        Returns:
            torch.Tensor: ``(batch, max_length)`` long tensor of token ids,
            padded with ``self.pad_id``.
        """
        predictions = []

        for batch_idx, batch in enumerate(self.finished):
            # if there is no terminated sentences, bring ongoing sentence which has the highest probability instead
            if len(batch) == 0:
                prob_batch = self.cumulative_ps[batch_idx]
                top_beam_idx = int(prob_batch.topk(1)[1])
                predictions.append(self.ongoing_beams[batch_idx, top_beam_idx])
            # bring highest probability sentence
            else:
                top_beam_idx = int(torch.FloatTensor(self.finished_ps[batch_idx]).topk(1)[1])
                predictions.append(self.finished[batch_idx][top_beam_idx])

        predictions = self._fill_sequence(predictions)
        return predictions

    def _is_all_finished(self, k: int) -> bool:
        """Return True when every batch element has at least ``k`` finished beams."""
        for done in self.finished:
            if len(done) < k:
                return False
        return True

    def _fill_sequence(self, y_hats: list) -> torch.Tensor:
        """
        Right-pad variable-length hypotheses into one ``(batch, max_length)`` tensor.

        Args:
            y_hats (list): per-batch 1-D tensors (or sequences) of token ids.

        Returns:
            torch.Tensor: long tensor padded with ``self.pad_id``.
        """
        batch_size = len(y_hats)
        # default=0 keeps an empty batch valid (0 x 0 tensor) instead of
        # attempting a negative dimension.
        max_length = max((len(y_hat) for y_hat in y_hats), default=0)

        # Pre-fill with pad_id so only the real prefix needs writing.
        matched = torch.full((batch_size, max_length), int(self.pad_id), dtype=torch.long)
        for batch_idx, y_hat in enumerate(y_hats):
            matched[batch_idx, :len(y_hat)] = y_hat

        return matched

    def forward(self, *args, **kwargs):
        """Subclasses must implement the actual beam-search decoding loop."""
        raise NotImplementedError