VAE Tutorial 2

VAE Tutorial 2: VAE 논문 & 코드 리뷰

이번 포스트에서는 VAE의 원래 논문인 “Auto-Encoding Variational Bayes”의 내용 중 일부를 다루고 Pytorch VAE example code를 리뷰해봅니다.


1. Variational Inference & Reparameterization Trick

논문의 Abstract에서는 다음과 같은 말을 던지면서 시작합니다. 이 포스트에서도 기본적으로 MNIST 데이터셋에 대한 generative model의 전제하에 이야기합니다.

How can we perform efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables with intractable posterior distributions, and large datasets?

VAE가 하고 싶은 것은 명확합니다. 또한 그것을 가로막는 문제도 명확히 제시합니다.

  • 목표: efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables
  • 문제: intractable posterior, large dataset

이것을 이해하기 위해 이전 포스트에서 언급했던 식을 다시 살펴보겠습니다. directed probabilistic model 이라는 말은 explicit density estimation이라고도 볼 수 있습니다. \(p_{\theta}(z)\)에서 latent variable을 sampling 한다면 대부분의 \(z\)에 대해 \(p_{\theta}(x \vert z)\)는 거의 0의 값을 가질 것입니다. 그렇다면 다음 식처럼 \(p_{\theta}(x)\)에 대한 Monte-Carlo estimation을 하는데 sample이 너무 많이 필요하게됩니다. 데이터포인트 하나당 많은 sample을 필요로 하므로 large dataset에 대해서 이렇게 estimation을 하면 학습과정이 너무 느려집니다. \[p_{\theta}(x)=\int p_{\theta}(z)p_{\theta}(x|z)dz \approx \frac{1}{N}\sum_{i=1}^N p_{\theta}(x|z^{(i)})\]

따라서 데이터에 dependent하게 \(z\)를 sampling하기 위해 posterior \(p_{\theta}(z \vert x)\)를 정의했었는데 이 posterior는 intractable 합니다. 따라서 이 posterior를 approximate하는 새로운 posterior \(q_{\phi}(z \vert x)\)를 정의했었습니다. 이렇게 우리가 다루기 쉬운 paramterized된 posterior를 대신 사용하고 이 posterior가 원래의 posterior와 최대한 가깝게 만드는 것이 variational inference 입니다. 또는 variational bayes라고도 합니다. 이것은 다음 그림과 같이 파란색 분포에 초록색 분포를 최대한 맞추는 것과 같습니다.

그림출처 https://blog.evjang.com/2016/08/variational-bayes.html


최적화과정을 거쳐 approximate된 posterior와 posterior가 최대한 비슷해지면 \(q_{\phi}(z \vert x)\)를 통해 inference 할 수 있습니다. Inference 하는 것을 variational parameter \(\phi\)를 통해 하는 것입니다. 이 최적화과정은 이전 포스트에서 구한 ELBO를 최대화하는 것입니다. \[\mathcal{L}(x^{(i)}, \theta, \phi) = \mathbb{E}_{z}[log p_{\theta}(x^{(i)} \vert z)] - D_{KL}(q_{\phi}(z \vert x^{(i)}) \Vert p_{\theta}(z)) - (1)\] \[\theta^*, \phi^* = argmax_{\theta, \phi}\sum_{i=1}^N \mathcal{L}(x^{(i)}, \theta, \phi)\]

이 ELBO의 값을 maximize하는 parameter는 (1) analytic하게 구하거나 (2) stochastic gradient ascent를 통해 구할 수 있습니다. Analytic하게 구하는 방식 중에 Mean-Field Variational Bayes가 있습니다. 논문에서는 이 방법이 likelihood function인 \(p_{\theta}(x \vert z)\)이 뉴럴넷과 같은 복잡한 함수로 표현될 경우 intractable 하다고 말합니다. 논문에서 VAE의 방법과 Mean-Field Variational Bayes 사이의 차이에 대해서 다음과 같이 언급합니다.

Note that in contrast with the approximate posterior in mean-field variational inference, it is not necessarily factorial and its parameters φ are not computed from some closed-form expectation

따라서 (1)식의 gradient를 구해서 stochastic하게 parameter를 업데이트하는 방식을 사용할 것입니다. 이 때 (1) 식을 \(\theta\)에 대해서 미분하는 것은 문제가 없으나 \(\phi\)에 대해서 미분하는 것은 문제가 있습니다. (1) 식 중에서도 첫번째 항이 문제가 있습니다. 첫번째 항의 expectation 안에 있는 함수를 \(f(z)\)라고 가정해보겠습니다. 이 함수의 expectation에 대한 미분은 다음과 같이 쓸 수 있습니다. \[\nabla_{\phi}\mathbb{E}_{q_{\phi}(z)}[f(z)] = \int \nabla_{\phi} q_{\phi}(z) f(z) dz\] \[= \int q_{\phi}(z)\frac{\nabla_{\phi} q_{\phi}(z)}{q_{\phi}(z)} f(z) dz\] \[=\mathbb{E}_{q_{\phi}(z)}[f(z)\nabla_{\phi}log q_{\phi}(z)]\]

이 미분값은 monte-carlo estimation을 통해 estimate 할 수 있습니다. 이 때, \(z^{(i)}\)는 \(q_{\phi}(z \vert x^{(i)})\)로부터 sampling 합니다. 따라서 다음 gradient 값은 상당히 variance가 높습니다. 이 경우 impractical 합니다. \[\frac{1}{L}\sum_{l=1}^L f(z^{(l)})\nabla_{\phi}log q_{\phi}(z^{(l)})\]

이러한 문제를 해결하기 위해 VAE는 reparameterization trick이라는 technique을 사용합니다. \(z\)를 posterior \(q_{\phi}(z \vert x)\)로부터 sampling 하는 것이 아니라 differentiable 한 함수 \(g_{\phi}(\epsilon, x)\)로부터 deterministic하게 정해진다고 보는 것입니다. 이 때, \(\epsilon\)은 noise variable입니다. \[\tilde{z} = g_{\phi}(\epsilon, x) \qquad where \quad \epsilon \sim p(\epsilon)\]

이 경우 다음과 같이 \(f(z)\)의 \(q_{\phi}(z)\) 대한 expectation을 \(epsilon\)에 대한 expectation으로 바꿀 수 있습니다. 이제 바뀐 expectation에 대해 monte carlo estimation을 적용할 수 있습니다. \(f(g_{\phi}(z^{(l)}, x^{(i)}))\)는 \(\theta\)와 \(\phi\)에 대해 미분가능하기 때문에 바로 미분할 수 있습니다. \[\mathbb{E}_{q_{\phi}(z \vert x^{(i)})}[f(z)] = \mathbb{E}_{\epsilon}[f(g_{\phi}(\epsilon, x^{(i)}))] = \frac{1}{L}\sum_{l=1}^L f(g_{\phi}(z^{(l)}, x^{(i)}))\]

이 수식을 이용해서 ELBO를 고쳐쓸 수 있습니다. 이 식을 SGVB(Stochastic Gradient Variational Bayes) estimator라고 합니다. 이 때, \(z^{(i,l)}=g_{\phi}(\epsilon^{(l)}, x^{(i)})\)이고 \(\epsilon^{(l)} \sim p(\theta)\) 입니다. 보통은 \(g_{\phi}(\epsilon, x) = \mu + \sigma\epsilon\)으로 많이 사용합니다(univariate gaussian case). \[\mathcal{\tilde{L^B}}(x^{(i)}, \theta, \phi) = \frac{1}{L}\sum_{l=1}^L (log p_{\theta}(x^{(i)} \vert z^{(i, l)})) - D_{KL}(q_{\phi}(z \vert x^{(i)}) \Vert p_{\theta}(z))\]

이러한 reparameterization trick을 그림으로 보자면 다음과 같습니다. 원래는 encoder로부터 구한 data dependent한 mean과 variance를 가지고 posterior를 만듭니다. 그 posterior로부터 \(z\)를 샘플링한 다음에 그 \(z\)를 가지고 decoder는 data를 generation 했습니다. 하지만 reparametization을 하면 computation graph 내의 sampling 과정이 noise sampling이 되어 옆으로 빠져버립니다. 따라서 Back propagation을 통해 decoder output으로부터 encoder까지 gradient가 전달될 수 있습니다.

그림출처 https://arxiv.org/pdf/1606.05908.pdf

이렇게 업데이트를 하는 알고리즘이 Auto-Encoding Variational Bayes이며 다음과 같습니다.



2. VAE code example

2.1 VAE example of paper

prior와 posterior를 모두 gaussian으로 가정하고 likelihood를 Bernoulli라고 가정하면 ELBO 식은 다음과 같이 쓸 수 있습니다. \(y\)는 \(z\)와 decoder를 통해 나온 값입니다. 마지막 식의 첫번째 항은 잘 보면 cross-entropy 인 것을 알 수 있습니다. \[\mathcal{\tilde{L^B}}(x^{(i)}, \theta, \phi) = \frac{1}{L}\sum_{l=1}^L (log p_{\theta}(x^{(i)} \vert z^{(i, l)})) - D_{KL}(q_{\phi}(z \vert x^{(i)}) \Vert p_{\theta}(z))\] \[\mathcal{\tilde{L^B}}(x^{(i)}, \theta, \phi) = \frac{1}{L}\sum_{l=1}^L (log p_{\theta}(x^{(i)} \vert z^{(i, l)})) + \frac{1}{2}\sum_{j=1}^J(1 + log((\sigma_j^{(i)})^2) - (\mu_{j}^{(i)})^2 - (\sigma_j^{(i)})^2)\] \[\mathcal{\tilde{L^B}}(x^{(i)}, \theta, \phi) = \frac{1}{L}\sum_{l=1}^L (x_i log y_{(i, l)} + (1 - x_i)(1 - y_{(i,l)}) + \frac{1}{2}\sum_{j=1}^J(1 + log((\sigma_j^{(i)})^2) - (\mu_{j}^{(i)})^2 - (\sigma_j^{(i)})^2)\]


논문의 실험결과는 다음과 같습니다. 이 그림은 \(z\)를 임의로 변형시켜보면서 data를 생성해낸 결과입니다. \(z\)에 따라 data가 continuous하게 변화하는 것을 볼 수 있습니다. 또한 비슷한 숫자끼리는 서로 뭉쳐있음을 알 수 있습니다. 이를 통해 VAE가 의미있는 representation을 학습하는 것을 확인합니다.

이 마지막 식을 가지고 이제 우리는 VAE 코드를 살펴볼 수 있습니다. Pytorch는 공식적으로 VAE에 대한 simple한 example을 제공합니다. 지금까지 살펴본 VAE의 이론에 충실한 코드입니다.


2.2 VAE network class

VAE의 network 구조는 다음과 같습니다. 28x28의 이미지를 일렬로 펴서 784 크기의 vector로 만듭니다. 이 vector를 입력으로 받기 때문에 self.fc1이 784에서 400개의 hidden unit으로의 fully connected인 것을 알 수 있습니다.

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5*logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


Encoder의 구조는 다음과 같습니다. 784-d vector가 input으로 들어와서 400-d hidden layer를 통과하고 이 때 activation function은 relu를 사용합니다. 그 이후에 20개의 gaussian 분포에 대한 mean과 variance를 내보낼 것입니다. Mean은 self.fc21(h1)이며 linear 연산을 통한 output입니다. Variance의 경우 항상 0보다 크거나 같아야하는데 linear 연산을 한 self.fc22(h1)의 경우 -값이 될 수 있습니다. 따라서 이 값을 variance로 보지 않고 log variance라고 보는 것입니다.

def encode(self, x):
    h1 = F.relu(self.fc1(x))
    return self.fc21(h1), self.fc22(h1)


latent space 상에서의 mean과 log variance를 구했다면 reparameterization을 통해 latent vector \(z\)를 sampling 할 수 있습니다. Noise로부터 \(\epsilon\)인 eps를 구하고 이 eps와 standard deviation을 곱하고 mean을 더해줍니다.

def reparameterize(self, mu, logvar):
    if self.training:
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)
    else:
        return mu

\(z\)를 구하고나면 이 latent 값으로부터 decoder를 통해 data에 대한 Bernoulli distribution을 출력할 수 있습니다. Bernoulli 분포는 0에서 1 사이이므로 sigmoid 함수를 output layer의 activation function으로 사용합니다.

def decode(self, z):
    h3 = F.relu(self.fc3(z))
    return F.sigmoid(self.fc4(h3))


결국 VAE의 forward path는 다음과 같습니다. x.view(-1, 784)는 이미지를 784-D vector로 만드는 부분입니다.

def forward(self, x):
    mu, logvar = self.encode(x.view(-1, 784))
    z = self.reparameterize(mu, logvar)
    return self.decode(z), mu, logvar


2.3 Train VAE

학습에 관련된 가장 중요한 부분 중의 하나인 loss function 정의 부분입니다. 다음 수식을 생각해서 코드로 어떻게 구현되었는지 보면 됩니다. \[\mathcal{\tilde{L^B}}(x^{(i)}, \theta, \phi) = \frac{1}{L}\sum_{l=1}^L (x_i log y_{(i, l)} + (1 - x_i)(1 - y_{(i,l)}) + \frac{1}{2}\sum_{j=1}^J(1 + log((\sigma_j^{(i)})^2) - (\mu_{j}^{(i)})^2 - (\sigma_j^{(i)})^2)\]

첫번째 항은 F.binary_cross_entropy로 구현되었으며 두번째 항은 KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())로 구현되었습니다.

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), size_average=False)

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD


학습 코드 부분은 상당히 간단합니다. model.train()을 통해 현재 학습할 것이라는 것을 선언합니다. 그 이후에 train_loader를 통해 mini_batch를 추출합니다. 그 다음에 model에 data를 입력으로 넣어서 출력을 받습니다. 그 출력을 통해 loss function 값을 구할 수 있고 loss.backward()를 통해 back-propagation으로 각 parameter의 gradient 값을 구합니다. 그 이후에 optimizer(Adam optimizer)를 통해 각 parameter를 update 합니다.

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))


2.4 Evaluate VAE

학습과정을 evaluate 하는 것은 더 간단합니다. model.eval()을 통해 평가 중이라는 것을 선언합니다. 그 이후에 test dataset에 대해 reconstruction을 출력합니다. 그 이후에 loss function 값을 출력해서 학습이 어떻게 진행되고 있는지 평가합니다. 그리고 각 평가과정마다 생성된 하나의 sample을 저장합니다.

def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(args.batch_size, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

2.5 Main loop

epoch마다 train을 하고 test를 한 다음에 64개의 sample 이미지를 생성합니다. latent는 임의로 20개를 normal distribution에서 sampling 합니다. 이 sampling된 latent variable을 Decoder에 통과시키면 decoder는 이미지를 생성해냅니다.

for epoch in range(1, args.epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        sample = torch.randn(64, 20).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28),
                   'results/sample_' + str(epoch) + '.png')


이 코드는 python main.py를 하면 바로 실행이 되며 실행 화면은 다음과 같다.

첫 epoch 때 test set에 대한 loss는 119.4120이며 real data reconstruction과 z sampled recontruction은 다음 사진과 같습니다.

50 epoch 때 test set에 대한 loss는 93.3473이며 real data reconstruction과 z sampled recontruction은 다음 사진과 같습니다.

학습이 잘 된 것을 확인할 수 있습니다. 이제 다음 post에서 VAE를 사용해 Music generation이라는 도메인에 적용한 MusicVAE 를 살펴보겠습니다.




© 2018. by Woongwon Lee

Powered by dnddnjs