⚗️ | Added some stupid ways for training + some makeup
@@ -1,71 +1,97 @@
import torch
import torch.nn.functional as F


def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
    """
    Convert stereo (C, N) to mono (1, N). Ensures a channel dimension.
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)  # (N,) -> (1, N)

    if waveform.shape[0] > 1:
        # Average across channels
        mono_waveform = torch.mean(waveform, dim=0, keepdim=True)  # (1, N)
    else:
        # Already mono
        mono_waveform = waveform
    return mono_waveform
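
# Quick sanity check for stereo_tensor_to_mono (illustrative shapes only, not
# part of the original module): a (2, N) stereo clip collapses to (1, N).
#   stereo = torch.randn(2, 48000)
#   mono = stereo_tensor_to_mono(stereo)
#   assert mono.shape == (1, 48000)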


def stretch_tensor(tensor: torch.Tensor, target_length: int) -> torch.Tensor:
    """
    Stretch audio along time dimension to target_length.
    Input assumed (1, N). Returns (1, target_length).
    """
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)  # ensure (1, N)

    tensor = tensor.unsqueeze(0)  # (1, 1, N) for interpolate
    stretched = F.interpolate(
        tensor, size=target_length, mode="linear", align_corners=False
    )
    return stretched.squeeze(0)  # back to (1, target_length)
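
# Illustrative use of stretch_tensor (hypothetical lengths, not from the
# original file): linearly resample a (1, 100) clip to 128 samples.
#   short = torch.randn(1, 100)
#   stretched = stretch_tensor(short, 128)   # -> (1, 128)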


def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128) -> torch.Tensor:
    """
    Pad to fixed length. Input assumed (1, N). Returns (1, target_length).
    """
    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)

    current_length = audio_tensor.shape[-1]

    if current_length < target_length:
        padding_needed = target_length - current_length
        padding_tuple = (0, padding_needed)
        padded_audio_tensor = F.pad(
            audio_tensor, padding_tuple, mode="constant", value=0
        )
    else:
        padded_audio_tensor = audio_tensor[..., :target_length]  # crop if too long

    return padded_audio_tensor
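
# Illustrative behavior of pad_tensor (hypothetical shapes, not from the
# original file): shorter inputs are zero-padded at the end, longer ones are
# cropped, so the result always has exactly target_length samples.
#   pad_tensor(torch.randn(1, 100), 128).shape   # -> (1, 128), zero-padded
#   pad_tensor(torch.randn(1, 200), 128).shape   # -> (1, 128), cropped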


def split_audio(
    audio_tensor: torch.Tensor, chunk_size: int = 128
) -> list[torch.Tensor]:
    """
    Split into chunks of (1, chunk_size).
    """
    if not isinstance(chunk_size, int) or chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer.")

    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)

    num_samples = audio_tensor.shape[-1]
    if num_samples == 0:
        return []  # Return empty list if the dimension to split is empty

    # torch.split handles the last chunk being potentially smaller automatically.
    chunks = list(torch.split(audio_tensor, chunk_size, dim=-1))
    return chunks


def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
    """
    Reconstruct audio from chunks. Returns (1, N).
    """
    if not chunks:
        return torch.empty(1, 0)

    chunks = [c if c.dim() == 2 else c.unsqueeze(0) for c in chunks]
    try:
        reconstructed_tensor = torch.cat(chunks, dim=-1)
    except RuntimeError as e:
        raise RuntimeError(
            "Failed to concatenate audio chunks. Ensure chunks have compatible shapes "
            f"for concatenation along dim -1. Original error: {e}"
        )

    return reconstructed_tensor
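
# Round-trip sketch for split_audio / reconstruct_audio (hypothetical length,
# not part of the original file): chunking then concatenating restores the
# original samples, with the last chunk allowed to be shorter.
#   wave = torch.randn(1, 300)
#   parts = split_audio(wave, chunk_size=128)   # shapes (1, 128), (1, 128), (1, 44)
#   assert torch.equal(reconstruct_audio(parts), wave)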


def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Peak-normalize to [-1, 1]. Near-silent input is returned unchanged.
    """
    max_val = torch.max(torch.abs(audio_tensor))
    if max_val < eps:
        return audio_tensor  # silence, skip normalization
    return audio_tensor / max_val
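

if __name__ == "__main__":
    # Minimal smoke test chaining the helpers above (added for illustration;
    # the shapes and chunk size are arbitrary, not taken from the training code).
    stereo = torch.randn(2, 1000)
    mono = stereo_tensor_to_mono(stereo)               # (1, 1000)
    stretched = stretch_tensor(mono, 512)              # (1, 512)
    fixed = pad_tensor(stretched, target_length=640)   # (1, 640), zero-padded
    chunks = split_audio(fixed, chunk_size=128)        # five (1, 128) chunks
    restored = reconstruct_audio(chunks)               # (1, 640)
    print(normalize(restored).abs().max())             # 1.0 after peak normalization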