Source code for openspeech.search.beam_search_transformer_transducer
# MIT License
#
# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch
from openspeech.search.beam_search_base import OpenspeechBeamSearchBase
from openspeech.decoders import TransformerTransducerDecoder
[docs]class BeamSearchTransformerTransducer(OpenspeechBeamSearchBase):
r"""
Transformer Transducer Beam Search
Reference: RNN-T FOR LATENCY CONTROLLED ASR WITH IMPROVED BEAM SEARCH (https://arxiv.org/pdf/1911.01629.pdf)
Args: joint, decoder, beam_size, expand_beam, state_beam, blank_id
joint: joint `encoder_outputs` and `decoder_outputs`
decoder (TransformerTransducerDecoder): base decoder of transformer transducer model.
beam_size (int): size of beam.
expand_beam (int): The threshold coefficient to limit the number
of expanded hypotheses that are added in A (process_hyp).
state_beam (int): The threshold coefficient in log space to decide if hyps in A (process_hyps)
is likely to compete with hyps in B (ongoing_beams)
blank_id (int): blank id
Inputs: encoder_outputs, max_length
encoder_outputs (torch.FloatTensor): A output sequence of encoders. `FloatTensor` of size
``(batch, seq_length, dimension)``
max_length (int): max decoding time step
Returns:
* predictions (torch.LongTensor): model predictions.
"""
def __init__(
self,
joint,
decoder: TransformerTransducerDecoder,
beam_size: int = 3,
expand_beam: float = 2.3,
state_beam: float = 4.6,
blank_id: int = 3,
) -> None:
super(BeamSearchTransformerTransducer, self).__init__(decoder, beam_size)
self.joint = joint
self.forward_step = self.decoder.forward_step
self.expand_beam = expand_beam
self.state_beam = state_beam
self.blank_id = blank_id
[docs] def forward(self, encoder_outputs: torch.Tensor, max_length: int):
r"""
Beam search decoding.
Inputs: encoder_outputs, max_length
encoder_outputs (torch.FloatTensor): A output sequence of encoders. `FloatTensor` of size
``(batch, seq_length, dimension)``
max_length (int): max decoding time step
Returns:
* predictions (torch.LongTensor): model predictions.
"""
hypothesis = list()
hypothesis_score = list()
for batch_idx in range(encoder_outputs.size(0)):
blank = (
torch.ones((1, 1), device=encoder_outputs.device, dtype=torch.long) * self.blank_id
)
step_input = (
torch.ones((1, 1), device=encoder_outputs.device, dtype=torch.long) * self.sos_id
)
hyp = {
"prediction": [self.sos_id],
"logp_score": 0.0,
}
ongoing_beams = [hyp]
for t_step in range(max_length):
process_hyps = ongoing_beams
ongoing_beams = list()
while True:
if len(ongoing_beams) >= self.beam_size:
break
a_best_hyp = max(process_hyps, key=lambda x: x["logp_score"] / len(x["prediction"]))
if len(ongoing_beams) > 0:
b_best_hyp = max(
ongoing_beams,
key=lambda x: x["logp_score"] / len(x["prediction"]),
)
a_best_prob = a_best_hyp["logp_score"]
b_best_prob = b_best_hyp["logp_score"]
if b_best_prob >= self.state_beam + a_best_prob:
break
process_hyps.remove(a_best_hyp)
step_input[0, 0] = a_best_hyp["prediction"][-1]
step_lengths = encoder_outputs.new_tensor([0], dtype=torch.long)
step_outputs = self.forward_step(step_input, step_lengths).squeeze(0).squeeze(0)
log_probs = self.joint(encoder_outputs[batch_idx, t_step, :], step_outputs)
topk_targets, topk_idx = log_probs.topk(k=self.beam_size)
if topk_idx[0] != blank:
best_logp = topk_targets[0]
else:
best_logp = topk_targets[1]
for j in range(topk_targets.size(0)):
topk_hyp = {
"prediction": a_best_hyp["prediction"][:],
"logp_score": a_best_hyp["logp_score"] + topk_targets[j],
}
if topk_idx[j] == self.blank_id:
ongoing_beams.append(topk_hyp)
continue
if topk_targets[j] >= best_logp - self.expand_beam:
topk_hyp["prediction"].append(topk_idx[j].item())
process_hyps.append(topk_hyp)
ongoing_beams = sorted(
ongoing_beams,
key=lambda x: x["logp_score"] / len(x["prediction"]),
reverse=True,
)[0]
hypothesis.append(torch.LongTensor(ongoing_beams["prediction"][1:]))
hypothesis_score.append(ongoing_beams["logp_score"] / len(ongoing_beams["prediction"]))
return self._fill_sequence(hypothesis)