From dfb452b5421e9e5b97315c6420b8766ac86f3f4f Mon Sep 17 00:00:00 2001
From: Jaan Altosaar
Date: Fri, 21 May 2021 19:10:22 -0400
Subject: [PATCH] add jax example

---
 .env                                     |   7 +
 .gitignore                               |   4 +
 README.md                                |  33 +--
 environment_jax.yml                      |  81 +++++++
 setup.cfg                                |   2 +
 train_variational_autoencoder_jax.py     | 267 +++++++++++++++++++++++
 train_variational_autoencoder_pytorch.py |   2 +
 7 files changed, 371 insertions(+), 25 deletions(-)
 create mode 100644 .env
 create mode 100644 .gitignore
 create mode 100644 environment_jax.yml
 create mode 100644 setup.cfg
 create mode 100644 train_variational_autoencoder_jax.py

diff --git a/.env b/.env
new file mode 100644
index 0000000..fd9db80
--- /dev/null
+++ b/.env
@@ -0,0 +1,7 @@
+# dev.env - development configuration
+
+# suppress warnings for jax
+JAX_PLATFORM_NAME=cpu
+
+# suppress tensorflow warnings
+TF_CPP_MIN_LOG_LEVEL=2
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..5f4cdef
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+*.pyc
+launch.json
+settings.json
+*.code-workspace
\ No newline at end of file
diff --git a/README.md b/README.md
index 62caf60..5f27575 100644
--- a/README.md
+++ b/README.md
@@ -34,29 +34,12 @@ step: 20000 train elbo: -101.51
 step: 20000 valid elbo: -105.02 valid log p(x): -99.11
 step: 30000 train elbo: -98.70
 step: 30000 valid elbo: -103.76 valid log p(x): -97.71
-step: 40000 train elbo: -104.31
-step: 40000 valid elbo: -103.71 valid log p(x): -97.27
-step: 50000 train elbo: -97.20
-step: 50000 valid elbo: -102.97 valid log p(x): -96.60
-step: 60000 train elbo: -97.50
-step: 60000 valid elbo: -102.82 valid log p(x): -96.49
-step: 70000 train elbo: -94.68
-step: 70000 valid elbo: -102.63 valid log p(x): -96.22
-step: 80000 train elbo: -92.86
-step: 80000 valid elbo: -102.53 valid log p(x): -96.09
-step: 90000 train elbo: -93.83
-step: 90000 valid elbo: -102.33 valid log p(x): -96.00
-step: 100000 train elbo: -93.91
-step: 100000 valid elbo: -102.48 valid log p(x): -95.92
-step: 110000 train elbo: -94.34
-step: 110000 valid elbo: -102.81 valid log p(x): -96.09
-step: 120000 train elbo: -88.63
-step: 120000 valid elbo: -102.53 valid log p(x): -95.80
-step: 130000 train elbo: -96.61
-step: 130000 valid elbo: -103.56 valid log p(x): -96.26
-step: 140000 train elbo: -94.92
-step: 140000 valid elbo: -102.81 valid log p(x): -95.86
-step: 150000 train elbo: -97.84
-step: 150000 valid elbo: -103.06 valid log p(x): -95.92
-step: 150000 test elbo: -101.64 test log p(x): -95.33
 ```
+
+Using jax:
+```
+Step 0 Validation ELBO estimate: -507.485 Validation log p(x) estimate: -507.485
+Step 10000 Validation ELBO estimate: -152.695 Validation log p(x) estimate: -152.695
+Step 20000 Validation ELBO estimate: -150.413 Validation log p(x) estimate: -150.413
+Step 30000 Validation ELBO estimate: -150.529 Validation log p(x) estimate: -150.529
+```
\ No newline at end of file
diff --git a/environment_jax.yml b/environment_jax.yml
new file mode 100644
index 0000000..4f12d2d
--- /dev/null
+++ b/environment_jax.yml
@@ -0,0 +1,81 @@
+name: jax
+channels:
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - ca-certificates=2021.4.13=h06a4308_1
+  - certifi=2020.12.5=py39h06a4308_0
+  - ld_impl_linux-64=2.33.1=h53a641e_7
+  - libffi=3.3=he6710b0_2
+  - libgcc-ng=9.1.0=hdf63c60_0
+  - libstdcxx-ng=9.1.0=hdf63c60_0
+  - ncurses=6.2=he6710b0_1
+  - openssl=1.1.1k=h27cfd23_0
+  - pip=21.1.1=py39h06a4308_0
+  - python=3.9.5=hdb3f193_3
+  - readline=8.1=h27cfd23_0
+  - setuptools=52.0.0=py39h06a4308_0
+  - sqlite=3.35.4=hdfb4753_0
+  - tk=8.6.10=hbc83047_0
+  - tzdata=2020f=h52ac0ba_0
+  - wheel=0.36.2=pyhd3eb1b0_0
+  - xz=5.2.5=h7b6447c_0
+  - zlib=1.2.11=h7b6447c_3
+  - pip:
+    - absl-py==0.12.0
+    - astunparse==1.6.3
+    - attrs==21.2.0
+    - cachetools==4.2.2
+    - chardet==4.0.0
+    - chex==0.0.7
+    - cloudpickle==1.6.0
+    - decorator==5.0.9
+    - dill==0.3.3
+    - dm-haiku==0.0.5.dev0
+    - dm-tree==0.1.6
+    - flatbuffers==1.12
+    - future==0.18.2
+    - gast==0.4.0
+    - google-auth==1.30.0
+    - google-auth-oauthlib==0.4.4
+    - google-pasta==0.2.0
+    - googleapis-common-protos==1.53.0
+    - grpcio==1.34.1
+    - h5py==3.1.0
+    - idna==2.10
+    - jax==0.2.13
+    - jaxlib==0.1.67
+    - jmp==0.0.2
+    - keras-nightly==2.5.0.dev2021032900
+    - keras-preprocessing==1.1.2
+    - markdown==3.3.4
+    - numpy==1.19.5
+    - oauthlib==3.1.0
+    - opt-einsum==3.3.0
+    - optax==0.0.7
+    - promise==2.3
+    - protobuf==3.17.0
+    - pyasn1==0.4.8
+    - pyasn1-modules==0.2.8
+    - requests==2.25.1
+    - requests-oauthlib==1.3.0
+    - rsa==4.7.2
+    - scipy==1.6.3
+    - six==1.15.0
+    - tabulate==0.8.9
+    - tensorboard==2.5.0
+    - tensorboard-data-server==0.6.1
+    - tensorboard-plugin-wit==1.8.0
+    - tensorflow==2.5.0
+    - tensorflow-datasets==4.3.0
+    - tensorflow-estimator==2.5.0
+    - tensorflow-metadata==1.0.0
+    - termcolor==1.1.0
+    - tfp-nightly==0.14.0.dev20210521
+    - toolz==0.11.1
+    - tqdm==4.60.0
+    - typing-extensions==3.7.4.3
+    - urllib3==1.26.4
+    - werkzeug==2.0.1
+    - wrapt==1.12.1
+prefix: /home/jaan/miniconda3/envs/jax
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..1d36346
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,2 @@
+[flake8]
+max-line-length = 88
\ No newline at end of file
diff --git a/train_variational_autoencoder_jax.py b/train_variational_autoencoder_jax.py
new file mode 100644
index 0000000..a9dac24
--- /dev/null
+++ b/train_variational_autoencoder_jax.py
@@ -0,0 +1,267 @@
+"""Train variational autoencoder on binary MNIST data.
+
+Largely follows https://github.com/deepmind/dm-haiku/blob/master/examples/vae.py"""
+
+import argparse
+import pathlib
+from typing import Generator, Mapping, NamedTuple, Sequence, Tuple
+
+import jax
+import numpy as np
+
+jax.config.update("jax_platform_name", "cpu")  # suppress warning about no GPUs
+
+import haiku as hk
+import jax.numpy as jnp
+import optax
+import tensorflow_datasets as tfds
+from tensorflow_probability.substrates import jax as tfp
+
+tfd = tfp.distributions
+
+Batch = Mapping[str, np.ndarray]
+MNIST_IMAGE_SHAPE: Sequence[int] = (28, 28, 1)
+PRNGKey = jnp.ndarray
+
+
+def add_args(parser):
+    parser.add_argument("--latent_size", type=int, default=128)
+    parser.add_argument("--hidden_size", type=int, default=256)
+    parser.add_argument("--variational", choices=["flow", "mean-field"])
+    parser.add_argument("--flow_depth", type=int, default=2)
+    parser.add_argument("--learning_rate", type=float, default=0.001)
+    parser.add_argument("--batch_size", type=int, default=128)
+    parser.add_argument("--training_steps", type=int, default=100000)
+    parser.add_argument("--log_interval", type=int, default=10000)
+    parser.add_argument("--early_stopping_interval", type=int, default=5)
+    parser.add_argument("--n_samples", type=int, default=128)
+    parser.add_argument(
+        "--use_gpu", default=False, action=argparse.BooleanOptionalAction
+    )
+    parser.add_argument("--random_seed", type=int, default=582838)
+    parser.add_argument("--train_dir", type=pathlib.Path, default="/tmp")
+    parser.add_argument("--data_dir", type=pathlib.Path, default="/tmp")
+
+
+def load_dataset(
+    split: str, batch_size: int, seed: int, repeat: bool = False
+) -> Generator[Batch, None, None]:
+    ds = tfds.load(
+        "binarized_mnist",
+        split=split,
+        shuffle_files=True,
+        read_config=tfds.ReadConfig(shuffle_seed=seed),
+    )
+    ds = ds.shuffle(buffer_size=10 * batch_size, seed=seed)
+    ds = ds.batch(batch_size)
+    ds = ds.prefetch(buffer_size=5)
+    if repeat:
+        ds = ds.repeat()
+    return iter(tfds.as_numpy(ds))
+
+
+class Model(hk.Module):
+    """Deep latent Gaussian model or variational autoencoder."""
+
+    def __init__(
+        self,
+        latent_size: int,
+        hidden_size: int,
+        output_shape: Sequence[int] = MNIST_IMAGE_SHAPE,
+    ):
+        super().__init__()
+        self._latent_size = latent_size
+        self._hidden_size = hidden_size
+        self._output_shape = output_shape
+        self.generative_network = hk.Sequential(
+            [
+                hk.Linear(self._hidden_size),
+                jax.nn.relu,
+                hk.Linear(self._hidden_size),
+                jax.nn.relu,
+                hk.Linear(np.prod(self._output_shape)),
+                hk.Reshape(self._output_shape, preserve_dims=2),
+            ]
+        )
+
+    def __call__(
+        self, x: jnp.ndarray, z: jnp.ndarray
+    ) -> Tuple[tfd.Distribution, tfd.Distribution]:
+        p_z = tfd.Normal(
+            loc=jnp.zeros(self._latent_size), scale=jnp.ones(self._latent_size)
+        )
+        logits = self.generative_network(z)
+        p_x_given_z = tfd.Bernoulli(logits=logits)
+        return p_z, p_x_given_z
+
+
+class VariationalMeanField(hk.Module):
+    """Mean field variational distribution q(z | x) parameterized by inference network."""
+
+    def __init__(self, latent_size: int, hidden_size: int):
+        super().__init__()
+        self._latent_size = latent_size
+        self._hidden_size = hidden_size
+        self.inference_network = hk.Sequential(
+            [
+                hk.Flatten(),
+                hk.Linear(self._hidden_size),
+                jax.nn.relu,
+                hk.Linear(self._hidden_size),
+                jax.nn.relu,
+                hk.Linear(self._latent_size * 2),
+            ]
+        )
+
+    def condition(self, inputs):
+        """Compute parameters of a multivariate independent Normal distribution based on the inputs."""
+        out = self.inference_network(inputs)
+        loc, scale_arg = jnp.split(out, 2, axis=-1)
+        scale = jax.nn.softplus(scale_arg)
+        return loc, scale
+
+    def __call__(self, x: jnp.ndarray) -> tfd.Distribution:
+        loc, scale = self.condition(x)
+        q_z = tfd.Normal(loc=loc, scale=scale)
+        return q_z
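+
+# Note: `condition` maps the raw network output through jax.nn.softplus,
+# scale = log(1 + exp(raw)), so the standard deviation of q(z | x) stays
+# strictly positive for any real-valued inference-network output.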
+
+
+class ModelAndVariationalOutput(NamedTuple):
+    p_z: tfd.Distribution
+    p_x_given_z: tfd.Distribution
+    q_z: tfd.Distribution
+    z: jnp.ndarray
+
+
+class ModelAndVariational(hk.Module):
+    """Parent class for creating inputs to the variational inference algorithm."""
+
+    def __init__(self, latent_size: int, hidden_size: int, output_shape: Sequence[int]):
+        super().__init__()
+        self._latent_size = latent_size
+        self._hidden_size = hidden_size
+        self._output_shape = output_shape
+
+    def __call__(self, x: jnp.ndarray) -> ModelAndVariationalOutput:
+        x = x.astype(jnp.float32)
+        q_z = VariationalMeanField(self._latent_size, self._hidden_size)(x)
+        # use a single sample from variational distribution to train
+        # shape [num_samples, batch_size, latent_size]
+        z = q_z.sample(sample_shape=[1], seed=hk.next_rng_key())
+
+        p_z, p_x_given_z = Model(
+            self._latent_size, self._hidden_size, MNIST_IMAGE_SHAPE
+        )(x=x, z=z)
+        return ModelAndVariationalOutput(p_z, p_x_given_z, q_z, z)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    add_args(parser)
+    args = parser.parse_args()
+    model_and_variational = hk.transform(
+        lambda x: ModelAndVariational(
+            args.latent_size, args.hidden_size, MNIST_IMAGE_SHAPE
+        )(x)
+    )
+
+    @jax.jit
+    def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarray:
+        """Negative evidence lower bound (ELBO), the objective to minimize."""
+        x = batch["image"]
+        out: ModelAndVariationalOutput = model_and_variational.apply(params, rng_key, x)
+        # sum over latent dimension
+        log_q_z = out.q_z.log_prob(out.z).sum(axis=-1)
+        # sum over last three image dimensions (width, height, channels)
+        log_p_x_given_z = out.p_x_given_z.log_prob(x).sum(axis=(-3, -2, -1))
+        # sum over latent dimension
+        log_p_z = out.p_z.log_prob(out.z).sum(axis=-1)
+
+        elbo = log_p_x_given_z + log_p_z - log_q_z
+        # average elbo over number of samples
+        elbo = elbo.mean(axis=0)
+        # sum elbo over batch
+        elbo = elbo.sum(axis=0)
+        return -elbo
+
+    rng_seq = hk.PRNGSequence(args.random_seed)
+
+    params = model_and_variational.init(
+        next(rng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE))
+    )
+    optimizer = optax.rmsprop(args.learning_rate, centered=True)
+    opt_state = optimizer.init(params)
+
+    @jax.jit
+    def train_step(
+        params: hk.Params, rng_key: PRNGKey, state: optax.OptState, batch: Batch
+    ) -> Tuple[hk.Params, optax.OptState]:
+        """Single update step to maximize the ELBO."""
+        grads = jax.grad(objective_fn)(params, rng_key, batch)
+        # update using the optimizer state passed in as `state`; closing over
+        # the initial `opt_state` here would freeze the RMSProp statistics
+        updates, new_opt_state = optimizer.update(grads, state)
+        new_params = optax.apply_updates(params, updates)
+        return new_params, new_opt_state
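+
+    # The next function estimates log p(x) by importance sampling with S
+    # samples z_1, ..., z_S from the proposal q(z | x):
+    #   log p(x) ~= logsumexp_s[log p(x, z_s) - log q(z_s | x)] - log S.
+    # Since ModelAndVariational draws a single sample (sample_shape=[1]),
+    # S = 1 and the estimate reduces to the ELBO itself, which is why the
+    # validation ELBO and log p(x) estimates in the README agree exactly.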
+
+    @jax.jit
+    def importance_weighted_estimate(
+        params: hk.Params, rng_key: PRNGKey, batch: Batch
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        """Estimate marginal log p(x) using importance sampling."""
+        x = batch["image"]
+        out: ModelAndVariationalOutput = model_and_variational.apply(params, rng_key, x)
+        # sum over latent dimension
+        log_q_z = out.q_z.log_prob(out.z).sum(axis=-1)
+        # sum over last three image dimensions (width, height, channels)
+        log_p_x_given_z = out.p_x_given_z.log_prob(x).sum(axis=(-3, -2, -1))
+        # sum over latent dimension
+        log_p_z = out.p_z.log_prob(out.z).sum(axis=-1)
+
+        elbo = log_p_x_given_z + log_p_z - log_q_z
+        # importance sampling of approximate marginal likelihood with q(z)
+        # as the proposal, and logsumexp in the sample dimension
+        log_p_x = jax.nn.logsumexp(elbo, axis=0) - jnp.log(jnp.shape(elbo)[0])
+        # sum over the elements of the minibatch
+        log_p_x = log_p_x.sum(0)
+        # average elbo over number of samples
+        elbo = elbo.mean(axis=0)
+        # sum elbo over batch
+        elbo = elbo.sum(axis=0)
+        return elbo, log_p_x
+
+    def evaluate(
+        dataset: Generator[Batch, None, None],
+        params: hk.Params,
+        rng_seq: hk.PRNGSequence,
+    ) -> Tuple[float, float]:
+        total_elbo = 0.0
+        total_log_p_x = 0.0
+        dataset_size = 0
+        for batch in dataset:
+            elbo, log_p_x = importance_weighted_estimate(params, next(rng_seq), batch)
+            total_elbo += elbo
+            total_log_p_x += log_p_x
+            dataset_size += len(batch["image"])
+        return total_elbo / dataset_size, total_log_p_x / dataset_size
+
+    train_ds = load_dataset(
+        tfds.Split.TRAIN, args.batch_size, args.random_seed, repeat=True
+    )
+    test_ds = load_dataset(tfds.Split.TEST, args.batch_size, args.random_seed)
+
+    for step in range(args.training_steps):
+        params, opt_state = train_step(params, next(rng_seq), opt_state, next(train_ds))
+        if step % args.log_interval == 0:
+            valid_ds = load_dataset(
+                tfds.Split.VALIDATION, args.batch_size, args.random_seed
+            )
+            elbo, log_p_x = evaluate(valid_ds, params, rng_seq)
+            train_elbo = (
+                -objective_fn(params, next(rng_seq), next(train_ds)) / args.batch_size
+            )
+            print(
+                f"Step {step:<10d}\t"
+                f"Train ELBO estimate: {train_elbo:<5.3f}\t"
+                f"Validation ELBO estimate: {elbo:<5.3f}\t"
+                f"Validation log p(x) estimate: {log_p_x:<5.3f}"
+            )
+
+
+if __name__ == "__main__":
+    main()
+
diff --git a/train_variational_autoencoder_pytorch.py b/train_variational_autoencoder_pytorch.py
index 8da495a..059dd71 100644
--- a/train_variational_autoencoder_pytorch.py
+++ b/train_variational_autoencoder_pytorch.py
@@ -1,3 +1,5 @@
+"""Train variational autoencoder on binary MNIST data."""
+
 import numpy as np
 import random
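
To try the new example (invocation inferred from the files in this patch, not
taken from the repository docs; the conda environment name `jax` comes from
environment_jax.yml, and the script runs with its argparse defaults):

```
conda env create -f environment_jax.yml
conda activate jax
python train_variational_autoencoder_jax.py
```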