⚗️ | Small fixes here and there
This commit is contained in:
@@ -8,13 +8,12 @@ def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
|
|||||||
|
|
||||||
|
|
||||||
def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 512) -> torch.Tensor:
|
def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 512) -> torch.Tensor:
|
||||||
padding_amount = target_length - audio_tensor.size(-1)
|
current = audio_tensor.size(-1)
|
||||||
|
padding_amount = target_length - current
|
||||||
if padding_amount <= 0:
|
if padding_amount <= 0:
|
||||||
return audio_tensor
|
return audio_tensor
|
||||||
|
|
||||||
padded_audio_tensor = F.pad(audio_tensor, (0, padding_amount))
|
return F.pad(audio_tensor, (0, padding_amount))
|
||||||
|
|
||||||
return padded_audio_tensor
|
|
||||||
|
|
||||||
|
|
||||||
def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 512, pad_last_tensor: bool = False) -> list[torch.Tensor]:
|
def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 512, pad_last_tensor: bool = False) -> list[torch.Tensor]:
|
||||||
|
|||||||
3
data.py
3
data.py
@@ -67,5 +67,6 @@ class AudioDataset(Dataset):
|
|||||||
if audio_clip[0].shape[1] < low_audio_clip.shape[1]:
|
if audio_clip[0].shape[1] < low_audio_clip.shape[1]:
|
||||||
low_audio_clip = low_audio_clip[:, :audio_clip[0].shape[1]]
|
low_audio_clip = low_audio_clip[:, :audio_clip[0].shape[1]]
|
||||||
elif audio_clip[0].shape[1] > low_audio_clip.shape[1]:
|
elif audio_clip[0].shape[1] > low_audio_clip.shape[1]:
|
||||||
low_audio_clip = AudioUtils.pad_tensor(low_audio_clip, self.clip_length)
|
target_len = audio_clip[0].shape[1]
|
||||||
|
low_audio_clip = AudioUtils.pad_tensor(low_audio_clip, target_len)
|
||||||
return ((audio_clip[0], low_audio_clip), (audio_clip[1], mangled_sample_rate))
|
return ((audio_clip[0], low_audio_clip), (audio_clip[1], mangled_sample_rate))
|
||||||
|
|||||||
@@ -34,7 +34,11 @@ args = parser.parse_args()
|
|||||||
# Init accelerator
|
# Init accelerator
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
|
|
||||||
accelerator = Accelerator(mixed_precision="bf16")
|
try:
|
||||||
|
accelerator = Accelerator(mixed_precision="bf16")
|
||||||
|
except Exception:
|
||||||
|
accelerator = Accelerator(mixed_precision="fp16")
|
||||||
|
accelerator.print("⚠️ | bf16 unavailable — falling back to fp16")
|
||||||
|
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
# Models
|
# Models
|
||||||
|
|||||||
@@ -44,13 +44,8 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
|||||||
|
|
||||||
|
|
||||||
def generator_adv_loss(disc_generated_outputs):
|
def generator_adv_loss(disc_generated_outputs):
|
||||||
"""
|
loss = 0.0
|
||||||
Least Squares GAN Loss for the Generator.
|
for dg in disc_generated_outputs:
|
||||||
Objective: Fake -> 1 (Fool the discriminator)
|
|
||||||
"""
|
|
||||||
loss = 0
|
|
||||||
for dg in zip(disc_generated_outputs):
|
|
||||||
dg = dg[0] # Unpack tuple
|
|
||||||
loss += torch.mean((dg - 1) ** 2)
|
loss += torch.mean((dg - 1) ** 2)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user