# Source code for openspeech.modules.relative_multi_head_attention

# MIT License
#
# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional

from openspeech.modules.wrapper import Linear


class RelativeMultiHeadAttention(nn.Module):
    r"""
    Multi-head attention with relative positional encoding.
    This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"

    Args:
        dim (int): The dimension of model
        num_heads (int): The number of attention heads.
        dropout_p (float): probability of dropout

    Inputs: query, key, value, pos_embedding, mask
        - **query** (batch, time, dim): Tensor containing query vector
        - **key** (batch, time, dim): Tensor containing key vector
        - **value** (batch, time, dim): Tensor containing value vector
        - **pos_embedding** (batch, time, dim) or (1, time, dim): Positional embedding tensor.
          A leading dimension of 1 is broadcast over the batch.
        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked

    Returns:
        - **outputs**: Tensor produced by relative multi head attention module.
    """

    def __init__(
            self,
            dim: int = 512,
            num_heads: int = 16,
            dropout_p: float = 0.1,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "d_model % num_heads should be zero."

        self.dim = dim
        self.d_head = dim // num_heads
        self.num_heads = num_heads
        # NOTE: scores are scaled by sqrt(dim), not sqrt(d_head); kept as-is so
        # existing checkpoints keep producing the same outputs.
        self.sqrt_dim = math.sqrt(dim)

        self.query_proj = Linear(dim, dim)
        self.key_proj = Linear(dim, dim)
        self.value_proj = Linear(dim, dim)
        self.pos_proj = Linear(dim, dim, bias=False)

        self.dropout = nn.Dropout(p=dropout_p)

        # Learnable per-head biases added to the query for the content term (u)
        # and the positional term (v); both are Xavier-initialised below.
        self.u_bias = nn.Parameter(torch.empty(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.empty(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)

        self.out_proj = Linear(dim, dim)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            pos_embedding: Tensor,
            mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Compute relative-position multi-head attention.

        The attention score is the sum of a content term ``(query + u_bias) @ key^T``
        and a relatively shifted positional term ``(query + v_bias) @ pos^T``,
        divided by ``sqrt(dim)``. Positions selected by ``mask`` are filled with
        ``-1e4`` before the softmax.

        Returns:
            Tensor of shape (batch, time, dim) after the output projection.
        """
        batch_size = value.size(0)

        # Split the projections into heads: (batch, time, num_heads, d_head).
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)

        # Reshape with the embedding's own leading dimension so that a shared
        # (1, time, dim) positional table broadcasts over the batch in matmul
        # instead of being mis-reshaped with the batch size of `value`.
        pos_embedding = self.pos_proj(pos_embedding)
        pos_embedding = pos_embedding.view(pos_embedding.size(0), -1, self.num_heads, self.d_head)

        content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
        pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
        pos_score = self._relative_shift(pos_score)

        score = (content_score + pos_score) / self.sqrt_dim

        if mask is not None:
            # Insert a head axis so the mask broadcasts over all heads.
            mask = mask.unsqueeze(1)
            score.masked_fill_(mask, -1e4)

        attn = F.softmax(score, -1)
        attn = self.dropout(attn)

        # Merge the heads back: (batch, heads, time, d_head) -> (batch, time, dim).
        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.dim)

        return self.out_proj(context)

    def _relative_shift(self, pos_score: Tensor) -> Tensor:
        """
        Shift the positional scores with the pad / reshape / slice trick.

        A zero column is prepended along the last axis, the last two axes are
        reinterpreted as ``(seq_length2 + 1, seq_length1)``, the first row is
        dropped and the result is viewed back to the input shape.

        Args:
            pos_score: Tensor of shape (batch, num_heads, seq_length1, seq_length2).

        Returns:
            Tensor with the same shape as ``pos_score``.
        """
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)

        padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)

        return pos_score