⚗️ | More architectural changes
This commit is contained in:
@@ -3,95 +3,39 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert stereo (C, N) to mono (1, N). Ensures a channel dimension.
|
||||
"""
|
||||
if waveform.dim() == 1:
|
||||
waveform = waveform.unsqueeze(0) # (N,) -> (1, N)
|
||||
|
||||
if waveform.shape[0] > 1:
|
||||
mono_waveform = torch.mean(waveform, dim=0, keepdim=True) # (1, N)
|
||||
else:
|
||||
mono_waveform = waveform
|
||||
return mono_waveform
|
||||
mono_tensor = torch.mean(waveform, dim=0, keepdim=True)
|
||||
return mono_tensor
|
||||
|
||||
|
||||
def stretch_tensor(tensor: torch.Tensor, target_length: int) -> torch.Tensor:
|
||||
"""
|
||||
Stretch audio along time dimension to target_length.
|
||||
Input assumed (1, N). Returns (1, target_length).
|
||||
"""
|
||||
if tensor.dim() == 1:
|
||||
tensor = tensor.unsqueeze(0) # ensure (1, N)
|
||||
def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 512) -> torch.Tensor:
|
||||
padding_amount = target_length - audio_tensor.size(-1)
|
||||
if padding_amount <= 0:
|
||||
return audio_tensor
|
||||
|
||||
tensor = tensor.unsqueeze(0) # (1, 1, N) for interpolate
|
||||
stretched = F.interpolate(
|
||||
tensor, size=target_length, mode="linear", align_corners=False
|
||||
)
|
||||
return stretched.squeeze(0) # back to (1, target_length)
|
||||
|
||||
|
||||
def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128) -> torch.Tensor:
|
||||
"""
|
||||
Pad to fixed length. Input assumed (1, N). Returns (1, target_length).
|
||||
"""
|
||||
if audio_tensor.dim() == 1:
|
||||
audio_tensor = audio_tensor.unsqueeze(0)
|
||||
|
||||
current_length = audio_tensor.shape[-1]
|
||||
if current_length < target_length:
|
||||
padding_needed = target_length - current_length
|
||||
padding_tuple = (0, padding_needed)
|
||||
padded_audio_tensor = F.pad(
|
||||
audio_tensor, padding_tuple, mode="constant", value=0
|
||||
)
|
||||
else:
|
||||
padded_audio_tensor = audio_tensor[..., :target_length] # crop if too long
|
||||
padded_audio_tensor = F.pad(audio_tensor, (0, padding_amount))
|
||||
|
||||
return padded_audio_tensor
|
||||
|
||||
|
||||
def split_audio(
|
||||
audio_tensor: torch.Tensor, chunk_size: int = 128
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
Split into chunks of (1, chunk_size).
|
||||
"""
|
||||
if not isinstance(chunk_size, int) or chunk_size <= 0:
|
||||
raise ValueError("chunk_size must be a positive integer.")
|
||||
def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 512, pad_last_tensor: bool = False) -> list[torch.Tensor]:
|
||||
chunks = list(torch.split(audio_tensor, chunk_size, dim=1))
|
||||
|
||||
if audio_tensor.dim() == 1:
|
||||
audio_tensor = audio_tensor.unsqueeze(0)
|
||||
if pad_last_tensor:
|
||||
last_chunk = chunks[-1]
|
||||
|
||||
num_samples = audio_tensor.shape[-1]
|
||||
if num_samples == 0:
|
||||
return []
|
||||
if last_chunk.size(-1) < chunk_size:
|
||||
chunks[-1] = pad_tensor(last_chunk, chunk_size)
|
||||
|
||||
chunks = list(torch.split(audio_tensor, chunk_size, dim=-1))
|
||||
return chunks
|
||||
|
||||
|
||||
def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Reconstruct audio from chunks. Returns (1, N).
|
||||
"""
|
||||
if not chunks:
|
||||
return torch.empty(1, 0)
|
||||
|
||||
chunks = [c if c.dim() == 2 else c.unsqueeze(0) for c in chunks]
|
||||
try:
|
||||
reconstructed_tensor = torch.cat(chunks, dim=-1)
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to concatenate audio chunks. Ensure chunks have compatible shapes "
|
||||
f"for concatenation along dim -1. Original error: {e}"
|
||||
)
|
||||
|
||||
reconstructed_tensor = torch.cat(chunks, dim=-1)
|
||||
return reconstructed_tensor
|
||||
|
||||
|
||||
def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
|
||||
max_val = torch.max(torch.abs(audio_tensor))
|
||||
if max_val < eps:
|
||||
return audio_tensor # silence, skip normalization
|
||||
return audio_tensor
|
||||
return audio_tensor / max_val
|
||||
|
||||
Reference in New Issue
Block a user