[pull] master from altosaar:master #4


Merged: 1 commit, merged on May 25, 2021
train_variational_autoencoder_jax.py: 27 changes (14 additions, 13 deletions)
@@ -26,8 +26,8 @@


 def add_args(parser):
-    parser.add_argument("--latent_size", type=int, default=128)
-    parser.add_argument("--hidden_size", type=int, default=256)
+    parser.add_argument("--latent_size", type=int, default=10)
+    parser.add_argument("--hidden_size", type=int, default=512)
     parser.add_argument("--variational", choices=["flow", "mean-field"])
     parser.add_argument("--flow_depth", type=int, default=2)
     parser.add_argument("--learning_rate", type=float, default=0.001)
@@ -39,7 +39,7 @@ def add_args(parser):
     parser.add_argument(
         "--use_gpu", default=False, action=argparse.BooleanOptionalAction
     )
-    parser.add_argument("--random_seed", type=int, default=582838)
+    parser.add_argument("--random_seed", type=int, default=42)
     parser.add_argument("--train_dir", type=pathlib.Path, default="/tmp")
     parser.add_argument("--data_dir", type=pathlib.Path, default="/tmp")

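Side note on the unchanged --use_gpu line in this hunk: argparse.BooleanOptionalAction (Python 3.9+) auto-generates a negated flag, so the CLI accepts both --use_gpu and --no-use_gpu. A minimal, self-contained sketch of that behavior (not part of the diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_gpu", default=False, action=argparse.BooleanOptionalAction)
print(parser.parse_args([]).use_gpu)                # False (the default)
print(parser.parse_args(["--use_gpu"]).use_gpu)     # True
print(parser.parse_args(["--no-use_gpu"]).use_gpu)  # False
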
@@ -78,8 +78,8 @@ def __init__(
             [
                 hk.Linear(self._hidden_size),
                 jax.nn.relu,
-                hk.Linear(self._hidden_size),
-                jax.nn.relu,
+                # hk.Linear(self._hidden_size),
+                # jax.nn.relu,
                 hk.Linear(np.prod(self._output_shape)),
                 hk.Reshape(self._output_shape, preserve_dims=2),
             ]
@@ -106,8 +106,8 @@ def __init__(self, latent_size: int, hidden_size: int):
                 hk.Flatten(),
                 hk.Linear(self._hidden_size),
                 jax.nn.relu,
-                hk.Linear(self._hidden_size),
-                jax.nn.relu,
+                # hk.Linear(self._hidden_size),
+                # jax.nn.relu,
                 hk.Linear(self._latent_size * 2),
             ]
         )
@@ -121,6 +121,7 @@ def condition(self, inputs):

     def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
         loc, scale = self.condition(x)
+        # IMPORTANT: need to check in source code that reparameterization_type=tfd.FULLY_REPARAMETERIZED for this class
         q_z = tfd.Normal(loc=loc, scale=scale)
         return q_z

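Note on the comment added above: tfd here is presumably TensorFlow Probability's JAX substrate, as used elsewhere in this file. Under that assumption, the reparameterization type can also be checked at runtime rather than by reading the source; a minimal sketch:

from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# Normal is a location-scale family: sampling is z = loc + scale * eps with eps ~ N(0, 1),
# so gradients flow through loc and scale and the distribution is fully reparameterized.
q_z = tfd.Normal(loc=0.0, scale=1.0)
assert q_z.reparameterization_type == tfd.FULLY_REPARAMETERIZED
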
@@ -164,8 +165,8 @@ def main():
         )(x)
     )

-    @jax.jit
-    def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch,) -> jnp.ndarray:
+    # @jax.jit
+    def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarray:
         x = batch["image"]
         out: ModelAndVariationalOutput = model_and_variational.apply(params, rng_key, x)
         log_q_z = out.q_z.log_prob(out.z).sum(axis=-1)
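
For context on this hunk: objective_fn appears to return the negative of a single-sample ELBO estimate. log_q_z comes from the visible diff; the name log_p_x_and_z and the final batch mean are assumptions about the collapsed lines. A sketch of the estimator:

import jax.numpy as jnp

def negative_elbo(log_p_x_and_z: jnp.ndarray, log_q_z: jnp.ndarray) -> jnp.ndarray:
    # Single-sample Monte Carlo estimate per example:
    #   ELBO(x) ~= log p(x, z) - log q(z | x),  z ~ q(z | x)
    # Returning the negative batch mean lets a minimizer maximize the ELBO.
    elbo = log_p_x_and_z - log_q_z
    return -jnp.mean(elbo)
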
@@ -186,12 +187,12 @@ def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch,) -> jnp.ndarray:
     params = model_and_variational.init(
         next(rng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE))
     )
-    optimizer = optax.rmsprop(args.learning_rate, centered=True)
+    optimizer = optax.adam(args.learning_rate)
     opt_state = optimizer.init(params)

-    @jax.jit
+    # @jax.jit
     def train_step(
-        params: hk.Params, rng_key: PRNGKey, state: optax.OptState, batch: Batch
+        params: hk.Params, rng_key: PRNGKey, opt_state: optax.OptState, batch: Batch
     ) -> Tuple[hk.Params, optax.OptState]:
         """Single update step to maximize the ELBO."""
         grads = jax.grad(objective_fn)(params, rng_key, batch)
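
The tail of train_step is collapsed in this view. A sketch of the standard optax update that presumably follows the grads line, reusing objective_fn, optimizer, and the type aliases defined earlier in this file (the actual collapsed code may differ):

def train_step(
    params: hk.Params, rng_key: PRNGKey, opt_state: optax.OptState, batch: Batch
) -> Tuple[hk.Params, optax.OptState]:
    """Single update step to maximize the ELBO (sketch of the collapsed hunk)."""
    grads = jax.grad(objective_fn)(params, rng_key, batch)
    # Assumed continuation: standard optax gradient transformation + parameter update.
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state
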
@@ -256,7 +257,7 @@ def evaluate(
            )
            print(
                f"Step {step:<10d}\t"
-                f"Train ELBO estimate: {train_elbo:<5.3f}"
+                f"Train ELBO estimate: {train_elbo:<5.3f}\t"
                f"Validation ELBO estimate: {elbo:<5.3f}\t"
                f"Validation log p(x) estimate: {log_p_x:<5.3f}"
            )