@@ -1,12 +1,17 @@
+import itertools
+import matplotlib as mpl
 import numpy as np
 import os
 import tensorflow as tf
 import tensorflow.contrib.slim as slim
 import time
+import seaborn as sns
 
+from matplotlib import pyplot as plt
 from scipy.misc import imsave
 from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
+sns.set_style('whitegrid')
 
 sg = tf.contrib.bayesflow.stochastic_graph
 distributions = tf.contrib.distributions
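Both aliases point at TF 0.x contrib namespaces: sg exposes the stochastic graph utilities and distributions provides the Normal and Bernoulli objects used throughout this commit. A minimal standalone sketch of the distributions API exactly as the script uses it (TF 0.x contrib, so treat it as era-specific):

import numpy as np
import tensorflow as tf

distributions = tf.contrib.distributions

# A 2-D standard Normal, built with the same mu/sigma keywords used below
p_z = distributions.Normal(mu=np.zeros(2, dtype=np.float32),
                           sigma=np.ones(2, dtype=np.float32))
z = p_z.sample_n(3)  # three independent draws, shape [3, 2]

with tf.Session() as sess:
    print(sess.run(z))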
@@ -15,10 +20,11 @@
 flags = tf.app.flags
 flags.DEFINE_string('data_dir', '/tmp/data/', 'Directory for storing data')
 flags.DEFINE_string('logdir', '/tmp/logs/', 'Directory for storing logs')
-flags.DEFINE_integer('latent_dim', 100, 'Latent dimensionality of model')
+flags.DEFINE_integer('latent_dim', 2, 'Latent dimensionality of model')
 flags.DEFINE_integer('batch_size', 64, 'Minibatch size')
 flags.DEFINE_integer('n_samples', 10, 'Number of samples to save')
-flags.DEFINE_integer('print_every', 1000, 'Print every n iterations')
+flags.DEFINE_integer('print_every', 10, 'Print every n iterations')
+flags.DEFINE_integer('hidden_size', 200, 'Hidden size for neural networks')
 
 FLAGS = flags.FLAGS
 
@@ -41,7 +47,10 @@ def inference_network(x, latent_dim, hidden_size):
   net = slim.fully_connected(net, hidden_size)
   gaussian_params = slim.fully_connected(
       net, latent_dim * 2, activation_fn=None)
+  # The mean parameter is unconstrained
   mu = gaussian_params[:, :latent_dim]
+  # The standard deviation must be positive. Parametrize with a softplus and
+  # add a small epsilon for numerical stability
   sigma = 1e-6 + tf.nn.softplus(gaussian_params[:, latent_dim:])
   return mu, sigma
 
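The softplus in the hunk above maps any real network output to a positive standard deviation, and the 1e-6 floor keeps sigma bounded away from zero. A quick NumPy check of that property (the softplus helper here is a local stand-in for tf.nn.softplus):

import numpy as np

def softplus(a):
    # log(1 + exp(a)): smooth and strictly positive for every real input
    return np.log1p(np.exp(a))

raw = np.array([-10.0, 0.0, 10.0])
sigma = 1e-6 + softplus(raw)
assert np.all(sigma > 0)  # valid standard deviations whatever the net emits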
@@ -67,7 +76,7 @@ def generative_network(z, hidden_size):
 def train():
   # Train a Variational Autoencoder on MNIST
 
-  # Input placehoolders
+  # Input placeholders
   with tf.name_scope('data'):
     x = tf.placeholder(tf.float32, [None, 28, 28, 1])
     tf.image_summary('data', x, max_images=10)
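The next hunk applies generative_network three times: to posterior samples, to prior samples, and to a latent placeholder. All three calls run under variable_scope('model'), with reuse=True on the later ones, so a single set of decoder weights is shared. A toy illustration of that reuse pattern (toy sizes and names, not the script's own networks):

import tensorflow as tf

slim = tf.contrib.slim

def decoder(z):
    return slim.fully_connected(z, 4, scope='fc')

with tf.variable_scope('model'):
    a = decoder(tf.zeros([1, 2]))   # creates the model/fc weights
with tf.variable_scope('model', reuse=True):
    b = decoder(tf.ones([1, 2]))    # reuses those same weights

# Exactly one weight matrix and one bias exist under model/fc
print([v.name for v in tf.all_variables()])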
@@ -76,39 +85,52 @@ def train():
   with tf.variable_scope('variational'):
     q_mu, q_sigma = inference_network(x=x,
                                       latent_dim=FLAGS.latent_dim,
-                                      hidden_size=200)
+                                      hidden_size=FLAGS.hidden_size)
     with sg.value_type(sg.SampleAndReshapeValue()):
+      # The variational distribution is a Normal with mean and standard
+      # deviation given by the inference network
       q_z = sg.DistributionTensor(distributions.Normal, mu=q_mu, sigma=q_sigma)
 
   with tf.variable_scope('model'):
-    p_x_given_z_logits = generative_network(z=q_z, hidden_size=200)
-    posterior_predictive = distributions.Bernoulli(logits=p_x_given_z_logits)
-    posterior_predictive_samples = posterior_predictive.sample()
+    # The likelihood is Bernoulli-distributed with logits given by the
+    # generative network
+    p_x_given_z_logits = generative_network(z=q_z,
+                                            hidden_size=FLAGS.hidden_size)
+    p_x_given_z = distributions.Bernoulli(logits=p_x_given_z_logits)
+    posterior_predictive_samples = p_x_given_z.sample()
     tf.image_summary('posterior_predictive',
                      tf.cast(posterior_predictive_samples, tf.float32),
                      max_images=10)
 
-
+  # Take samples from the prior
   with tf.variable_scope('model', reuse=True):
     p_z = distributions.Normal(mu=np.zeros(FLAGS.latent_dim, dtype=np.float32),
                                sigma=np.ones(FLAGS.latent_dim, dtype=np.float32))
     p_z_sample = p_z.sample_n(FLAGS.n_samples)
-    p_x_given_z_logits = generative_network(z=p_z_sample, hidden_size=200)
+    p_x_given_z_logits = generative_network(z=p_z_sample,
+                                            hidden_size=FLAGS.hidden_size)
     prior_predictive = distributions.Bernoulli(logits=p_x_given_z_logits)
     prior_predictive_samples = prior_predictive.sample()
     tf.image_summary('prior_predictive',
                      tf.cast(prior_predictive_samples, tf.float32),
                      max_images=10)
 
+  # Take samples from the prior with a placeholder
+  with tf.variable_scope('model', reuse=True):
+    z_input = tf.placeholder(tf.float32, [None, FLAGS.latent_dim])
+    p_x_given_z_logits = generative_network(z=z_input,
+                                            hidden_size=FLAGS.hidden_size)
+    prior_predictive_inp = distributions.Bernoulli(logits=p_x_given_z_logits)
+    prior_predictive_inp_sample = prior_predictive_inp.sample()
 
   # Build the evidence lower bound (ELBO) or the negative loss
   kl = tf.reduce_sum(distributions.kl(q_z.distribution, p_z), 1)
-  expected_log_likelihood = tf.reduce_sum(posterior_predictive.log_pmf(x),
+  expected_log_likelihood = tf.reduce_sum(p_x_given_z.log_pmf(x),
                                           [1, 2, 3])
 
   elbo = tf.reduce_sum(expected_log_likelihood - kl, 0)
 
-  optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
+  optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001)
 
   train_op = optimizer.minimize(-elbo)
 
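For the diagonal Gaussian posterior q(z|x) = N(mu, sigma^2) and the standard Normal prior built above, the KL term that distributions.kl computes has a closed form, 0.5 * sum(sigma^2 + mu^2 - 1 - 2*log sigma), and the ELBO is the expected log-likelihood minus that KL. A NumPy sketch of the formula (the helper name is hypothetical):

import numpy as np

def kl_to_standard_normal(mu, sigma):
    # KL(N(mu, sigma^2) || N(0, 1)), summed over the latent dimension
    return 0.5 * np.sum(sigma ** 2 + mu ** 2 - 1.0 - 2.0 * np.log(sigma), axis=-1)

# At the prior itself (mu=0, sigma=1) the divergence vanishes
assert np.allclose(kl_to_standard_normal(np.zeros((1, 2)), np.ones((1, 2))), 0.0)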
@@ -126,11 +148,17 @@ def train():
   print('Saving TensorBoard summaries and images to: %s' % FLAGS.logdir)
   train_writer = tf.train.SummaryWriter(FLAGS.logdir, sess.graph)
 
-  for i in range(100000):
+  # Get fixed MNIST digits for plotting posterior means during training
+  np_x_fixed, np_y = mnist.test.next_batch(5000)
+  np_x_fixed = np_x_fixed.reshape(5000, 28, 28, 1)
+  np_x_fixed = (np_x_fixed > 0.5).astype(np.float32)
+
+  for i in range(1000):
+    # Re-binarize the data at every batch; this improves results
     np_x, _ = mnist.train.next_batch(FLAGS.batch_size)
     np_x = np_x.reshape(FLAGS.batch_size, 28, 28, 1)
     np_x = (np_x > 0.5).astype(np.float32)
-    sess.run(train_op, feed_dict={x: np_x})
+    sess.run(train_op, {x: np_x})
 
     # Print progress and save samples every so often
     t0 = time.time()
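The loop binarizes each batch at a fixed 0.5 threshold, matching the Bernoulli likelihood; note that since MNIST intensities are fixed per image, thresholding yields the same binary digit every epoch. If genuinely stochastic re-binarization is intended, a common alternative (a sketch, not what this commit does) samples each pixel as a Bernoulli draw with the grayscale intensity as its probability:

import numpy as np

rng = np.random.RandomState(0)

def dynamic_binarize(batch):
    # Each pixel is 1 with probability equal to its intensity, so successive
    # epochs see different binary versions of the same digit
    return (rng.uniform(size=batch.shape) < batch).astype(np.float32)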
@@ -157,6 +185,47 @@ def train():
           FLAGS.logdir, 'iter_%d_prior_predictive_%d.jpg' % (i, k))
         imsave(f_name, np_prior_samples[k, :, :, 0])
 
+      # Plot the posterior predictive space
+      if FLAGS.latent_dim == 2:
+        np_q_mu = sess.run(q_mu, {x: np_x_fixed})
+        cmap = mpl.colors.ListedColormap(sns.color_palette("husl"))
+        f, ax = plt.subplots(1, figsize=(6 * 1.1618, 6))
+        im = ax.scatter(np_q_mu[:, 0], np_q_mu[:, 1], c=np.argmax(np_y, 1), cmap=cmap,
+                        alpha=0.7)
+        ax.set_xlabel('First dimension of sampled latent variable $z_1$')
+        ax.set_ylabel('Second dimension of sampled latent variable mean $z_2$')
+        ax.set_xlim([-10., 10.])
+        ax.set_ylim([-10., 10.])
+        f.colorbar(im, ax=ax, label='Digit class')
+        plt.tight_layout()
+        plt.savefig(os.path.join(FLAGS.logdir,
+                                 'posterior_predictive_map_frame_%d.png' % i))
+        plt.close()
+
+        nx = ny = 20
+        x_values = np.linspace(-3, 3, nx)
+        y_values = np.linspace(-3, 3, ny)
+        canvas = np.empty((28 * ny, 28 * nx))
+        for ii, yi in enumerate(x_values):
+          for j, xi in enumerate(y_values):
+            np_z = np.array([[xi, yi]])
+            x_mean = sess.run(prior_predictive_inp_sample, {z_input: np_z})
+            canvas[(nx - ii - 1) * 28:(nx - ii) * 28, j * 28:(j + 1) * 28] = x_mean[0].reshape(28, 28)
+        imsave(os.path.join(FLAGS.logdir,
+                            'prior_predictive_map_frame_%d.png' % i), canvas)
+        # plt.figure(figsize=(8, 10))
+        # Xi, Yi = np.meshgrid(x_values, y_values)
+        # plt.imshow(canvas, origin="upper")
+        # plt.tight_layout()
+        # plt.savefig()
+
+  # Make the gifs
+  os.system(
+      'convert -delay 15 -loop 0 {0}/posterior_predictive_map_frame*png {0}/posterior_predictive.gif'
+      .format(FLAGS.logdir))
+  os.system(
+      'convert -delay 15 -loop 0 {0}/prior_predictive_map_frame*png {0}/prior_predictive.gif'
+      .format(FLAGS.logdir))
 
 
 def main(_):
   if tf.gfile.Exists(FLAGS.logdir):
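The gif step above shells out to ImageMagick's convert via os.system with an unquoted glob. Assuming ImageMagick is installed, a slightly more robust equivalent expands the frames in Python and avoids the shell entirely; like the shell glob, it orders frames lexicographically (logdir stands in for FLAGS.logdir):

import glob
import os
import subprocess

logdir = '/tmp/logs/'
frames = sorted(glob.glob(os.path.join(logdir, 'posterior_predictive_map_frame*.png')))
subprocess.check_call(['convert', '-delay', '15', '-loop', '0'] + frames +
                      [os.path.join(logdir, 'posterior_predictive.gif')])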