GANomaly
一,简介
异常检测旨在只使用正常样本建模从而能够区分OK样本与NG样本,所面对的场景的数据分布极不平衡, 通常OK样本非常多,而NG样本非常少。
自编码器是异常检测算法中比较经典的模型,它利用大量OK训练一个自编码网络,然后通过原图与重构图像之间的重构误差来检测NG样本,但该方法非常容易受噪声影响,其对NG样本也能够重建,导致所谓的重构误差“崩塌”。
GANomaly 采用编码器-解码器-编码器的模型结构, 同时对“原图-》重建图” 及“原图的高维特征编码->重建图的高维特征编码”进行重构误差约束。另外引入生成对抗网络的对抗训练思想, Encoder-Decoder-Encoder结构当作生成网络G-Net, 又定义了一个判别网络D-Net。
推理接断,用于推断异常的不是原图和重建图的差异,而是第一部分编码器产生的隐空间特征(原图的编码)和第二部分编码器产生的隐空间特征(重建图的编码)的差异。这种方法更关注图片实质内容的差异,对图片中的微小变化不敏感,因而能解决自编码器中易受噪声影响的问题,鲁棒性更好。
二, 网络结构
(1) G-Net, Encoder 1- Decoder - Encoder 2 , 自编码器结构参考了DCGAN, Encoder 1E将一张3通道的图片映射为一个n维的向量,Decoder则为Encoder的逆过程。Encoder2将重建出的图像在编码为一个n维的向量。
(2)D_Net, 判别网络用于区分原图和重建图,即要将原图判别为真,将重建图判别为假。它的结构和第一个子网络的解码网络是一样的。D-Net的引入,是为了引入对抗训练思想,旨在学到更好的G-Net。
三,损失函数
本文包含三个子网络,每个子网络对应一个损失函数。
第一个子网络的损失是自编码器的重建损失,这里借鉴了pix2pix文章中生成网络的损失,采用的是L1损失,而不是L2损失。因为采用L2损失生成的图像通常比采用L1生成的图像要模糊。
第二个子网络的损失是编码网络的损失,这里需要比对的是原图和重建图在高一层抽象空间中的差异,即两个bottleneck(上文中的bottleneck1和bottleneck2)间的差异,采用的是L2损失。
第三个子网络的损失是常规的GAN中判别网络的损失,这里采用的是二分类的交叉熵损失,论文中采用的L2。
四, 训练及推理
本文采用的训练策略和常规的GAN一样的,即交替地优化D-Net和G-Net。
(1) 优化D-Net时,采用的损失为上述第三个子网络的损失, 输入为 concat(input_real, input_fake)。input_fake由G-Net生成,在训练D-Net时,G-Net参数固定。
(2) 优化G-Net时,采用的损失比较复杂:
主体为重建损失Lrec,编码损失Lenc为重建损失的一个约束,对抗损失Ladv则用来和D-Net博弈。需要注意的一点是,这里的对抗损失的输入对象和优化D-Net时的输入对象是不一样的,这里的 input_d为input_fake,这和常规GAN的训练是一致的。
五,推理
前面提到,本文采用的推断方式和一般的基于自编码器的异常检测方法是不一样的。这里推断以来的不是重建损失Lrec,而是编码损失Lenc。具体而言,网络训练收敛以后,我们可以计算得到所有OK样本中的Lenc值,选取其中最大的作为判别阈值。推断时,给定一张图片,我们可以利用学好的网络,计算其 Lenc值,如果它小于判别阈值则判断为OK样本(正常样本),大于则判断为NG样本(异常样本)。