Merge new-arch, because it has proven to give the best results #1
data.py (8 changed lines)

@@ -12,8 +12,9 @@ class AudioDataset(Dataset):
     #audio_sample_rates = [8000, 11025, 16000, 22050]
     audio_sample_rates = [11025]
 
-    def __init__(self, input_dir):
+    def __init__(self, input_dir, device):
         self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')]
+        self.device = device
 
 
     def __len__(self):

@@ -32,4 +33,7 @@ class AudioDataset(Dataset):
         resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
         low_quality_audio = resample_transform_high(low_quality_audio)
 
-        return (AudioUtils.stereo_tensor_to_mono(high_quality_audio), original_sample_rate), (AudioUtils.stereo_tensor_to_mono(low_quality_audio), mangled_sample_rate)
+        high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio).to(self.device)
+        low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio).to(self.device)
+
+        return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
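Since `__getitem__` now moves each clip to `self.device`, the dataset must be constructed with the target device. A minimal sketch of the new call pattern, reusing the `./dataset/good` directory from `training.py` (the loader settings are illustrative, not part of this PR). Note that returning CUDA/ROCm tensors from `__getitem__` generally forces `num_workers=0`, since CUDA tensors should not be created in DataLoader worker processes:

```python
import torch
from torch.utils.data import DataLoader
from data import AudioDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = AudioDataset('./dataset/good', device)

# batch_size=1 sidesteps collation of variable-length clips; num_workers stays 0
# because the samples already live on the training device.
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

(high_quality, original_sr), (low_quality, mangled_sr) = next(iter(loader))
print(high_quality.device, low_quality.device)  # both tensors are already on `device`
```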
Dependency pins:

@@ -4,11 +4,11 @@ Jinja2==3.1.4
 MarkupSafe==2.1.5
 mpmath==1.3.0
 networkx==3.4.2
-numpy==2.2.1
-pytorch-triton-rocm==3.2.0+git0d4682f0
+numpy==2.2.3
+pytorch-triton-rocm==3.2.0+git4b3bb1f8
 setuptools==70.2.0
-sympy==1.13.1
-torch==2.6.0.dev20241222+rocm6.2.4
-torchaudio==2.6.0.dev20241222+rocm6.2.4
+sympy==1.13.3
+torch==2.7.0.dev20250226+rocm6.3
+torchaudio==2.6.0.dev20250226+rocm6.3
 tqdm==4.67.1
 typing_extensions==4.12.2
training.py (95 changed lines)
@@ -20,49 +20,35 @@ from data import AudioDataset
 from generator import SISUGenerator
 from discriminator import SISUDiscriminator
 
-import librosa
-import numpy as np
+import torchaudio.transforms as T
 
-def mfcc_loss(y_true, y_pred, sr):
-    """Calculates MFCC loss between two audio signals.
-
-    Args:
-        y_true: Target audio signal (PyTorch tensor).
-        y_pred: Predicted audio signal (PyTorch tensor).
-        sr: Sample rate (NumPy scalar).
-
-    Returns:
-        MFCC loss (PyTorch tensor).
-    """
-
-    # 1. Ensure sr is a NumPy scalar (not a Tensor)
-    if isinstance(sr, torch.Tensor):
-        sr = sr.item()
-
-    # 2. Convert y_true and y_pred to NumPy arrays (and detach from graph)
-    y_true_np = y_true.cpu().detach().numpy()[0]  # .cpu() and .detach() are crucial!
-    y_pred_np = y_pred.cpu().detach().numpy()[0]
-
-    # 3. Dynamically calculate n_fft based on signal length
-    signal_length = min(y_true_np.shape[0], y_pred_np.shape[0])  # Use shortest signal length
-    n_fft = min(2048, 2**int(np.log2(signal_length)))  # Power of 2, up to 2048
-
-    # 4. Calculate MFCCs using adjusted n_fft
-    mfccs_true = librosa.feature.mfcc(y=y_true_np, sr=sr, n_fft=n_fft, n_mfcc=20)
-    mfccs_pred = librosa.feature.mfcc(y=y_pred_np, sr=sr, n_fft=n_fft, n_mfcc=20)
-
-    # 5. Truncate MFCCs to the same length (important!)
-    len_true = mfccs_true.shape[1]
-    len_pred = mfccs_pred.shape[1]
-    min_len = min(len_true, len_pred)
-
-    mfccs_true = mfccs_true[:, :min_len]
-    mfccs_pred = mfccs_pred[:, :min_len]
-
-    # 6. Convert MFCCs back to PyTorch tensors and ensure correct device
-    mfccs_true = torch.tensor(mfccs_true, device=y_true.device, dtype=torch.float32)
-    mfccs_pred = torch.tensor(mfccs_pred, device=y_pred.device, dtype=torch.float32)
-
-    return torch.mean((mfccs_true - mfccs_pred)**2)
+# Init script argument parser
+parser = argparse.ArgumentParser(description="Training script")
+parser.add_argument("--generator", type=str, default=None,
+                    help="Path to the generator model file")
+parser.add_argument("--discriminator", type=str, default=None,
+                    help="Path to the discriminator model file")
+parser.add_argument("--device", type=str, default="cpu", help="Select device")
+parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
+parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
+
+args = parser.parse_args()
+
+device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
+mfcc_transform = T.MFCC(
+    sample_rate=16000,  # Adjust to your sample rate
+    n_mfcc=20,
+    melkwargs={'n_fft': 2048, 'hop_length': 512}  # adjust n_fft and hop_length to your needs.
+).to(device)
+
+def gpu_mfcc_loss(y_true, y_pred):
+    mfccs_true = mfcc_transform(y_true)
+    mfccs_pred = mfcc_transform(y_pred)
+    min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
+    mfccs_true = mfccs_true[:, :, :min_len]
+    mfccs_pred = mfccs_pred[:, :, :min_len]
+    return torch.mean((mfccs_true - mfccs_pred)**2)
 
 def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
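The new loss keeps the whole computation on `device` and drops the librosa/NumPy round-trip of the old `mfcc_loss`. A minimal sketch of how the `T.MFCC` transform behaves, using dummy mono waveforms shaped `[batch, samples]` at 16 kHz rather than the project's data:

```python
import torch
import torchaudio.transforms as T

mfcc = T.MFCC(sample_rate=16000, n_mfcc=20,
              melkwargs={'n_fft': 2048, 'hop_length': 512})

clean = torch.randn(4, 16000)      # one second of reference audio per item
restored = torch.randn(4, 16000)   # stand-in for generator output

feat_clean = mfcc(clean)           # shape: [batch, n_mfcc, frames]
feat_restored = mfcc(restored)
loss = torch.mean((feat_clean - feat_restored) ** 2)
```

Worth noting: `sample_rate` is fixed at 16000 in the transform (the inline comment says to adjust it), while the dataset resamples clips back to their original rate, so the MFCC filterbank only matches clips whose original rate really is 16 kHz.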
@@ -93,7 +79,7 @@ def generator_train(low_quality, high_quality, real_labels):
     # Forward pass for fake samples (from generator output)
     generator_output = generator(low_quality[0])
 
-    mfcc_l = mfcc_loss(high_quality[0], generator_output, high_quality[1])
+    mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)
 
     discriminator_decision = discriminator(generator_output)
     adversarial_loss = criterion_g(discriminator_decision, real_labels)
@@ -105,26 +91,11 @@ def generator_train(low_quality, high_quality, real_labels):
 
     return (generator_output, combined_loss, adversarial_loss, mfcc_l)
 
-# Init script argument parser
-parser = argparse.ArgumentParser(description="Training script")
-parser.add_argument("--generator", type=str, default=None,
-                    help="Path to the generator model file")
-parser.add_argument("--discriminator", type=str, default=None,
-                    help="Path to the discriminator model file")
-parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
-parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
-
-args = parser.parse_args()
-
-# Check for CUDA availability
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print(f"Using device: {device}")
-
 debug = args.verbose
 
 # Initialize dataset and dataloader
 dataset_dir = './dataset/good'
-dataset = AudioDataset(dataset_dir)
+dataset = AudioDataset(dataset_dir, device)
 
 # ========= MULTIPLE =========
 
@@ -147,14 +118,14 @@ discriminator = SISUDiscriminator()
 
 epoch: int = args.epoch
 
-if args.generator is not None:
-    generator.load_state_dict(torch.load(args.generator, weights_only=True))
-if args.discriminator is not None:
-    discriminator.load_state_dict(torch.load(args.discriminator, weights_only=True))
-
 generator = generator.to(device)
 discriminator = discriminator.to(device)
 
+if args.generator is not None:
+    generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
+if args.discriminator is not None:
+    discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
+
 # Loss
 criterion_g = nn.MSELoss()
 criterion_d = nn.BCELoss()
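Moving the `load_state_dict` calls after `.to(device)` and adding `map_location=device` makes the restore path independent of the device the checkpoint was saved on. A small self-contained illustration with a stand-in module (`nn.Linear` here, not the real SISU models; the path is hypothetical):

```python
import torch
import torch.nn as nn

saved = nn.Linear(4, 4)                                # stand-in for SISUGenerator
torch.save(saved.state_dict(), "/tmp/checkpoint.pt")   # in real training this may contain GPU tensors

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 4).to(device)
state = torch.load("/tmp/checkpoint.pt", map_location=device, weights_only=True)
model.load_state_dict(state)                           # parameters end up on `device` no matter where they were saved
```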
@@ -182,8 +153,8 @@ def start_training():
     # ========= TRAINING =========
     for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
     # for high_quality_clip, low_quality_clip in train_data_loader:
-        high_quality_sample = (high_quality_clip[0].to(device), high_quality_clip[1])
-        low_quality_sample = (low_quality_clip[0].to(device), low_quality_clip[1])
+        high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
+        low_quality_sample = (low_quality_clip[0], low_quality_clip[1])
 
         # ========= LABELS =========
         batch_size = high_quality_clip[0].size(0)
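With these changes the per-batch `.to(device)` calls become redundant, since the dataset already returns device-resident tensors, and the training device is now picked at launch time via the new flag, e.g. `python training.py --device cuda --generator <path-to-generator> --discriminator <path-to-discriminator> --epoch 12 --verbose` (the checkpoint paths here are placeholders).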