您好,欢迎来到12图资源库!分享精神,快乐你我!我们只是素材的搬运工!!
  • 首 页
  • 当前位置:首页 > 开发 > WEB开发 >
    复杂运用PyTorch搭建GAN模型
    时间:2021-08-25 21:17 来源:网络整理 作者:网络 浏览:收藏 挑错 推荐 打印

    复杂运用PyTorch搭建GAN模型

    以往人们普遍以为生成图像是不能够完成的义务,由于按照传统的机器学习思绪,我们基本没有真值(ground truth)可以拿来检验生成的图像能否合格。

    2014年,Goodfellow等人则提出生成 对立网络(Generative Adversarial Network, GAN) ,可以让我们完全依托机器学习来生成极为逼真的图片。GAN的横空出生使得整团体工智能行业都为之震动,计算机视觉和图像生成范围发作了剧变。

    本文将带大家了解 GAN的任务原理 ,并引见如何 经过PyTorch复杂上手GAN 。

    GAN的原理

    按照传统的办法,模型的预测结果可以直接与已有的真值停止比较。但是,我们却很难定义和权衡究竟怎样才算作是“正确的”生成图像。

    Goodfellow等人则提出了一个幽默的处置办法:我们可以先训练好一个分类工具,来自动区分生成图像和真实图像。这样一来,我们就可以用这个分类工具来训练一个生成网络,直到它可以输入完全以假乱真的图像,连分类工具本人都没有办法评判真假。

    复杂运用PyTorch搭建GAN模型

    按照这一思绪,我们便有了GAN:也就是一个 生成器(generator) 和一个 判别器(discriminator) 。生成器担任依据给定的数据集生成图像,判别器则担任区分图像是真是假。GAN的运作流程如上图所示。

    损失函数

    在GAN的运作流程中,我们可以发现一个清楚的矛盾:同时优化生成器和判别器是很困难的。可以想象,这两个模型有着完全相反的目的:生成器想要尽能够伪造出真实的东西,而判别器则必需要识破生成器生成的图像。

    为了阐明这一点,我们设D(x)为判别器的输入,即x是真实图像的概率,并设G(z)为生成器的输入。判别器相似于一种二进制的分类器,所以其目的是使该函数的结果最大化:

    复杂运用PyTorch搭建GAN模型

    这一函数本质上是非负的二元交叉熵损失函数。另一方面,生成器的目的是最小化判别器做出正确判别的机率,因此它的目的是使上述函数的结果最小化。

    因此,最终的损失函数将会是两个分类器之间的极小极大博弈,表示如下:

    复杂运用PyTorch搭建GAN模型

    实际下去说,博弈的最终结果将是让判别器判别成功的概率收敛到0.5。但是在实际中,极大极小博弈通常会招致网络不收敛,因此细心调整模型训练的参数十分重要。

    在训练GAN时,我们尤其要留意学习率等超参数,学习率比较小时能让GAN在输入噪音较多的状况下也能有较为一致的输入。

    计算环境 库

    本文将指点大家经过PyTorch搭建整个顺序(包括torchvision)。同时,我们将会运用Matplotlib来让GAN的生成结果可视化。以下代码可以导入上述一切库:

    ""

    Import necessary libraries to create a generative adversarial network 

    The code is mainly developed using the PyTorch library 

    ""

    import time 

    import torch 

    import torch.nn as nn 

    import torch.optim as optim 

    from torch.utils.data import DataLoader 

    from torchvision import datasets 

    from torchvision.transforms import transforms 

    from model import discriminator, generator 

    import numpy as np 

    import matplotlib.pyplot as plt 

    数据集

    数据集关于训练GAN来说十分重要,尤其思索到我们在GAN中处置的通常是非结构化数据(普通是图片、视频等),恣意一class都可以有数据的散布。这种数据散布恰恰是GAN生成输入的基础。

    为了更好地演示GAN的搭建流程,本文将带大家运用最复杂的MNIST数据集,其中含有6万张手写阿拉伯数字的图片。

    像 MNIST 这样高质量的非结构化数据集都可以在 格物钛 的 地下数据集 网站上找到。理想上,格物钛Open Datasets平台涵盖了很多优质的地下数据集,同时也可以完成 数据集托管及一站式搜索的功用 ,这对AI开发者来说,是相当适用的社区平台。

    复杂运用PyTorch搭建GAN模型

    硬件需求

    普通来说,虽然可以运用CPU来训练神经网络,但最佳选择其实是GPU,由于这样可以大幅提升训练速度。我们可以用下面的代码来测试本人的机器能否用GPU来训练:

    ""

    Determine if any GPUs are available 

    ""

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'

    完成 网络结构

    由于数字是十分复杂的信息,我们可以将判别器和生成器这两层结构都组建成全衔接层(fully connected layers)。

    我们可以用以下代码在PyTorch中搭建判别器和生成器:

    ""

    Network Architectures 

    The following are the discriminator and generator architectures 

    ""

     

    class discriminator(nn.Module): 

        def __init__(self): 

            super(discriminator, self).__init__() 

            self.fc1 = nn.Linear(784512

            self.fc2 = nn.Linear(5121

            self.activation = nn.LeakyReLU(0.1

     

        def forward(self, x): 

            x = x.view(-1784

            x = self.activation(self.fc1(x)) 

            x = self.fc2(x) 

            return nn.Sigmoid()(x) 

     

     

    class generator(nn.Module): 

        def __init__(self): 

            super(generator, self).__init__() 

            self.fc1 = nn.Linear(1281024

            self.fc2 = nn.Linear(10242048

            self.fc3 = nn.Linear(2048784

            self.activation = nn.ReLU() 

     

    def forward(self, x): 

        x = self.activation(self.fc1(x)) 

        x = self.activation(self.fc2(x)) 

        x = self.fc3(x) 

        x = x.view(-112828

        return nn.Tanh()(x) 

    训练

    在训练GAN的时分,我们需求一边优化判别器,一边改良生成器,因此每次迭代我们都需求同时优化两个相互矛盾的损失函数。

    关于生成器,我们将输入一些随机噪音,让生成器来依据噪音的庞大改动输入的图像:

    ""

    Network training procedure 

    Every step both the loss for disciminator and generator is updated 

    Discriminator aims to classify reals and fakes 

    Generator aims to generate images as realistic as possible 

    ""

    for epoch in range(epochs): 

        for idx, (imgs, _) in enumerate(train_loader): 

            idx += 1 

     

            # Training the discriminator 

            # Real inputs are actual images of the MNIST dataset 

            # Fake inputs are from the generator 

            # Real inputs should be classified as 1 and fake as 0 

            real_inputs = imgs.to(device) 

            real_outputs = D(real_inputs) 

            real_label = torch.ones(real_inputs.shape[0], 1).to(device) 

     

    (责任编辑:admin)