[pull] master from altosaar:master #1

Merged: 25 commits, Apr 4, 2021
85 changes: 82 additions & 3 deletions README.md
@@ -1,6 +1,85 @@
# Variational Autoencoder (Deep Latent Gaussian Model) in tf
Reference implementation for a variational autoencoder in TensorFlow.
# Variational Autoencoder in TensorFlow and PyTorch
[![DOI](https://zenodo.org/badge/65744394.svg)](https://zenodo.org/badge/latestdoi/65744394)

Mean-field variational inference is used to fit the model to binarized MNIST handwritten digits images. An inference network (encoder) is used to amortize the inference and share parameters across datapoints. The likelihood is parameterized by a generative network (decoder).
Reference implementation for a variational autoencoder in TensorFlow and PyTorch.

I recommend the PyTorch version. It includes an example of a more expressive variational family, the [inverse autoregressive flow](https://arxiv.org/abs/1606.04934).

Variational inference is used to fit the model to binarized MNIST handwritten digit images. An inference network (encoder) amortizes inference and shares parameters across datapoints. The likelihood is parameterized by a generative network (decoder).

Blog post: https://jaan.io/what-is-variational-autoencoder-vae-tutorial/
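
The model is trained by maximizing the evidence lower bound (ELBO). The following is a minimal sketch, in PyTorch, of how an encoder and decoder combine into a single-sample ELBO estimate; `encoder`, `decoder`, and `x` are placeholder names for illustration, not the identifiers used in this repository.

```python
import torch

def elbo(encoder, decoder, x):
    """Single-sample Monte Carlo estimate of the evidence lower bound.

    encoder(x) returns the mean and log-variance of q(z | x);
    decoder(z) returns Bernoulli logits for p(x | z).
    """
    mu, logvar = encoder(x)
    std = torch.exp(0.5 * logvar)
    z = mu + std * torch.randn_like(std)          # reparameterization trick
    logits = decoder(z)
    # log p(x | z): Bernoulli log-likelihood of the binarized pixels
    log_px_z = -torch.nn.functional.binary_cross_entropy_with_logits(
        logits, x, reduction="none").sum(-1)
    # analytic KL(q(z | x) || p(z)) for a standard normal prior
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(-1)
    return (log_px_z - kl).mean()
```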

Example output, using importance sampling to estimate the marginal likelihood on Hugo Larochelle's Binary MNIST dataset. The final marginal likelihood on the test set is `-97.10` nats; a sketch of the importance-sampling estimator follows the log below.

```
$ python train_variational_autoencoder_pytorch.py --variational mean-field
step: 0 train elbo: -558.69
step: 0 valid elbo: -391.84 valid log p(x): -363.25
step: 5000 train elbo: -116.09
step: 5000 valid elbo: -112.57 valid log p(x): -107.01
step: 10000 train elbo: -105.82
step: 10000 valid elbo: -108.49 valid log p(x): -102.62
step: 15000 train elbo: -106.78
step: 15000 valid elbo: -106.97 valid log p(x): -100.97
step: 20000 train elbo: -108.43
step: 20000 valid elbo: -106.23 valid log p(x): -100.04
step: 25000 train elbo: -99.68
step: 25000 valid elbo: -104.89 valid log p(x): -98.83
step: 30000 train elbo: -96.71
step: 30000 valid elbo: -104.50 valid log p(x): -98.34
step: 35000 train elbo: -98.64
step: 35000 valid elbo: -104.05 valid log p(x): -97.87
step: 40000 train elbo: -93.60
step: 40000 valid elbo: -104.10 valid log p(x): -97.68
step: 45000 train elbo: -96.45
step: 45000 valid elbo: -104.58 valid log p(x): -97.76
step: 50000 train elbo: -101.63
step: 50000 valid elbo: -104.72 valid log p(x): -97.81
step: 55000 train elbo: -106.78
step: 55000 valid elbo: -105.14 valid log p(x): -98.06
step: 60000 train elbo: -100.58
step: 60000 valid elbo: -104.13 valid log p(x): -97.30
step: 65000 train elbo: -96.19
step: 65000 valid elbo: -104.46 valid log p(x): -97.43
step: 65000 test elbo: -103.31 test log p(x): -97.10
```
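
The `valid log p(x)` and `test log p(x)` columns are importance-sampled estimates of the marginal likelihood, with the variational posterior as the proposal distribution. A minimal sketch of that estimator is below; it reuses the placeholder `encoder`/`decoder` interface from the sketch above and a made-up `num_samples` argument, so it illustrates the idea rather than reproducing the training script's exact routine.

```python
import math
import torch

def log_marginal_likelihood(encoder, decoder, x, num_samples=100):
    """Importance-sampling estimate of log p(x), using q(z | x) as the proposal."""
    mu, logvar = encoder(x)
    std = torch.exp(0.5 * logvar)
    log_weights = []
    for _ in range(num_samples):
        z = mu + std * torch.randn_like(std)
        logits = decoder(z)
        log_px_z = -torch.nn.functional.binary_cross_entropy_with_logits(
            logits, x, reduction="none").sum(-1)
        log_pz = torch.distributions.Normal(0.0, 1.0).log_prob(z).sum(-1)
        log_qz_x = torch.distributions.Normal(mu, std).log_prob(z).sum(-1)
        log_weights.append(log_px_z + log_pz - log_qz_x)
    # log p(x) is approximately logsumexp of the importance weights minus log K
    return torch.logsumexp(torch.stack(log_weights), dim=0) - math.log(num_samples)
```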


Using a more expressive, non-mean-field variational posterior approximation (an inverse autoregressive flow, https://arxiv.org/abs/1606.04934), the test marginal log-likelihood improves to `-95.33` nats:

```
$ python train_variational_autoencoder_pytorch.py --variational flow
step: 0 train elbo: -578.35
step: 0 valid elbo: -407.06 valid log p(x): -367.88
step: 10000 train elbo: -106.63
step: 10000 valid elbo: -110.12 valid log p(x): -104.00
step: 20000 train elbo: -101.51
step: 20000 valid elbo: -105.02 valid log p(x): -99.11
step: 30000 train elbo: -98.70
step: 30000 valid elbo: -103.76 valid log p(x): -97.71
step: 40000 train elbo: -104.31
step: 40000 valid elbo: -103.71 valid log p(x): -97.27
step: 50000 train elbo: -97.20
step: 50000 valid elbo: -102.97 valid log p(x): -96.60
step: 60000 train elbo: -97.50
step: 60000 valid elbo: -102.82 valid log p(x): -96.49
step: 70000 train elbo: -94.68
step: 70000 valid elbo: -102.63 valid log p(x): -96.22
step: 80000 train elbo: -92.86
step: 80000 valid elbo: -102.53 valid log p(x): -96.09
step: 90000 train elbo: -93.83
step: 90000 valid elbo: -102.33 valid log p(x): -96.00
step: 100000 train elbo: -93.91
step: 100000 valid elbo: -102.48 valid log p(x): -95.92
step: 110000 train elbo: -94.34
step: 110000 valid elbo: -102.81 valid log p(x): -96.09
step: 120000 train elbo: -88.63
step: 120000 valid elbo: -102.53 valid log p(x): -95.80
step: 130000 train elbo: -96.61
step: 130000 valid elbo: -103.56 valid log p(x): -96.26
step: 140000 train elbo: -94.92
step: 140000 valid elbo: -102.81 valid log p(x): -95.86
step: 150000 train elbo: -97.84
step: 150000 valid elbo: -103.06 valid log p(x): -95.92
step: 150000 test elbo: -101.64 test log p(x): -95.33
```
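
The flow posterior transforms the mean-field sample through a stack of invertible autoregressive updates and corrects `log q(z | x)` by the log-determinant of each update. Below is a small numeric sketch of the LSTM-type update (Eqs. 11-14 of the IAF paper) with made-up tensors; the real implementation is in `flow.py` later in this pull request.

```python
import torch

# one IAF step: z_new = sigmoid(s) * z + (1 - sigmoid(s)) * m,
# where (m, s) come from an autoregressive network conditioned on z
z = torch.randn(4)                 # sample from the mean-field posterior
m = torch.randn(4)                 # "mean" output of the autoregressive net
s = torch.randn(4) + 2.0           # "gate" output, biased so sigmoid(s) is near 1 at init
gate = torch.sigmoid(s)
z_new = gate * z + (1 - gate) * m
# the density changes by the log-determinant of the (triangular) Jacobian:
# log q(z_new | x) = log q(z | x) - sum(log sigmoid(s))
log_det = torch.nn.functional.logsigmoid(s).sum()
```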
43 changes: 43 additions & 0 deletions data.py
@@ -0,0 +1,43 @@
"""Get the binarized MNIST dataset and convert to hdf5.
From https://github.com/yburda/iwae/blob/master/datasets.py
"""
import urllib.request
import os
import numpy as np
import h5py


def parse_binary_mnist(data_dir):
    def lines_to_np_array(lines):
        return np.array([[int(i) for i in line.split()] for line in lines])
    with open(os.path.join(data_dir, 'binarized_mnist_train.amat')) as f:
        lines = f.readlines()
    train_data = lines_to_np_array(lines).astype('float32')
    with open(os.path.join(data_dir, 'binarized_mnist_valid.amat')) as f:
        lines = f.readlines()
    validation_data = lines_to_np_array(lines).astype('float32')
    with open(os.path.join(data_dir, 'binarized_mnist_test.amat')) as f:
        lines = f.readlines()
    test_data = lines_to_np_array(lines).astype('float32')
    return train_data, validation_data, test_data


def download_binary_mnist(fname):
    data_dir = '/tmp/'
    subdatasets = ['train', 'valid', 'test']
    for subdataset in subdatasets:
        filename = 'binarized_mnist_{}.amat'.format(subdataset)
        url = 'http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat'.format(
            subdataset)
        local_filename = os.path.join(data_dir, filename)
        urllib.request.urlretrieve(url, local_filename)

    train, validation, test = parse_binary_mnist(data_dir)

    data_dict = {'train': train, 'valid': validation, 'test': test}
    f = h5py.File(fname, 'w')
    f.create_dataset('train', data=data_dict['train'])
    f.create_dataset('valid', data=data_dict['valid'])
    f.create_dataset('test', data=data_dict['test'])
    f.close()
    print(f'Saved binary MNIST data to: {fname}')
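
For context, a short sketch of how the resulting hdf5 file might be consumed; the file path is an arbitrary example and the import assumes `data.py` is on the Python path.

```python
import h5py
import torch

from data import download_binary_mnist

# download once, then read the three splits back as float32 arrays
fname = '/tmp/binarized_mnist.h5'      # example path, not fixed by data.py
download_binary_mnist(fname)
with h5py.File(fname, 'r') as f:
    train = torch.from_numpy(f['train'][:])   # shape (50000, 784) for the train split
    valid = torch.from_numpy(f['valid'][:])
    test = torch.from_numpy(f['test'][:])
```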
74 changes: 74 additions & 0 deletions environment.yml
@@ -0,0 +1,74 @@
name: dev
channels:
- defaults
dependencies:
- blas=1.0=mkl
- ca-certificates=2019.5.15=1
- certifi=2019.6.16=py37_1
- freetype=2.9.1=hb4e5f40_0
- imageio=2.5.0=py37_0
- intel-openmp=2019.4=233
- jpeg=9b=he5867d9_2
- libcxx=4.0.1=hcfea43d_1
- libcxxabi=4.0.1=hcfea43d_1
- libedit=3.1.20181209=hb402a30_0
- libffi=3.2.1=h475c297_4
- libgfortran=3.0.1=h93005f0_2
- libpng=1.6.37=ha441bb4_0
- libtiff=4.0.10=hcb84e12_2
- mkl=2019.4=233
- mkl-service=2.3.0=py37hfbe908c_0
- mkl_fft=1.0.14=py37h5e564d8_0
- mkl_random=1.0.2=py37h27c97d8_0
- ncurses=6.1=h0a44026_1
- numpy=1.16.5=py37hacdab7b_0
- numpy-base=1.16.5=py37h6575580_0
- olefile=0.46=py37_0
- openssl=1.1.1d=h1de35cc_1
- pillow=6.1.0=py37hb68e598_0
- python=3.7.4=h359304d_1
- readline=7.0=h1de35cc_5
- setuptools=41.0.1=py37_0
- six=1.12.0=py37_0
- sqlite=3.29.0=ha441bb4_0
- tk=8.6.8=ha441bb4_0
- wheel=0.33.4=py37_0
- xz=5.2.4=h1de35cc_4
- zlib=1.2.11=h1de35cc_3
- zstd=1.3.7=h5bba6e5_0
- pip:
  - absl-py==0.8.0
  - astor==0.8.0
  - attrs==19.1.0
  - chardet==3.0.4
  - cloudpickle==1.2.2
  - decorator==4.4.0
  - dill==0.3.0
  - future==0.17.1
  - gast==0.3.2
  - google-pasta==0.1.7
  - googleapis-common-protos==1.6.0
  - grpcio==1.23.0
  - h5py==2.10.0
  - idna==2.8
  - keras-applications==1.0.8
  - keras-preprocessing==1.1.0
  - markdown==3.1.1
  - pip==19.2.3
  - promise==2.2.1
  - protobuf==3.9.1
  - psutil==5.6.3
  - requests==2.22.0
  - tensorboard==1.14.0
  - tensorflow==1.14.0
  - tensorflow-datasets==1.2.0
  - tensorflow-estimator==1.14.0
  - tensorflow-metadata==0.14.0
  - tensorflow-probability==0.7.0
  - termcolor==1.1.0
  - tqdm==4.36.0
  - urllib3==1.25.3
  - werkzeug==0.15.6
  - wrapt==1.11.2
prefix: /usr/local/anaconda3/envs/dev
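
To reproduce the pinned environment, `conda env create -f environment.yml` followed by `conda activate dev` should work (the environment name `dev` comes from the `name:` field above).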

152 changes: 152 additions & 0 deletions flow.py
@@ -0,0 +1,152 @@
"""Credit: mostly based on Ilya's excellent implementation here: https://github.com/ikostrikov/pytorch-flows"""
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F


class InverseAutoregressiveFlow(nn.Module):
    """Inverse Autoregressive Flow with LSTM-type update. One block.

    Eq 11-14 of https://arxiv.org/abs/1606.04934
    """
    def __init__(self, num_input, num_hidden, num_context):
        super().__init__()
        self.made = MADE(num_input=num_input, num_output=num_input * 2,
                         num_hidden=num_hidden, num_context=num_context)
        # init such that sigmoid(s) is close to 1 for stability
        self.sigmoid_arg_bias = nn.Parameter(torch.ones(num_input) * 2)
        self.sigmoid = nn.Sigmoid()
        self.log_sigmoid = nn.LogSigmoid()

    def forward(self, input, context=None):
        m, s = torch.chunk(self.made(input, context), chunks=2, dim=-1)
        s = s + self.sigmoid_arg_bias
        sigmoid = self.sigmoid(s)
        z = sigmoid * input + (1 - sigmoid) * m
        return z, -self.log_sigmoid(s)


class FlowSequential(nn.Sequential):
    """Chain of flow blocks; accumulates each block's log-determinant term."""

    def forward(self, input, context=None):
        total_log_prob = torch.zeros_like(input, device=input.device)
        for block in self._modules.values():
            input, log_prob = block(input, context)
            total_log_prob += log_prob
        return input, total_log_prob


class MaskedLinear(nn.Module):
    """Linear layer with some input-output connections masked."""
    def __init__(self, in_features, out_features, mask, context_features=None, bias=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
        self.register_buffer("mask", mask)
        if context_features is not None:
            self.cond_linear = nn.Linear(context_features, out_features, bias=False)

    def forward(self, input, context=None):
        output = F.linear(input, self.mask * self.linear.weight, self.linear.bias)
        if context is None:
            return output
        else:
            return output + self.cond_linear(context)


class MADE(nn.Module):
    """Implements MADE: Masked Autoencoder for Distribution Estimation.

    Follows https://arxiv.org/abs/1502.03509

    This is used to build MAF: Masked Autoregressive Flow (https://arxiv.org/abs/1705.07057).
    """
    def __init__(self, num_input, num_output, num_hidden, num_context):
        super().__init__()
        # m corresponds to m(k), the maximum degree of a node in the MADE paper
        self._m = []
        self._masks = []
        self._build_masks(num_input, num_output, num_hidden, num_layers=3)
        self._check_masks()
        modules = []
        self.input_context_net = MaskedLinear(num_input, num_hidden, self._masks[0], num_context)
        modules.append(nn.ReLU())
        modules.append(MaskedLinear(num_hidden, num_hidden, self._masks[1], context_features=None))
        modules.append(nn.ReLU())
        modules.append(MaskedLinear(num_hidden, num_output, self._masks[2], context_features=None))
        self.net = nn.Sequential(*modules)

    def _build_masks(self, num_input, num_output, num_hidden, num_layers):
        """Build the masks according to Eq 12 and 13 in the MADE paper."""
        rng = np.random.RandomState(0)
        # assign input units a number between 1 and D
        self._m.append(np.arange(1, num_input + 1))
        for i in range(1, num_layers + 1):
            # randomly assign maximum number of input nodes to connect to
            if i == num_layers:
                # assign output layer units a number between 1 and D
                m = np.arange(1, num_input + 1)
                assert num_output % num_input == 0, "num_output must be multiple of num_input"
                self._m.append(np.hstack([m for _ in range(num_output // num_input)]))
            else:
                # assign hidden layer units a number between 1 and D-1
                self._m.append(rng.randint(1, num_input, size=num_hidden))
                #self._m.append(np.arange(1, num_hidden + 1) % (num_input - 1) + 1)
            if i == num_layers:
                mask = self._m[i][None, :] > self._m[i - 1][:, None]
            else:
                # input to hidden & hidden to hidden
                mask = self._m[i][None, :] >= self._m[i - 1][:, None]
            # need to transpose for torch linear layer, shape (num_output, num_input)
            self._masks.append(torch.from_numpy(mask.astype(np.float32).T))

    def _check_masks(self):
        """Check that the connectivity matrix between layers is lower triangular."""
        # (num_input, num_hidden)
        prev = self._masks[0].t()
        for i in range(1, len(self._masks)):
            # num_hidden is second axis
            prev = prev @ self._masks[i].t()
        final = prev.numpy()
        num_input = self._masks[0].shape[1]
        num_output = self._masks[-1].shape[0]
        assert final.shape == (num_input, num_output)
        if num_output == num_input:
            assert np.triu(final).all() == 0
        else:
            for submat in np.split(final,
                                   indices_or_sections=num_output // num_input,
                                   axis=1):
                assert np.triu(submat).all() == 0

    def forward(self, input, context=None):
        # first hidden layer receives input and context
        hidden = self.input_context_net(input, context)
        # rest of the network is conditioned on both input and context
        return self.net(hidden)



class Reverse(nn.Module):
    """ An implementation of a reversing layer from
    Density estimation using Real NVP
    (https://arxiv.org/abs/1605.08803).

    From https://github.com/ikostrikov/pytorch-flows/blob/master/main.py
    """

    def __init__(self, num_input):
        super(Reverse, self).__init__()
        self.perm = np.array(np.arange(0, num_input)[::-1])
        self.inv_perm = np.argsort(self.perm)

    def forward(self, inputs, context=None, mode='forward'):
        if mode == "forward":
            return inputs[:, :, self.perm], torch.zeros_like(inputs, device=inputs.device)
        elif mode == "inverse":
            return inputs[:, :, self.inv_perm], torch.zeros_like(inputs, device=inputs.device)
        else:
            raise ValueError("Mode must be one of {forward, inverse}.")
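
For orientation, a small sketch of how these blocks might be composed into a flow posterior; the layer sizes are made up, and the training script in this pull request wires the modules together with its own hyperparameters.

```python
import torch

from flow import FlowSequential, InverseAutoregressiveFlow, Reverse

# stack two IAF blocks, reversing the variable order in between so that
# every latent dimension eventually conditions on every other one
latent_size, hidden_size, context_size = 8, 16, 8
flow = FlowSequential(
    InverseAutoregressiveFlow(latent_size, hidden_size, context_size),
    Reverse(latent_size),
    InverseAutoregressiveFlow(latent_size, hidden_size, context_size),
)

z0 = torch.randn(1, 1, latent_size)       # sample from the mean-field posterior
h = torch.randn(1, 1, context_size)       # context vector from the encoder
z, neg_log_det = flow(z0, context=h)      # flowed sample and accumulated -log sigmoid(s)
# standard IAF density correction: log q(z | x) = log q(z0 | x) + neg_log_det.sum(-1)
# (the Reverse blocks are volume-preserving and contribute zero)
```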

