# MIT License
#
# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
import tarfile
import wget
import pytorch_lightning as pl
import logging
from omegaconf import DictConfig
from typing import Optional, Tuple
from torch.utils.data import DataLoader
from openspeech.data.audio.dataset import SpeechToTextDataset
from openspeech.datasets import register_data_module
from openspeech.data.sampler import BucketingSampler
from openspeech.data.audio.data_loader import AudioDataLoader
from openspeech.tokenizers.tokenizer import Tokenizer
from openspeech.tokenizers import TOKENIZER_REGISTRY
from openspeech.datasets.aishell.preprocess import (
generate_character_labels,
generate_character_script,
)
@register_data_module('aishell')
class LightningAIShellDataModule(pl.LightningDataModule):
    r"""
    Lightning data module for AIShell-1. The corpus includes a training set, a development set and a test set.
    The training set contains 120,098 utterances from 340 speakers; the development set contains 14,326
    utterances from 40 speakers; the test set contains 7,176 utterances from 20 speakers. For each speaker,
    around 360 utterances (about 26 minutes of speech) are released.

    The manifest is split positionally into ``train`` / ``valid`` / ``test`` using the counts below, so it is
    assumed to list training utterances first, then development, then test. That ordering is decided by
    ``generate_character_script`` (TODO confirm there).

    Args:
        configs (DictConfig): configuration set.
    """
    AISHELL_TRAIN_NUM = 120098
    AISHELL_VALID_NUM = 14326
    AISHELL_TEST_NUM = 7176

    def __init__(self, configs: DictConfig):
        super().__init__()
        self.configs = configs
        # Maps split name ('train' / 'valid' / 'test') to its SpeechToTextDataset; populated by ``setup``.
        self.dataset = dict()
        self.logger = logging.getLogger(__name__)

    def _download_dataset(self) -> None:
        r"""
        Download the AISHELL-1 archive into ``configs.dataset.dataset_path`` and extract it there.

        Afterwards the archive is deleted and ``configs.dataset.dataset_path`` is rewritten to point at the
        extracted ``data_aishell`` directory.
        """
        url = "https://www.openslr.org/resources/33/data_aishell.tgz"
        dataset_path = self.configs.dataset.dataset_path
        archive_path = f"{dataset_path}/data_aishell.tgz"

        # Unlike os.mkdir, makedirs also creates any missing parent directories.
        os.makedirs(dataset_path, exist_ok=True)

        wget.download(url, archive_path)

        self.logger.info(f"Un-tarring archive {archive_path}")
        # The context manager closes the archive even if extraction fails part-way.
        with tarfile.open(archive_path, mode="r:gz") as tar:
            # NOTE(review): extractall trusts member paths from a downloaded archive. On Python versions that
            # support extraction filters (3.12+, and patched 3.8-3.11), consider ``filter="data"`` to reject
            # path traversal.
            tar.extractall(dataset_path)
        os.remove(archive_path)

        self.configs.dataset.dataset_path = os.path.join(dataset_path, "data_aishell")

    def _generate_manifest_files(self, manifest_file_path: str) -> None:
        r"""
        Build the character vocabulary and then the manifest file from ``configs.dataset.dataset_path``.

        Args:
            manifest_file_path (str): destination path of the generated manifest file.
        """
        generate_character_labels(
            dataset_path=self.configs.dataset.dataset_path,
            vocab_path=self.configs.tokenizer.vocab_path,
        )
        generate_character_script(
            dataset_path=self.configs.dataset.dataset_path,
            manifest_file_path=manifest_file_path,
            vocab_path=self.configs.tokenizer.vocab_path,
        )

    def _parse_manifest_file(self, manifest_file_path: str) -> Tuple[list, list]:
        r"""
        Parse a tab-separated manifest into parallel lists of audio paths and transcripts.

        Each non-blank line must hold exactly three tab-separated fields. The first is kept as the audio path
        and the third as the transcript; the middle field is ignored.

        Args:
            manifest_file_path (str): path of the manifest file to read.

        Returns:
            Tuple[list, list]: ``(audio_paths, transcripts)`` in file order.
        """
        audio_paths = list()
        transcripts = list()

        # NOTE(review): opened with the platform default encoding, presumably matching how the manifest is
        # written; confirm before switching reader and writer to an explicit encoding.
        with open(manifest_file_path) as f:
            for line in f:  # stream lines instead of materialising them all with readlines()
                if not line.strip():
                    # Skip blank lines (e.g. a trailing empty line) instead of failing on unpacking.
                    continue
                audio_path, _, transcript = line.rstrip('\n').split('\t')
                audio_paths.append(audio_path)
                transcripts.append(transcript)

        return audio_paths, transcripts

    def prepare_data(self):
        r"""
        Prepare the AI-Shell manifest file, generating it if it does not exist yet.

        The corpus is downloaded first when ``configs.dataset.dataset_download`` is set.

        Returns:
            tokenizer (Tokenizer): tokenizer is in charge of preparing the inputs for a model.

        Raises:
            ValueError: if the manifest must be generated but ``configs.dataset.dataset_path`` does not exist.
        """
        if self.configs.dataset.dataset_download:
            self._download_dataset()

        if not os.path.exists(self.configs.dataset.manifest_file_path):
            self.logger.info("Manifest file is not exists !!\n"
                             "Generate manifest files..")
            if not os.path.exists(self.configs.dataset.dataset_path):
                raise ValueError("Dataset path is not valid.")
            self._generate_manifest_files(self.configs.dataset.manifest_file_path)

        return TOKENIZER_REGISTRY[self.configs.tokenizer.unit](self.configs)

    def setup(self, stage: Optional[str] = None, tokenizer: Tokenizer = None):
        r"""
        Split the manifest into ``train``, ``valid`` and ``test`` datasets.

        Args:
            stage (str): stage of training. Not used; every split is always built.
            tokenizer (Tokenizer): tokenizer is in charge of preparing the inputs for a model. Its ``sos_id``
                and ``eos_id`` are passed to every dataset.

        Returns:
            None

        Raises:
            ValueError: if ``tokenizer`` is None and no datasets have been built yet.
        """
        if tokenizer is None:
            if self.dataset:
                # The Lightning Trainer invokes ``setup(stage)`` without a tokenizer. If an earlier explicit
                # ``setup(tokenizer=...)`` call already built the splits, keep them instead of crashing.
                return
            raise ValueError("A tokenizer is required to build the AIShell datasets.")

        audio_paths, transcripts = self._parse_manifest_file(self.configs.dataset.manifest_file_path)

        expected_num = self.AISHELL_TRAIN_NUM + self.AISHELL_VALID_NUM + self.AISHELL_TEST_NUM
        if len(audio_paths) != expected_num:
            # Splits are positional, so a manifest of unexpected size silently shifts utterances between splits.
            self.logger.warning(
                f"Manifest has {len(audio_paths)} entries but AIShell-1 has {expected_num}; "
                f"train/valid/test splits may be misaligned."
            )

        valid_end_idx = self.AISHELL_TRAIN_NUM + self.AISHELL_VALID_NUM
        split_slices = {
            "train": slice(None, self.AISHELL_TRAIN_NUM),
            "valid": slice(self.AISHELL_TRAIN_NUM, valid_end_idx),
            "test": slice(valid_end_idx, None),
        }

        for split, index in split_slices.items():
            is_train = split == 'train'
            self.dataset[split] = SpeechToTextDataset(
                configs=self.configs,
                dataset_path=self.configs.dataset.dataset_path,
                audio_paths=audio_paths[index],
                transcripts=transcripts[index],
                sos_id=tokenizer.sos_id,
                eos_id=tokenizer.eos_id,
                # SpecAugment and silence removal are enabled (per config) for the training split only.
                apply_spec_augment=self.configs.audio.apply_spec_augment if is_train else False,
                del_silence=self.configs.audio.del_silence if is_train else False,
            )

    def _build_dataloader(self, split: str) -> DataLoader:
        r"""
        Build an ``AudioDataLoader`` over ``self.dataset[split]``.

        The loader uses a ``BucketingSampler`` with ``configs.trainer.batch_size``.
        """
        sampler = BucketingSampler(self.dataset[split], batch_size=self.configs.trainer.batch_size)
        return AudioDataLoader(
            dataset=self.dataset[split],
            num_workers=self.configs.trainer.num_workers,
            batch_sampler=sampler,
        )

    def train_dataloader(self) -> DataLoader:
        r""" Return the data loader for the ``train`` split. """
        return self._build_dataloader('train')

    def val_dataloader(self) -> DataLoader:
        r""" Return the data loader for the ``valid`` split. """
        return self._build_dataloader('valid')

    def test_dataloader(self) -> DataLoader:
        r""" Return the data loader for the ``test`` split. """
        return self._build_dataloader('test')