import torch
import torch.nn.functional as F


def stereo_tensor_to_mono(waveform):
    if waveform.shape[0] > 1:
        # Average across channels
        mono_waveform = torch.mean(waveform, dim=0, keepdim=True)
    else:
        # Already mono
        mono_waveform = waveform
    return mono_waveform

def stretch_tensor(tensor, target_length):
    # F.interpolate with mode='linear' expects a 3D (batch, channels, length)
    # input, while the rest of this module works with (channels, samples)
    # tensors, so add a temporary batch dimension. Requesting size=target_length
    # directly (instead of a rounded scale_factor) guarantees the exact length.
    stretched = F.interpolate(tensor.unsqueeze(0), size=target_length,
                              mode='linear', align_corners=False)
    return stretched.squeeze(0)

def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128):
    current_length = audio_tensor.shape[-1]

    if current_length < target_length:
        padding_needed = target_length - current_length

        padding_tuple = (0, padding_needed)
        padded_audio_tensor = F.pad(audio_tensor, padding_tuple, mode='constant', value=0)
    else:
        padded_audio_tensor = audio_tensor

    return padded_audio_tensor

def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 128) -> list[torch.Tensor]:
    if not isinstance(chunk_size, int) or chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer.")

    # Handle scalar tensor edge case if necessary
    if audio_tensor.dim() == 0:
        return [audio_tensor] if audio_tensor.numel() > 0 else []

    # Identify the dimension to split (usually the last one, representing time/samples)
    split_dim = -1
    num_samples = audio_tensor.shape[split_dim]

    if num_samples == 0:
        return []  # Return empty list if the dimension to split is empty

    # Use torch.split to divide the tensor into chunks.
    # It handles the last chunk being potentially smaller automatically.
    chunks = list(torch.split(audio_tensor, chunk_size, dim=split_dim))

    return chunks
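

# --- Illustrative usage (not part of the original module) --------------------
# A minimal sketch of how the helpers above compose into a simple preprocessing
# pass: down-mix to mono, stretch to a fixed length, pad a short clip, and split
# a longer one into fixed-size chunks. The synthetic waveform, its 2-channel x
# 300-sample shape, and the lengths used here are illustrative assumptions, not
# values taken from this module.
if __name__ == "__main__":
    waveform = torch.randn(2, 300)               # fake stereo clip: (channels, samples)

    mono = stereo_tensor_to_mono(waveform)       # -> (1, 300)
    stretched = stretch_tensor(mono, 256)        # -> (1, 256)
    padded = pad_tensor(mono[..., :100], 128)    # -> (1, 128), zero-padded on the right
    chunks = split_audio(mono, chunk_size=128)   # -> [(1, 128), (1, 128), (1, 44)]

    print(mono.shape, stretched.shape, padded.shape, [tuple(c.shape) for c in chunks])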