Compare commits

..

2 Commits

4 changed files with 13 additions and 12 deletions

View File

@@ -8,12 +8,13 @@ def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 512) -> torch.Tensor:
current = audio_tensor.size(-1)
padding_amount = target_length - current
padding_amount = target_length - audio_tensor.size(-1)
if padding_amount <= 0:
return audio_tensor
return F.pad(audio_tensor, (0, padding_amount))
padded_audio_tensor = F.pad(audio_tensor, (0, padding_amount))
return padded_audio_tensor
def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 512, pad_last_tensor: bool = False) -> list[torch.Tensor]:

View File

@@ -67,6 +67,5 @@ class AudioDataset(Dataset):
if audio_clip[0].shape[1] < low_audio_clip.shape[1]:
low_audio_clip = low_audio_clip[:, :audio_clip[0].shape[1]]
elif audio_clip[0].shape[1] > low_audio_clip.shape[1]:
target_len = audio_clip[0].shape[1]
low_audio_clip = AudioUtils.pad_tensor(low_audio_clip, target_len)
low_audio_clip = AudioUtils.pad_tensor(low_audio_clip, self.clip_length)
return ((audio_clip[0], low_audio_clip), (audio_clip[1], mangled_sample_rate))

View File

@@ -34,11 +34,7 @@ args = parser.parse_args()
# Init accelerator
# ---------------------------
try:
accelerator = Accelerator(mixed_precision="bf16")
except Exception:
accelerator = Accelerator(mixed_precision="fp16")
accelerator.print("⚠️ | bf16 unavailable — falling back to fp16")
# ---------------------------
# Models

View File

@@ -44,8 +44,13 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
def generator_adv_loss(disc_generated_outputs):
loss = 0.0
for dg in disc_generated_outputs:
"""
Least Squares GAN Loss for the Generator.
Objective: Fake -> 1 (Fool the discriminator)
"""
loss = 0
for dg in zip(disc_generated_outputs):
dg = dg[0] # Unpack tuple
loss += torch.mean((dg - 1) ** 2)
return loss