
Commit f4c3ece

Author: Jaan Altosaar
Message: fix stale opt_state
Parent: dfb452b

1 file changed: train_variational_autoencoder_jax.py (+14, -13 lines)

@@ -26,8 +26,8 @@


 def add_args(parser):
-    parser.add_argument("--latent_size", type=int, default=128)
-    parser.add_argument("--hidden_size", type=int, default=256)
+    parser.add_argument("--latent_size", type=int, default=10)
+    parser.add_argument("--hidden_size", type=int, default=512)
     parser.add_argument("--variational", choices=["flow", "mean-field"])
     parser.add_argument("--flow_depth", type=int, default=2)
     parser.add_argument("--learning_rate", type=float, default=0.001)

@@ -39,7 +39,7 @@ def add_args(parser):
     parser.add_argument(
         "--use_gpu", default=False, action=argparse.BooleanOptionalAction
     )
-    parser.add_argument("--random_seed", type=int, default=582838)
+    parser.add_argument("--random_seed", type=int, default=42)
     parser.add_argument("--train_dir", type=pathlib.Path, default="/tmp")
     parser.add_argument("--data_dir", type=pathlib.Path, default="/tmp")

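The first two hunks only change defaults: a 10-dimensional latent code, a wider 512-unit hidden layer, and a memorable random seed. The old values stay reachable from the command line; a minimal sketch of how the flags behave (the parser wiring below is an assumption, since only the add_argument calls appear in the diff):

    import argparse

    def add_args(parser):
        parser.add_argument("--latent_size", type=int, default=10)
        parser.add_argument("--hidden_size", type=int, default=512)
        parser.add_argument("--random_seed", type=int, default=42)

    parser = argparse.ArgumentParser()
    add_args(parser)
    # Recover the pre-commit configuration without editing the file:
    args = parser.parse_args(["--latent_size", "128", "--hidden_size", "256"])
    assert args.latent_size == 128 and args.random_seed == 42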

@@ -78,8 +78,8 @@ def __init__(
             [
                 hk.Linear(self._hidden_size),
                 jax.nn.relu,
-                hk.Linear(self._hidden_size),
-                jax.nn.relu,
+                # hk.Linear(self._hidden_size),
+                # jax.nn.relu,
                 hk.Linear(np.prod(self._output_shape)),
                 hk.Reshape(self._output_shape, preserve_dims=2),
             ]

@@ -106,8 +106,8 @@ def __init__(self, latent_size: int, hidden_size: int):
                 hk.Flatten(),
                 hk.Linear(self._hidden_size),
                 jax.nn.relu,
-                hk.Linear(self._hidden_size),
-                jax.nn.relu,
+                # hk.Linear(self._hidden_size),
+                # jax.nn.relu,
                 hk.Linear(self._latent_size * 2),
             ]
         )
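
These two hunks comment out the second Linear/relu pair in both the generative (decoder) and inference (encoder) networks, leaving each with a single hidden layer. A self-contained sketch of the resulting encoder stack (the input shape and the hk.transform wiring are assumptions; Haiku modules must be built inside a transformed function):

    import haiku as hk
    import jax
    import numpy as np

    LATENT_SIZE, HIDDEN_SIZE = 10, 512

    def encoder_fn(x):
        # Single hidden layer, matching the diff above.
        net = hk.Sequential(
            [
                hk.Flatten(),
                hk.Linear(HIDDEN_SIZE),
                jax.nn.relu,
                hk.Linear(LATENT_SIZE * 2),  # loc and raw scale for q(z|x)
            ]
        )
        return net(x)

    init, apply = hk.transform(encoder_fn)
    params = init(jax.random.PRNGKey(0), np.zeros((1, 28, 28, 1)))
    out = apply(params, None, np.zeros((1, 28, 28, 1)))
    assert out.shape == (1, LATENT_SIZE * 2)
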
@@ -121,6 +121,7 @@ def condition(self, inputs):

     def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
         loc, scale = self.condition(x)
+        # IMPORTANT: need to check in source code that reparameterization_type=tfd.FULLY_REPARAMETERIZED for this class
         q_z = tfd.Normal(loc=loc, scale=scale)
         return q_z

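The new comment flags the assumption that makes the pathwise (reparameterization) gradient estimator valid: gradients must flow through samples of q_z via loc and scale. Rather than reading the source, this can be checked at runtime; a quick sketch against TensorFlow Probability's JAX substrate:

    from tensorflow_probability.substrates import jax as tfp

    tfd = tfp.distributions

    # Normal samples as loc + scale * eps with eps ~ N(0, 1), so the class
    # advertises full reparameterization and jax.grad flows through .sample():
    q_z = tfd.Normal(loc=0.0, scale=1.0)
    assert q_z.reparameterization_type == tfd.FULLY_REPARAMETERIZED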

@@ -164,8 +165,8 @@ def main():
         )(x)
     )

-    @jax.jit
-    def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch,) -> jnp.ndarray:
+    # @jax.jit
+    def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarray:
         x = batch["image"]
         out: ModelAndVariationalOutput = model_and_variational.apply(params, rng_key, x)
         log_q_z = out.q_z.log_prob(out.z).sum(axis=-1)
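
Commenting out @jax.jit here (and on train_step below) makes the function run eagerly, so prints and breakpoints inside it work while debugging. The same effect is available without editing decorators, via JAX's global switch; a sketch reusing the script's names:

    import jax

    # Run traced functions op-by-op for one debugging session:
    with jax.disable_jit():
        loss = objective_fn(params, rng_key, batch)
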
@@ -186,12 +187,12 @@ def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch,) -> jnp.ndarray:
     params = model_and_variational.init(
         next(rng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE))
     )
-    optimizer = optax.rmsprop(args.learning_rate, centered=True)
+    optimizer = optax.adam(args.learning_rate)
     opt_state = optimizer.init(params)

-    @jax.jit
+    # @jax.jit
     def train_step(
-        params: hk.Params, rng_key: PRNGKey, state: optax.OptState, batch: Batch
+        params: hk.Params, rng_key: PRNGKey, opt_state: optax.OptState, batch: Batch
     ) -> Tuple[hk.Params, optax.OptState]:
         """Single update step to maximize the ELBO."""
         grads = jax.grad(objective_fn)(params, rng_key, batch)
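
The signature rename from state to opt_state is the fix named in the commit message (the hunk also swaps centered RMSProp for Adam). If the body refers to opt_state while the parameter is called state, Python resolves the name to the opt_state initialized just above in main(), so every step applies updates against the initial optimizer state and Adam's moment estimates never advance. A minimal sketch of a correctly threaded optax step (everything past the gradient line is an assumption; the diff does not show it):

    def train_step(
        params: hk.Params, rng_key: PRNGKey, opt_state: optax.OptState, batch: Batch
    ) -> Tuple[hk.Params, optax.OptState]:
        """Single update step to maximize the ELBO."""
        grads = jax.grad(objective_fn)(params, rng_key, batch)
        # Update against the state passed in, never a closed-over copy:
        updates, new_opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state
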
@@ -256,7 +257,7 @@ def evaluate(
         )
         print(
             f"Step {step:<10d}\t"
-            f"Train ELBO estimate: {train_elbo:<5.3f}"
+            f"Train ELBO estimate: {train_elbo:<5.3f}\t"
             f"Validation ELBO estimate: {elbo:<5.3f}\t"
             f"Validation log p(x) estimate: {log_p_x:<5.3f}"
         )
