Merge new-arch, because it has proven to give the best results #1
data.py (8 changes)

@@ -12,8 +12,9 @@ class AudioDataset(Dataset):
     #audio_sample_rates = [8000, 11025, 16000, 22050]
     audio_sample_rates = [11025]
 
-    def __init__(self, input_dir):
+    def __init__(self, input_dir, device):
         self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')]
+        self.device = device
 
 
     def __len__(self):
@@ -32,4 +33,7 @@ class AudioDataset(Dataset):
         resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
         low_quality_audio = resample_transform_high(low_quality_audio)
 
-        return (AudioUtils.stereo_tensor_to_mono(high_quality_audio), original_sample_rate), (AudioUtils.stereo_tensor_to_mono(low_quality_audio), mangled_sample_rate)
+        high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio).to(self.device)
+        low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio).to(self.device)
+
+        return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
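Note: with this change AudioDataset takes the training device at construction time and returns mono tensors that already live on that device. A minimal usage sketch, not from the repo (the device choice is an illustrative assumption; the dataset path matches training.py):

    # Sketch only: how the updated AudioDataset is expected to be used.
    import torch
    from data import AudioDataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = AudioDataset('./dataset/good', device)

    # Each item is ((hq_mono, original_sr), (lq_mono, mangled_sr)),
    # with both tensors already on `device`.
    (hq, hq_sr), (lq, lq_sr) = dataset[0]
    print(hq.device, lq.device)

One caveat worth keeping in mind: returning CUDA tensors from __getitem__ generally only works cleanly with num_workers=0, since DataLoader worker processes are not meant to produce CUDA tensors.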
requirements.txt

@@ -4,11 +4,11 @@ Jinja2==3.1.4
 MarkupSafe==2.1.5
 mpmath==1.3.0
 networkx==3.4.2
-numpy==2.2.1
-pytorch-triton-rocm==3.2.0+git0d4682f0
+numpy==2.2.3
+pytorch-triton-rocm==3.2.0+git4b3bb1f8
 setuptools==70.2.0
-sympy==1.13.1
-torch==2.6.0.dev20241222+rocm6.2.4
-torchaudio==2.6.0.dev20241222+rocm6.2.4
+sympy==1.13.3
+torch==2.7.0.dev20250226+rocm6.3
+torchaudio==2.6.0.dev20250226+rocm6.3
 tqdm==4.67.1
 typing_extensions==4.12.2
training.py (95 changes)
@@ -20,49 +20,35 @@ from data import AudioDataset
 from generator import SISUGenerator
 from discriminator import SISUDiscriminator
 
-import librosa
-import numpy as np
+import torchaudio.transforms as T
 
-def mfcc_loss(y_true, y_pred, sr):
-    """Calculates MFCC loss between two audio signals.
-
-    Args:
-        y_true: Target audio signal (PyTorch tensor).
-        y_pred: Predicted audio signal (PyTorch tensor).
-        sr: Sample rate (NumPy scalar).
-
-    Returns:
-        MFCC loss (PyTorch tensor).
-    """
-    # 1. Ensure sr is a NumPy scalar (not a Tensor)
-    if isinstance(sr, torch.Tensor):
-        sr = sr.item()
-
-    # 2. Convert y_true and y_pred to NumPy arrays (and detach from graph)
-    y_true_np = y_true.cpu().detach().numpy()[0] # .cpu() and .detach() are crucial!
-    y_pred_np = y_pred.cpu().detach().numpy()[0]
-
-    # 3. Dynamically calculate n_fft based on signal length
-    signal_length = min(y_true_np.shape[0], y_pred_np.shape[0]) # Use shortest signal length
-    n_fft = min(2048, 2**int(np.log2(signal_length))) # Power of 2, up to 2048
-
-    # 4. Calculate MFCCs using adjusted n_fft
-    mfccs_true = librosa.feature.mfcc(y=y_true_np, sr=sr, n_fft=n_fft, n_mfcc=20)
-    mfccs_pred = librosa.feature.mfcc(y=y_pred_np, sr=sr, n_fft=n_fft, n_mfcc=20)
-
-    # 5. Truncate MFCCs to the same length (important!)
-    len_true = mfccs_true.shape[1]
-    len_pred = mfccs_pred.shape[1]
-    min_len = min(len_true, len_pred)
-
-    mfccs_true = mfccs_true[:, :min_len]
-    mfccs_pred = mfccs_pred[:, :min_len]
-
-    # 6. Convert MFCCs back to PyTorch tensors and ensure correct device
-    mfccs_true = torch.tensor(mfccs_true, device=y_true.device, dtype=torch.float32)
-    mfccs_pred = torch.tensor(mfccs_pred, device=y_pred.device, dtype=torch.float32)
+# Init script argument parser
+parser = argparse.ArgumentParser(description="Training script")
+parser.add_argument("--generator", type=str, default=None,
+                    help="Path to the generator model file")
+parser.add_argument("--discriminator", type=str, default=None,
+                    help="Path to the discriminator model file")
+parser.add_argument("--device", type=str, default="cpu", help="Select device")
+parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
+parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
+
+args = parser.parse_args()
+
+device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
+mfcc_transform = T.MFCC(
+    sample_rate=16000, # Adjust to your sample rate
+    n_mfcc=20,
+    melkwargs={'n_fft': 2048, 'hop_length': 512} # adjust n_fft and hop_length to your needs.
+).to(device)
+
+def gpu_mfcc_loss(y_true, y_pred):
+    mfccs_true = mfcc_transform(y_true)
+    mfccs_pred = mfcc_transform(y_pred)
+    min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
+    mfccs_true = mfccs_true[:, :, :min_len]
+    mfccs_pred = mfccs_pred[:, :, :min_len]
+    return torch.mean((mfccs_true - mfccs_pred)**2)
 
 def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
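The old librosa-based mfcc_loss round-tripped every batch through .cpu().detach().numpy(); the replacement keeps the MFCC computation on the GPU via torchaudio. Worth noting: the transform's sample_rate is fixed at 16000 while the dataset resamples everything to 11025, so the mel filterbank is built for a different rate than the training audio (the inline comment "Adjust to your sample rate" already flags this). A minimal sketch, not from the repo, of the same computation on dummy 2-D (batch, time) inputs, which is the layout the shape[2] indexing assumes:

    # Sketch: the torchaudio MFCC loss pattern from the diff, on dummy data.
    # Shapes, batch size, and the 16 kHz assumption are illustrative.
    import torch
    import torchaudio.transforms as T

    mfcc = T.MFCC(sample_rate=16000, n_mfcc=20,
                  melkwargs={'n_fft': 2048, 'hop_length': 512})

    y_true = torch.randn(4, 16000)   # 4 clips, 1 s each
    y_pred = torch.randn(4, 16000)

    m_true, m_pred = mfcc(y_true), mfcc(y_pred)      # (batch, n_mfcc, frames)
    min_len = min(m_true.shape[2], m_pred.shape[2])  # guard against differing frame counts
    loss = torch.mean((m_true[:, :, :min_len] - m_pred[:, :, :min_len]) ** 2)
    print(loss.item())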
@@ -93,7 +79,7 @@ def generator_train(low_quality, high_quality, real_labels):
     # Forward pass for fake samples (from generator output)
     generator_output = generator(low_quality[0])
 
-    mfcc_l = mfcc_loss(high_quality[0], generator_output, high_quality[1])
+    mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)
 
     discriminator_decision = discriminator(generator_output)
     adversarial_loss = criterion_g(discriminator_decision, real_labels)
@@ -105,26 +91,11 @@ def generator_train(low_quality, high_quality, real_labels):
 
     return (generator_output, combined_loss, adversarial_loss, mfcc_l)
 
-# Init script argument parser
-parser = argparse.ArgumentParser(description="Training script")
-parser.add_argument("--generator", type=str, default=None,
-                    help="Path to the generator model file")
-parser.add_argument("--discriminator", type=str, default=None,
-                    help="Path to the discriminator model file")
-parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
-parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
-
-args = parser.parse_args()
-
-# Check for CUDA availability
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print(f"Using device: {device}")
-
 debug = args.verbose
 
 # Initialize dataset and dataloader
 dataset_dir = './dataset/good'
-dataset = AudioDataset(dataset_dir)
+dataset = AudioDataset(dataset_dir, device)
 
 # ========= MULTIPLE =========
@@ -147,14 +118,14 @@ discriminator = SISUDiscriminator()
 
 epoch: int = args.epoch
 
-if args.generator is not None:
-    generator.load_state_dict(torch.load(args.generator, weights_only=True))
-if args.discriminator is not None:
-    discriminator.load_state_dict(torch.load(args.discriminator, weights_only=True))
-
 generator = generator.to(device)
 discriminator = discriminator.to(device)
 
+if args.generator is not None:
+    generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
+if args.discriminator is not None:
+    discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
+
 # Loss
 criterion_g = nn.MSELoss()
 criterion_d = nn.BCELoss()
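Loading the state dicts after the models are on `device`, with map_location=device, lets a checkpoint saved on one device (say CUDA) be restored on a CPU-only run, or vice versa, without device-mismatch errors. The pattern in isolation, as a sketch (the checkpoint path is a placeholder, not a file from the repo):

    # Sketch of the checkpoint-loading pattern used above.
    import torch
    from generator import SISUGenerator

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = SISUGenerator().to(device)

    checkpoint_path = "generator.pt"   # placeholder path
    state = torch.load(checkpoint_path, map_location=device, weights_only=True)
    generator.load_state_dict(state)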
@@ -182,8 +153,8 @@ def start_training():
         # ========= TRAINING =========
         for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
             # for high_quality_clip, low_quality_clip in train_data_loader:
-            high_quality_sample = (high_quality_clip[0].to(device), high_quality_clip[1])
-            low_quality_sample = (low_quality_clip[0].to(device), low_quality_clip[1])
+            high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
+            low_quality_sample = (low_quality_clip[0], low_quality_clip[1])
 
             # ========= LABELS =========
             batch_size = high_quality_clip[0].size(0)
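Because AudioDataset now places tensors on the device in __getitem__, the per-batch .to(device) calls in the loop become redundant and are dropped; batches coming out of the DataLoader are already device-resident. A small sketch of what the loop now relies on, continuing the earlier dataset sketch (batch_size=1 is illustrative to sidestep variable clip lengths; the actual train_data_loader is defined in training.py):

    # Sketch: batches arrive on the training device without explicit .to(device).
    from torch.utils.data import DataLoader

    train_data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

    for (hq, hq_sr), (lq, lq_sr) in train_data_loader:
        print(hq.device, lq.device)   # already on `device`
        break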