From 571c403b93e693fb923b3b419ed6ad4b9067abb0 Mon Sep 17 00:00:00 2001
From: NikkeDoy
Date: Thu, 11 Dec 2025 23:06:38 +0200
Subject: [PATCH] :alembic: | Small fixes here and there

---
 AudioUtils.py          | 7 +++----
 data.py                | 3 ++-
 training.py            | 6 +++++-
 utils/TrainingTools.py | 9 ++-------
 4 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/AudioUtils.py b/AudioUtils.py
index ff6a24f..bff5de8 100644
--- a/AudioUtils.py
+++ b/AudioUtils.py
@@ -8,13 +8,12 @@ def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
 
 
 def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 512) -> torch.Tensor:
-    padding_amount = target_length - audio_tensor.size(-1)
+    current = audio_tensor.size(-1)
+    padding_amount = target_length - current
     if padding_amount <= 0:
         return audio_tensor
 
-    padded_audio_tensor = F.pad(audio_tensor, (0, padding_amount))
-
-    return padded_audio_tensor
+    return F.pad(audio_tensor, (0, padding_amount))
 
 
 def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 512, pad_last_tensor: bool = False) -> list[torch.Tensor]:
diff --git a/data.py b/data.py
index 13b0ca8..f44a3a3 100644
--- a/data.py
+++ b/data.py
@@ -67,5 +67,6 @@ class AudioDataset(Dataset):
         if audio_clip[0].shape[1] < low_audio_clip.shape[1]:
             low_audio_clip = low_audio_clip[:, :audio_clip[0].shape[1]]
         elif audio_clip[0].shape[1] > low_audio_clip.shape[1]:
-            low_audio_clip = AudioUtils.pad_tensor(low_audio_clip, self.clip_length)
+            target_len = audio_clip[0].shape[1]
+            low_audio_clip = AudioUtils.pad_tensor(low_audio_clip, target_len)
         return ((audio_clip[0], low_audio_clip), (audio_clip[1], mangled_sample_rate))
diff --git a/training.py b/training.py
index 8107b03..f4c57c6 100644
--- a/training.py
+++ b/training.py
@@ -34,7 +34,11 @@ args = parser.parse_args()
 
 # Init accelerator
 # ---------------------------
-accelerator = Accelerator(mixed_precision="bf16")
+try:
+    accelerator = Accelerator(mixed_precision="bf16")
+except Exception:
+    accelerator = Accelerator(mixed_precision="fp16")
+    accelerator.print("⚠️ | bf16 unavailable — falling back to fp16")
 
 # ---------------------------
 # Models
diff --git a/utils/TrainingTools.py b/utils/TrainingTools.py
index cd0350a..42d11dc 100644
--- a/utils/TrainingTools.py
+++ b/utils/TrainingTools.py
@@ -44,13 +44,8 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
 
 
 def generator_adv_loss(disc_generated_outputs):
-    """
-    Least Squares GAN Loss for the Generator.
-    Objective: Fake -> 1 (Fool the discriminator)
-    """
-    loss = 0
-    for dg in zip(disc_generated_outputs):
-        dg = dg[0] # Unpack tuple
+    loss = 0.0
+    for dg in disc_generated_outputs:
         loss += torch.mean((dg - 1) ** 2)
     return loss
 