| Made training bit... spicier.

This commit is contained in:
2025-09-10 19:52:53 +03:00
parent ff38cefdd3
commit 0bc8fc2792
8 changed files with 581 additions and 303 deletions

59
data.py
View File

@@ -1,41 +1,68 @@
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
import torchaudio
import os
import random
import torchaudio.transforms as T
import torchaudio
import torchcodec.decoders as decoders
import tqdm
from torch.utils.data import Dataset
import AudioUtils
class AudioDataset(Dataset):
audio_sample_rates = [11025]
def __init__(self, input_dir, device, clip_length = 1024):
self.device = device
input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav') or f.endswith('.mp3') or f.endswith('.flac')]
def __init__(self, input_dir, clip_length=16384):
input_files = [
os.path.join(root, f)
for root, _, files in os.walk(input_dir)
for f in files
if f.endswith(".wav") or f.endswith(".mp3") or f.endswith(".flac")
]
data = []
for audio_clip in tqdm.tqdm(input_files, desc=f"Processing {len(input_files)} audio file(s)"):
audio, original_sample_rate = torchaudio.load(audio_clip, normalize=True)
for audio_clip in tqdm.tqdm(
input_files, desc=f"Processing {len(input_files)} audio file(s)"
):
decoder = decoders.AudioDecoder(audio_clip)
decoded_samples = decoder.get_all_samples()
audio = decoded_samples.data
original_sample_rate = decoded_samples.sample_rate
audio = AudioUtils.stereo_tensor_to_mono(audio)
# Generate low-quality audio with random downsampling
mangled_sample_rate = random.choice(self.audio_sample_rates)
resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
resample_transform_low = torchaudio.transforms.Resample(
original_sample_rate, mangled_sample_rate
)
resample_transform_high = torchaudio.transforms.Resample(
mangled_sample_rate, original_sample_rate
)
low_audio = resample_transform_low(audio)
low_audio = resample_transform_high(low_audio)
splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length)
splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(splitted_high_quality_audio[-1], clip_length)
splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(
splitted_high_quality_audio[-1], clip_length
)
splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length)
splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(splitted_low_quality_audio[-1], clip_length)
splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(
splitted_low_quality_audio[-1], clip_length
)
for high_quality_sample, low_quality_sample in zip(splitted_high_quality_audio, splitted_low_quality_audio):
data.append(((high_quality_sample, low_quality_sample), (original_sample_rate, mangled_sample_rate)))
for high_quality_sample, low_quality_sample in zip(
splitted_high_quality_audio, splitted_low_quality_audio
):
data.append(
(
(high_quality_sample, low_quality_sample),
(original_sample_rate, mangled_sample_rate),
)
)
self.audio_data = data