Files
SISU/AudioUtils.py

98 lines
3.0 KiB
Python

import torch
import torch.nn.functional as F
def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
    """
    Collapse a multi-channel waveform (C, N) into a single channel (1, N).

    A 1-D input (N,) first gains a leading channel axis. When more than one
    channel is present, the channels are averaged over dim 0; otherwise the
    (possibly unsqueezed) input is returned as is.
    """
    # Promote bare sample vectors to (1, N) so dim 0 is always the channel axis.
    channels_first = waveform.unsqueeze(0) if waveform.dim() == 1 else waveform

    # Nothing to mix down when there is at most one channel.
    if channels_first.shape[0] <= 1:
        return channels_first

    # keepdim=True preserves the channel axis: (C, N) -> (1, N).
    return channels_first.mean(dim=0, keepdim=True)
def stretch_tensor(tensor: torch.Tensor, target_length: int) -> torch.Tensor:
    """
    Resample audio along its last (time) axis to exactly target_length samples.

    Uses linear interpolation with align_corners=False. A 1-D input (N,) is
    treated as (1, N). For a (1, N) input the result is (1, target_length).
    """
    # Guarantee a leading channel axis before batching.
    signal = tensor.unsqueeze(0) if tensor.dim() == 1 else tensor

    # F.interpolate in "linear" mode expects (batch, channels, length).
    batched = signal.unsqueeze(0)
    resampled = F.interpolate(
        batched,
        size=target_length,
        mode="linear",
        align_corners=False,
    )

    # Drop the temporary batch axis added above.
    return torch.squeeze(resampled, 0)
def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128) -> torch.Tensor:
    """
    Force the last (time) axis to exactly target_length samples.

    Shorter inputs are right-padded with zeros; longer inputs are cropped
    from the end. A 1-D input (N,) is treated as (1, N), so a (1, N) or (N,)
    input yields (1, target_length).

    Raises:
        ValueError: If target_length is negative.
    """
    # A negative length would otherwise fall into the crop branch, where
    # `[..., :negative]` silently trims samples from the end instead of failing.
    if target_length < 0:
        raise ValueError("target_length must be a non-negative integer.")

    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)

    current_length = audio_tensor.shape[-1]
    if current_length >= target_length:
        return audio_tensor[..., :target_length]  # crop if too long

    # (0, k) pads only the right side of the last dimension.
    padding_needed = target_length - current_length
    return F.pad(audio_tensor, (0, padding_needed), mode="constant", value=0)
def split_audio(
    audio_tensor: torch.Tensor, chunk_size: int = 128
) -> list[torch.Tensor]:
    """
    Cut audio into consecutive pieces along the last (time) axis.

    A 1-D input (N,) is treated as (1, N). Each piece holds chunk_size
    samples, except the final one, which is shorter when the length is not a
    multiple of chunk_size. Input with zero samples yields an empty list.

    Raises ValueError when chunk_size is not a positive int.
    """
    size_is_valid = isinstance(chunk_size, int) and chunk_size > 0
    if not size_is_valid:
        raise ValueError("chunk_size must be a positive integer.")

    signal = audio_tensor.unsqueeze(0) if audio_tensor.dim() == 1 else audio_tensor

    # Checked explicitly so empty audio produces no chunks at all.
    if signal.shape[-1] == 0:
        return []

    return list(signal.split(chunk_size, dim=-1))
def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
    """
    Reconstruct audio by concatenating chunks along the last (time) axis.

    Chunks that are not 2-D gain a leading axis first, so 1-D chunks (N,)
    are treated as (1, N). For (1, n_i) chunks the result is (1, sum(n_i)).
    An empty list yields an empty tensor of shape (1, 0).

    Raises:
        RuntimeError: If the chunks cannot be concatenated; the underlying
            torch error is attached as the exception's cause.
    """
    if not chunks:
        return torch.empty(1, 0)

    chunks = [c if c.dim() == 2 else c.unsqueeze(0) for c in chunks]
    try:
        return torch.cat(chunks, dim=-1)
    except RuntimeError as e:
        # Chain explicitly so the original torch failure is kept as __cause__.
        raise RuntimeError(
            f"Failed to concatenate audio chunks. Ensure chunks have compatible shapes "
            f"for concatenation along dim -1. Original error: {e}"
        ) from e
def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Peak-normalize audio so the largest absolute sample becomes 1.

    Args:
        audio_tensor: Samples to scale; any shape.
        eps: Peak threshold below which the input is treated as silence.

    Returns:
        audio_tensor divided by its maximum absolute value, or the input
        itself when it has no elements or its peak is below eps.
    """
    # torch.max raises on a zero-element tensor; there is nothing to scale.
    if audio_tensor.numel() == 0:
        return audio_tensor

    peak = torch.max(torch.abs(audio_tensor))
    if peak < eps:
        return audio_tensor  # silence, skip normalization
    return audio_tensor / peak