"""Dataset that pairs original audio clips with artificially degraded copies."""
import os
import random

import torchaudio
import torchcodec.decoders as decoders
import tqdm
from torch.utils.data import Dataset

import AudioUtils

class AudioDataset(Dataset):
    """Paired high-/low-quality audio clip dataset.

    Scans a directory of audio files, decodes each one, and precomputes
    training pairs in memory. Each item has the shape::

        ((high_quality_clip, low_quality_clip),
         (original_sample_rate, mangled_sample_rate))

    The low-quality clip is produced by resampling the audio down to a
    degraded rate and back up to the original rate.
    """

    # Candidate degraded sample rates used to "mangle" the audio.
    audio_sample_rates = [11025]

    def __init__(self, input_dir, clip_length: int = 8000, normalize: bool = True):
        """Scan *input_dir* for audio files and precompute clip pairs.

        Args:
            input_dir: Directory containing ``.wav``/``.mp3``/``.flac`` files
                (non-recursive; only regular files are considered).
            clip_length: Number of samples per clip; the trailing (short)
                clip of each file is padded to this length.
            normalize: Whether to run each file through
                ``AudioUtils.normalize`` before splitting.
        """
        self.clip_length = clip_length
        self.normalize = normalize

        input_files = [
            os.path.join(input_dir, f)
            for f in os.listdir(input_dir)
            if os.path.isfile(os.path.join(input_dir, f))
            and f.lower().endswith((".wav", ".mp3", ".flac"))
        ]

        # Resample kernels depend only on the (from_rate, to_rate) pair, so
        # build each transform once instead of re-creating it for every file.
        resamplers = {}

        def _get_resampler(from_rate, to_rate):
            # Build-once lookup for a torchaudio Resample transform.
            key = (from_rate, to_rate)
            if key not in resamplers:
                resamplers[key] = torchaudio.transforms.Resample(from_rate, to_rate)
            return resamplers[key]

        data = []
        for audio_clip in tqdm.tqdm(
            input_files, desc=f"Processing {len(input_files)} audio file(s)"
        ):
            decoder = decoders.AudioDecoder(audio_clip)
            decoded_samples = decoder.get_all_samples()

            audio = decoded_samples.data.float()  # ensure float32
            original_sample_rate = decoded_samples.sample_rate

            audio = AudioUtils.stereo_tensor_to_mono(audio)
            if self.normalize:
                audio = AudioUtils.normalize(audio)

            # Degrade: downsample to a randomly chosen low rate, then
            # upsample back to the original rate.
            mangled_sample_rate = random.choice(self.audio_sample_rates)
            downsample = _get_resampler(original_sample_rate, mangled_sample_rate)
            upsample = _get_resampler(mangled_sample_rate, original_sample_rate)
            low_audio = upsample(downsample(audio))

            splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length)
            splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length)

            if not splitted_high_quality_audio or not splitted_low_quality_audio:
                continue  # skip empty or invalid clips

            # Pad the last (possibly short) clip so all clips share one length.
            splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(
                splitted_high_quality_audio[-1], clip_length
            )
            splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(
                splitted_low_quality_audio[-1], clip_length
            )

            for high_quality_data, low_quality_data in zip(
                splitted_high_quality_audio, splitted_low_quality_audio
            ):
                data.append(
                    (
                        (high_quality_data, low_quality_data),
                        (original_sample_rate, mangled_sample_rate),
                    )
                )

        # All pairs live in memory for the dataset's lifetime.
        self.audio_data = data

    def __len__(self):
        """Return the number of precomputed clip pairs."""
        return len(self.audio_data)

    def __getitem__(self, idx):
        """Return ``((hq_clip, lq_clip), (orig_rate, mangled_rate))`` at *idx*."""
        return self.audio_data[idx]