# MIT License
#
# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch.nn as nn
from torch import Tensor
from typing import Tuple, Optional
from openspeech.modules.wrapper import Linear
from openspeech.modules.dot_product_attention import DotProductAttention
class MultiHeadAttention(nn.Module):
    r"""
    Multi-Head Attention proposed in "Attention Is All You Need".

    Instead of performing a single attention function with d_model-dimensional keys, values, and queries,
    the queries, keys and values are projected h times with different, learned linear projections to
    d_head dimensions. Attention is computed for every head in parallel and the per-head results are
    concatenated, which allows the model to jointly attend to information from different representation
    subspaces at different positions.

        MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
        where head_i = Attention(Q · W_q, K · W_k, V · W_v)

    Note:
        This module implements the W_q, W_k and W_v projections and the per-head attention, but it does
        NOT apply the final output projection W_o: ``forward`` returns Concat(head_1, ..., head_h)
        directly. Callers that need W_o must apply their own linear layer to ``output``.

    Args:
        dim (int): The dimension of the model; must be divisible by ``num_heads``. (default: 512)
        num_heads (int): The number of attention heads. Each head works on ``dim // num_heads``
            features. (default: 8)

    Inputs: query, key, value, mask
        - **query** (batch, q_len, dim): tensor containing projection vector for decoders.
        - **key** (batch, k_len, dim): tensor containing projection vector for encoders.
        - **value** (batch, v_len, dim): tensor containing features of the encoded input sequence.
        - **mask** (optional): expected to be 3-D, presumably (batch, q_len, k_len). It is copied to
          every head along a new dimension 1 before being handed to ``DotProductAttention``, which
          decides how its values are interpreted.

    Returns: output, attn
        - **output** (batch, q_len, dim): tensor containing the concatenated per-head attended features
          (assuming ``DotProductAttention`` returns per-head context of shape
          (batch, num_heads, q_len, d_head)).
        - **attn**: the attention (alignment) weights exactly as returned by ``DotProductAttention``,
          presumably of shape (batch, num_heads, q_len, k_len).
    """

    def __init__(self, dim: int = 512, num_heads: int = 8) -> None:
        super(MultiHeadAttention, self).__init__()

        # Every head gets an equal slice of the model dimension.
        assert dim % num_heads == 0, "hidden_dim % num_heads should be zero."

        # Exact integer division (int(dim / num_heads) would round-trip through a float).
        self.d_head = dim // num_heads
        self.num_heads = num_heads

        self.query_proj = Linear(dim, self.d_head * num_heads)
        self.key_proj = Linear(dim, self.d_head * num_heads)
        self.value_proj = Linear(dim, self.d_head * num_heads)

        # NOTE(review): the paper scales scores by sqrt(d_head). DotProductAttention is given `dim`
        # here, so if it divides by the square root of its first argument, the scores are scaled by
        # sqrt(dim) instead. Left unchanged so existing checkpoints keep their behavior — confirm intent.
        self.scaled_dot_attn = DotProductAttention(dim, scale=True)

    def _split_heads(self, projected: Tensor, batch_size: int) -> Tensor:
        """Reshape (batch, len, dim) into (batch, num_heads, len, d_head)."""
        return projected.view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Project, split into heads, attend, and concatenate the heads (no W_o is applied)."""
        batch_size = value.size(0)

        query = self._split_heads(self.query_proj(query), batch_size)
        key = self._split_heads(self.key_proj(key), batch_size)
        value = self._split_heads(self.value_proj(value), batch_size)

        if mask is not None:
            # Give every head its own copy of the mask:
            # (batch, a, b) -> (batch, 1, a, b) -> (batch, num_heads, a, b) for a 3-D mask.
            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)

        context, attn = self.scaled_dot_attn(query, key, value, mask)

        # Merge the heads back: (batch, num_heads, len, d_head) -> (batch, len, num_heads * d_head).
        context = context.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.d_head)

        return context, attn