import optax
import tensorflow_datasets as tfds
from tensorflow_probability.substrates import jax as tfp
+import distrax

tfd = tfp.distributions

@@ -110,13 +111,91 @@ def condition(self, inputs):
        scale = jax.nn.softplus(scale_arg)
        return loc, scale

-    def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    def __call__(self, x: jnp.ndarray) -> tfd.Distribution:
        loc, scale = self.condition(x)
        # IMPORTANT: need to check in source code that reparameterization_type=tfd.FULLY_REPARAMETERIZED for this class
        q_z = tfd.Normal(loc=loc, scale=scale)
        return q_z

+def make_conditioner(
+    event_shape: Sequence[int], hidden_sizes: Sequence[int], num_bijector_params: int
+) -> hk.Sequential:
+    """Creates an MLP conditioner for each layer of the flow."""
+    return hk.Sequential(
+        [
+            hk.Flatten(preserve_dims=-len(event_shape)),
+            hk.nets.MLP(hidden_sizes, activate_final=True),
+            # We initialize this linear layer to zero so that the flow is initialized
+            # to the identity function.
+            hk.Linear(
+                np.prod(event_shape) * num_bijector_params,
+                w_init=jnp.zeros,
+                b_init=jnp.zeros,
+            ),
+            hk.Reshape(tuple(event_shape) + (num_bijector_params,), preserve_dims=-1),
+        ]
+    )
+
+
+def make_flow(
+    event_shape: Sequence[int],
+    num_layers: int,
+    hidden_sizes: Sequence[int],
+    num_bins: int,
+) -> distrax.Transformed:
+    """Creates the flow model."""
+    # Alternating binary mask.
+    mask = jnp.arange(0, np.prod(event_shape)) % 2
+    mask = jnp.reshape(mask, event_shape)
+    mask = mask.astype(bool)
+
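+    # Dimensions where the mask is True pass through each coupling layer
+    # unchanged and (via the conditioner) parameterize the spline applied to
+    # the remaining dimensions; the mask is flipped between layers so that
+    # every dimension gets transformed.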
+    def bijector_fn(params: jnp.ndarray):
+        return distrax.RationalQuadraticSpline(params, range_min=0.0, range_max=1.0)
+
+    # Number of parameters for the rational-quadratic spline:
+    # - `num_bins` bin widths
+    # - `num_bins` bin heights
+    # - `num_bins + 1` knot slopes
+    # for a total of `3 * num_bins + 1` parameters.
+    num_bijector_params = 3 * num_bins + 1
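+    # e.g. with the num_bins=4 used by VariationalFlow below, the conditioner
+    # outputs 3 * 4 + 1 = 13 parameters per transformed dimension.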
+
+    layers = []
+    for _ in range(num_layers):
+        layer = distrax.MaskedCoupling(
+            mask=mask,
+            bijector=bijector_fn,
+            conditioner=make_conditioner(
+                event_shape, hidden_sizes, num_bijector_params
+            ),
+        )
+        layers.append(layer)
+        # Flip the mask after each layer.
+        mask = jnp.logical_not(mask)
+
+    # We invert the flow so that the `forward` method is called with `log_prob`.
+    flow = distrax.Inverse(distrax.Chain(layers))
+    base_distribution = distrax.MultivariateNormalDiag(
+        loc=jnp.zeros(event_shape), scale_diag=jnp.ones(event_shape)
+    )
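+    # distrax.Transformed computes log_prob(z) as base_distribution.log_prob of
+    # flow.inverse(z) plus the log-det Jacobian; with the Inverse above, that
+    # inner call runs the chain in its forward direction.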
+    return distrax.Transformed(base_distribution, flow)
+
+
+class VariationalFlow(hk.Module):
+    def __init__(self, latent_size: int, hidden_size: int):
+        super().__init__(name="variational")
+        self._latent_size = latent_size
+        self._hidden_size = hidden_size
+
+    def __call__(self, x: jnp.ndarray) -> distrax.Distribution:
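+        # NOTE: `x` is unused; this flow defines a single q(z) shared across
+        # inputs rather than an amortized q(z|x).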
+        return make_flow(
+            event_shape=(self._latent_size,),
+            num_layers=2,
+            hidden_sizes=[self._hidden_size] * 2,
+            num_bins=4,
+        )
+
+
def main():
    start_time = time.time()
    parser = argparse.ArgumentParser()
@@ -126,8 +205,11 @@ def main():
    model = hk.transform(
        lambda x, z: Model(args.latent_size, args.hidden_size, MNIST_IMAGE_SHAPE)(x, z)
    )
+    # variational = hk.transform(
+    #     lambda x: VariationalMeanField(args.latent_size, args.hidden_size)(x)
+    # )
    variational = hk.transform(
-        lambda x: VariationalMeanField(args.latent_size, args.hidden_size)(x)
+        lambda x: VariationalFlow(args.latent_size, args.hidden_size)(x)
    )
    p_params = model.init(
        next(rng_seq),
@@ -139,16 +221,14 @@ def main():
    optimizer = optax.rmsprop(args.learning_rate)
    opt_state = optimizer.init(params)

-    @jax.jit
+    # @jax.jit
    def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarray:
        x = batch["image"]
        predicate = lambda module_name, name, value: "model" in module_name
        p_params, q_params = hk.data_structures.partition(predicate, params)
        q_z = variational.apply(q_params, rng_key, x)
-        z = q_z.sample(sample_shape=[1], seed=rng_key)
+        z, log_q_z = q_z.sample_and_log_prob(sample_shape=[1], seed=rng_key)
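+        # distrax's sample_and_log_prob returns the sample and its log-density
+        # in one pass, already summed over the event dimension, so the separate
+        # log_prob(z).sum(axis=-1) below is no longer needed.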
        p_z, p_x_given_z = model.apply(p_params, rng_key, x, z)
-        # out: ModelAndVariationalOutput = model_and_variational.apply(params, rng_key, x)
-        log_q_z = q_z.log_prob(z).sum(axis=-1)
        # sum over last three image dimensions (width, height, channels)
        log_p_x_given_z = p_x_given_z.log_prob(x).sum(axis=(-3, -2, -1))
        # sum over latent dimension
@@ -160,7 +240,7 @@ def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarr
        elbo = elbo.sum(axis=0)
        return -elbo

-    @jax.jit
+    # @jax.jit
    def train_step(
        params: hk.Params, rng_key: PRNGKey, opt_state: optax.OptState, batch: Batch
    ) -> Tuple[hk.Params, optax.OptState]:
@@ -170,7 +250,7 @@ def train_step(
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state

-    @jax.jit
+    # @jax.jit
    def importance_weighted_estimate(
        params: hk.Params, rng_key: PRNGKey, batch: Batch
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
@@ -180,9 +260,9 @@ def importance_weighted_estimate(
        predicate = lambda module_name, name, value: "model" in module_name
        p_params, q_params = hk.data_structures.partition(predicate, params)
        q_z = variational.apply(q_params, rng_key, x)
-        z = q_z.sample(args.num_eval_samples, seed=rng_key)
+        z, log_q_z = q_z.sample_and_log_prob(sample_shape=[args.num_eval_samples], seed=rng_key)
        p_z, p_x_given_z = model.apply(p_params, rng_key, x, z)
-        log_q_z = q_z.log_prob(z).sum(axis=-1)
+        # log_q_z = q_z.log_prob(z).sum(axis=-1)
        # sum over last three image dimensions (width, height, channels)
        log_p_x_given_z = p_x_given_z.log_prob(x).sum(axis=(-3, -2, -1))
        # sum over latent dimension