SISU/data.py
2024-12-25 00:09:57 +02:00

32 lines
1.1 KiB
Python

from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
import torchaudio
import os
import random
import torchaudio.transforms as T
import AudioUtils
class AudioDataset(Dataset):
    """Pair each ``.wav`` file in a directory with a resampled copy of itself.

    Each item is ``((hq_audio, original_rate), (lq_audio, mangled_rate))``.
    The "low-quality" copy is the original waveform resampled to a rate drawn
    at random from ``audio_sample_rates``. Both waveforms are passed through
    ``AudioUtils.stereo_tensor_to_mono`` before being returned.
    """

    # Target rates for the degraded copy. Other values that have been tried
    # here: 8000, 11025, 16000, 22050.
    audio_sample_rates = [11025]

    def __init__(self, input_dir, sample_rates=None):
        """Index the ``.wav`` files directly inside ``input_dir``.

        Args:
            input_dir: Directory to scan (non-recursive). The extension match
                is case-insensitive, and only regular files are kept.
            sample_rates: Optional iterable of target rates for the degraded
                copy. When given, it overrides the class-level
                ``audio_sample_rates`` for this instance only.

        Raises:
            ValueError: If ``sample_rates`` is given but empty.
        """
        candidates = (
            os.path.join(input_dir, name)
            for name in os.listdir(input_dir)
            if name.lower().endswith('.wav')
        )
        # Sort so the index -> file mapping is stable across runs and
        # filesystems (os.listdir returns entries in arbitrary order).
        self.input_files = sorted(path for path in candidates if os.path.isfile(path))

        if sample_rates is not None:
            sample_rates = list(sample_rates)
            if not sample_rates:
                raise ValueError("sample_rates must contain at least one rate")
            self.audio_sample_rates = sample_rates

        # Resample computes its filter kernel at construction time, so keep
        # one transform per (original_rate, target_rate) pair instead of
        # rebuilding it for every item.
        self._resamplers = {}

    def __len__(self):
        """Return the number of indexed audio files."""
        return len(self.input_files)

    def _get_resampler(self, orig_freq, new_freq):
        """Return a cached ``Resample`` transform for ``orig_freq -> new_freq``."""
        key = (orig_freq, new_freq)
        resampler = self._resamplers.get(key)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_freq, new_freq)
            self._resamplers[key] = resampler
        return resampler

    def __getitem__(self, idx):
        """Load file ``idx`` and return it together with a resampled copy.

        Returns:
            ``((hq_audio, original_sample_rate), (lq_audio, mangled_sample_rate))``.
            Both waveforms have been converted by
            ``AudioUtils.stereo_tensor_to_mono``.
        """
        # Load the high-quality source audio.
        high_quality_audio, original_sample_rate = torchaudio.load(
            self.input_files[idx], normalize=True
        )
        # Build the degraded copy by resampling to a randomly chosen rate.
        # NOTE: if the file's own rate is below the chosen rate, this
        # upsamples rather than downsamples.
        mangled_sample_rate = random.choice(self.audio_sample_rates)
        resampler = self._get_resampler(original_sample_rate, mangled_sample_rate)
        low_quality_audio = resampler(high_quality_audio)
        return (
            (AudioUtils.stereo_tensor_to_mono(high_quality_audio), original_sample_rate),
            (AudioUtils.stereo_tensor_to_mono(low_quality_audio), mangled_sample_rate),
        )