import torch
import torch.nn.functional as F


def stereo_tensor_to_mono(waveform):
    """Downmix a multi-channel waveform to a single channel.

    The leading dimension is treated as the channel axis and averaged away.

    Args:
        waveform: Tensor laid out as ``(channels, samples)``. A 1-D tensor is
            treated as a single channel of samples.

    Returns:
        Tensor of shape ``(1, samples)``. Input that already has exactly one
        channel is returned unchanged (the same tensor object, not a copy).
    """
    if waveform.dim() == 1:
        # A 1-D tensor has no channel axis; averaging over dim 0 would collapse
        # every sample into a single value, so add the channel axis instead.
        return waveform.unsqueeze(0)
    if waveform.shape[0] > 1:
        # Average across channels, keeping the channel axis so the result
        # stays (1, samples).
        return torch.mean(waveform, dim=0, keepdim=True)
    # Already mono.
    return waveform


def stretch_tensor(tensor, target_length):
    """Resample ``tensor`` along its last dimension to ``target_length`` points.

    Uses linear interpolation (``align_corners=False``) via
    ``torch.nn.functional.interpolate``.

    Args:
        tensor: Floating-point tensor of shape ``(length,)``,
            ``(channels, length)`` or ``(batch, channels, length)``.
        target_length: Desired size of the last dimension; must be >= 1.

    Returns:
        Tensor with the same leading dimensions as ``tensor`` and a last
        dimension of exactly ``target_length``.

    Raises:
        ValueError: If ``target_length`` < 1 or ``tensor`` is not 1-D, 2-D
            or 3-D.
    """
    target_length = int(target_length)
    if target_length < 1:
        raise ValueError(f"target_length must be >= 1, got {target_length}")

    ndim = tensor.dim()
    if ndim not in (1, 2, 3):
        raise ValueError(
            f"expected a 1-D, 2-D or 3-D tensor, got {ndim}-D with shape "
            f"{tuple(tensor.shape)}"
        )

    # 'linear' mode in F.interpolate only accepts 3-D (batch, channels, length)
    # input, so lift lower-rank tensors to 3-D by prepending singleton axes.
    lifted = tensor.reshape((1,) * (3 - ndim) + tuple(tensor.shape))

    # Pass ``size`` instead of ``scale_factor``: the output length is then
    # exactly target_length, whereas floor(length * scale_factor) can be off
    # by one because of floating-point rounding.
    stretched = F.interpolate(
        lifted, size=target_length, mode='linear', align_corners=False
    )

    # Drop the singleton axes added above, restoring the caller's leading shape.
    return stretched.reshape(tuple(tensor.shape[:-1]) + (target_length,))