SISU/data.py

46 lines
2.0 KiB
Python

from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
import torchaudio
import os
import random
import torchaudio.transforms as T
import AudioUtils
class AudioDataset(Dataset):
    """Dataset of (original, degraded) mono audio pairs split into chunks.

    Each item corresponds to one ``.wav`` file found under ``input_dir``. The
    audio is converted to mono, degraded by resampling down to a rate drawn
    from ``audio_sample_rates`` and back up to the original rate, and both
    versions are split into ``CHUNK_SIZE`` pieces that are moved to ``device``.
    """

    # Candidate intermediate sample rates (Hz) used to degrade the audio.
    audio_sample_rates = [11025]
    # NOTE(review): not referenced anywhere in this class; kept in case
    # callers elsewhere read it.
    MAX_LENGTH = 44100  # Define your desired maximum length here
    # Chunk size handed to AudioUtils.split_audio / AudioUtils.pad_tensor.
    CHUNK_SIZE = 128

    def __init__(self, input_dir, device):
        """Collect every ``.wav`` file below ``input_dir``.

        Args:
            input_dir: Root directory, searched recursively with ``os.walk``.
            device: Target passed to ``Tensor.to`` for every returned chunk.
        """
        # Sorted so that item indices are reproducible; os.walk order is
        # filesystem dependent.
        self.input_files = sorted(
            os.path.join(root, name)
            for root, _, files in os.walk(input_dir)
            for name in files
            if name.endswith('.wav')
        )
        self.device = device

    def __len__(self):
        """Return the number of ``.wav`` files found."""
        return len(self.input_files)

    @staticmethod
    def _match_length(audio, target_length):
        """Crop or right-zero-pad ``audio`` along its last dim to ``target_length``."""
        current_length = audio.shape[-1]
        if current_length > target_length:
            return audio[..., :target_length]
        if current_length < target_length:
            return F.pad(audio, (0, target_length - current_length))
        return audio

    def _split_and_pad(self, audio):
        """Split ``audio`` into ``CHUNK_SIZE`` chunks, pad the last one, move all to device."""
        chunks = AudioUtils.split_audio(audio, self.CHUNK_SIZE)
        # Guard against an empty split, where chunks[-1] would raise IndexError.
        if len(chunks) > 0:
            chunks[-1] = AudioUtils.pad_tensor(chunks[-1], self.CHUNK_SIZE)
        # NOTE(review): moving to a CUDA device here will fail inside forked
        # DataLoader workers; consider moving batches to device in the
        # training loop instead.
        return [chunk.to(self.device) for chunk in chunks]

    def __getitem__(self, idx):
        """Load file ``idx`` and return its original and degraded chunk lists.

        Returns:
            ``((high_quality_chunks, original_sample_rate),
            (low_quality_chunks, mangled_sample_rate))``. Both chunk lists are
            built from the same number of samples.
        """
        # Load high-quality audio and change to mono.
        high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True)
        high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio)

        # Generate low-quality audio with random downsampling.
        mangled_sample_rate = random.choice(self.audio_sample_rates)
        downsample = T.Resample(original_sample_rate, mangled_sample_rate)
        upsample = T.Resample(mangled_sample_rate, original_sample_rate)
        low_quality_audio = upsample(downsample(high_quality_audio))

        # Resample rounds its output length up, so the round trip can come back
        # a few samples longer than the input. Realign to the original length,
        # otherwise the two chunk lists can differ in count or content.
        # Assumes samples lie along the last dim (TODO confirm against AudioUtils).
        low_quality_audio = self._match_length(low_quality_audio, high_quality_audio.shape[-1])

        return (
            (self._split_and_pad(high_quality_audio), original_sample_rate),
            (self._split_and_pad(low_quality_audio), mangled_sample_rate),
        )