Я изучаю GAN. Я закончил один курс, который дал мне пример программы, которая генерирует изображения на основе введенных примеров.
Пример можно найти здесь:
https://github.com/davidsonmizael/gan
Поэтому я решил использовать это для создания новых изображений на основе набора данных фронтальных фотографий лиц, но я не добился успеха. В отличие от приведенного выше примера, код генерирует только шум, а на входе - фактические изображения.
На самом деле, я не имею ни малейшего представления о том, что я должен изменить, чтобы код указывал в правильном направлении и учился на изображениях. Я не изменяю ни одного значения кода, представленного в примере, но он не работает.
Если кто-нибудь может помочь мне понять это и указать мне в правильном направлении, это будет очень полезно. Спасибо заранее.
Мой Дискриминатор:
class D(nn.Module):
def __init__(self):
super(D, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(3, 64, 4, 2, 1, bias = False),
nn.LeakyReLU(0.2, inplace = True),
nn.Conv2d(64, 128, 4, 2, 1, bias = False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace = True),
nn.Conv2d(128, 256, 4, 2, 1, bias = False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace = True),
nn.Conv2d(256, 512, 4, 2, 1, bias = False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace = True),
nn.Conv2d(512, 1, 4, 1, 0, bias = False),
nn.Sigmoid()
)
def forward(self, input):
return self.main(input).view(-1)
Мой генератор:
class G(nn.Module):
def __init__(self):
super(G, self).__init__()
self.main = nn.Sequential(
nn.ConvTranspose2d(100, 512, 4, 1, 0, bias = False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias = False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.ConvTranspose2d(64, 3, 4, 2, 1, bias = False),
nn.Tanh()
)
def forward(self, input):
return self.main(input)
Моя функция для запуска весов:
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
Полный код можно увидеть здесь:
https://github.com/davidsonmizael/criminal-gan