资源 | NIPS 2017 Spotlight论文Bayesian GAN的

选自GitHub

作者:Andrew Gordon Wilson 

机器之心编译

参与:路雪、刘晓坤


用生成模型学习高维自然信号(比如图像、视频和音频)长期以来一直是机器学习的重要发展方向之一。来自 Uber AI Lab 的 Yunus Saatchi 等人今年五月提出了 Bayesian GAN——利用一个简单的贝叶斯公式进行端到端无监督/半监督 GAN 学习。该研究的论文已被列入 NIPS 2017 大会 Spotlight。最近,这篇论文的另一作者 Andrew Gordon Wilson 在 GitHub 上发布了 Bayesian GAN 的 TensorFlow 实现。


项目链接:https://github.com/andrewgordonwilson/bayesgan/


论文:Bayesian GAN



论文链接:https://arxiv.org/abs/1705.09558


摘要:生成对抗网络(GAN)可以隐性地学习难以用显性似然(explicit likelihood)建模的图像、音频和数据的丰富分布。我们展示了一种实际的贝叶斯公式,用 GAN 进行无监督和半监督学习。在该框架下,我们使用随机梯度哈密尔顿蒙特卡罗(Hamiltonian Monte Carlo)来边缘化生成器和判别器的权重。得到的方法很直接,且可在没有标准干预(如特征匹配或小批量判别)的情况下达到不错的性能。通过探索生成器参数具有表达性的后验,贝叶斯 GAN 避免了模式崩溃(mode-collapse),输出可解释的多种候选样本,在 SVHN、CelebA 和 CIFAR-10 等多个基准数据集上取得了顶尖的半监督学习量化结果,优于 DCGAN、Wasserstein GAN 和 DCGAN。


介绍


在贝叶斯 GAN 中,我们提出了生成器和判别器权重的条件后验,通过随机梯度哈密尔顿蒙特卡罗边缘化这些后验。贝叶斯 GAN 的主要特性有:(1)在半监督学习问题上的准确预测;(2)对优秀性能的最小干预;(3)响应对抗反馈的推断的概率公式;(4)避免模式崩溃;(5)展示多个互补的生成和判别模型,形成一个概率集成(probabilistic ensemble)。



我们介绍了一个生成器参数的多模态后验。这些参数的每个设置对应数据的不同生成假设。这里我们将展示两种权重向量设置下生成的样本,不同的权重向量设置对应不同的写作风格。贝叶斯 GAN 保留该参数分布。相反,标准 GAN 用点估计(类似最大似然解决方案)来展示整个分布,降低了数据的可解释性。


环境需求


该代码有以下依赖项(版本号很关键)


  • python 2.7

  • tensorflow==1.0.0


在 Linux 上安装 TensorfFow 1.0.0,请按照 https://www.tensorflow.org/versions/r1.0/install/说明进行操作。


  • scikit-learn==0.17.1


你可以使用下列命令安装 scikit-learn 0.17.1:


 
           
 
 
pip install scikit-learn==0.17.1


或者,使用提供的 environment.yml 文件创建 conda 环境,并进行设置:


 
           
 
 
conda env create -f environment.yml -n bgan


然后,使用下列命令加载环境:


 
           
 
 
source activate bgan


训练选项


bayesian_gan_hmc.py 具备以下训练选项。


  • --out_dir:文件夹路径,用于存储输出

  • --n_save: 每 n_save 次迭代存储的样本和权重;默认值 100

  • --z_dim: 生成器 z 向量的维度;默认值 100

  • --data_path:数据路径;具体讨论详见 https://github.com/andrewgordonwilson/bayesgan/#data-preparation;该参数是必需的

  • --dataset:可以是 mnist、cifar、svhn 或 celeb;默认 mnist

  • --gen_observed: 生成器「观察到」的数据;影响噪声变量和先验的缩放;默认值 1000

  • --batch_size:训练的批量大小;默认值 64

  • --prior_std:权重先验分布的 std;默认值 1

  • --numz:和论文中的 J 一样; z 的样本数,实现整合;默认值 1

  • --num_mcmc: 和论文中的 M 一样;每个 z 的 MCMC NN 权重样本数;默认值 1

  • --lr: Adam 优化器使用的学习率;默认值 0.0002

  • --optimizer:使用的优化方法:adam (tf.train.AdamOptimizer) 或 sgd (tf.train.MomentumOptimizer);默认 adam

  • --semi_supervised:进行半监督学习

  • --N:半监督学习所需标注样本数量

  • --train_iter:训练迭代次数;默认值 50000

  • --save_samples:保存训练过程中生成的样本

  • --save_weights:训练过程中,保存权重

  • --random_seed:随机种子;如果使用 GPU,那么注意设置该种子不会引起 100% 的可复现结果


你还可以用--wasserstein 运行 WGAN,或用--ml_ensemble <num_dcgans> 训练 <num_dcgans> DCGAN 的集成。尤其是,你可以使用--ml_ensemble 1 训练一个 DCGAN。


使用


安装


1. 安装所需依赖项

2. 复制该 repository


合成数据


你可以使用 bgan_synth 脚本运行论文中的合成实验。例如,以下命令用于训练贝叶斯 GAN(D=100,d=10),进行 5000 次迭代,并把结果保存在<results_path>。


 
           
 
 
./bgan_synth.py --x_dim 100 --z_dim 10 --numz 10 --out <results_path>


运行以下命令使用相同的数据运行 ML GAN:


 
           
 
 
./bgan_synth.py --x_dim 100 --z_dim 10 --numz 1 --out <results_path>


bgan_synth 的参数有 --save_weights、--out_dir、--z_dim、--numz、--wasserstein、--train_iter 和 --x_dim。x_dim 控制观测数据(论文中的 x)的维度。通过这个链接查看其它参数的说明:https://github.com/andrewgordonwilson/bayesgan/#training-option。


运行了上面的两个命令之后,你可以在<results_path>里查看每 100 次迭代后的输出。例如,第 900 次迭代的贝叶斯 GAN 的输出结果如下:



相对地,标准 GAN(numz=1,强制执行 ML 评估)的输出结果如下:



可以清晰地看到在这个合成数据的例子中,标准 GAN 出现了模式崩溃的趋势,而贝叶斯 GAN 完全没有这样的问题。


你可以查看 synth.iptnb,进一步探索合成实验,并生成詹森-香农差异图。


MNIST、CIFAR10、CELEBA、SVHN


bayesian_gan_hmc 脚本允许在标准和自定义数据集上训练模型。下面,我们将介绍如何使用该脚本。


数据准备


为了重现在 MNIST、CIFAR10、CelebA 和 SVHN 数据集上的实验,你需要准备这些数据,并使用一个正确的——data_path。


  • 对于 MNIST,你不需要准备数据,并可以提供任意的——data_path;

  • 对于 CIFAR10,请从该地址(https://www.cs.toronto.edu/~kriz/cifar.html)下载和获取数据的 Python 版本;然后使用包含 cifar-10-batchs-py 的目录的路径作为——data_path;

  • 对于 SVHN,请从该地址(http://ufldl.stanford.edu/housenumbers/)下载 train_32x32.mat 和 test_32x32.mat 文件,并使用包含这些文件的目录的路径作为——data_path;

  • 对于 CelebA,你需要安装 OpenCV。数据下载地址:http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html。你需要创建 celebA 文件夹,该文件夹包含 Anno 和 img_align_celeba 子文件夹。其中 Anno 必须包含 list_attr_celeba.txt,img_align_celeba 必须包含.jpg 文件。你还需要通过在——data_path <path>(其中<path>是包含了 celebA 的文件夹的路径)中运行 datasets/crop_faces.py 脚本对图像进行剪裁。训练模型的时候,你需要在——data_path 中使用相同的<path>。


无监督学习


你可以在没有 -- semi 参数的情况下通过运行 bayesian_gan_hmc 脚本对模型进行无监督训练。例如,使用以下命令:


 
           
 
 
./bayesian_gan_hmc.py --data_path <data_path> --dataset svhn --numz 1 --num_mcmc 10 --out_dir <results_path> --train_iter 75000 --save_samples --n_save 100


在 SVHN 数据集上训练模型。该命令会使用这个方法执行 75000 次迭代,每 100 次迭代保存样本。这里<results_path>必须是保存结果的目录。可查看数据准备部分,了解如何设置<data_path>。可查看训练选项部分,了解其它训练选项。



半监督训练


你可以使用--semi 选项来运行 bayesian_gan_hmc,对模型进行半监督训练。使用-N 参数设置训练所用标注样本的数量。例如,使用


 
           
 
 
./bayesian_gan_hmc.py --data_path <data_path> --dataset cifar --numz 1 --num_mcmc 10--out_dir <results_path> --train_iter 75000 --N 4000 --semi --lr 0.00005


在 CIFAR10 数据集上用 4000 个标注样本训练该模型。该命令使训练经历 75000 次迭代,输出结果储存在<results_path> 文件夹中。



要想在 MNIST 数据集上使用 200 个标注样本训练该模型,你需要使用以下命令:


 
           
 
 
./bayesian_gan_hmc.py --data_path <data_path>/ --dataset mnist --numz 5 --num_mcmc 5--out_dir <results_path> --train_iter 30000 -N 200 --semi --lr 0.001



自定义数据


要想在自定义数据集上训练该模型,你需要用特定的接口定义类别。假设我们想在 digits 数据集上训练模型。该数据集包含 8x8 数字图像。假设数据的储存格式为 x_tr.npy、y_tr.npy、x_te.npy 和 y_te.npy。我们假设 x_tr.npy 和 x_te.npy 的形态为 (?, 8, 8, 1)。然后在 bgan_util.py 中定义该数据集对应的类别,如下所示:


 
           
 
 
class Digits:    def __init__(self):        self.imgs = np.load('x_tr.npy')        self.test_imgs = np.load('x_te.npy')        self.labels = np.load('y_tr.npy')        self.test_labels = np.load('y_te.npy')        self.labels = one_hot_encoded(self.labels, 10)        self.test_labels = one_hot_encoded(self.test_labels, 10)        self.x_dim = [8, 8, 1]        self.num_classes = 10    @staticmethod    def get_batch(batch_size, x, y):        """Returns a batch from the given arrays.        """        idx = np.random.choice(range(x.shape[0]), size=(batch_size,), replace=False)        return x[idx], y[idx]    def next_batch(self, batch_size, class_id=None):        return self.get_batch(batch_size, self.imgs, self.labels)    def test_batch(self, batch_size):        return self.get_batch(batch_size, self.test_imgs, self.test_labels)


该类别必须具备 next_batch 和 test_batch,以及 imgs、labels、test_imgs、test_labels、x_dim 和 num_classes。


现在,我们可以把 Digits 类输入到 bayesian_gan_hmc.py:


 
           
 
 
from bgan_util import Digits


将下列行添加至处理--dataset 参数的代码中:


 
           
 
 
if args.dataset == "digits":    dataset = Digits()


准备过程完成后,我们可以使用以下命令训练模型:


 
           
 
 
./bayesian_gan_hmc.py --data_path <any_path> --dataset digits --numz 1 --num_mcmc 10 --out_dir <results path> --train_iter 5000 --save_samples



本文为机器之心编译,转载请联系本公众号获得授权

✄------------------------------------------------

加入机器之心(全职记者/实习生):hr@jiqizhixin.com

投稿或寻求报道:content@jiqizhixin.com

广告&商务合作:bd@jiqizhixin.com

相关文章
相关标签/搜索