🐛 | Fixed model and training

This commit is contained in:
2024-12-18 18:08:44 +02:00
parent 3bcc356eef
commit e43b2ab7ef
3 changed files with 71 additions and 52 deletions

31
data.py
View File

@ -1,9 +1,13 @@
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
import torchaudio
import os
import random
class AudioDataset(Dataset):
audio_sample_rates = [8000, 11025, 16000, 22050]
def __init__(self, input_dir, target_duration=None, padding_mode='constant', padding_value=0.0):
self.input_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.wav')]
self.target_duration = target_duration # Duration in seconds or None if not set
@ -17,29 +21,30 @@ class AudioDataset(Dataset):
def __getitem__(self, idx):
high_quality_wav, sr_original = torchaudio.load(self.input_files[idx], normalize=True)
resample_transform = torchaudio.transforms.Resample(sr_original, 16000)
sample_rate = random.choice(self.audio_sample_rates)
resample_transform = torchaudio.transforms.Resample(sr_original, sample_rate)
low_quality_wav = resample_transform(high_quality_wav)
low_quality_wav = -low_quality_wav
# Calculate target length based on desired duration and 16000 Hz
if self.target_duration is not None:
target_length = int(self.target_duration * 44100)
else:
# Calculate duration of original high quality audio
duration_original = high_quality_wav.shape[1] / sr_original
target_length = int(duration_original * 16000)
target_length = high_quality_wav.size(1)
# Pad both to the calculated target length
high_quality_wav = self.pad_tensor(high_quality_wav, target_length)
low_quality_wav = self.pad_tensor(low_quality_wav, target_length)
high_quality_wav = self.stretch_tensor(high_quality_wav, target_length)
low_quality_wav = self.stretch_tensor(low_quality_wav, target_length)
return low_quality_wav, high_quality_wav
def pad_tensor(self, tensor, target_length):
def stretch_tensor(self, tensor, target_length):
current_length = tensor.size(1)
if current_length < target_length:
padding_amount = target_length - current_length
padding = (0, padding_amount)
tensor = torch.nn.functional.pad(tensor, padding, mode=self.padding_mode, value=self.padding_value)
elif current_length > target_length:
tensor = tensor[:, :target_length]
scale_factor = target_length / current_length
# Resample the tensor using linear interpolation
tensor = F.interpolate(tensor.unsqueeze(0), scale_factor=scale_factor, mode='linear', align_corners=False).squeeze(0)
return tensor