DatasetGAN(超簡略ver)を実装した話

論文実装

昨日のブログではCVPR2021oralの論文である、DatasetGANの紹介した

紹介しているうちに、次のような疑問がうまれた

DatasetGANはセグメンテーションに特化した話なのか?

DatasetGANはStyleGANのAdaIN層の特徴が良かったからうまくいったのか?

疑問を解消するために、次のような超簡略verのDatasetGANを作成した

  • StyleGANではなくバニラなGANを使う
  • セグメンテーションのアノテーションではなく、ラベルのアノテーションを行う

GAN実装参考:https://github.com/znxlwm/pytorch-generative-model-collections

GANの実装

重要ではないので具体的な実装は付録にまわす。とりあえずGANがまともに動いているかを確認

データはMNISTの訓練データを利用した

ロスは全然安定してないものの、generatorはまともなものを生成しているのでヨシ!とする

解釈器の実装

付録を見てもらえればわかるが、Generatorの特徴マップを抽出できるようにしている

def extract_features(self, z):
   z1 = self.fc1(z)
    z2 = self.fc2(z1)
    z2 = z2.view(-1, 128, 7, 7)
    z2 = self.deconv1(z2)
    x = self.deconv2(z2)
    return z, z1, z2, x

これらの特徴マップと出力を、新たな入力とし、少数アノテーションを教師として解釈器の学習を行う

解釈器は単純な構造のCNNを利用している

class Interpreter(nn.Module):
    def __init__(self):
        super(Interpreter, self).__init__()
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2)
        )
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(64+64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )        
        
        self.fc = nn.Sequential(
            nn.Linear(128 * 7 * 7*2 + 32, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 10),
        )
        
        initialize_weights(self)
    
    def forward(self, x, z2, z1, z):
        h = self.conv1(x)
        h = self.conv2(torch.cat([h, z2], dim=1))
        h = h.view(-1, 128 * 7 * 7)
        h = self.fc(torch.cat([h, z1, z], dim=1))
        return h
    
def detach(*xs):
    return [x.detach() for x in xs]

生成器から100個のサンプルを取り出し、生成器に勾配が伝わらないようにdetachする

G = Generator().cuda()
G.load_state_dict(torch.load("./logs/MNIST_G_100.pth"))
G.eval()

def detach(*xs):
    return [x.detach() for x in xs]

with torch.no_grad():
    torch.manual_seed(42)
    sample_z = torch.randn((100, z_dim))
    sample_z = sample_z.cuda()
    
    z, z1, z2, x = G.extract_features(sample_z)
    z, z1, z2, x = detach(z, z1, z2, x)
    save_image(x.data.cpu(), os.path.join(log_dir, 'training_images.png'), nrow=10)

取り出した100個のサンプルにアノテーションを行う

5, 2, 9, 9, 9, …

t = [
    5, 2, 9, 9, 9, 8, 6, 3, 5, 7,
    6, 2, 7, 5, 2, 3, 5, 0, 7, 2,
    7, 5, 6, 0, 5, 4, 7, 6, 7, 7,
    0, 1, 6, 4, 1, 6, 1, 5, 4, 1,
    3, 2, 4, 6, 6, 4, 7, 3, 4, 9,
    2, 2, 9, 1, 8, 3, 6, 5, 7, 0,
    1, 7, 8, 6, 0, 8, 9, 2, 0, 5,
    8, 6, 0, 5, 4, 9, 5, 7, 6, 2,
    7, 4, 3, 0, 8, 2, 3, 7, 4, 7, 
    6, 7, 4, 9, 6, 5, 1, 7, 1, 7
    ]

10個の解釈器を上記の生成画像を使って訓練する

N = 10
Is = [Interpreter().cuda() for i in range(N)]
Is_opt = [optim.Adam(Is[i].parameters()) for i in range(N)]
criterion = nn.CrossEntropyLoss()
bs = 5
t = torch.tensor(t).cuda()

def train(optimizer, interpreter, z, z1, z2, x, t):
    running_loss = 0.0
    z, z1, z2, x, t = shuffle(z, z1, z2, x, t)
    for i in range(int(len(t)/bs)):
        z_, z1_, z2_, x_, t_ = z[bs*i:bs*(i+1)], z1[
            bs*i:bs*(i+1)], z2[bs*i:bs*(i+1)], x[bs*i:bs*(i+1)], t[bs*i:bs*(i+1)]
 
        optimizer.zero_grad()
 
        y = interpreter(x_, z2_, z1_, z_)
        loss = criterion(y, t_)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss/(len(t)/bs)

for opt, I in zip(Is_opt, Is):
    for epoch in range(30): 
        loss = train(opt, I, z, z1, z2, x, t)

解釈器による予測を行い、多数決の結果と不確かさ(エントロピー)を画像の上に表示する

def interpreter_pred(Is, x, z2, z1, z):
    preds = []
    for I in Is:
        pred = I(x, z2, z1, z).argmax(1).detach().cpu().numpy()
        preds.append(pred)
    return np.array(preds)

for I in Is:
    I.eval()

sample_size = 5
with torch.no_grad():
    sample_z = torch.randn((sample_size, z_dim))
    sample_z = sample_z.cuda()
    
    z, z1, z2, x = G.extract_features(sample_z)
    z, z1, z2, x = detach(z, z1, z2, x)
    
    preds = interpreter_pred(Is, x, z2, z1, z)

plt.figure(figsize=(7, 2))
for i, x_ in enumerate(x.cpu().numpy()):
    plt.subplot(1, 5, i+1)
    plt.xticks([])
    plt.yticks([])
    hist, _ = np.histogram(preds[:, i], range(11), density=True)
    plt.title("{}, {:.3f}".format(hist.argmax(), entropy(hist)))
    plt.imshow(x_[0], cmap=plt.cm.gray)
plt.show()

不確かさが0だと解釈器の予測が割れていない状態を意味する

画像に alt 属性が指定されていません。ファイル名: download-1-1.png
パン
パン

これで一旦「無限」データセット生成器が完成した

エントロピーが0.4以下(モデルが9つ以上同じ予測)のとき、データセットに加えることにする

def make_datasets(G, Is, M=1000, thresh=0.1, z_dim=32):
    images = []
    labels = []
    sample_size = 64
    for j in tqdm(range(M)):
        with torch.no_grad():
            sample_z = torch.randn((sample_size, z_dim))
            sample_z = sample_z.cuda()

            z, z1, z2, x = G.extract_features(sample_z)
            z, z1, z2, x = detach(z, z1, z2, x)

            preds = interpreter_pred(Is, x, z2, z1, z)

        for i, x_ in enumerate(x.cpu().numpy()):
            hist, _ = np.histogram(preds[:, i], range(11), density=True)
            pred = hist.argmax()
            ent = entropy(hist)
            if ent < thresh:
                images.append(x_)
                labels.append(pred)
    return np.array(images), np.array(labels)

images, labels = make_datasets(G, Is)   

plt.figure(figsize=(20, 6))
for i, (x_, y_) in enumerate(zip(images[:100], labels[:100])):
    plt.subplot(5, 20, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.title(f"{y_}")
    plt.imshow(x_[0], cmap=plt.cm.gray)
plt.show()

画像の上の数字は解釈器で予測したラベル

苦笑いする人
苦笑いする人

いくつか怪しいラベルが…

分類器の実装

単純なCNNった分類器を用意し、上記で作成したデータセットを使って訓練する

C = Classifier().cuda()
optimizer = optim.Adam(C.parameters())
C.train()

bs = 128
for epoch in range(30): 
    running_loss = 0.0
    images, labels = shuffle(images, labels)
    for i in range(int(len(images)/bs)):
        x_, t_ = images[bs*i:bs*(i+1)], labels[bs*i:bs*(i+1)]
 
        optimizer.zero_grad()
        y = C(torch.from_numpy(x_).cuda())
        loss = criterion(y, torch.from_numpy(t_).cuda())
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    loss = running_loss/(len(t)/bs)

テストデータで分類器の評価を行う

dataset = datasets.MNIST("./data/mnist", transform=transform, train=False)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)  

C.eval()
pred = []
true = []
with torch.no_grad():
    for x_, t_ in data_loader:
        y = C(x_.cuda()).detach().cpu().numpy().argmax(1)
        pred.extend(y)
        true.extend(t_.numpy())

print(classification_report(pred, true))
confusion_matrix(pred, true)

# 出力
              precision    recall  f1-score   support

           0       0.85      0.92      0.88       900
           1       0.84      0.94      0.89      1013
           2       0.87      0.70      0.78      1282
           3       0.72      0.74      0.73       983
           4       0.69      0.91      0.78       745
           5       0.66      0.69      0.68       853
           6       0.95      0.74      0.83      1232
           7       0.82      0.82      0.82      1023
           8       0.81      0.86      0.83       919
           9       0.78      0.75      0.77      1050

    accuracy                           0.80     10000
   macro avg       0.80      0.81      0.80     10000
weighted avg       0.81      0.80      0.80     10000

array([[831,   0,   3,   5,   0,  16,   6,  10,  23,   6],
       [  1, 957,   0,   0,   8,   1,   0,  22,   4,  20],
       [ 22,  22, 899, 179,  35,   4,   8,  74,  25,  14],
       [  5,   1,   7, 727,   1, 142,   0,  22,  29,  49],
       [  0,   6,   1,   0, 676,   2,   1,   6,   5,  48],
       [ 21,  63,   3,  36,  21, 592,  30,  19,  35,  33],
       [ 79,  52,   8,  13,  29,  93, 911,   2,  37,   8],
       [  1,  21,  94,  10,  12,   2,   0, 838,  22,  23],
       [ 14,  11,  14,  33,   2,  35,   2,   1, 788,  19],
       [  6,   2,   3,   7, 198,   5,   0,  34,   6, 789]])
苦笑いする人
苦笑いする人

全然精度でてないね。半教師あり学習なんだから98%近くでてもいいのに。

考察

結果は巷の半教師あり学習と比べて全然精度がでなかった

GANの特徴マップで得られる空間的成分を有効に活用したセグメンテーションのほうがDatasetGANを能力をいかすことができるのだろう

付録:GANの実装

import os
from os.path import join, exists
from glob import glob

from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.utils.data
import torch.optim as optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

from sklearn.utils import shuffle
from sklearn.metrics import confusion_matrix, classification_report
from scipy.stats import entropy
  
batch_size = 128
lr = 0.0002
z_dim = 32
log_dir = './logs'

transform = transforms.Compose([
    transforms.ToTensor()
])
dataset = datasets.MNIST("./data/mnist", transform=transform, train=True)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)  

def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.ConvTranspose2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        
        
def generate(epoch, G, log_dir='logs'):
    G.eval()
    
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    torch.manual_seed(42)
    sample_z = torch.randn((64, z_dim))
    sample_z = sample_z.cuda()

    samples = G(sample_z).data.cpu()
    save_image(samples, os.path.join(log_dir, 'MNIST_epoch_%03d.png' % (epoch)))

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.fc1 = nn.Sequential(
            nn.Linear(z_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * 7 * 7),        
        )
        
        self.fc2 = nn.Sequential(
            nn.BatchNorm1d(128 * 7 * 7),
            nn.ReLU(),
        )
        
        self.deconv1 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        )
        
        self.deconv2 = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )
        
        initialize_weights(self)
        
    def extract_features(self, z):
        z1 = self.fc1(z)
        z2 = self.fc2(z1)
        z2 = z2.view(-1, 128, 7, 7)
        z2 = self.deconv1(z2)
        x = self.deconv2(z2)
        return z, z1, z2, x
        

    def forward(self, z):
        x = self.fc1(z)
        x = self.fc2(x)
        x = x.view(-1, 128, 7, 7)
        x = self.deconv1(x)
        x = self.deconv2(x)        
        return x

class Discriminator(nn.Module):
    
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        
        self.fc = nn.Sequential(
            nn.Linear(128 * 7 * 7, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )
        
        initialize_weights(self)
    
    def forward(self, input):
        x = self.conv(input)
        x = x.view(-1, 128 * 7 * 7)
        x = self.fc(x)
        return x

G = Generator().cuda()
D = Discriminator().cuda()

G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

criterion = nn.BCELoss()

def train(D, G, criterion, D_optimizer, G_optimizer, data_loader):
    D.train()
    G.train()

    y_real = torch.ones(batch_size, 1).cuda()
    y_fake = torch.zeros(batch_size, 1).cuda() 


    D_running_loss = 0
    G_running_loss = 0
    for batch_idx, (real_images, _) in enumerate(data_loader):
        if real_images.size()[0] != batch_size:
            break

        z = torch.normal(mean=torch.zeros(batch_size, z_dim), std=torch.ones(batch_size, z_dim))
        real_images, z = real_images.cuda(), z.cuda()

        # Discriminator
        D_optimizer.zero_grad()

        D_real = D(real_images)
        D_real_loss = criterion(D_real, y_real)

        fake_images = G(z)
        D_fake = D(fake_images.detach())
        D_fake_loss = criterion(D_fake, y_fake)

        D_loss = D_real_loss + D_fake_loss
        D_loss.backward()
        D_optimizer.step()
        D_running_loss += D_loss.item()

        # Generator
        z = torch.randn((batch_size, z_dim)).cuda()

        G_optimizer.zero_grad()

        fake_images = G(z)
        D_fake = D(fake_images)
        G_loss = criterion(D_fake, y_real)
        G_loss.backward()
        G_optimizer.step()
        G_running_loss += G_loss.item()
    
    D_running_loss /= len(data_loader)
    G_running_loss /= len(data_loader)
    
    return D_running_loss, G_running_loss

num_epochs = 100

history = {}
history['D_loss'] = []
history['G_loss'] = []
for epoch in range(num_epochs):
    D_loss, G_loss = train(D, G, criterion, D_optimizer, G_optimizer, data_loader)
    
    print('epoch %d, D_loss: %.4f G_loss: %.4f' % (epoch + 1, D_loss, G_loss))
    history['D_loss'].append(D_loss)
    history['G_loss'].append(G_loss)
    
    generate(epoch + 1, G, log_dir)
    if (epoch+1) % 50 == 0:
        torch.save(G.state_dict(), os.path.join(log_dir, 'MNIST_G_%03d.pth' % (epoch + 1)))
        torch.save(D.state_dict(), os.path.join(log_dir, 'MNIST_D_%03d.pth' % (epoch + 1)))
   

コメント

タイトルとURLをコピーしました