Source code for openspeech.modules.multi_head_attention

# MIT License
#
# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch.nn as nn
from torch import Tensor
from typing import Tuple, Optional

from openspeech.modules.wrapper import Linear
from openspeech.modules.dot_product_attention import DotProductAttention


class MultiHeadAttention(nn.Module):
    r"""
    Multi-Head Attention proposed in "Attention Is All You Need".

    Instead of performing a single attention function with d_model-dimensional keys, values, and queries,
    project the queries, keys, and values h times with different, learned linear projections to d_head
    dimensions. These are concatenated and once again projected, resulting in the final values.
    Multi-head attention allows the model to jointly attend to information from different representation
    subspaces at different positions.

    MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
        where head_i = Attention(Q · W_q, K · W_k, V · W_v)

    Args:
        dim (int): The dimension of the model (default: 512)
        num_heads (int): The number of attention heads (default: 8)

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): tensor containing the projection vector for the decoder.
        - **key** (batch, k_len, d_model): tensor containing the projection vector for the encoder.
        - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence.
        - **mask** (-): tensor containing indices to be masked.

    Returns: output, attn
        - **output** (batch, output_len, dimensions): tensor containing the attended output features.
        - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) over the encoder outputs.
    """

    def __init__(self, dim: int = 512, num_heads: int = 8) -> None:
        super(MultiHeadAttention, self).__init__()

        assert dim % num_heads == 0, "hidden_dim % num_heads should be zero."

        self.d_head = int(dim / num_heads)
        self.num_heads = num_heads
        self.query_proj = Linear(dim, self.d_head * num_heads)
        self.key_proj = Linear(dim, self.d_head * num_heads)
        self.value_proj = Linear(dim, self.d_head * num_heads)
        self.scaled_dot_attn = DotProductAttention(dim, scale=True)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        batch_size = value.size(0)

        # Project inputs and split them into heads: (batch, num_heads, seq_len, d_head)
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)

        if mask is not None:
            # Broadcast the mask to every attention head.
            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)

        # Apply scaled dot-product attention per head, then merge the heads back to d_model.
        context, attn = self.scaled_dot_attn(query, key, value, mask)
        context = context.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.d_head)

        return context, attn
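

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream module): a minimal, hedged example of
# calling MultiHeadAttention on random tensors. The batch size and sequence
# lengths below are illustrative assumptions, and the mask argument is omitted;
# the exact layout of the returned attention weights depends on
# DotProductAttention's implementation.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    attention = MultiHeadAttention(dim=512, num_heads=8)

    # Decoder queries attend over encoder keys/values (assumed shapes).
    query = torch.randn(4, 50, 512)   # (batch, q_len, d_model)
    key = torch.randn(4, 120, 512)    # (batch, k_len, d_model)
    value = torch.randn(4, 120, 512)  # (batch, v_len, d_model)

    context, attn = attention(query, key, value)
    print(context.shape)  # expected: torch.Size([4, 50, 512]) == (batch, q_len, num_heads * d_head)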