Files
SISU/data.py
2025-11-18 21:34:59 +02:00

72 lines
2.4 KiB
Python

import os
import random
import torch
import torchaudio
import torchcodec.decoders as decoders
import tqdm
from torch.utils.data import Dataset
import AudioUtils
class AudioDataset(Dataset):
    """Dataset of fixed-length audio clips paired with degraded copies.

    ``__init__`` eagerly decodes every supported audio file in ``input_dir``,
    optionally normalizes it, and splits it into ``clip_length``-sample
    mono clips. ``__getitem__`` then degrades a clip on the fly by
    resampling it down to a randomly chosen rate and back up again,
    returning the (high-quality, degraded) pair.
    """

    # Candidate "mangled" rates used to synthesize the low-quality copy.
    audio_sample_rates = [8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100]

    def __init__(self, input_dir, clip_length: int = 512, normalize: bool = True):
        """Decode, optionally normalize, and chunk every audio file in *input_dir*.

        Args:
            input_dir: Directory scanned (non-recursively) for .wav/.mp3/.flac files.
            clip_length: Number of samples per stored clip (passed to
                ``AudioUtils.split_audio``).
            normalize: Whether to run ``AudioUtils.normalize`` on decoded audio.
        """
        self.clip_length = clip_length
        self.normalize = normalize
        input_files = [
            os.path.join(input_dir, f)
            for f in os.listdir(input_dir)
            if os.path.isfile(os.path.join(input_dir, f))
            and f.lower().endswith((".wav", ".mp3", ".flac"))
        ]
        data = []
        # NOTE: renamed loop variables — the original reused `audio_clip` for
        # the file path AND the inner torch.split result, shadowing itself.
        for file_path in tqdm.tqdm(
            input_files, desc=f"Processing {len(input_files)} audio file(s)"
        ):
            decoder = decoders.AudioDecoder(file_path)
            decoded_samples = decoder.get_all_samples()
            audio = decoded_samples.data.float()
            original_sample_rate = decoded_samples.sample_rate
            if normalize:
                audio = AudioUtils.normalize(audio)
            chunks = AudioUtils.split_audio(audio, clip_length, True)
            if not chunks:
                continue
            for chunk in chunks:
                # torch.split(chunk, 1) peels off one channel at a time,
                # so each stored clip is single-channel.
                for mono_clip in torch.split(chunk, 1):
                    data.append((mono_clip, original_sample_rate))
        # List of (clip_tensor, original_sample_rate) pairs.
        self.audio_data = data

    def __len__(self):
        """Return the number of stored (clip, sample_rate) pairs."""
        return len(self.audio_data)

    def __getitem__(self, idx):
        """Return ``((high_clip, low_clip), (orig_rate, mangled_rate))``.

        The low-quality clip is produced by resampling the stored clip down
        to a randomly chosen rate and back up to its original rate, then
        forced to the exact length of the high-quality clip.
        """
        high_clip, orig_rate = self.audio_data[idx]
        mangled_sample_rate = random.choice(self.audio_sample_rates)
        # Perf fix: cache Resample transforms per (src, dst) pair instead of
        # rebuilding the resampling kernels on every __getitem__ call. The
        # cache dict is created lazily so instances built by the original
        # __init__ (or unpickled ones) still work.
        cache = self.__dict__.setdefault("_resample_cache", {})
        down_key = (orig_rate, mangled_sample_rate)
        if down_key not in cache:
            cache[down_key] = torchaudio.transforms.Resample(*down_key)
        up_key = (mangled_sample_rate, orig_rate)
        if up_key not in cache:
            cache[up_key] = torchaudio.transforms.Resample(*up_key)
        low_clip = cache[up_key](cache[down_key](high_clip))
        # A resample round-trip can change the length by a few samples; force
        # the degraded clip to match the high-quality clip exactly.
        # Bug fix: the original compared against the clip's real length but
        # padded to self.clip_length — mismatched shapes whenever a stored
        # clip is not exactly clip_length samples long.
        target_len = high_clip.shape[1]
        if low_clip.shape[1] > target_len:
            low_clip = low_clip[:, :target_len]
        elif low_clip.shape[1] < target_len:
            low_clip = AudioUtils.pad_tensor(low_clip, target_len)
        return ((high_clip, low_clip), (orig_rate, mangled_sample_rate))