Source code for openspeech.search.beam_search_transformer
# MIT License
#
# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch
from openspeech.search.beam_search_base import OpenspeechBeamSearchBase
from openspeech.decoders import TransformerDecoder
class BeamSearchTransformer(OpenspeechBeamSearchBase):
    r"""
    Transformer Beam Search Decoder

    Args: decoder, beam_size
        decoder (TransformerDecoder): transformer decoder. This class reads ``decoder.fc`` and
            ``decoder.max_length`` from it.
        beam_size (int): size of beam.

    Inputs: encoder_outputs, encoder_output_lengths
        encoder_outputs (torch.FloatTensor): A output sequence of encoders. `FloatTensor` of size
            ``(batch, seq_length, dimension)``
        encoder_output_lengths (torch.LongTensor): A encoder output lengths sequence.
            `LongTensor` of size ``(batch)``

    Returns:
        * The value of ``self._get_hypothesis()``, which is defined in the base class
          (presumably the best hypothesis per utterance - TODO confirm against the base class).
    """

    def __init__(self, decoder: TransformerDecoder, beam_size: int = 3) -> None:
        super().__init__(decoder, beam_size)
        # Kept as a public attribute for backward compatibility; `forward` places its
        # tensors on the device of `encoder_outputs` instead of consulting this flag.
        self.use_cuda = torch.cuda.is_available()

    def forward(
        self,
        encoder_outputs: torch.FloatTensor,
        encoder_output_lengths: torch.LongTensor,
    ):
        r"""
        Run beam search over the encoder outputs.

        Args:
            encoder_outputs (torch.FloatTensor): encoder outputs, indexed here as
                ``(batch, seq_length, dimension)``.
            encoder_output_lengths (torch.LongTensor): one length per batch element.

        Returns:
            Whatever ``self._get_hypothesis()`` returns (defined outside this class).
        """
        batch_size = encoder_outputs.size(0)
        beam_size = self.beam_size
        num_beams = batch_size * beam_size
        # Create every helper tensor on the same device as the encoder outputs so the
        # search also works when the model lives on a GPU.
        device = encoder_outputs.device

        self.finished = [[] for _ in range(batch_size)]
        self.finished_ps = [[] for _ in range(batch_size)]

        # One <sos> token per beam row; the first `batch_size` rows seed the first step.
        sos_column = torch.full((num_beams, 1), self.sos_id, dtype=torch.long, device=device)

        outputs = self.forward_step(
            decoder_inputs=sos_column[:batch_size],
            decoder_input_lengths=torch.ones(batch_size, dtype=torch.long, device=device),
            encoder_outputs=encoder_outputs,
            encoder_output_lengths=encoder_output_lengths,
            positional_encoding_length=1,
        )
        step_outputs = self.decoder.fc(outputs).log_softmax(dim=-1)
        self.cumulative_ps, self.ongoing_beams = step_outputs.topk(beam_size)

        self.ongoing_beams = self.ongoing_beams.view(num_beams, 1)
        self.cumulative_ps = self.cumulative_ps.view(num_beams, 1)

        decoder_inputs = torch.cat((sos_column, self.ongoing_beams), dim=-1)  # bsz * beam x 2

        # Repeat every utterance `beam_size` times, keeping the copies of one utterance
        # adjacent so that row `b * beam_size + k` belongs to beam `k` of utterance `b`.
        encoder_dim = encoder_outputs.size(2)
        encoder_outputs = self._inflate(encoder_outputs, beam_size, dim=0)
        encoder_outputs = encoder_outputs.view(beam_size, batch_size, -1, encoder_dim)
        encoder_outputs = encoder_outputs.transpose(0, 1)
        encoder_outputs = encoder_outputs.reshape(num_beams, -1, encoder_dim)
        encoder_output_lengths = encoder_output_lengths.unsqueeze(1).repeat(1, beam_size).view(-1)

        for di in range(2, self.decoder.max_length):
            if self._is_all_finished(beam_size):
                break

            decoder_input_lengths = torch.full((num_beams,), di, dtype=torch.long, device=device)

            # `decoder_inputs` is rebuilt at the end of every iteration and is exactly
            # `di` tokens wide here (<sos> + di - 1 generated tokens).
            step_outputs = self.forward_step(
                decoder_inputs=decoder_inputs,
                decoder_input_lengths=decoder_input_lengths,
                encoder_outputs=encoder_outputs,
                encoder_output_lengths=encoder_output_lengths,
                positional_encoding_length=di,
            )
            step_outputs = self.decoder.fc(step_outputs).log_softmax(dim=-1)

            # Use the real number of classes instead of a hard-coded constant.
            # Assumes the decoder output is (bsz * beam, di, hidden) - TODO confirm.
            num_classes = step_outputs.size(-1)
            step_outputs = step_outputs.reshape(batch_size, beam_size, -1, num_classes)

            # Only the distribution of the last position is needed to extend the beams.
            current_ps, current_vs = step_outputs[:, :, -1, :].topk(beam_size)  # bsz x beam x beam

            self.cumulative_ps = self.cumulative_ps.view(batch_size, beam_size)
            self.ongoing_beams = self.ongoing_beams.view(batch_size, beam_size, -1)

            # Add each parent beam's running score to the scores of its candidates.
            current_ps = current_ps + self.cumulative_ps.unsqueeze(2)
            current_ps = current_ps.reshape(batch_size, beam_size ** 2)
            current_vs = current_vs.reshape(batch_size, beam_size ** 2)

            topk_current_ps, topk_status_ids = current_ps.topk(beam_size)
            # Candidates are laid out parent-major, so integer division recovers the parent.
            prev_status_ids = topk_status_ids // beam_size

            # Vectorised selection of the surviving tokens and of their parent prefixes.
            topk_current_vs = current_vs.gather(1, topk_status_ids)
            prefix_index = prev_status_ids.unsqueeze(-1).expand(-1, -1, self.ongoing_beams.size(-1))
            prev_status = self.ongoing_beams.gather(1, prefix_index)

            self.ongoing_beams = torch.cat([prev_status, topk_current_vs.unsqueeze(2)], dim=2)
            self.cumulative_ps = topk_current_ps

            eos_mask = topk_current_vs == self.eos_id
            if torch.any(eos_mask):
                finished_ids = torch.where(eos_mask)
                num_successors = [1] * batch_size

                for (batch_idx, idx) in zip(*finished_ids):
                    self.finished[batch_idx].append(self.ongoing_beams[batch_idx, idx])
                    self.finished_ps[batch_idx].append(self.cumulative_ps[batch_idx, idx])

                    if beam_size != 1:
                        # `_get_successor` is defined outside this class; presumably it
                        # swaps a finished beam for the next-best candidate - TODO confirm.
                        eos_count = self._get_successor(
                            current_ps=current_ps,
                            current_vs=current_vs,
                            finished_ids=(batch_idx, idx),
                            num_successor=num_successors[batch_idx],
                            eos_count=1,
                            k=beam_size,
                        )
                        num_successors[batch_idx] += eos_count

            # Rebuild the decoder input from the *re-ordered* beams, including the token
            # chosen in this step, so the next step conditions on the actual hypotheses.
            flat_beams = self.ongoing_beams.reshape(num_beams, -1)
            decoder_inputs = torch.cat((sos_column, flat_beams), dim=-1)

        return self._get_hypothesis()