@@ -26,8 +26,8 @@


 def add_args(parser):
-    parser.add_argument("--latent_size", type=int, default=128)
-    parser.add_argument("--hidden_size", type=int, default=256)
+    parser.add_argument("--latent_size", type=int, default=10)
+    parser.add_argument("--hidden_size", type=int, default=512)
     parser.add_argument("--variational", choices=["flow", "mean-field"])
     parser.add_argument("--flow_depth", type=int, default=2)
     parser.add_argument("--learning_rate", type=float, default=0.001)
@@ -39,7 +39,7 @@ def add_args(parser):
     parser.add_argument(
         "--use_gpu", default=False, action=argparse.BooleanOptionalAction
     )
-    parser.add_argument("--random_seed", type=int, default=582838)
+    parser.add_argument("--random_seed", type=int, default=42)
     parser.add_argument("--train_dir", type=pathlib.Path, default="/tmp")
     parser.add_argument("--data_dir", type=pathlib.Path, default="/tmp")

@@ -78,8 +78,8 @@ def __init__(
             [
                 hk.Linear(self._hidden_size),
                 jax.nn.relu,
-                hk.Linear(self._hidden_size),
-                jax.nn.relu,
+                # hk.Linear(self._hidden_size),
+                # jax.nn.relu,
                 hk.Linear(np.prod(self._output_shape)),
                 hk.Reshape(self._output_shape, preserve_dims=2),
             ]
@@ -106,8 +106,8 @@ def __init__(self, latent_size: int, hidden_size: int):
                 hk.Flatten(),
                 hk.Linear(self._hidden_size),
                 jax.nn.relu,
-                hk.Linear(self._hidden_size),
-                jax.nn.relu,
+                # hk.Linear(self._hidden_size),
+                # jax.nn.relu,
                 hk.Linear(self._latent_size * 2),
             ]
         )
@@ -121,6 +121,7 @@ def condition(self, inputs):

     def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
         loc, scale = self.condition(x)
+        # IMPORTANT: need to check in source code that reparameterization_type=tfd.FULLY_REPARAMETERIZED for this class
         q_z = tfd.Normal(loc=loc, scale=scale)
         return q_z

@@ -164,8 +165,8 @@ def main():
         )(x)
     )

-    @jax.jit
-    def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch,) -> jnp.ndarray:
+    # @jax.jit
+    def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarray:
         x = batch["image"]
         out: ModelAndVariationalOutput = model_and_variational.apply(params, rng_key, x)
         log_q_z = out.q_z.log_prob(out.z).sum(axis=-1)
@@ -186,12 +187,12 @@ def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch,) -> jnp.ndar
     params = model_and_variational.init(
         next(rng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE))
     )
-    optimizer = optax.rmsprop(args.learning_rate, centered=True)
+    optimizer = optax.adam(args.learning_rate)
     opt_state = optimizer.init(params)

-    @jax.jit
+    # @jax.jit
     def train_step(
-        params: hk.Params, rng_key: PRNGKey, state: optax.OptState, batch: Batch
+        params: hk.Params, rng_key: PRNGKey, opt_state: optax.OptState, batch: Batch
     ) -> Tuple[hk.Params, optax.OptState]:
         """Single update step to maximize the ELBO."""
         grads = jax.grad(objective_fn)(params, rng_key, batch)
@@ -256,7 +257,7 @@ def evaluate(
         )
         print(
             f"Step {step:<10d}\t"
-            f"Train ELBO estimate: {train_elbo:<5.3f}"
+            f"Train ELBO estimate: {train_elbo:<5.3f}\t"
             f"Validation ELBO estimate: {elbo:<5.3f}\t"
             f"Validation log p(x) estimate: {log_p_x:<5.3f}"
         )
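
A note on the IMPORTANT comment added in __call__ above: TensorFlow Probability exposes the reparameterization type as a property on every distribution, so this can be verified in code rather than by reading the source. A minimal sketch, assuming tfd is the JAX substrate of TFP as in this file:

    import tensorflow_probability.substrates.jax as tfp

    tfd = tfp.distributions

    # Normal is a location-scale family, so TFP marks it fully
    # reparameterized; pathwise gradients through q_z.sample() are valid.
    q_z = tfd.Normal(loc=0.0, scale=1.0)
    assert q_z.reparameterization_type == tfd.FULLY_REPARAMETERIZED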
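
On the commented-out @jax.jit decorators: disabling jit this way is handy for debugging (prints fire and tracebacks point at real lines), but JAX also ships a context manager that turns compilation off globally, so the decorators can stay in place. A sketch, using the train_step signature from this diff:

    import jax

    # Inside this context every jitted function runs op-by-op,
    # as if @jax.jit had been removed.
    with jax.disable_jit():
        params, opt_state = train_step(params, next(rng_seq), opt_state, batch)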
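
The rmsprop-to-adam swap is a one-line change because every optax optimizer is a GradientTransformation with the same init/update interface. A hedged sketch of the standard usage pattern the surrounding code presumably follows (the update helper here is illustrative, not from the file):

    import optax

    optimizer = optax.adam(1e-3)  # or optax.rmsprop(1e-3, centered=True)
    opt_state = optimizer.init(params)

    def update(params, opt_state, grads):
        # update() turns raw gradients into parameter updates;
        # apply_updates adds them to the current parameters.
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state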