Skip to content

Commit 567ec74

Browse files
committed
update experiments
1 parent 7d796b7 commit 567ec74

File tree

2 files changed

+84
-13
lines changed

2 files changed

+84
-13
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
# vae
22
Variational Autoencoder or Deep Latent Gaussian Model demo
3+
4+
Blog post: https://jaan.io/unreasonable-confusion/

vae.py

Lines changed: 82 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
import itertools
2+
import matplotlib as mpl
13
import numpy as np
24
import os
35
import tensorflow as tf
46
import tensorflow.contrib.slim as slim
57
import time
8+
import seaborn as sns
69

10+
from matplotlib import pyplot as plt
711
from scipy.misc import imsave
812
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
913

14+
sns.set_style('whitegrid')
1015

1116
sg = tf.contrib.bayesflow.stochastic_graph
1217
distributions = tf.contrib.distributions
@@ -15,10 +20,11 @@
1520
flags = tf.app.flags
1621
flags.DEFINE_string('data_dir', '/tmp/data/', 'Directory for storing data')
1722
flags.DEFINE_string('logdir', '/tmp/logs/', 'Directory for storing data')
18-
flags.DEFINE_integer('latent_dim', 100, 'Latent dimensionality of model')
23+
flags.DEFINE_integer('latent_dim', 2, 'Latent dimensionality of model')
1924
flags.DEFINE_integer('batch_size', 64, 'Minibatch size')
2025
flags.DEFINE_integer('n_samples', 10, 'Number of samples to save')
21-
flags.DEFINE_integer('print_every', 1000, 'Print every n iterations')
26+
flags.DEFINE_integer('print_every', 10, 'Print every n iterations')
27+
flags.DEFINE_integer('hidden_size', 200, 'Hidden size for neural networks')
2228

2329
FLAGS = flags.FLAGS
2430

@@ -41,7 +47,10 @@ def inference_network(x, latent_dim, hidden_size):
4147
net = slim.fully_connected(net, hidden_size)
4248
gaussian_params = slim.fully_connected(
4349
net, latent_dim * 2, activation_fn=None)
50+
# The mean parameter is unconstrained
4451
mu = gaussian_params[:, :latent_dim]
52+
# The standard deviation must be positive. Parametrize with a softplus and
53+
# add a small epsilon for numerical stability
4554
sigma = 1e-6 + tf.nn.softplus(gaussian_params[:, latent_dim:])
4655
return mu, sigma
4756

@@ -67,7 +76,7 @@ def generative_network(z, hidden_size):
6776
def train():
6877
# Train a Variational Autoencoder on MNIST
6978

70-
# Input placehoolders
79+
# Input placeholders
7180
with tf.name_scope('data'):
7281
x = tf.placeholder(tf.float32, [None, 28, 28, 1])
7382
tf.image_summary('data', x, max_images=10)
@@ -76,39 +85,52 @@ def train():
7685
with tf.variable_scope('variational'):
7786
q_mu, q_sigma = inference_network(x=x,
7887
latent_dim=FLAGS.latent_dim,
79-
hidden_size=200)
88+
hidden_size=FLAGS.hidden_size)
8089
with sg.value_type(sg.SampleAndReshapeValue()):
90+
# The variational distribution is a Normal with mean and standard
91+
# deviation given by the inference network
8192
q_z = sg.DistributionTensor(distributions.Normal, mu=q_mu, sigma=q_sigma)
8293

8394
with tf.variable_scope('model'):
84-
p_x_given_z_logits = generative_network(z=q_z, hidden_size=200)
85-
posterior_predictive = distributions.Bernoulli(logits=p_x_given_z_logits)
86-
posterior_predictive_samples = posterior_predictive.sample()
95+
# The likelihood is Bernoulli-distributed with logits given by the
96+
# generative network
97+
p_x_given_z_logits = generative_network(z=q_z,
98+
hidden_size=FLAGS.hidden_size)
99+
p_x_given_z = distributions.Bernoulli(logits=p_x_given_z_logits)
100+
posterior_predictive_samples = p_x_given_z.sample()
87101
tf.image_summary('posterior_predictive',
88102
tf.cast(posterior_predictive_samples, tf.float32),
89103
max_images=10)
90104

91-
105+
# Take samples from the prior
92106
with tf.variable_scope('model', reuse=True):
93107
p_z = distributions.Normal(mu=np.zeros(FLAGS.latent_dim, dtype=np.float32),
94108
sigma=np.ones(FLAGS.latent_dim, dtype=np.float32))
95109
p_z_sample = p_z.sample_n(FLAGS.n_samples)
96-
p_x_given_z_logits = generative_network(z=p_z_sample, hidden_size=200)
110+
p_x_given_z_logits = generative_network(z=p_z_sample,
111+
hidden_size=FLAGS.hidden_size)
97112
prior_predictive = distributions.Bernoulli(logits=p_x_given_z_logits)
98113
prior_predictive_samples = prior_predictive.sample()
99114
tf.image_summary('prior_predictive',
100115
tf.cast(prior_predictive_samples, tf.float32),
101116
max_images=10)
102117

118+
# Take samples from the prior with a placeholder
119+
with tf.variable_scope('model', reuse=True):
120+
z_input = tf.placeholder(tf.float32, [None, FLAGS.latent_dim])
121+
p_x_given_z_logits = generative_network(z=z_input,
122+
hidden_size=FLAGS.hidden_size)
123+
prior_predictive_inp = distributions.Bernoulli(logits=p_x_given_z_logits)
124+
prior_predictive_inp_sample = prior_predictive_inp.sample()
103125

104126
# Build the evidence lower bound (ELBO) or the negative loss
105127
kl = tf.reduce_sum(distributions.kl(q_z.distribution, p_z), 1)
106-
expected_log_likelihood = tf.reduce_sum(posterior_predictive.log_pmf(x),
128+
expected_log_likelihood = tf.reduce_sum(p_x_given_z.log_pmf(x),
107129
[1, 2, 3])
108130

109131
elbo = tf.reduce_sum(expected_log_likelihood - kl, 0)
110132

111-
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
133+
optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001)
112134

113135
train_op = optimizer.minimize(-elbo)
114136

@@ -126,11 +148,17 @@ def train():
126148
print('Saving TensorBoard summaries and images to: %s' % FLAGS.logdir)
127149
train_writer = tf.train.SummaryWriter(FLAGS.logdir, sess.graph)
128150

129-
for i in range(100000):
151+
# Get fixed MNIST digits for plotting posterior means during training
152+
np_x_fixed, np_y = mnist.test.next_batch(5000)
153+
np_x_fixed = np_x_fixed.reshape(5000, 28, 28, 1)
154+
np_x_fixed = (np_x_fixed > 0.5).astype(np.float32)
155+
156+
for i in range(1000):
157+
# Re-binarize the data at every batch; this improves results
130158
np_x, _ = mnist.train.next_batch(FLAGS.batch_size)
131159
np_x = np_x.reshape(FLAGS.batch_size, 28, 28, 1)
132160
np_x = (np_x > 0.5).astype(np.float32)
133-
sess.run(train_op, feed_dict={x: np_x})
161+
sess.run(train_op, {x: np_x})
134162

135163
# Print progress and save samples every so often
136164
t0 = time.time()
@@ -157,6 +185,47 @@ def train():
157185
FLAGS.logdir, 'iter_%d_prior_predictive_%d.jpg' % (i, k))
158186
imsave(f_name, np_prior_samples[k, :, :, 0])
159187

188+
# Plot the posterior predictive space
189+
if FLAGS.latent_dim == 2:
190+
np_q_mu = sess.run(q_mu, {x: np_x_fixed})
191+
cmap = mpl.colors.ListedColormap(sns.color_palette("husl"))
192+
f, ax = plt.subplots(1, figsize=(6 * 1.1618, 6))
193+
im = ax.scatter(np_q_mu[:, 0], np_q_mu[:, 1], c=np.argmax(np_y, 1), cmap=cmap,
194+
alpha=0.7)
195+
ax.set_xlabel('First dimension of sampled latent variable $z_1$')
196+
ax.set_ylabel('Second dimension of sampled latent variable mean $z_2$')
197+
ax.set_xlim([-10., 10.])
198+
ax.set_ylim([-10., 10.])
199+
f.colorbar(im, ax=ax, label='Digit class')
200+
plt.tight_layout()
201+
plt.savefig(os.path.join(FLAGS.logdir,
202+
'posterior_predictive_map_frame_%d.png' % i))
203+
plt.close()
204+
205+
nx = ny = 20
206+
x_values = np.linspace(-3, 3, nx)
207+
y_values = np.linspace(-3, 3, ny)
208+
canvas = np.empty((28*ny, 28*nx))
209+
for ii, yi in enumerate(x_values):
210+
for j, xi in enumerate(y_values):
211+
np_z = np.array([[xi, yi]])
212+
x_mean = sess.run(prior_predictive_inp_sample, {z_input: np_z})
213+
canvas[(nx-ii-1)*28:(nx-ii)*28, j*28:(j+1)*28] = x_mean[0].reshape(28, 28)
214+
imsave(os.path.join(FLAGS.logdir,
215+
'prior_predictive_map_frame_%d.png' % i), canvas)
216+
# plt.figure(figsize=(8, 10))
217+
# Xi, Yi = np.meshgrid(x_values, y_values)
218+
# plt.imshow(canvas, origin="upper")
219+
# plt.tight_layout()
220+
# plt.savefig()
221+
222+
# Make the gifs
223+
os.system(
224+
'convert -delay 15 -loop 0 {0}/posterior_predictive_map_frame*png {0}/posterior_predictive.gif'
225+
.format(FLAGS.logdir))
226+
os.system(
227+
'convert -delay 15 -loop 0 {0}/prior_predictive_map_frame*png {0}/prior_predictive.gif'
228+
.format(FLAGS.logdir))
160229

161230
def main(_):
162231
if tf.gfile.Exists(FLAGS.logdir):

0 commit comments

Comments
 (0)