diff --git a/train_gan3d.py b/train_gan3d.py index 792d31f..e759d1f 100644 --- a/train_gan3d.py +++ b/train_gan3d.py @@ -102,7 +102,7 @@ def train(self): self.d_optim.zero_grad() predict_real, predict_id, predict_ex= self.discriminator(real_data) - error_real = self.criterion_gan(predict_real, make_ones(batch_size).to(device)) + self.criterion_class(predict_ex, label_ex) + self.criterion_class(predict_id, label_ex) + error_real = self.criterion_gan(predict_real, make_ones(batch_size).to(device)) + self.criterion_class(predict_ex, label_ex) + self.criterion_class(predict_id, label_id) error_real.backward() predict_fake, fake_id, fake_ex = self.discriminator(fake_data)