diff --git a/data.py b/data.py
index ac69730..2f05581 100644
--- a/data.py
+++ b/data.py
@@ -12,8 +12,9 @@ class AudioDataset(Dataset):
     #audio_sample_rates = [8000, 11025, 16000, 22050]
     audio_sample_rates = [11025]
 
-    def __init__(self, input_dir):
+    def __init__(self, input_dir, device):
         self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')]
+        self.device = device
 
 
     def __len__(self):
@@ -32,4 +33,7 @@ class AudioDataset(Dataset):
         resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
         low_quality_audio = resample_transform_high(low_quality_audio)
 
-        return (AudioUtils.stereo_tensor_to_mono(high_quality_audio), original_sample_rate), (AudioUtils.stereo_tensor_to_mono(low_quality_audio), mangled_sample_rate)
+        high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio).to(self.device)
+        low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio).to(self.device)
+
+        return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
diff --git a/requirements.txt b/requirements.txt
index 5cb5df1..eacfc3b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,11 +4,11 @@ Jinja2==3.1.4
 MarkupSafe==2.1.5
 mpmath==1.3.0
 networkx==3.4.2
-numpy==2.2.1
-pytorch-triton-rocm==3.2.0+git0d4682f0
+numpy==2.2.3
+pytorch-triton-rocm==3.2.0+git4b3bb1f8
 setuptools==70.2.0
-sympy==1.13.1
-torch==2.6.0.dev20241222+rocm6.2.4
-torchaudio==2.6.0.dev20241222+rocm6.2.4
+sympy==1.13.3
+torch==2.7.0.dev20250226+rocm6.3
+torchaudio==2.6.0.dev20250226+rocm6.3
 tqdm==4.67.1
 typing_extensions==4.12.2
diff --git a/training.py b/training.py
index 710dd65..bf60c5c 100644
--- a/training.py
+++ b/training.py
@@ -20,49 +20,35 @@ from data import AudioDataset
 from generator import SISUGenerator
 from discriminator import SISUDiscriminator
 
-import librosa
-import numpy as np
+import torchaudio.transforms as T
 
-def mfcc_loss(y_true, y_pred, sr):
-    """Calculates MFCC loss between two audio signals.
+# Init script argument parser
+parser = argparse.ArgumentParser(description="Training script")
+parser.add_argument("--generator", type=str, default=None,
+                    help="Path to the generator model file")
+parser.add_argument("--discriminator", type=str, default=None,
+                    help="Path to the discriminator model file")
+parser.add_argument("--device", type=str, default="cpu", help="Select device")
+parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
+parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
 
-    Args:
-        y_true: Target audio signal (PyTorch tensor).
-        y_pred: Predicted audio signal (PyTorch tensor).
-        sr: Sample rate (NumPy scalar).
+args = parser.parse_args()
 
-    Returns:
-        MFCC loss (PyTorch tensor).
-    """
+device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
 
-    # 1. Ensure sr is a NumPy scalar (not a Tensor)
-    if isinstance(sr, torch.Tensor):
-        sr = sr.item()
-
-    # 2. Convert y_true and y_pred to NumPy arrays (and detach from graph)
-    y_true_np = y_true.cpu().detach().numpy()[0] # .cpu() and .detach() are crucial!
-    y_pred_np = y_pred.cpu().detach().numpy()[0]
-
-    # 3. Dynamically calculate n_fft based on signal length
-    signal_length = min(y_true_np.shape[0], y_pred_np.shape[0]) # Use shortest signal length
-    n_fft = min(2048, 2**int(np.log2(signal_length))) # Power of 2, up to 2048
-
-    # 4. Calculate MFCCs using adjusted n_fft
-    mfccs_true = librosa.feature.mfcc(y=y_true_np, sr=sr, n_fft=n_fft, n_mfcc=20)
-    mfccs_pred = librosa.feature.mfcc(y=y_pred_np, sr=sr, n_fft=n_fft, n_mfcc=20)
-
-    # 5. Truncate MFCCs to the same length (important!)
-    len_true = mfccs_true.shape[1]
-    len_pred = mfccs_pred.shape[1]
-    min_len = min(len_true, len_pred)
-
-    mfccs_true = mfccs_true[:, :min_len]
-    mfccs_pred = mfccs_pred[:, :min_len]
-
-    # 6. Convert MFCCs back to PyTorch tensors and ensure correct device
-    mfccs_true = torch.tensor(mfccs_true, device=y_true.device, dtype=torch.float32)
-    mfccs_pred = torch.tensor(mfccs_pred, device=y_pred.device, dtype=torch.float32)
+mfcc_transform = T.MFCC(
+    sample_rate=16000, # Adjust to your sample rate
+    n_mfcc=20,
+    melkwargs={'n_fft': 2048, 'hop_length': 512} # adjust n_fft and hop_length to your needs.
+).to(device)
 
+def gpu_mfcc_loss(y_true, y_pred):
+    mfccs_true = mfcc_transform(y_true)
+    mfccs_pred = mfcc_transform(y_pred)
+    min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
+    mfccs_true = mfccs_true[:, :, :min_len]
+    mfccs_pred = mfccs_pred[:, :, :min_len]
     return torch.mean((mfccs_true - mfccs_pred)**2)
 
 def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
@@ -93,7 +79,7 @@ def generator_train(low_quality, high_quality, real_labels):
     # Forward pass for fake samples (from generator output)
     generator_output = generator(low_quality[0])
 
-    mfcc_l = mfcc_loss(high_quality[0], generator_output, high_quality[1])
+    mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)
 
     discriminator_decision = discriminator(generator_output)
     adversarial_loss = criterion_g(discriminator_decision, real_labels)
@@ -105,26 +91,11 @@
 
     return (generator_output, combined_loss, adversarial_loss, mfcc_l)
 
-# Init script argument parser
-parser = argparse.ArgumentParser(description="Training script")
-parser.add_argument("--generator", type=str, default=None,
-                    help="Path to the generator model file")
-parser.add_argument("--discriminator", type=str, default=None,
-                    help="Path to the discriminator model file")
-parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
-parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
-
-args = parser.parse_args()
-
-# Check for CUDA availability
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print(f"Using device: {device}")
-
 debug = args.verbose
 
 # Initialize dataset and dataloader
 dataset_dir = './dataset/good'
-dataset = AudioDataset(dataset_dir)
+dataset = AudioDataset(dataset_dir, device)
 
 # ========= MULTIPLE =========
 
@@ -147,14 +118,14 @@ discriminator = SISUDiscriminator()
 
 epoch: int = args.epoch
 
-if args.generator is not None:
-    generator.load_state_dict(torch.load(args.generator, weights_only=True))
-if args.discriminator is not None:
-    discriminator.load_state_dict(torch.load(args.discriminator, weights_only=True))
-
 generator = generator.to(device)
 discriminator = discriminator.to(device)
 
+if args.generator is not None:
+    generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
+if args.discriminator is not None:
+    discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
+
 # Loss
 criterion_g = nn.MSELoss()
 criterion_d = nn.BCELoss()
@@ -182,8 +153,8 @@ def start_training():
         # ========= TRAINING =========
         for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
         # for high_quality_clip, low_quality_clip in train_data_loader:
-            high_quality_sample = (high_quality_clip[0].to(device), high_quality_clip[1])
-            low_quality_sample = (low_quality_clip[0].to(device), low_quality_clip[1])
+            high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
+            low_quality_sample = (low_quality_clip[0], low_quality_clip[1])
 
             # ========= LABELS =========
             batch_size = high_quality_clip[0].size(0)
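Not part of the patch above: a minimal, self-contained sketch of how the new torchaudio-based MFCC loss behaves once the diff is applied. It copies the mfcc_transform / gpu_mfcc_loss definitions introduced in training.py; the (batch, samples) input shape, the 1-second clip length, and the random dummy tensors are illustrative assumptions only.

# Standalone sketch (assumptions: mono waveforms shaped (batch, samples),
# 16 kHz audio, random dummy data) -- mirrors the new code in training.py.
import torch
import torchaudio.transforms as T

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

mfcc_transform = T.MFCC(
    sample_rate=16000,
    n_mfcc=20,
    melkwargs={'n_fft': 2048, 'hop_length': 512}
).to(device)

def gpu_mfcc_loss(y_true, y_pred):
    # MFCCs are computed directly on `device`; no .cpu()/.numpy() round trip
    # as with the removed librosa implementation.
    mfccs_true = mfcc_transform(y_true)
    mfccs_pred = mfcc_transform(y_pred)
    # Truncate both to the shorter time axis before comparing.
    min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
    return torch.mean((mfccs_true[:, :, :min_len] - mfccs_pred[:, :, :min_len]) ** 2)

high_quality = torch.randn(4, 16000, device=device)  # assumed batch of 1-second clips
generated = torch.randn(4, 16000, device=device)
loss = gpu_mfcc_loss(high_quality, generated)  # scalar tensor, stays on `device`
print(loss.item())

With the new --device flag, the training script itself would be launched as, e.g., python training.py --device cuda, falling back to CPU when torch.cuda.is_available() is False.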