diff --git a/train_variational_autoencoder_jax.py b/train_variational_autoencoder_jax.py
index a9dac24..2c765dc 100644
--- a/train_variational_autoencoder_jax.py
+++ b/train_variational_autoencoder_jax.py
@@ -26,8 +26,8 @@ def add_args(parser):
-    parser.add_argument("--latent_size", type=int, default=128)
-    parser.add_argument("--hidden_size", type=int, default=256)
+    parser.add_argument("--latent_size", type=int, default=10)
+    parser.add_argument("--hidden_size", type=int, default=512)
     parser.add_argument("--variational", choices=["flow", "mean-field"])
     parser.add_argument("--flow_depth", type=int, default=2)
     parser.add_argument("--learning_rate", type=float, default=0.001)
@@ -39,7 +39,7 @@ def add_args(parser):
     parser.add_argument(
         "--use_gpu", default=False, action=argparse.BooleanOptionalAction
     )
-    parser.add_argument("--random_seed", type=int, default=582838)
+    parser.add_argument("--random_seed", type=int, default=42)
     parser.add_argument("--train_dir", type=pathlib.Path, default="/tmp")
     parser.add_argument("--data_dir", type=pathlib.Path, default="/tmp")
@@ -78,8 +78,8 @@ def __init__(
             [
                 hk.Linear(self._hidden_size),
                 jax.nn.relu,
-                hk.Linear(self._hidden_size),
-                jax.nn.relu,
+                # hk.Linear(self._hidden_size),
+                # jax.nn.relu,
                 hk.Linear(np.prod(self._output_shape)),
                 hk.Reshape(self._output_shape, preserve_dims=2),
             ]
@@ -106,8 +106,8 @@ def __init__(self, latent_size: int, hidden_size: int):
                 hk.Flatten(),
                 hk.Linear(self._hidden_size),
                 jax.nn.relu,
-                hk.Linear(self._hidden_size),
-                jax.nn.relu,
+                # hk.Linear(self._hidden_size),
+                # jax.nn.relu,
                 hk.Linear(self._latent_size * 2),
             ]
         )
@@ -121,6 +121,7 @@ def condition(self, inputs):

     def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
         loc, scale = self.condition(x)
+        # IMPORTANT: need to check in source code that reparameterization_type=tfd.FULLY_REPARAMETERIZED for this class
         q_z = tfd.Normal(loc=loc, scale=scale)
         return q_z
@@ -164,8 +165,8 @@ def main():
         )(x)
     )

-    @jax.jit
-    def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch,) -> jnp.ndarray:
+    # @jax.jit
+    def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarray:
         x = batch["image"]
         out: ModelAndVariationalOutput = model_and_variational.apply(params, rng_key, x)
         log_q_z = out.q_z.log_prob(out.z).sum(axis=-1)
@@ -186,12 +187,12 @@ def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch,) -> jnp.ndar
     params = model_and_variational.init(
         next(rng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE))
     )
-    optimizer = optax.rmsprop(args.learning_rate, centered=True)
+    optimizer = optax.adam(args.learning_rate)
     opt_state = optimizer.init(params)

-    @jax.jit
+    # @jax.jit
     def train_step(
-        params: hk.Params, rng_key: PRNGKey, state: optax.OptState, batch: Batch
+        params: hk.Params, rng_key: PRNGKey, opt_state: optax.OptState, batch: Batch
     ) -> Tuple[hk.Params, optax.OptState]:
         """Single update step to maximize the ELBO."""
         grads = jax.grad(objective_fn)(params, rng_key, batch)
@@ -256,7 +257,7 @@ def evaluate(
         )
         print(
             f"Step {step:<10d}\t"
-            f"Train ELBO estimate: {train_elbo:<5.3f}"
+            f"Train ELBO estimate: {train_elbo:<5.3f}\t"
             f"Validation ELBO estimate: {elbo:<5.3f}\t"
             f"Validation log p(x) estimate: {log_p_x:<5.3f}"
         )
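
Note on the IMPORTANT comment added in the __call__ hunk: TensorFlow Probability distributions expose a reparameterization_type attribute, so this can be checked directly rather than by reading source. A minimal sketch of that check, assuming tfd refers to TFP's JAX substrate distributions (consistent with the rest of the script):

    import tensorflow_probability.substrates.jax as tfp

    tfd = tfp.distributions

    # Normal is a location-scale family, so TFP marks it as fully
    # reparameterized; this assert verifies that at runtime.
    q_z = tfd.Normal(loc=0.0, scale=1.0)
    assert q_z.reparameterization_type == tfd.FULLY_REPARAMETERIZED

If the assert holds, pathwise (reparameterization) gradients through samples of q_z are valid, which is what the Monte Carlo ELBO estimate in objective_fn relies on.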