:alembic: | More tests.
This commit is contained in:
		| @@ -28,7 +28,7 @@ class AttentionBlock(nn.Module): | |||||||
|         return x * attention_weights |         return x * attention_weights | ||||||
|  |  | ||||||
| class SISUDiscriminator(nn.Module): | class SISUDiscriminator(nn.Module): | ||||||
|     def __init__(self, layers=64): #Increased base layer count |     def __init__(self, layers=4): #Increased base layer count | ||||||
|         super(SISUDiscriminator, self).__init__() |         super(SISUDiscriminator, self).__init__() | ||||||
|         self.model = nn.Sequential( |         self.model = nn.Sequential( | ||||||
|             discriminator_block(1, layers, kernel_size=3, stride=1), #Aggressive downsampling |             discriminator_block(1, layers, kernel_size=3, stride=1), #Aggressive downsampling | ||||||
|   | |||||||
| @@ -34,7 +34,7 @@ class ResidualInResidualBlock(nn.Module): | |||||||
|         return x + residual |         return x + residual | ||||||
|  |  | ||||||
| class SISUGenerator(nn.Module): | class SISUGenerator(nn.Module): | ||||||
|     def __init__(self, layer=64, num_rirb=4): #increased base layer and rirb amounts |     def __init__(self, layer=4, num_rirb=4): #increased base layer and rirb amounts | ||||||
|         super(SISUGenerator, self).__init__() |         super(SISUGenerator, self).__init__() | ||||||
|         self.conv1 = nn.Sequential( |         self.conv1 = nn.Sequential( | ||||||
|             nn.Conv1d(1, layer, kernel_size=7, padding=3), |             nn.Conv1d(1, layer, kernel_size=7, padding=3), | ||||||
|   | |||||||
							
								
								
									
										22
									
								
								training.py
									
									
									
									
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										22
									
								
								training.py
									
									
									
									
									
								
							| @@ -30,7 +30,7 @@ parser.add_argument("--discriminator", type=str, default=None, | |||||||
|                     help="Path to the discriminator model file") |                     help="Path to the discriminator model file") | ||||||
| parser.add_argument("--device", type=str, default="cpu",  help="Select device") | parser.add_argument("--device", type=str, default="cpu",  help="Select device") | ||||||
| parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning") | parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning") | ||||||
| parser.add_argument("--verbose", action="store_true",  help="Increase output verbosity") | parser.add_argument("--debug", action="store_true",  help="Print debug logs") | ||||||
|  |  | ||||||
| args = parser.parse_args() | args = parser.parse_args() | ||||||
|  |  | ||||||
| @@ -80,19 +80,20 @@ def generator_train(low_quality, high_quality, real_labels): | |||||||
|     # Forward pass for fake samples (from generator output) |     # Forward pass for fake samples (from generator output) | ||||||
|     generator_output = generator(low_quality[0]) |     generator_output = generator(low_quality[0]) | ||||||
|  |  | ||||||
|     mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output) |     #mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output) | ||||||
|  |  | ||||||
|     discriminator_decision = discriminator(generator_output) |     discriminator_decision = discriminator(generator_output) | ||||||
|     adversarial_loss = criterion_g(discriminator_decision, real_labels) |     adversarial_loss = criterion_g(discriminator_decision, real_labels) | ||||||
|  |  | ||||||
|     combined_loss = adversarial_loss + 0.5 * mfcc_l |     #combined_loss = adversarial_loss + 0.5 * mfcc_l | ||||||
|  |  | ||||||
|     combined_loss.backward() |     adversarial_loss.backward() | ||||||
|     optimizer_g.step() |     optimizer_g.step() | ||||||
|  |  | ||||||
|     return (generator_output, combined_loss, adversarial_loss, mfcc_l) |     #return (generator_output, combined_loss, adversarial_loss, mfcc_l) | ||||||
|  |     return (generator_output, adversarial_loss) | ||||||
|  |  | ||||||
| debug = args.verbose | debug = args.debug | ||||||
|  |  | ||||||
| # Initialize dataset and dataloader | # Initialize dataset and dataloader | ||||||
| dataset_dir = './dataset/good' | dataset_dir = './dataset/good' | ||||||
| @@ -100,7 +101,7 @@ dataset = AudioDataset(dataset_dir, device) | |||||||
|  |  | ||||||
| # ========= SINGLE ========= | # ========= SINGLE ========= | ||||||
|  |  | ||||||
| train_data_loader = DataLoader(dataset, batch_size=16, shuffle=True) | train_data_loader = DataLoader(dataset, batch_size=256, shuffle=True) | ||||||
|  |  | ||||||
| # Initialize models and move them to device | # Initialize models and move them to device | ||||||
| generator = SISUGenerator() | generator = SISUGenerator() | ||||||
| @@ -157,12 +158,13 @@ def start_training(): | |||||||
|  |  | ||||||
|             # ========= GENERATOR ========= |             # ========= GENERATOR ========= | ||||||
|             generator.train() |             generator.train() | ||||||
|             generator_output, combined_loss, adversarial_loss, mfcc_l = generator_train(low_quality_sample, high_quality_sample, real_labels) |             #generator_output, combined_loss, adversarial_loss, mfcc_l = generator_train(low_quality_sample, high_quality_sample, real_labels) | ||||||
|  |             generator_output, adversarial_loss = generator_train(low_quality_sample, high_quality_sample, real_labels) | ||||||
|  |  | ||||||
|             if debug: |             if debug: | ||||||
|                 print(d_loss, combined_loss, adversarial_loss, mfcc_l) |                 print(d_loss, adversarial_loss) | ||||||
|             scheduler_d.step(d_loss) |             scheduler_d.step(d_loss) | ||||||
|             #scheduler_g.step(combined_loss) |             scheduler_g.step(adversarial_loss) | ||||||
|  |  | ||||||
|             # ========= SAVE LATEST AUDIO ========= |             # ========= SAVE LATEST AUDIO ========= | ||||||
|             high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0]) |             high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0]) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user