Merge new-arch, because it has proven to give the best results #1
26
data.py
26
data.py
@ -4,23 +4,20 @@ import torch
|
|||||||
import torchaudio
|
import torchaudio
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import torchaudio.transforms as T
|
import torchaudio.transforms as T
|
||||||
import AudioUtils
|
import AudioUtils
|
||||||
|
|
||||||
class AudioDataset(Dataset):
|
class AudioDataset(Dataset):
|
||||||
#audio_sample_rates = [8000, 11025, 16000, 22050]
|
|
||||||
audio_sample_rates = [11025]
|
audio_sample_rates = [11025]
|
||||||
|
MAX_LENGTH = 88200 # Define your desired maximum length here
|
||||||
|
|
||||||
def __init__(self, input_dir, device):
|
def __init__(self, input_dir, device):
|
||||||
self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')]
|
self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')]
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.input_files)
|
return len(self.input_files)
|
||||||
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
# Load high-quality audio
|
# Load high-quality audio
|
||||||
high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True)
|
high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True)
|
||||||
@ -33,7 +30,24 @@ class AudioDataset(Dataset):
|
|||||||
resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
|
resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
|
||||||
low_quality_audio = resample_transform_high(low_quality_audio)
|
low_quality_audio = resample_transform_high(low_quality_audio)
|
||||||
|
|
||||||
high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio).to(self.device)
|
high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio)
|
||||||
low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio).to(self.device)
|
low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio)
|
||||||
|
|
||||||
|
# Pad or truncate high-quality audio
|
||||||
|
if high_quality_audio.shape[1] < self.MAX_LENGTH:
|
||||||
|
padding = self.MAX_LENGTH - high_quality_audio.shape[1]
|
||||||
|
high_quality_audio = F.pad(high_quality_audio, (0, padding))
|
||||||
|
elif high_quality_audio.shape[1] > self.MAX_LENGTH:
|
||||||
|
high_quality_audio = high_quality_audio[:, :self.MAX_LENGTH]
|
||||||
|
|
||||||
|
# Pad or truncate low-quality audio
|
||||||
|
if low_quality_audio.shape[1] < self.MAX_LENGTH:
|
||||||
|
padding = self.MAX_LENGTH - low_quality_audio.shape[1]
|
||||||
|
low_quality_audio = F.pad(low_quality_audio, (0, padding))
|
||||||
|
elif low_quality_audio.shape[1] > self.MAX_LENGTH:
|
||||||
|
low_quality_audio = low_quality_audio[:, :self.MAX_LENGTH]
|
||||||
|
|
||||||
|
high_quality_audio = high_quality_audio.to(self.device)
|
||||||
|
low_quality_audio = low_quality_audio.to(self.device)
|
||||||
|
|
||||||
return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
|
return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
|
||||||
|
@ -2,35 +2,54 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.utils as utils
|
import torch.nn.utils as utils
|
||||||
|
|
||||||
def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
|
def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, spectral_norm=True):
|
||||||
padding = (kernel_size // 2) * dilation
|
padding = (kernel_size // 2) * dilation
|
||||||
|
conv_layer = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding)
|
||||||
|
if spectral_norm:
|
||||||
|
conv_layer = utils.spectral_norm(conv_layer)
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
utils.spectral_norm(nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding)),
|
conv_layer,
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
nn.LeakyReLU(0.2, inplace=True),
|
||||||
nn.BatchNorm1d(out_channels)
|
nn.BatchNorm1d(out_channels)
|
||||||
)
|
)
|
||||||
|
|
||||||
class SISUDiscriminator(nn.Module):
|
class AttentionBlock(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self, channels):
|
||||||
super(SISUDiscriminator, self).__init__()
|
super(AttentionBlock, self).__init__()
|
||||||
layers = 4 # Increased base layer count
|
self.attention = nn.Sequential(
|
||||||
self.model = nn.Sequential(
|
nn.Conv1d(channels, channels // 4, kernel_size=1),
|
||||||
discriminator_block(1, layers, kernel_size=7, stride=2), # Initial downsampling
|
nn.ReLU(),
|
||||||
discriminator_block(layers, layers * 2, kernel_size=5, stride=2), # Downsampling
|
nn.Conv1d(channels // 4, channels, kernel_size=1),
|
||||||
discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=2), # Increased dilation
|
nn.Sigmoid()
|
||||||
discriminator_block(layers * 4, layers * 4, kernel_size=5, dilation=4), # Increased dilation
|
|
||||||
discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=8), # Deeper layer!
|
|
||||||
discriminator_block(layers * 8, layers * 8, kernel_size=5, dilation=1), # Deeper layer!
|
|
||||||
discriminator_block(layers * 8, layers * 4, kernel_size=3, dilation=2), # Reduced dilation
|
|
||||||
discriminator_block(layers * 4, layers * 2, kernel_size=3, dilation=1),
|
|
||||||
discriminator_block(layers * 2, layers, kernel_size=3, stride=1), # Final convolution
|
|
||||||
discriminator_block(layers, 1, kernel_size=3, stride=1)
|
|
||||||
)
|
)
|
||||||
self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# Gaussian noise is not necessary here for discriminator as it is already implicit in the training process
|
attention_weights = self.attention(x)
|
||||||
|
return x * attention_weights
|
||||||
|
|
||||||
|
class SISUDiscriminator(nn.Module):
|
||||||
|
def __init__(self, layers=4): #Increased base layer count
|
||||||
|
super(SISUDiscriminator, self).__init__()
|
||||||
|
self.model = nn.Sequential(
|
||||||
|
discriminator_block(1, layers, kernel_size=7, stride=4), #Aggressive downsampling
|
||||||
|
discriminator_block(layers, layers * 2, kernel_size=5, stride=2),
|
||||||
|
discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=2),
|
||||||
|
discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=4),
|
||||||
|
AttentionBlock(layers * 8), #Added attention
|
||||||
|
discriminator_block(layers * 8, layers * 16, kernel_size=5, dilation=8),
|
||||||
|
discriminator_block(layers * 16, layers * 16, kernel_size=3, dilation=1),
|
||||||
|
discriminator_block(layers * 16, layers * 8, kernel_size=3, dilation=2),
|
||||||
|
discriminator_block(layers * 8, layers * 4, kernel_size=3, dilation=1),
|
||||||
|
discriminator_block(layers * 4, layers * 2, kernel_size=3, stride=1),
|
||||||
|
discriminator_block(layers * 2, layers, kernel_size=3, stride=1),
|
||||||
|
discriminator_block(layers, 1, kernel_size=3, stride=1, spectral_norm=False) #last layer no spectral norm.
|
||||||
|
)
|
||||||
|
self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
x = self.model(x)
|
x = self.model(x)
|
||||||
x = self.global_avg_pool(x)
|
x = self.global_avg_pool(x)
|
||||||
x = x.view(-1, 1)
|
x = x.view(-1, 1)
|
||||||
|
x = self.sigmoid(x)
|
||||||
return x
|
return x
|
||||||
|
44
generator.py
44
generator.py
@ -7,30 +7,46 @@ def conv_block(in_channels, out_channels, kernel_size=3, dilation=1):
|
|||||||
nn.PReLU()
|
nn.PReLU()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
|
||||||
|
def __init__(self, channels):
|
||||||
|
super(AttentionBlock, self).__init__()
|
||||||
|
self.attention = nn.Sequential(
|
||||||
|
nn.Conv1d(channels, channels // 4, kernel_size=1),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv1d(channels // 4, channels, kernel_size=1),
|
||||||
|
nn.Sigmoid()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
attention_weights = self.attention(x)
|
||||||
|
return x * attention_weights
|
||||||
|
|
||||||
|
class ResidualInResidualBlock(nn.Module):
|
||||||
|
def __init__(self, channels, num_convs=3):
|
||||||
|
super(ResidualInResidualBlock, self).__init__()
|
||||||
|
self.conv_layers = nn.Sequential(*[conv_block(channels, channels) for _ in range(num_convs)])
|
||||||
|
self.attention = AttentionBlock(channels)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
x = self.conv_layers(x)
|
||||||
|
x = self.attention(x)
|
||||||
|
return x + residual
|
||||||
|
|
||||||
class SISUGenerator(nn.Module):
|
class SISUGenerator(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self, layer=4, num_rirb=4): #increased base layer and rirb amounts
|
||||||
super(SISUGenerator, self).__init__()
|
super(SISUGenerator, self).__init__()
|
||||||
layer = 4 # Increased base layer count
|
|
||||||
self.conv1 = nn.Sequential(
|
self.conv1 = nn.Sequential(
|
||||||
nn.Conv1d(1, layer, kernel_size=7, padding=3),
|
nn.Conv1d(1, layer, kernel_size=7, padding=3),
|
||||||
nn.BatchNorm1d(layer),
|
nn.BatchNorm1d(layer),
|
||||||
nn.PReLU(),
|
nn.PReLU(),
|
||||||
)
|
)
|
||||||
self.conv_blocks = nn.Sequential(
|
self.rir_blocks = nn.Sequential(*[ResidualInResidualBlock(layer) for _ in range(num_rirb)])
|
||||||
conv_block(layer, layer, kernel_size=3, dilation=1), # Local details
|
self.final_layer = nn.Conv1d(layer, 1, kernel_size=3, padding=1)
|
||||||
conv_block(layer, layer*2, kernel_size=5, dilation=2), # Local Context
|
|
||||||
conv_block(layer*2, layer*2, kernel_size=3, dilation=16), # Longer range dependencies
|
|
||||||
conv_block(layer*2, layer*2, kernel_size=5, dilation=8), # Wider context
|
|
||||||
conv_block(layer*2, layer, kernel_size=5, dilation=2), # Local Context
|
|
||||||
conv_block(layer, layer, kernel_size=3, dilation=1), # Local details
|
|
||||||
)
|
|
||||||
self.final_layer = nn.Sequential(
|
|
||||||
nn.Conv1d(layer, 1, kernel_size=3, padding=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
residual = x
|
residual = x
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.conv_blocks(x)
|
x = self.rir_blocks(x)
|
||||||
x = self.final_layer(x)
|
x = self.final_layer(x)
|
||||||
return x + residual
|
return x + residual
|
||||||
|
27
training.py
27
training.py
@ -38,7 +38,7 @@ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
|||||||
print(f"Using device: {device}")
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
mfcc_transform = T.MFCC(
|
mfcc_transform = T.MFCC(
|
||||||
sample_rate=16000, # Adjust to your sample rate
|
sample_rate=44100, # Adjust to your sample rate
|
||||||
n_mfcc=20,
|
n_mfcc=20,
|
||||||
melkwargs={'n_fft': 2048, 'hop_length': 512} # adjust n_fft and hop_length to your needs.
|
melkwargs={'n_fft': 2048, 'hop_length': 512} # adjust n_fft and hop_length to your needs.
|
||||||
).to(device)
|
).to(device)
|
||||||
@ -97,20 +97,9 @@ debug = args.verbose
|
|||||||
dataset_dir = './dataset/good'
|
dataset_dir = './dataset/good'
|
||||||
dataset = AudioDataset(dataset_dir, device)
|
dataset = AudioDataset(dataset_dir, device)
|
||||||
|
|
||||||
# ========= MULTIPLE =========
|
|
||||||
|
|
||||||
# dataset_size = len(dataset)
|
|
||||||
# train_size = int(dataset_size * .9)
|
|
||||||
# val_size = int(dataset_size-train_size)
|
|
||||||
|
|
||||||
#train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
|
|
||||||
|
|
||||||
# train_data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
|
|
||||||
# val_data_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)
|
|
||||||
|
|
||||||
# ========= SINGLE =========
|
# ========= SINGLE =========
|
||||||
|
|
||||||
train_data_loader = DataLoader(dataset, batch_size=1, shuffle=True)
|
train_data_loader = DataLoader(dataset, batch_size=128, shuffle=True)
|
||||||
|
|
||||||
# Initialize models and move them to device
|
# Initialize models and move them to device
|
||||||
generator = SISUGenerator()
|
generator = SISUGenerator()
|
||||||
@ -175,17 +164,17 @@ def start_training():
|
|||||||
scheduler_g.step(combined_loss)
|
scheduler_g.step(combined_loss)
|
||||||
|
|
||||||
# ========= SAVE LATEST AUDIO =========
|
# ========= SAVE LATEST AUDIO =========
|
||||||
high_quality_audio = high_quality_clip
|
high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
|
||||||
low_quality_audio = low_quality_clip
|
low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0])
|
||||||
ai_enhanced_audio = (generator_output, high_quality_clip[1])
|
ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0])
|
||||||
|
|
||||||
new_epoch = generator_epoch+epoch
|
new_epoch = generator_epoch+epoch
|
||||||
|
|
||||||
if generator_epoch % 10 == 0:
|
if generator_epoch % 10 == 0:
|
||||||
print(f"Saved epoch {new_epoch}!")
|
print(f"Saved epoch {new_epoch}!")
|
||||||
torchaudio.save(f"./output/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
|
torchaudio.save(f"./output/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
|
||||||
torchaudio.save(f"./output/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
|
torchaudio.save(f"./output/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu(), ai_enhanced_audio[1])
|
||||||
torchaudio.save(f"./output/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])
|
torchaudio.save(f"./output/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu(), high_quality_audio[1])
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
print(generator.state_dict().keys())
|
print(generator.state_dict().keys())
|
||||||
|
Loading…
Reference in New Issue
Block a user