From ed283715cea16fd9eac7c67593cb603b40d6c637 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Sat, 19 Jan 2019 16:02:49 -0500 Subject: [PATCH 01/21] add pytorch training script --- train_vae_pytorch.py | 150 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 train_vae_pytorch.py diff --git a/train_vae_pytorch.py b/train_vae_pytorch.py new file mode 100644 index 0000000..9f285d2 --- /dev/null +++ b/train_vae_pytorch.py @@ -0,0 +1,150 @@ +"""Fit a VAE to MNIST. + +Conventions: + - batch size is the innermost dimension, then the sample dimension, then latent dimension +""" +import torch +import torch.utils +from torch import nn +import nomen +import yaml +import numpy as np +import logging + +import data + +config = """ +latent_size: 128 +data_size: 784 +learning_rate: 0.001 +batch_size: 128 +test_batch_size: 512 +max_iterations: 100000 +log_interval: 1000 +n_samples: 77 +""" + +class NeuralNetwork(nn.Module): + def __init__(self, input_size, output_size, hidden_size): + super().__init__() + modules = [nn.Linear(input_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, output_size)] + self.net = nn.Sequential(*modules) + + def forward(self, input): + return self.net(input) + + + +class Model(nn.Module): + """Bernoulli model parameterized by a generative network with Gaussian latents for MNIST.""" + def __init__(self, latent_size, data_size, batch_size): + super().__init__() + # prior on latents is standard normal + self.p_z = torch.distributions.Normal(torch.zeros(latent_size), torch.ones(latent_size)) + # likelihood is bernoulli, equivalent to negative binary cross entropy + self.log_p_x = BernoulliLogProb() + # generative network is a MLP + self.generative_network = NeuralNetwork(input_size=latent_size, output_size=data_size, hidden_size=latent_size * 2) + + + def forward(self, z, x): + """Return log probability of model.""" + log_p_z = self.p_z.log_prob(z).sum(-1) + logits = self.generative_network(z) + log_p_x = self.log_p_x(logits, x).sum(-1) + return log_p_z + log_p_x + + +class NormalLogProb(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, loc, scale, z): + var = torch.pow(scale, 2) + return -0.5 * torch.log(2 * np.pi * var) + torch.pow(z - loc, 2) / (2 * var) + +class BernoulliLogProb(nn.Module): + def __init__(self): + super().__init__() + self.bce_with_logits = nn.BCEWithLogitsLoss(reduction='none') + + def forward(self, logits, target): + logits, target = torch.broadcast_tensors(logits, target.unsqueeze(1)) + return -self.bce_with_logits(logits, target) + +class Variational(nn.Module): + """Approximate posterior parameterized by an inference network.""" + def __init__(self, latent_size, data_size): + super().__init__() + self.inference_network = NeuralNetwork(input_size=data_size, output_size=latent_size * 2, hidden_size=latent_size*2) + self.log_q_z = NormalLogProb() + self.softplus = nn.Softplus() + + def forward(self, x, n_samples=1): + """Return sample of latent variable and log prob.""" + loc, scale_arg = torch.chunk(self.inference_network(x).unsqueeze(1), chunks=2, dim=-1) + scale = self.softplus(scale_arg) + eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1])) + z = loc + scale * eps # reparameterization + log_q_z = self.log_q_z(loc, scale, z).sum(-1) + return z, log_q_z + + +def cycle(iterable): + while True: + for x in iterable: + yield x + + +def evaluate(n_samples, model, variational, eval_data): + model.eval() + total_log_p_x = 
0.0 + total_elbo = 0.0 + for batch in eval_data: + x = batch[0] + z, log_q_z = variational(x, n_samples) + log_p_x_and_z = model(z, x) + # importance sampling of approximate marginal likelihood + # using logsumexp in the sample dimension + elbo = log_p_x_and_z - log_q_z + log_p_x = torch.logsumexp(elbo, dim=1) - np.log(n_samples) + # average over sample dimension, sum over minibatch + total_elbo += elbo.cpu().numpy().mean(1).sum() + # sum over minibatch + total_log_p_x += log_p_x.cpu().numpy().sum() + n_data = len(eval_data.dataset) + return total_elbo / n_data, total_log_p_x / n_data + + +if __name__ == '__main__': + dictionary = yaml.load(config) + cfg = nomen.Config(dictionary) + + model = Model(latent_size=cfg.latent_size, data_size=cfg.data_size, batch_size=cfg.batch_size) + variational = Variational(latent_size=cfg.latent_size, data_size=cfg.data_size) + + optimizer = torch.optim.RMSprop(list(model.parameters()) + list(variational.parameters()), + lr=cfg.learning_rate) + + train_data, valid_data, test_data = data.load_binary_mnist(cfg) + + for step, batch in enumerate(cycle(train_data)): + x = batch[0] + model.zero_grad() + variational.zero_grad() + z, log_q_z = variational(x) + log_p_x_and_z = model(z, x) + elbo = (log_p_x_and_z - log_q_z).mean(1) + loss = -elbo.mean(0) + loss.backward() + optimizer.step() + + if step % cfg.log_interval == 0: + print(f'step:\t{step}\ttrain elbo: {elbo.detach().cpu().numpy()[0]:.2f}') + with torch.no_grad(): + valid_elbo, valid_log_p_x = evaluate(cfg.n_samples, model, variational, valid_data) + print(f'step:\t{step}\tvalid elbo: {valid_elbo:.2f}\tvalid log p(x): {valid_log_p_x:.2f}') From 47dc18178d60ac0a447d087e53462fcfbd9ce785 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Sat, 19 Jan 2019 16:03:39 -0500 Subject: [PATCH 02/21] Update README.md --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index efe40e4..bf36993 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ -# Variational Autoencoder (Deep Latent Gaussian Model) in tf -Reference implementation for a variational autoencoder in TensorFlow. +# Variational Autoencoder / Deep Latent Gaussian Model in tensorflow and pytorch +Reference implementation for a variational autoencoder in TensorFlow and PyTorch. + +I recommend the PyTorch version. Mean-field variational inference is used to fit the model to binarized MNIST handwritten digits images. An inference network (encoder) is used to amortize the inference and share parameters across datapoints. The likelihood is parameterized by a generative network (decoder). 
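A minimal sketch of the amortized variational inference described in the README above, assuming a Bernoulli likelihood over binarized pixels and a mean-field Gaussian approximate posterior. Layer sizes and names are illustrative rather than the exact architecture in train_vae_pytorch.py:

```
import torch
from torch import nn

latent_size, data_size = 128, 784
# encoder (inference network) outputs the parameters of q(z | x); decoder
# (generative network) outputs Bernoulli logits for p(x | z)
encoder = nn.Sequential(nn.Linear(data_size, 256), nn.ReLU(), nn.Linear(256, 2 * latent_size))
decoder = nn.Sequential(nn.Linear(latent_size, 256), nn.ReLU(), nn.Linear(256, data_size))

def elbo(x):
    loc, scale_arg = torch.chunk(encoder(x), chunks=2, dim=-1)
    scale = nn.functional.softplus(scale_arg)
    q_z = torch.distributions.Normal(loc, scale)
    z = q_z.rsample()  # reparameterized sample, so gradients reach the encoder
    p_z = torch.distributions.Normal(torch.zeros_like(z), torch.ones_like(z))
    log_p_x_given_z = torch.distributions.Bernoulli(logits=decoder(z)).log_prob(x).sum(-1)
    # single-sample ELBO: log p(x | z) + log p(z) - log q(z | x)
    return log_p_x_given_z + p_z.log_prob(z).sum(-1) - q_z.log_prob(z).sum(-1)

x = torch.bernoulli(torch.rand(32, data_size))  # stand-in for a binarized MNIST batch
loss = -elbo(x).mean()  # maximizing the ELBO = minimizing its negative
```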
From b9378bb2573488fa55ef416d960da21104cdb620 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Sat, 19 Jan 2019 16:04:14 -0500 Subject: [PATCH 03/21] Rename train_vae_pytorch.py to train_variational_autoencoder_pytorch.py --- train_vae_pytorch.py => train_variational_autoencoder_pytorch.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename train_vae_pytorch.py => train_variational_autoencoder_pytorch.py (100%) diff --git a/train_vae_pytorch.py b/train_variational_autoencoder_pytorch.py similarity index 100% rename from train_vae_pytorch.py rename to train_variational_autoencoder_pytorch.py From 7ec71325fce31424ad868443cadf008c1f3c8b8f Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Sat, 19 Jan 2019 16:05:59 -0500 Subject: [PATCH 04/21] Update README.md --- README.md | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/README.md b/README.md index bf36993..b21ac45 100644 --- a/README.md +++ b/README.md @@ -6,3 +6,38 @@ I recommend the PyTorch version. Mean-field variational inference is used to fit the model to binarized MNIST handwritten digits images. An inference network (encoder) is used to amortize the inference and share parameters across datapoints. The likelihood is parameterized by a generative network (decoder). Blog post: https://jaan.io/what-is-variational-autoencoder-vae-tutorial/ + +Example output with importance sampling for estimating the marginal likelihood on Hugo Larochelle's Binary MNIST dataset. Finaly marginal likelihood on the test set of `-97.10` nats. + +``` +$ python train_variational_autoencoder_pytorch.py +step: 0 train elbo: -558.69 +step: 0 valid elbo: -391.84 valid log p(x): -363.25 +step: 5000 train elbo: -116.09 +step: 5000 valid elbo: -112.57 valid log p(x): -107.01 +step: 10000 train elbo: -105.82 +step: 10000 valid elbo: -108.49 valid log p(x): -102.62 +step: 15000 train elbo: -106.78 +step: 15000 valid elbo: -106.97 valid log p(x): -100.97 +step: 20000 train elbo: -108.43 +step: 20000 valid elbo: -106.23 valid log p(x): -100.04 +step: 25000 train elbo: -99.68 +step: 25000 valid elbo: -104.89 valid log p(x): -98.83 +step: 30000 train elbo: -96.71 +step: 30000 valid elbo: -104.50 valid log p(x): -98.34 +step: 35000 train elbo: -98.64 +step: 35000 valid elbo: -104.05 valid log p(x): -97.87 +step: 40000 train elbo: -93.60 +step: 40000 valid elbo: -104.10 valid log p(x): -97.68 +step: 45000 train elbo: -96.45 +step: 45000 valid elbo: -104.58 valid log p(x): -97.76 +step: 50000 train elbo: -101.63 +step: 50000 valid elbo: -104.72 valid log p(x): -97.81 +step: 55000 train elbo: -106.78 +step: 55000 valid elbo: -105.14 valid log p(x): -98.06 +step: 60000 train elbo: -100.58 +step: 60000 valid elbo: -104.13 valid log p(x): -97.30 +step: 65000 train elbo: -96.19 +step: 65000 valid elbo: -104.46 valid log p(x): -97.43 +step: 65000 test elbo: -103.31 test log p(x): -97.10 +``` From aeee4f94d9ef5efd1ada75b39363b104c4f71931 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Sat, 19 Jan 2019 16:07:13 -0500 Subject: [PATCH 05/21] fix test marginal likelihood --- train_variational_autoencoder_pytorch.py | 173 +++++++++++++++-------- 1 file changed, 114 insertions(+), 59 deletions(-) diff --git a/train_variational_autoencoder_pytorch.py b/train_variational_autoencoder_pytorch.py index 9f285d2..11db31c 100644 --- a/train_variational_autoencoder_pytorch.py +++ b/train_variational_autoencoder_pytorch.py @@ -1,17 +1,19 @@ -"""Fit a VAE to MNIST. +"""Fit a variational autoencoder to MNIST. 
-Conventions: +Notes: + - run https://github.com/altosaar/proximity_vi/blob/master/get_binary_mnist.py to download binary MNIST file - batch size is the innermost dimension, then the sample dimension, then latent dimension """ import torch import torch.utils +import torch.utils.data from torch import nn import nomen import yaml import numpy as np import logging - -import data +import pathlib +import h5py config = """ latent_size: 128 @@ -20,67 +22,42 @@ batch_size: 128 test_batch_size: 512 max_iterations: 100000 -log_interval: 1000 -n_samples: 77 +log_interval: 5000 +n_samples: 128 +use_gpu: true +train_dir: $TMPDIR """ -class NeuralNetwork(nn.Module): - def __init__(self, input_size, output_size, hidden_size): - super().__init__() - modules = [nn.Linear(input_size, hidden_size), - nn.ReLU(), - nn.Linear(hidden_size, hidden_size), - nn.ReLU(), - nn.Linear(hidden_size, output_size)] - self.net = nn.Sequential(*modules) - - def forward(self, input): - return self.net(input) - - class Model(nn.Module): """Bernoulli model parameterized by a generative network with Gaussian latents for MNIST.""" - def __init__(self, latent_size, data_size, batch_size): + def __init__(self, latent_size, data_size, batch_size, device): super().__init__() - # prior on latents is standard normal - self.p_z = torch.distributions.Normal(torch.zeros(latent_size), torch.ones(latent_size)) - # likelihood is bernoulli, equivalent to negative binary cross entropy + self.p_z = torch.distributions.Normal( + torch.zeros(latent_size, device=device), + torch.ones(latent_size, device=device)) self.log_p_x = BernoulliLogProb() - # generative network is a MLP - self.generative_network = NeuralNetwork(input_size=latent_size, output_size=data_size, hidden_size=latent_size * 2) - + self.generative_network = NeuralNetwork(input_size=latent_size, + output_size=data_size, + hidden_size=latent_size * 2) def forward(self, z, x): """Return log probability of model.""" log_p_z = self.p_z.log_prob(z).sum(-1) logits = self.generative_network(z) + # unsqueeze sample dimension + logits, x = torch.broadcast_tensors(logits, x.unsqueeze(1)) log_p_x = self.log_p_x(logits, x).sum(-1) return log_p_z + log_p_x - -class NormalLogProb(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, loc, scale, z): - var = torch.pow(scale, 2) - return -0.5 * torch.log(2 * np.pi * var) + torch.pow(z - loc, 2) / (2 * var) - -class BernoulliLogProb(nn.Module): - def __init__(self): - super().__init__() - self.bce_with_logits = nn.BCEWithLogitsLoss(reduction='none') - - def forward(self, logits, target): - logits, target = torch.broadcast_tensors(logits, target.unsqueeze(1)) - return -self.bce_with_logits(logits, target) class Variational(nn.Module): """Approximate posterior parameterized by an inference network.""" def __init__(self, latent_size, data_size): super().__init__() - self.inference_network = NeuralNetwork(input_size=data_size, output_size=latent_size * 2, hidden_size=latent_size*2) + self.inference_network = NeuralNetwork(input_size=data_size, + output_size=latent_size * 2, + hidden_size=latent_size*2) self.log_q_z = NormalLogProb() self.softplus = nn.Softplus() @@ -88,28 +65,75 @@ def forward(self, x, n_samples=1): """Return sample of latent variable and log prob.""" loc, scale_arg = torch.chunk(self.inference_network(x).unsqueeze(1), chunks=2, dim=-1) scale = self.softplus(scale_arg) - eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1])) + eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]), device=loc.device) z = 
loc + scale * eps # reparameterization log_q_z = self.log_q_z(loc, scale, z).sum(-1) return z, log_q_z +class NeuralNetwork(nn.Module): + def __init__(self, input_size, output_size, hidden_size): + super().__init__() + modules = [nn.Linear(input_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, output_size)] + self.net = nn.Sequential(*modules) + + def forward(self, input): + return self.net(input) + + +class NormalLogProb(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, loc, scale, z): + var = torch.pow(scale, 2) + return -0.5 * torch.log(2 * np.pi * var) - torch.pow(z - loc, 2) / (2 * var) + + +class BernoulliLogProb(nn.Module): + def __init__(self): + super().__init__() + self.bce_with_logits = nn.BCEWithLogitsLoss(reduction='none') + + def forward(self, logits, target): + # bernoulli log prob is equivalent to negative binary cross entropy + return -self.bce_with_logits(logits, target) + + def cycle(iterable): while True: for x in iterable: yield x +def load_binary_mnist(cfg, **kwcfg): + f = h5py.File(pathlib.os.path.join(pathlib.os.environ['DAT'], 'binarized_mnist.hdf5'), 'r') + x_train = f['train'][::] + x_val = f['valid'][::] + x_test = f['test'][::] + train = torch.utils.data.TensorDataset(torch.from_numpy(x_train)) + train_loader = torch.utils.data.DataLoader(train, batch_size=cfg.batch_size, shuffle=True) + validation = torch.utils.data.TensorDataset(torch.from_numpy(x_val)) + val_loader = torch.utils.data.DataLoader(validation, batch_size=cfg.test_batch_size, shuffle=False) + test = torch.utils.data.TensorDataset(torch.from_numpy(x_test)) + test_loader = torch.utils.data.DataLoader(test, batch_size=cfg.test_batch_size, shuffle=False) + return train_loader, val_loader, test_loader + + def evaluate(n_samples, model, variational, eval_data): model.eval() total_log_p_x = 0.0 total_elbo = 0.0 for batch in eval_data: - x = batch[0] + x = batch[0].to(next(model.parameters()).device) z, log_q_z = variational(x, n_samples) log_p_x_and_z = model(z, x) - # importance sampling of approximate marginal likelihood - # using logsumexp in the sample dimension + # importance sampling of approximate marginal likelihood with q(z) + # as the proposal, and logsumexp in the sample dimension elbo = log_p_x_and_z - log_q_z log_p_x = torch.logsumexp(elbo, dim=1) - np.log(n_samples) # average over sample dimension, sum over minibatch @@ -123,28 +147,59 @@ def evaluate(n_samples, model, variational, eval_data): if __name__ == '__main__': dictionary = yaml.load(config) cfg = nomen.Config(dictionary) - - model = Model(latent_size=cfg.latent_size, data_size=cfg.data_size, batch_size=cfg.batch_size) - variational = Variational(latent_size=cfg.latent_size, data_size=cfg.data_size) + device = torch.device("cuda:0" if cfg.use_gpu else "cpu") + + model = Model(latent_size=cfg.latent_size, + data_size=cfg.data_size, + batch_size=cfg.batch_size, + device=device) + variational = Variational(latent_size=cfg.latent_size, + data_size=cfg.data_size) + model.to(device) + variational.to(device) + + optimizer = torch.optim.RMSprop(list(model.parameters()) + + list(variational.parameters()), + lr=cfg.learning_rate, + centered=True) - optimizer = torch.optim.RMSprop(list(model.parameters()) + list(variational.parameters()), - lr=cfg.learning_rate) + kwargs = {'num_workers': 0, 'pin_memory': False} if cfg.use_gpu else {} + train_data, valid_data, test_data = load_binary_mnist(cfg, **kwargs) - train_data, valid_data, test_data = 
data.load_binary_mnist(cfg) + best_valid_elbo = -np.inf + num_no_improvement = 0 for step, batch in enumerate(cycle(train_data)): - x = batch[0] + x = batch[0].to(device) model.zero_grad() variational.zero_grad() z, log_q_z = variational(x) log_p_x_and_z = model(z, x) + # average over sample dimension elbo = (log_p_x_and_z - log_q_z).mean(1) - loss = -elbo.mean(0) + # sum over batch dimension + loss = -elbo.sum(0) loss.backward() optimizer.step() if step % cfg.log_interval == 0: - print(f'step:\t{step}\ttrain elbo: {elbo.detach().cpu().numpy()[0]:.2f}') + print(f'step:\t{step}\ttrain elbo: {elbo.detach().cpu().numpy().mean():.2f}') with torch.no_grad(): valid_elbo, valid_log_p_x = evaluate(cfg.n_samples, model, variational, valid_data) - print(f'step:\t{step}\tvalid elbo: {valid_elbo:.2f}\tvalid log p(x): {valid_log_p_x:.2f}') + print(f'step:\t{step}\t\tvalid elbo: {valid_elbo:.2f}\tvalid log p(x): {valid_log_p_x:.2f}') + if valid_elbo > best_valid_elbo: + best_valid_elbo = valid_elbo + states = {'model': model.state_dict(), + 'variational': variational.state_dict()} + torch.save(states, cfg.train_dir / 'best_state_dict') + else: + num_no_improvement += 1 + + if num_no_improvement > 5: + checkpoint = torch.load(cfg.train_dir / 'best_state_dict') + model.load_state_dict(checkpoint['model']) + variational.load_state_dict(checkpoint['variational']) + with torch.no_grad(): + test_elbo, test_log_p_x = evaluate(cfg.n_samples, model, variational, test_data) + print(f'step:\t{step}\t\ttest elbo: {test_elbo:.2f}\ttest log p(x): {test_log_p_x:.2f}') + break From 68c85351bc45b371a6cb4ea14b67d020154db572 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Sun, 20 Jan 2019 15:55:17 -0500 Subject: [PATCH 06/21] pin memory; fix seed --- train_variational_autoencoder_pytorch.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/train_variational_autoencoder_pytorch.py b/train_variational_autoencoder_pytorch.py index 11db31c..9537222 100644 --- a/train_variational_autoencoder_pytorch.py +++ b/train_variational_autoencoder_pytorch.py @@ -14,6 +14,7 @@ import logging import pathlib import h5py +import random config = """ latent_size: 128 @@ -22,10 +23,11 @@ batch_size: 128 test_batch_size: 512 max_iterations: 100000 -log_interval: 5000 +log_interval: 10000 n_samples: 128 use_gpu: true train_dir: $TMPDIR +seed: 582838 """ @@ -116,7 +118,7 @@ def load_binary_mnist(cfg, **kwcfg): x_val = f['valid'][::] x_test = f['test'][::] train = torch.utils.data.TensorDataset(torch.from_numpy(x_train)) - train_loader = torch.utils.data.DataLoader(train, batch_size=cfg.batch_size, shuffle=True) + train_loader = torch.utils.data.DataLoader(train, batch_size=cfg.batch_size, shuffle=True, **kwcfg) validation = torch.utils.data.TensorDataset(torch.from_numpy(x_val)) val_loader = torch.utils.data.DataLoader(validation, batch_size=cfg.test_batch_size, shuffle=False) test = torch.utils.data.TensorDataset(torch.from_numpy(x_test)) @@ -148,6 +150,9 @@ def evaluate(n_samples, model, variational, eval_data): dictionary = yaml.load(config) cfg = nomen.Config(dictionary) device = torch.device("cuda:0" if cfg.use_gpu else "cpu") + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + random.seed(cfg.seed) model = Model(latent_size=cfg.latent_size, data_size=cfg.data_size, @@ -163,7 +168,7 @@ def evaluate(n_samples, model, variational, eval_data): lr=cfg.learning_rate, centered=True) - kwargs = {'num_workers': 0, 'pin_memory': False} if cfg.use_gpu else {} + kwargs = {'num_workers': 4, 'pin_memory': True} if 
cfg.use_gpu else {} train_data, valid_data, test_data = load_binary_mnist(cfg, **kwargs) best_valid_elbo = -np.inf @@ -188,6 +193,7 @@ def evaluate(n_samples, model, variational, eval_data): valid_elbo, valid_log_p_x = evaluate(cfg.n_samples, model, variational, valid_data) print(f'step:\t{step}\t\tvalid elbo: {valid_elbo:.2f}\tvalid log p(x): {valid_log_p_x:.2f}') if valid_elbo > best_valid_elbo: + num_no_improvement = 0 best_valid_elbo = valid_elbo states = {'model': model.state_dict(), 'variational': variational.state_dict()} From c78cb3f7208dccd029112265518c38abb4abda55 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Sun, 20 Jan 2019 16:11:14 -0500 Subject: [PATCH 07/21] add inverse autoregressive flow classes --- flow.py | 152 +++++++++++++++++++++++ train_variational_autoencoder_pytorch.py | 66 ++++++++-- 2 files changed, 205 insertions(+), 13 deletions(-) create mode 100644 flow.py diff --git a/flow.py b/flow.py new file mode 100644 index 0000000..b3b33e7 --- /dev/null +++ b/flow.py @@ -0,0 +1,152 @@ +"""Credit: mostly based on Ilya's excellent implementation here: https://github.com/ikostrikov/pytorch-flows""" +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + + +class InverseAutoregressiveFlow(nn.Module): + """Inverse Autoregressive Flows with LSTM-type update. One block. + + Eq 11-14 of https://arxiv.org/abs/1606.04934 + """ + def __init__(self, num_input, num_hidden, num_context): + super().__init__() + self.made = MADE(num_input=num_input, num_output=num_input * 2, + num_hidden=num_hidden, num_context=num_context) + # init such that sigmoid(s) is close to 1 for stability + self.sigmoid_arg_bias = nn.Parameter(torch.ones(num_input) * 2) + self.sigmoid = nn.Sigmoid() + self.log_sigmoid = nn.LogSigmoid() + + def forward(self, input, context=None): + m, s = torch.chunk(self.made(input, context), chunks=2, dim=-1) + s = s + self.sigmoid_arg_bias + sigmoid = self.sigmoid(s) + z = sigmoid * input + (1 - sigmoid) * m + return z, -self.log_sigmoid(s) + + +class FlowSequential(nn.Sequential): + """Forward pass.""" + + def forward(self, input, context=None): + total_log_prob = torch.zeros_like(input, device=input.device) + for block in self._modules.values(): + input, log_prob = block(input, context) + total_log_prob += log_prob + return input, total_log_prob + + +class MaskedLinear(nn.Module): + """Linear layer with some input-output connections masked.""" + def __init__(self, in_features, out_features, mask, context_features=None, bias=True): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias) + self.register_buffer("mask", mask) + if context_features is not None: + self.cond_linear = nn.Linear(context_features, out_features, bias=False) + + def forward(self, input, context=None): + output = F.linear(input, self.mask * self.linear.weight, self.linear.bias) + if context is None: + return output + else: + return output + self.cond_linear(context) + + +class MADE(nn.Module): + """Implements MADE: Masked Autoencoder for Distribution Estimation. + + Follows https://arxiv.org/abs/1502.03509 + + This is used to build MAF: Masked Autoregressive Flow (https://arxiv.org/abs/1705.07057). 
+ """ + def __init__(self, num_input, num_output, num_hidden, num_context): + super().__init__() + # m corresponds to m(k), the maximum degree of a node in the MADE paper + self._m = [] + self._masks = [] + self._build_masks(num_input, num_output, num_hidden, num_layers=3) + self._check_masks() + modules = [] + self.input_context_net = MaskedLinear(num_input, num_hidden, self._masks[0], num_context) + modules.append(nn.ReLU()) + modules.append(MaskedLinear(num_hidden, num_hidden, self._masks[1], context_features=None)) + modules.append(nn.ReLU()) + modules.append(MaskedLinear(num_hidden, num_output, self._masks[2], context_features=None)) + self.net = nn.Sequential(*modules) + + def _build_masks(self, num_input, num_output, num_hidden, num_layers): + """Build the masks according to Eq 12 and 13 in the MADE paper.""" + rng = np.random.RandomState(0) + # assign input units a number between 1 and D + self._m.append(np.arange(1, num_input + 1)) + for i in range(1, num_layers + 1): + # randomly assign maximum number of input nodes to connect to + if i == num_layers: + # assign output layer units a number between 1 and D + m = np.arange(1, num_input + 1) + assert num_output % num_input == 0, "num_output must be multiple of num_input" + self._m.append(np.hstack([m for _ in range(num_output // num_input)])) + else: + # assign hidden layer units a number between 1 and D-1 + self._m.append(rng.randint(1, num_input, size=num_hidden)) + #self._m.append(np.arange(1, num_hidden + 1) % (num_input - 1) + 1) + if i == num_layers: + mask = self._m[i][None, :] > self._m[i - 1][:, None] + else: + # input to hidden & hidden to hidden + mask = self._m[i][None, :] >= self._m[i - 1][:, None] + # need to transpose for torch linear layer, shape (num_output, num_input) + self._masks.append(torch.from_numpy(mask.astype(np.float32).T)) + + def _check_masks(self): + """Check that the connectivity matrix between layers is lower triangular.""" + # (num_input, num_hidden) + prev = self._masks[0].t() + for i in range(1, len(self._masks)): + # num_hidden is second axis + prev = prev @ self._masks[i].t() + final = prev.numpy() + num_input = self._masks[0].shape[1] + num_output = self._masks[-1].shape[0] + assert final.shape == (num_input, num_output) + if num_output == num_input: + assert np.triu(final).all() == 0 + else: + for submat in np.split(final, + indices_or_sections=num_output // num_input, + axis=1): + assert np.triu(submat).all() == 0 + + def forward(self, input, context=None): + # first hidden layer receives input and context + hidden = self.input_context_net(input, context) + # rest of the network is conditioned on both input and context + return self.net(hidden) + + + +class Reverse(nn.Module): + """ An implementation of a reversing layer from + Density estimation using Real NVP + (https://arxiv.org/abs/1605.08803). 
+ + From https://github.com/ikostrikov/pytorch-flows/blob/master/main.py + """ + + def __init__(self, num_input): + super(Reverse, self).__init__() + self.perm = np.array(np.arange(0, num_input)[::-1]) + self.inv_perm = np.argsort(self.perm) + + def forward(self, inputs, context=None, mode='forward'): + if mode == "forward": + return inputs[:, :, self.perm], torch.zeros_like(inputs, device=inputs.device) + elif mode == "inverse": + return inputs[:, :, self.inv_perm], torch.zeros_like(inputs, device=inputs.device) + else: + raise ValueError("Mode must be one of {forward, inverse}.") + + diff --git a/train_variational_autoencoder_pytorch.py b/train_variational_autoencoder_pytorch.py index 9537222..3b5e5d2 100644 --- a/train_variational_autoencoder_pytorch.py +++ b/train_variational_autoencoder_pytorch.py @@ -15,9 +15,12 @@ import pathlib import h5py import random +import flow config = """ latent_size: 128 +variational: flow +flow_depth: 2 data_size: 784 learning_rate: 0.001 batch_size: 128 @@ -30,14 +33,13 @@ seed: 582838 """ - class Model(nn.Module): """Bernoulli model parameterized by a generative network with Gaussian latents for MNIST.""" - def __init__(self, latent_size, data_size, batch_size, device): + def __init__(self, latent_size, data_size): super().__init__() - self.p_z = torch.distributions.Normal( - torch.zeros(latent_size, device=device), - torch.ones(latent_size, device=device)) + self.register_buffer('p_z_loc', torch.zeros(latent_size)) + self.register_buffer('p_z_scale', torch.ones(latent_size)) + self.log_p_z = NormalLogProb() self.log_p_x = BernoulliLogProb() self.generative_network = NeuralNetwork(input_size=latent_size, output_size=data_size, @@ -45,15 +47,15 @@ def __init__(self, latent_size, data_size, batch_size, device): def forward(self, z, x): """Return log probability of model.""" - log_p_z = self.p_z.log_prob(z).sum(-1) + log_p_z = self.log_p_z(self.p_z_loc, self.p_z_scale, z).sum(-1, keepdim=True) logits = self.generative_network(z) # unsqueeze sample dimension logits, x = torch.broadcast_tensors(logits, x.unsqueeze(1)) - log_p_x = self.log_p_x(logits, x).sum(-1) + log_p_x = self.log_p_x(logits, x).sum(-1, keepdim=True) return log_p_z + log_p_x -class Variational(nn.Module): +class VariationalMeanField(nn.Module): """Approximate posterior parameterized by an inference network.""" def __init__(self, latent_size, data_size): super().__init__() @@ -73,6 +75,38 @@ def forward(self, x, n_samples=1): return z, log_q_z +class VariationalFlow(nn.Module): + """Approximate posterior parameterized by a flow (https://arxiv.org/abs/1606.04934).""" + def __init__(self, latent_size, data_size, flow_depth): + super().__init__() + hidden_size = latent_size * 2 + self.inference_network = NeuralNetwork(input_size=data_size, + # loc, scale, and context + output_size=latent_size * 3, + hidden_size=hidden_size) + modules = [] + for _ in range(flow_depth): + modules.append(flow.InverseAutoregressiveFlow(num_input=latent_size, + num_hidden=hidden_size, + num_context=latent_size)) + modules.append(flow.Reverse(latent_size)) + self.q_z_flow = flow.FlowSequential(*modules) + self.log_q_z_0 = NormalLogProb() + self.softplus = nn.Softplus() + + def forward(self, x, n_samples=1): + """Return sample of latent variable and log prob.""" + loc, scale_arg, h = torch.chunk(self.inference_network(x).unsqueeze(1), chunks=3, dim=-1) + scale = self.softplus(scale_arg) + eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]), device=loc.device) + z_0 = loc + scale * eps # reparameterization + 
log_q_z_0 = self.log_q_z_0(loc, scale, z_0) + z_T, log_q_z_flow = self.q_z_flow(z_0, context=h) + log_q_z = (log_q_z_0 + log_q_z_flow).sum(-1, keepdim=True) + return z_T, log_q_z + + + class NeuralNetwork(nn.Module): def __init__(self, input_size, output_size, hidden_size): super().__init__() @@ -155,11 +189,17 @@ def evaluate(n_samples, model, variational, eval_data): random.seed(cfg.seed) model = Model(latent_size=cfg.latent_size, - data_size=cfg.data_size, - batch_size=cfg.batch_size, - device=device) - variational = Variational(latent_size=cfg.latent_size, - data_size=cfg.data_size) + data_size=cfg.data_size) + if cfg.variational == 'flow': + variational = VariationalFlow(latent_size=cfg.latent_size, + data_size=cfg.data_size, + flow_depth=cfg.flow_depth) + elif cfg.variational == 'mean-field': + variational = VariationalMeanField(latent_size=cfg.latent_size, + data_size=cfg.data_size) + else: + raise ValueError('Variational distribution not implemented: %s' % cfg.variational) + model.to(device) variational.to(device) From 54bf13e99215ca31f973ab827d57e80f1bca3283 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Sun, 20 Jan 2019 16:14:03 -0500 Subject: [PATCH 08/21] Update README.md --- README.md | 46 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index b21ac45..800c71c 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,16 @@ # Variational Autoencoder / Deep Latent Gaussian Model in tensorflow and pytorch Reference implementation for a variational autoencoder in TensorFlow and PyTorch. -I recommend the PyTorch version. +I recommend the PyTorch version. It includes an example of a more expressive variational family (the [inverse autoregressive flow](https://arxiv.org/abs/1606.04934). -Mean-field variational inference is used to fit the model to binarized MNIST handwritten digits images. An inference network (encoder) is used to amortize the inference and share parameters across datapoints. The likelihood is parameterized by a generative network (decoder). +Variational inference is used to fit the model to binarized MNIST handwritten digits images. An inference network (encoder) is used to amortize the inference and share parameters across datapoints. The likelihood is parameterized by a generative network (decoder). Blog post: https://jaan.io/what-is-variational-autoencoder-vae-tutorial/ Example output with importance sampling for estimating the marginal likelihood on Hugo Larochelle's Binary MNIST dataset. Finaly marginal likelihood on the test set of `-97.10` nats. 
``` -$ python train_variational_autoencoder_pytorch.py +$ python train_variational_autoencoder_pytorch.py --variational mean-field step: 0 train elbo: -558.69 step: 0 valid elbo: -391.84 valid log p(x): -363.25 step: 5000 train elbo: -116.09 @@ -41,3 +41,43 @@ step: 65000 train elbo: -96.19 step: 65000 valid elbo: -104.46 valid log p(x): -97.43 step: 65000 test elbo: -103.31 test log p(x): -97.10 ``` + + +Using a non mean-field, more expressive variational posterior approximation, the test marginal log-likelihood improves to `-95.33` nats: + +``` +$ python train_variational_autoencoder_pytorch.py --variational flow +step: 0 train elbo: -578.35 +step: 0 valid elbo: -407.06 valid log p(x): -367.88 +step: 10000 train elbo: -106.63 +step: 10000 valid elbo: -110.12 valid log p(x): -104.00 +step: 20000 train elbo: -101.51 +step: 20000 valid elbo: -105.02 valid log p(x): -99.11 +step: 30000 train elbo: -98.70 +step: 30000 valid elbo: -103.76 valid log p(x): -97.71 +step: 40000 train elbo: -104.31 +step: 40000 valid elbo: -103.71 valid log p(x): -97.27 +step: 50000 train elbo: -97.20 +step: 50000 valid elbo: -102.97 valid log p(x): -96.60 +step: 60000 train elbo: -97.50 +step: 60000 valid elbo: -102.82 valid log p(x): -96.49 +step: 70000 train elbo: -94.68 +step: 70000 valid elbo: -102.63 valid log p(x): -96.22 +step: 80000 train elbo: -92.86 +step: 80000 valid elbo: -102.53 valid log p(x): -96.09 +step: 90000 train elbo: -93.83 +step: 90000 valid elbo: -102.33 valid log p(x): -96.00 +step: 100000 train elbo: -93.91 +step: 100000 valid elbo: -102.48 valid log p(x): -95.92 +step: 110000 train elbo: -94.34 +step: 110000 valid elbo: -102.81 valid log p(x): -96.09 +step: 120000 train elbo: -88.63 +step: 120000 valid elbo: -102.53 valid log p(x): -95.80 +step: 130000 train elbo: -96.61 +step: 130000 valid elbo: -103.56 valid log p(x): -96.26 +step: 140000 train elbo: -94.92 +step: 140000 valid elbo: -102.81 valid log p(x): -95.86 +step: 150000 train elbo: -97.84 +step: 150000 valid elbo: -103.06 valid log p(x): -95.92 +step: 150000 test elbo: -101.64 test log p(x): -95.33 +``` From 75949e3c0501edf0c00e31d4e09c8b54d3cf435e Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Sun, 20 Jan 2019 16:14:46 -0500 Subject: [PATCH 09/21] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 800c71c..68acb24 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Variational Autoencoder / Deep Latent Gaussian Model in tensorflow and pytorch Reference implementation for a variational autoencoder in TensorFlow and PyTorch. -I recommend the PyTorch version. It includes an example of a more expressive variational family (the [inverse autoregressive flow](https://arxiv.org/abs/1606.04934). +I recommend the PyTorch version. It includes an example of a more expressive variational family, the [inverse autoregressive flow](https://arxiv.org/abs/1606.04934). Variational inference is used to fit the model to binarized MNIST handwritten digits images. An inference network (encoder) is used to amortize the inference and share parameters across datapoints. The likelihood is parameterized by a generative network (decoder). 
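The "valid log p(x)" and "test log p(x)" columns in the README output above are importance-sampling estimates of the marginal likelihood, using the approximate posterior as the proposal. A small sketch of that estimator, restating what evaluate() computes with torch.logsumexp (the function name below is illustrative):

```
import math
import torch

def log_marginal_estimate(log_p_x_and_z, log_q_z):
    """Estimate log p(x) from K proposal samples per datapoint.

    Both arguments have shape (batch_size, n_samples): the joint log p(x, z_k)
    and the proposal log q(z_k | x) for samples z_k ~ q(z | x).
    """
    log_weights = log_p_x_and_z - log_q_z  # per-sample log importance weights
    n_samples = log_weights.shape[1]
    # log p(x) is approximated by logsumexp_k(log_weights) - log K
    return torch.logsumexp(log_weights, dim=1) - math.log(n_samples)

# shape check with random inputs: a batch of 2 datapoints, K = 128 samples each
print(log_marginal_estimate(torch.randn(2, 128), torch.randn(2, 128)))
```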
From 2e48e80b5b7a0aea521ad96a307842a180e3847e Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Sun, 20 Jan 2019 16:15:08 -0500 Subject: [PATCH 10/21] Rename vae.py to train_variational_autoencoder_tensorflow.py --- vae.py => train_variational_autoencoder_tensorflow.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename vae.py => train_variational_autoencoder_tensorflow.py (100%) diff --git a/vae.py b/train_variational_autoencoder_tensorflow.py similarity index 100% rename from vae.py rename to train_variational_autoencoder_tensorflow.py From 4e58b8c2a6be7dbfe949df2c8b5f6c55d090ffdc Mon Sep 17 00:00:00 2001 From: "Ilya V. Schurov" Date: Wed, 20 Mar 2019 02:30:09 +0300 Subject: [PATCH 11/21] fixes error with size mismatch --- train_variational_autoencoder_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_variational_autoencoder_pytorch.py b/train_variational_autoencoder_pytorch.py index 3b5e5d2..f663dea 100644 --- a/train_variational_autoencoder_pytorch.py +++ b/train_variational_autoencoder_pytorch.py @@ -71,7 +71,7 @@ def forward(self, x, n_samples=1): scale = self.softplus(scale_arg) eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]), device=loc.device) z = loc + scale * eps # reparameterization - log_q_z = self.log_q_z(loc, scale, z).sum(-1) + log_q_z = self.log_q_z(loc, scale, z).sum(-1, keepdim=True) return z, log_q_z From 517c52b0eb119ab68ed92807759a261a51e59fe2 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Thu, 21 Mar 2019 13:39:35 -0400 Subject: [PATCH 12/21] Update train_variational_autoencoder_pytorch.py --- train_variational_autoencoder_pytorch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/train_variational_autoencoder_pytorch.py b/train_variational_autoencoder_pytorch.py index f663dea..279e05c 100644 --- a/train_variational_autoencoder_pytorch.py +++ b/train_variational_autoencoder_pytorch.py @@ -27,6 +27,7 @@ test_batch_size: 512 max_iterations: 100000 log_interval: 10000 +early_stopping_interval: 5 n_samples: 128 use_gpu: true train_dir: $TMPDIR @@ -183,6 +184,7 @@ def evaluate(n_samples, model, variational, eval_data): if __name__ == '__main__': dictionary = yaml.load(config) cfg = nomen.Config(dictionary) + cfg.parse_args() device = torch.device("cuda:0" if cfg.use_gpu else "cpu") torch.manual_seed(cfg.seed) np.random.seed(cfg.seed) @@ -241,7 +243,7 @@ def evaluate(n_samples, model, variational, eval_data): else: num_no_improvement += 1 - if num_no_improvement > 5: + if num_no_improvement > cfg.early_stopping_interval: checkpoint = torch.load(cfg.train_dir / 'best_state_dict') model.load_state_dict(checkpoint['model']) variational.load_state_dict(checkpoint['variational']) From 1d0f52358f04b96d0450a98fe877fb7060011e4a Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Thu, 21 Mar 2019 13:48:14 -0400 Subject: [PATCH 13/21] add mnist download utility --- data.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 data.py diff --git a/data.py b/data.py new file mode 100644 index 0000000..580c978 --- /dev/null +++ b/data.py @@ -0,0 +1,42 @@ +"""Get the binarized MNIST dataset and convert to hdf5. 
+From https://github.com/yburda/iwae/blob/master/datasets.py +""" +import urllib.request +import os +import numpy as np +import h5py + + +def parse_binary_mnist(): + def lines_to_np_array(lines): + return np.array([[int(i) for i in line.split()] for line in lines]) + with open(os.path.join(DATASETS_DIR, 'binarized_mnist_train.amat')) as f: + lines = f.readlines() + train_data = lines_to_np_array(lines).astype('float32') + with open(os.path.join(DATASETS_DIR, 'binarized_mnist_valid.amat')) as f: + lines = f.readlines() + validation_data = lines_to_np_array(lines).astype('float32') + with open(os.path.join(DATASETS_DIR, 'binarized_mnist_test.amat')) as f: + lines = f.readlines() + test_data = lines_to_np_array(lines).astype('float32') + return train_data, validation_data, test_data + + +def download_binary_mnist(fname): + DATASETS_DIR = '/tmp/' + subdatasets = ['train', 'valid', 'test'] + for subdataset in subdatasets: + filename = 'binarized_mnist_{}.amat'.format(subdataset) + url = 'http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat'.format( + subdataset) + local_filename = os.path.join(DATASETS_DIR, filename) + urllib.request.urlretrieve(url, local_filename) + + train, validation, test = parse_binary_mnist() + + data_dict = {'train': train, 'valid': validation, 'test': test} + f = h5py.File(fname, 'w') + f.create_dataset('train', data=data_dict['train']) + f.create_dataset('valid', data=data_dict['valid']) + f.create_dataset('test', data=data_dict['test']) + f.close() From cdf9000699e28d1c69ad08bf08a1c6d75e95a75d Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Thu, 21 Mar 2019 13:48:36 -0400 Subject: [PATCH 14/21] add mnist download utility --- train_variational_autoencoder_pytorch.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/train_variational_autoencoder_pytorch.py b/train_variational_autoencoder_pytorch.py index f663dea..16bd92c 100644 --- a/train_variational_autoencoder_pytorch.py +++ b/train_variational_autoencoder_pytorch.py @@ -15,6 +15,7 @@ import pathlib import h5py import random +import data import flow config = """ @@ -30,6 +31,7 @@ n_samples: 128 use_gpu: true train_dir: $TMPDIR +data_dir: $TMPDIR seed: 582838 """ @@ -147,6 +149,9 @@ def cycle(iterable): def load_binary_mnist(cfg, **kwcfg): + fname = cfg.data_dir / 'binary_mnist.h5' + if not fname.exists(): + data.download_binary_mnist(fname) f = h5py.File(pathlib.os.path.join(pathlib.os.environ['DAT'], 'binarized_mnist.hdf5'), 'r') x_train = f['train'][::] x_val = f['valid'][::] From 8d1c764b283cc03517606caf69871e8db7a31cb8 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Thu, 21 Mar 2019 13:57:27 -0400 Subject: [PATCH 15/21] rename DATASETS_DIR to data_dir --- data.py | 15 ++++++++------- train_variational_autoencoder_pytorch.py | 1 + 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/data.py b/data.py index 580c978..f49249d 100644 --- a/data.py +++ b/data.py @@ -7,32 +7,32 @@ import h5py -def parse_binary_mnist(): +def parse_binary_mnist(data_dir): def lines_to_np_array(lines): return np.array([[int(i) for i in line.split()] for line in lines]) - with open(os.path.join(DATASETS_DIR, 'binarized_mnist_train.amat')) as f: + with open(os.path.join(data_dir, 'binarized_mnist_train.amat')) as f: lines = f.readlines() train_data = lines_to_np_array(lines).astype('float32') - with open(os.path.join(DATASETS_DIR, 'binarized_mnist_valid.amat')) as f: + with open(os.path.join(data_dir, 'binarized_mnist_valid.amat')) as f: lines = f.readlines() validation_data = 
lines_to_np_array(lines).astype('float32') - with open(os.path.join(DATASETS_DIR, 'binarized_mnist_test.amat')) as f: + with open(os.path.join(data_dir, 'binarized_mnist_test.amat')) as f: lines = f.readlines() test_data = lines_to_np_array(lines).astype('float32') return train_data, validation_data, test_data def download_binary_mnist(fname): - DATASETS_DIR = '/tmp/' + data_dir = '/tmp/' subdatasets = ['train', 'valid', 'test'] for subdataset in subdatasets: filename = 'binarized_mnist_{}.amat'.format(subdataset) url = 'http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat'.format( subdataset) - local_filename = os.path.join(DATASETS_DIR, filename) + local_filename = os.path.join(data_dir, filename) urllib.request.urlretrieve(url, local_filename) - train, validation, test = parse_binary_mnist() + train, validation, test = parse_binary_mnist(data_dir) data_dict = {'train': train, 'valid': validation, 'test': test} f = h5py.File(fname, 'w') @@ -40,3 +40,4 @@ def download_binary_mnist(fname): f.create_dataset('valid', data=data_dict['valid']) f.create_dataset('test', data=data_dict['test']) f.close() + print(f'Saved binary MNIST data to: {fname}') diff --git a/train_variational_autoencoder_pytorch.py b/train_variational_autoencoder_pytorch.py index be48c4d..ed2b4f3 100644 --- a/train_variational_autoencoder_pytorch.py +++ b/train_variational_autoencoder_pytorch.py @@ -152,6 +152,7 @@ def cycle(iterable): def load_binary_mnist(cfg, **kwcfg): fname = cfg.data_dir / 'binary_mnist.h5' if not fname.exists(): + print('Downloading binary MNIST data...') data.download_binary_mnist(fname) f = h5py.File(pathlib.os.path.join(pathlib.os.environ['DAT'], 'binarized_mnist.hdf5'), 'r') x_train = f['train'][::] From 96337fa367e720bd59da6979b64b2a526e96cfc1 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Fri, 12 Apr 2019 17:36:19 -0400 Subject: [PATCH 16/21] Update train_variational_autoencoder_pytorch.py --- train_variational_autoencoder_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_variational_autoencoder_pytorch.py b/train_variational_autoencoder_pytorch.py index ed2b4f3..5a662a9 100644 --- a/train_variational_autoencoder_pytorch.py +++ b/train_variational_autoencoder_pytorch.py @@ -226,7 +226,7 @@ def evaluate(n_samples, model, variational, eval_data): x = batch[0].to(device) model.zero_grad() variational.zero_grad() - z, log_q_z = variational(x) + z, log_q_z = variational(x, n_samples=1) log_p_x_and_z = model(z, x) # average over sample dimension elbo = (log_p_x_and_z - log_q_z).mean(1) From b43325e297498269be86b281a02e2ef79664f273 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Wed, 18 Sep 2019 18:50:37 -0400 Subject: [PATCH 17/21] upgrade to tf 1.1.4; slim <- keras; tf.distributions <- tf.probability; tf <- tf.compat.v1 --- environment.yml | 74 +++++++++ train_variational_autoencoder_tensorflow.py | 171 ++++++-------------- 2 files changed, 122 insertions(+), 123 deletions(-) create mode 100644 environment.yml diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..64be185 --- /dev/null +++ b/environment.yml @@ -0,0 +1,74 @@ +name: dev +channels: + - defaults +dependencies: + - blas=1.0=mkl + - ca-certificates=2019.5.15=1 + - certifi=2019.6.16=py37_1 + - freetype=2.9.1=hb4e5f40_0 + - imageio=2.5.0=py37_0 + - intel-openmp=2019.4=233 + - jpeg=9b=he5867d9_2 + - libcxx=4.0.1=hcfea43d_1 + - libcxxabi=4.0.1=hcfea43d_1 + - libedit=3.1.20181209=hb402a30_0 + - libffi=3.2.1=h475c297_4 + - 
libgfortran=3.0.1=h93005f0_2 + - libpng=1.6.37=ha441bb4_0 + - libtiff=4.0.10=hcb84e12_2 + - mkl=2019.4=233 + - mkl-service=2.3.0=py37hfbe908c_0 + - mkl_fft=1.0.14=py37h5e564d8_0 + - mkl_random=1.0.2=py37h27c97d8_0 + - ncurses=6.1=h0a44026_1 + - numpy=1.16.5=py37hacdab7b_0 + - numpy-base=1.16.5=py37h6575580_0 + - olefile=0.46=py37_0 + - openssl=1.1.1d=h1de35cc_1 + - pillow=6.1.0=py37hb68e598_0 + - python=3.7.4=h359304d_1 + - readline=7.0=h1de35cc_5 + - setuptools=41.0.1=py37_0 + - six=1.12.0=py37_0 + - sqlite=3.29.0=ha441bb4_0 + - tk=8.6.8=ha441bb4_0 + - wheel=0.33.4=py37_0 + - xz=5.2.4=h1de35cc_4 + - zlib=1.2.11=h1de35cc_3 + - zstd=1.3.7=h5bba6e5_0 + - pip: + - absl-py==0.8.0 + - astor==0.8.0 + - attrs==19.1.0 + - chardet==3.0.4 + - cloudpickle==1.2.2 + - decorator==4.4.0 + - dill==0.3.0 + - future==0.17.1 + - gast==0.3.2 + - google-pasta==0.1.7 + - googleapis-common-protos==1.6.0 + - grpcio==1.23.0 + - h5py==2.10.0 + - idna==2.8 + - keras-applications==1.0.8 + - keras-preprocessing==1.1.0 + - markdown==3.1.1 + - pip==19.2.3 + - promise==2.2.1 + - protobuf==3.9.1 + - psutil==5.6.3 + - requests==2.22.0 + - tensorboard==1.14.0 + - tensorflow==1.14.0 + - tensorflow-datasets==1.2.0 + - tensorflow-estimator==1.14.0 + - tensorflow-metadata==0.14.0 + - tensorflow-probability==0.7.0 + - termcolor==1.1.0 + - tqdm==4.36.0 + - urllib3==1.25.3 + - werkzeug==0.15.6 + - wrapt==1.11.2 +prefix: /usr/local/anaconda3/envs/dev + diff --git a/train_variational_autoencoder_tensorflow.py b/train_variational_autoencoder_tensorflow.py index a6cba28..922b83f 100644 --- a/train_variational_autoencoder_tensorflow.py +++ b/train_variational_autoencoder_tensorflow.py @@ -1,33 +1,20 @@ import itertools -import matplotlib as mpl import numpy as np import os import tensorflow as tf +import tensorflow.keras as tfk import tensorflow.contrib.slim as slim import time -import seaborn as sns - -from matplotlib import pyplot as plt +import tensorflow_datasets as tfds +import tensorflow_probability as tfp from imageio import imwrite from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets - -sns.set_style('whitegrid') - -distributions = tf.distributions +tfkl = tfk.layers +tfc = tf.compat.v1 flags = tf.app.flags flags.DEFINE_string('data_dir', '/tmp/dat/', 'Directory for data') flags.DEFINE_string('logdir', '/tmp/log/', 'Directory for logs') - -# For making plots: -# flags.DEFINE_integer('latent_dim', 2, 'Latent dimensionality of model') -# flags.DEFINE_integer('batch_size', 64, 'Minibatch size') -# flags.DEFINE_integer('n_samples', 10, 'Number of samples to save') -# flags.DEFINE_integer('print_every', 10, 'Print every n iterations') -# flags.DEFINE_integer('hidden_size', 200, 'Hidden size for neural networks') -# flags.DEFINE_integer('n_iterations', 1000, 'number of iterations') - -# For bigger model: flags.DEFINE_integer('latent_dim', 100, 'Latent dimensionality of model') flags.DEFINE_integer('batch_size', 64, 'Minibatch size') flags.DEFINE_integer('n_samples', 1, 'Number of samples to save') @@ -50,12 +37,13 @@ def inference_network(x, latent_dim, hidden_size): mu: Mean parameters for the variational family Normal sigma: Standard deviation parameters for the variational family Normal """ - with slim.arg_scope([slim.fully_connected], activation_fn=tf.nn.relu): - net = slim.flatten(x) - net = slim.fully_connected(net, hidden_size) - net = slim.fully_connected(net, hidden_size) - gaussian_params = slim.fully_connected( - net, latent_dim * 2, activation_fn=None) + inference_net = tfk.Sequential([ + 
tfkl.Flatten(), + tfkl.Dense(hidden_size, activation=tf.nn.relu), + tfkl.Dense(hidden_size, activation=tf.nn.relu), + tfkl.Dense(latent_dim * 2, activation=None) + ]) + gaussian_params = inference_net(x) # The mean parameter is unconstrained mu = gaussian_params[:, :latent_dim] # The standard deviation must be positive. Parametrize with a softplus @@ -73,12 +61,13 @@ def generative_network(z, hidden_size): Returns: bernoulli_logits: logits for the Bernoulli likelihood of the data """ - with slim.arg_scope([slim.fully_connected], activation_fn=tf.nn.relu): - net = slim.fully_connected(z, hidden_size) - net = slim.fully_connected(net, hidden_size) - bernoulli_logits = slim.fully_connected(net, 784, activation_fn=None) - bernoulli_logits = tf.reshape(bernoulli_logits, [-1, 28, 28, 1]) - return bernoulli_logits + generative_net = tfk.Sequential([ + tfkl.Dense(hidden_size, activation=tf.nn.relu), + tfkl.Dense(hidden_size, activation=tf.nn.relu), + tfkl.Dense(28 * 28, activation=None) + ]) + bernoulli_logits = generative_net(z) + return tf.reshape(bernoulli_logits, [-1, 28, 28, 1]) def train(): @@ -86,87 +75,76 @@ def train(): # Input placeholders with tf.name_scope('data'): - x = tf.placeholder(tf.float32, [None, 28, 28, 1]) - tf.summary.image('data', x) + x = tfc.placeholder(tf.float32, [None, 28, 28, 1]) + tfc.summary.image('data', x) - with tf.variable_scope('variational'): + with tfc.variable_scope('variational'): q_mu, q_sigma = inference_network(x=x, latent_dim=FLAGS.latent_dim, hidden_size=FLAGS.hidden_size) # The variational distribution is a Normal with mean and standard # deviation given by the inference network - q_z = distributions.Normal(loc=q_mu, scale=q_sigma) - assert q_z.reparameterization_type == distributions.FULLY_REPARAMETERIZED + q_z = tfp.distributions.Normal(loc=q_mu, scale=q_sigma) + assert q_z.reparameterization_type == tfp.distributions.FULLY_REPARAMETERIZED - with tf.variable_scope('model'): + with tfc.variable_scope('model'): # The likelihood is Bernoulli-distributed with logits given by the # generative network p_x_given_z_logits = generative_network(z=q_z.sample(), hidden_size=FLAGS.hidden_size) - p_x_given_z = distributions.Bernoulli(logits=p_x_given_z_logits) + p_x_given_z = tfp.distributions.Bernoulli(logits=p_x_given_z_logits) posterior_predictive_samples = p_x_given_z.sample() - tf.summary.image('posterior_predictive', + tfc.summary.image('posterior_predictive', tf.cast(posterior_predictive_samples, tf.float32)) # Take samples from the prior - with tf.variable_scope('model', reuse=True): - p_z = distributions.Normal(loc=np.zeros(FLAGS.latent_dim, dtype=np.float32), + with tfc.variable_scope('model', reuse=True): + p_z = tfp.distributions.Normal(loc=np.zeros(FLAGS.latent_dim, dtype=np.float32), scale=np.ones(FLAGS.latent_dim, dtype=np.float32)) p_z_sample = p_z.sample(FLAGS.n_samples) p_x_given_z_logits = generative_network(z=p_z_sample, hidden_size=FLAGS.hidden_size) - prior_predictive = distributions.Bernoulli(logits=p_x_given_z_logits) + prior_predictive = tfp.distributions.Bernoulli(logits=p_x_given_z_logits) prior_predictive_samples = prior_predictive.sample() - tf.summary.image('prior_predictive', + tfc.summary.image('prior_predictive', tf.cast(prior_predictive_samples, tf.float32)) # Take samples from the prior with a placeholder - with tf.variable_scope('model', reuse=True): + with tfc.variable_scope('model', reuse=True): z_input = tf.placeholder(tf.float32, [None, FLAGS.latent_dim]) p_x_given_z_logits = generative_network(z=z_input, 
hidden_size=FLAGS.hidden_size) - prior_predictive_inp = distributions.Bernoulli(logits=p_x_given_z_logits) + prior_predictive_inp = tfp.distributions.Bernoulli(logits=p_x_given_z_logits) prior_predictive_inp_sample = prior_predictive_inp.sample() # Build the evidence lower bound (ELBO) or the negative loss - kl = tf.reduce_sum(distributions.kl_divergence(q_z, p_z), 1) + kl = tf.reduce_sum(tfp.distributions.kl_divergence(q_z, p_z), 1) expected_log_likelihood = tf.reduce_sum(p_x_given_z.log_prob(x), [1, 2, 3]) elbo = tf.reduce_sum(expected_log_likelihood - kl, 0) - - optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001) - + optimizer = tfc.train.RMSPropOptimizer(learning_rate=0.001) train_op = optimizer.minimize(-elbo) # Merge all the summaries - summary_op = tf.summary.merge_all() + summary_op = tfc.summary.merge_all() - init_op = tf.global_variables_initializer() + init_op = tfc.global_variables_initializer() # Run training - sess = tf.InteractiveSession() + sess = tfc.InteractiveSession() sess.run(init_op) - mnist = read_data_sets(FLAGS.data_dir, one_hot=True) + mnist_data = tfds.load(name='binarized_mnist', split='train', shuffle_files=False) + dataset = mnist_data.repeat().shuffle(buffer_size=1024).batch(FLAGS.batch_size) print('Saving TensorBoard summaries and images to: %s' % FLAGS.logdir) - train_writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph) - - # Get fixed MNIST digits for plotting posterior means during training - np_x_fixed, np_y = mnist.test.next_batch(5000) - np_x_fixed = np_x_fixed.reshape(5000, 28, 28, 1) - np_x_fixed = (np_x_fixed > 0.5).astype(np.float32) + train_writer = tfc.summary.FileWriter(FLAGS.logdir, sess.graph) t0 = time.time() - for i in range(FLAGS.n_iterations): - # Re-binarize the data at every batch; this improves results - np_x, _ = mnist.train.next_batch(FLAGS.batch_size) - np_x = np_x.reshape(FLAGS.batch_size, 28, 28, 1) - np_x = (np_x > 0.5).astype(np.float32) + for i, batch in enumerate(tfds.as_numpy(dataset)): + np_x = batch['image'] sess.run(train_op, {x: np_x}) - - # Print progress and save samples every so often if i % FLAGS.print_every == 0: np_elbo, summary_str = sess.run([elbo, summary_op], {x: np_x}) train_writer.add_summary(summary_str, i) @@ -174,73 +152,20 @@ def train(): i, np_elbo / FLAGS.batch_size, (time.time() - t0) / FLAGS.print_every)) - t0 = time.time() - # Save samples np_posterior_samples, np_prior_samples = sess.run( [posterior_predictive_samples, prior_predictive_samples], {x: np_x}) for k in range(FLAGS.n_samples): f_name = os.path.join( FLAGS.logdir, 'iter_%d_posterior_predictive_%d_data.jpg' % (i, k)) - imwrite(f_name, np_x[k, :, :, 0]) + imwrite(f_name, np_x[k, :, :, 0].astype(np.uint8)) f_name = os.path.join( FLAGS.logdir, 'iter_%d_posterior_predictive_%d_sample.jpg' % (i, k)) - imwrite(f_name, np_posterior_samples[k, :, :, 0]) + imwrite(f_name, np_posterior_samples[k, :, :, 0].astype(np.uint8)) f_name = os.path.join( FLAGS.logdir, 'iter_%d_prior_predictive_%d.jpg' % (i, k)) - imwrite(f_name, np_prior_samples[k, :, :, 0]) - - # Plot the posterior predictive space - if FLAGS.latent_dim == 2: - np_q_mu = sess.run(q_mu, {x: np_x_fixed}) - cmap = mpl.colors.ListedColormap(sns.color_palette("husl")) - f, ax = plt.subplots(1, figsize=(6 * 1.1618, 6)) - im = ax.scatter(np_q_mu[:, 0], np_q_mu[:, 1], c=np.argmax(np_y, 1), cmap=cmap, - alpha=0.7) - ax.set_xlabel('First dimension of sampled latent variable $z_1$') - ax.set_ylabel('Second dimension of sampled latent variable mean $z_2$') - ax.set_xlim([-10., 10.]) - 
ax.set_ylim([-10., 10.]) - f.colorbar(im, ax=ax, label='Digit class') - plt.tight_layout() - plt.savefig(os.path.join(FLAGS.logdir, - 'posterior_predictive_map_frame_%d.png' % i)) - plt.close() - - nx = ny = 20 - x_values = np.linspace(-3, 3, nx) - y_values = np.linspace(-3, 3, ny) - canvas = np.empty((28 * ny, 28 * nx)) - for ii, yi in enumerate(x_values): - for j, xi in enumerate(y_values): - np_z = np.array([[xi, yi]]) - x_mean = sess.run(prior_predictive_inp_sample, {z_input: np_z}) - canvas[(nx - ii - 1) * 28:(nx - ii) * 28, j * - 28:(j + 1) * 28] = x_mean[0].reshape(28, 28) - imwrite(os.path.join(FLAGS.logdir, - 'prior_predictive_map_frame_%d.png' % i), canvas) - # plt.figure(figsize=(8, 10)) - # Xi, Yi = np.meshgrid(x_values, y_values) - # plt.imshow(canvas, origin="upper") - # plt.tight_layout() - # plt.savefig() - - # Make the gifs - if FLAGS.latent_dim == 2: - os.system( - 'convert -delay 15 -loop 0 {0}/posterior_predictive_map_frame*png {0}/posterior_predictive.gif' - .format(FLAGS.logdir)) - os.system( - 'convert -delay 15 -loop 0 {0}/prior_predictive_map_frame*png {0}/prior_predictive.gif' - .format(FLAGS.logdir)) - - -def main(_): - if tf.gfile.Exists(FLAGS.logdir): - tf.gfile.DeleteRecursively(FLAGS.logdir) - tf.gfile.MakeDirs(FLAGS.logdir) - train() - + imwrite(f_name, np_prior_samples[k, :, :, 0].astype(np.uint8)) + t0 = time.time() if __name__ == '__main__': - tf.app.run() + train() From 2c702b5ccbe029632a86c1f2ff904a8142619310 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Mon, 18 Nov 2019 14:24:26 -0500 Subject: [PATCH 18/21] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 68acb24..39e78e2 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ step: 65000 test elbo: -103.31 test log p(x): -97.10 ``` -Using a non mean-field, more expressive variational posterior approximation, the test marginal log-likelihood improves to `-95.33` nats: +Using a non mean-field, more expressive variational posterior approximation (inverse autoregressive flow, https://arxiv.org/abs/1606.04934), the test marginal log-likelihood improves to `-95.33` nats: ``` $ python train_variational_autoencoder_pytorch.py --variational flow From 526d716d7799c20e6bd94a7283e8ab1741d34adf Mon Sep 17 00:00:00 2001 From: "Mohamad H. Danesh" Date: Mon, 23 Mar 2020 16:18:35 -0700 Subject: [PATCH 19/21] Update train_variational_autoencoder_pytorch.py --- train_variational_autoencoder_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_variational_autoencoder_pytorch.py b/train_variational_autoencoder_pytorch.py index 5a662a9..bda329f 100644 --- a/train_variational_autoencoder_pytorch.py +++ b/train_variational_autoencoder_pytorch.py @@ -154,7 +154,7 @@ def load_binary_mnist(cfg, **kwcfg): if not fname.exists(): print('Downloading binary MNIST data...') data.download_binary_mnist(fname) - f = h5py.File(pathlib.os.path.join(pathlib.os.environ['DAT'], 'binarized_mnist.hdf5'), 'r') + f = h5py.File(pathlib.os.path.join(pathlib.os.environ['DAT'], 'binary_mnist.h5'), 'r') x_train = f['train'][::] x_val = f['valid'][::] x_test = f['test'][::] From 898ffd0aa3bfde5fd8a4bc4c45e9630e2edb5e0e Mon Sep 17 00:00:00 2001 From: "Mohamad H. 
Danesh" Date: Mon, 23 Mar 2020 16:19:33 -0700 Subject: [PATCH 20/21] Update train_variational_autoencoder_tensorflow.py --- train_variational_autoencoder_tensorflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_variational_autoencoder_tensorflow.py b/train_variational_autoencoder_tensorflow.py index 922b83f..b7531f2 100644 --- a/train_variational_autoencoder_tensorflow.py +++ b/train_variational_autoencoder_tensorflow.py @@ -135,7 +135,7 @@ def train(): sess = tfc.InteractiveSession() sess.run(init_op) - mnist_data = tfds.load(name='binarized_mnist', split='train', shuffle_files=False) + mnist_data = tfds.load(name='binary_mnist', split='train', shuffle_files=False) dataset = mnist_data.repeat().shuffle(buffer_size=1024).batch(FLAGS.batch_size) print('Saving TensorBoard summaries and images to: %s' % FLAGS.logdir) From c7b298ef64773c798b3f24444af780f45ec81043 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Fri, 22 Jan 2021 13:48:15 -0500 Subject: [PATCH 21/21] Update README.md --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 39e78e2..ebe274c 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,6 @@ -# Variational Autoencoder / Deep Latent Gaussian Model in tensorflow and pytorch +# Variational Autoencoder in tensorflow and pytorch +[![DOI](https://zenodo.org/badge/65744394.svg)](https://zenodo.org/badge/latestdoi/65744394) + Reference implementation for a variational autoencoder in TensorFlow and PyTorch. I recommend the PyTorch version. It includes an example of a more expressive variational family, the [inverse autoregressive flow](https://arxiv.org/abs/1606.04934).