[pull] master from altosaar:master #5

Merged 2 commits on May 25, 2021
4 changes: 3 additions & 1 deletion .env
@@ -4,4 +4,6 @@
JAX_PLATFORM_NAME=cpu

# suppress tensorflow warnings
TF_CPP_MIN_LOG_LEVEL=2
TF_CPP_MIN_LOG_LEVEL=2

TFDS_DATA_DIR=/scratch/gpfs/altosaar/tensorflow_datasets
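The variables in `.env` only take effect once something copies them into the process environment. How this repository loads the file is not shown in the diff, so the following is a minimal sketch, assuming a plain `KEY=VALUE` file with `#` comments. The helper names `parse_env_lines` and `load_env_file` are hypothetical, and the data directory in the demo is a placeholder rather than the path from the diff.

```python
import os
import tempfile


def parse_env_lines(lines):
    """Parse KEY=VALUE lines, skipping blank lines and '#' comments."""
    values = {}
    for raw in lines:
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        key, sep, value = line.partition("=")
        if not sep:
            continue  # not an assignment, ignore it
        values[key.strip()] = value.strip()
    return values


def load_env_file(path, environ=os.environ):
    """Hypothetical helper: copy variables from a .env file into `environ`.

    Variables that are already set win, so a value exported in the shell
    is not overwritten by the file.
    """
    with open(path) as f:
        parsed = parse_env_lines(f)
    for key, value in parsed.items():
        environ.setdefault(key, value)
    return parsed


# Demo on a temporary file that uses the same variable names as the .env above.
example = """\
JAX_PLATFORM_NAME=cpu

# suppress tensorflow warnings
TF_CPP_MIN_LOG_LEVEL=2

TFDS_DATA_DIR=/tmp/tensorflow_datasets
"""
with tempfile.NamedTemporaryFile("w", suffix=".env", delete=False) as tmp:
    tmp.write(example)

demo_environ = {"TF_CPP_MIN_LOG_LEVEL": "0"}  # pretend the shell already set this
loaded = load_env_file(tmp.name, environ=demo_environ)
os.remove(tmp.name)

print(loaded)
print(demo_environ)
```

In a training script, a call like this would most safely come before `jax` or `tensorflow` is imported, since settings such as `TF_CPP_MIN_LOG_LEVEL` are generally read when the library starts up.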
34 changes: 21 additions & 13 deletions README.md
@@ -9,16 +9,20 @@ Variational inference is used to fit the model to binarized MNIST handwritten di

Blog post: https://jaan.io/what-is-variational-autoencoder-vae-tutorial/

Example output with importance sampling for estimating the marginal likelihood on Hugo Larochelle's Binary MNIST dataset. Final marginal likelihood on the test set was `-97.10` nats after 65k iterations.

## PyTorch implementation

(anaconda environment is in `environment-jax.yml`)

Importance sampling is used to estimate the marginal likelihood on Hugo Larochelle's Binary MNIST dataset. The final marginal likelihood on the test set was `-97.10` nats, which is comparable to published numbers.

```
$ python train_variational_autoencoder_pytorch.py --variational mean-field
step: 0 train elbo: -558.28
step: 0 valid elbo: -392.78 valid log p(x): -359.91
step: 10000 train elbo: -106.67
step: 10000 valid elbo: -109.12 valid log p(x): -103.11
step: 20000 train elbo: -107.28
step: 20000 valid elbo: -105.65 valid log p(x): -99.74
$ python train_variational_autoencoder_pytorch.py --variational mean-field --use_gpu --data_dir $DAT --max_iterations 30000 --log_interval 10000
Step 0 Train ELBO estimate: -558.027 Validation ELBO estimate: -384.432 Validation log p(x) estimate: -355.430 Speed: 2.72e+06 examples/s
Step 10000 Train ELBO estimate: -111.323 Validation ELBO estimate: -109.048 Validation log p(x) estimate: -103.746 Speed: 2.64e+04 examples/s
Step 20000 Train ELBO estimate: -103.013 Validation ELBO estimate: -107.655 Validation log p(x) estimate: -101.275 Speed: 2.63e+04 examples/s
Step 29999 Test ELBO estimate: -106.642 Test log p(x) estimate: -100.309
Total time: 2.49 minutes
```
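The log reports both an ELBO estimate and an importance-sampled estimate of log p(x), and the ELBO is the lower of the two at every step. The relationship can be sketched on a toy model whose marginal likelihood is known in closed form. This is not the code from `train_variational_autoencoder_pytorch.py`: the one-dimensional Gaussian model, the function names, and the sample counts are all assumptions chosen for illustration.

```python
import math
import random


def log_normal(x, mean, var):
    """Log density of a univariate Gaussian."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def log_mean_exp(values):
    """Numerically stable log of the mean of exp(values)."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def log_weights(x, q_mean, q_var, num_samples, rng):
    """Sample z ~ q(z | x) and return log p(x, z) - log q(z | x) per sample.

    Toy model (an assumption for this sketch): z ~ N(0, 1), x | z ~ N(z, 1).
    """
    out = []
    for _ in range(num_samples):
        z = rng.gauss(q_mean, math.sqrt(q_var))
        log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
        out.append(log_joint - log_normal(z, q_mean, q_var))
    return out


rng = random.Random(0)
x = 1.5
exact = log_normal(x, 0.0, 2.0)  # marginal of the toy model: x ~ N(0, 2)

# A deliberately mismatched proposal, standing in for an imperfect encoder.
lw = log_weights(x, q_mean=0.0, q_var=1.0, num_samples=5000, rng=rng)
elbo = sum(lw) / len(lw)   # average log weight: estimates a lower bound
log_px = log_mean_exp(lw)  # importance-sampled estimate of log p(x)

# With the exact posterior N(x / 2, 1 / 2), every weight equals p(x).
lw_exact = log_weights(x, q_mean=x / 2, q_var=0.5, num_samples=10, rng=rng)

print(f"exact log p(x):    {exact:.3f}")
print(f"ELBO estimate:     {elbo:.3f}")
print(f"log p(x) estimate: {log_px:.3f}")
```

By Jensen's inequality the average log weight can never exceed the log of the average weight computed from the same samples, which is why the ELBO column sits below the log p(x) column. The gap closes as the proposal approaches the true posterior.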


@@ -36,10 +40,14 @@ step: 30000 train elbo: -98.70
step: 30000 valid elbo: -103.76 valid log p(x): -97.71
```

Using jax:
## jax implementation

Using jax (anaconda environment is in `environment-jax.yml`), to get a 3x speedup over pytorch:
```
Step 0 Validation ELBO estimate: -507.485 Validation log p(x) estimate: -507.485
Step 10000 Validation ELBO estimate: -152.695 Validation log p(x) estimate: -152.695
Step 20000 Validation ELBO estimate: -150.413 Validation log p(x) estimate: -150.413
Step 30000 Validation ELBO estimate: -150.529 Validation log p(x) estimate: -150.529
$ python train_variational_autoencoder_jax.py --gpu
Step 0 Train ELBO estimate: -566.059 Validation ELBO estimate: -565.755 Validation log p(x) estimate: -557.914 Speed: 2.56e+11 examples/s
Step 10000 Train ELBO estimate: -98.560 Validation ELBO estimate: -105.725 Validation log p(x) estimate: -98.973 Speed: 7.03e+04 examples/s
Step 20000 Train ELBO estimate: -109.794 Validation ELBO estimate: -105.756 Validation log p(x) estimate: -97.914 Speed: 4.26e+04 examples/s
Step 29999 Test ELBO estimate: -104.867 Test log p(x) estimate: -96.716
Total time: 0.810 minutes
```
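The "3x speedup" figure can be checked against the two `Total time` lines in the logs. The sketch below pulls those values out of the log text and computes their ratio. The regular expression and the function name `total_minutes` are assumptions about the log format shown here, not part of the training scripts.

```python
import re

# Final lines of the PyTorch and JAX logs, copied from the README diff.
PYTORCH_LOG = """\
Step 29999 Test ELBO estimate: -106.642 Test log p(x) estimate: -100.309
Total time: 2.49 minutes
"""
JAX_LOG = """\
Step 29999 Test ELBO estimate: -104.867 Test log p(x) estimate: -96.716
Total time: 0.810 minutes
"""


def total_minutes(log_text):
    """Return the value from a 'Total time: <float> minutes' line."""
    match = re.search(r"Total time:\s*([0-9.]+)\s*minutes", log_text)
    if match is None:
        raise ValueError("no 'Total time' line found")
    return float(match.group(1))


speedup = total_minutes(PYTORCH_LOG) / total_minutes(JAX_LOG)
print(f"speedup: {speedup:.2f}x")
```

Each framework was timed on a single run and the two commands use different flags, so the ratio is indicative rather than a controlled benchmark.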
28 changes: 14 additions & 14 deletions environment_jax.yml → environment-jax.yml
@@ -1,4 +1,4 @@
name: jax
name: /scratch/gpfs/altosaar/environment-jax
channels:
- defaults
dependencies:
@@ -11,7 +11,6 @@ dependencies:
- libstdcxx-ng=9.1.0=hdf63c60_0
- ncurses=6.2=he6710b0_1
- openssl=1.1.1k=h27cfd23_0
- pip=21.1.1=py39h06a4308_0
- python=3.9.5=hdb3f193_3
- readline=8.1=h27cfd23_0
- setuptools=52.0.0=py39h06a4308_0
@@ -36,25 +35,26 @@ dependencies:
- flatbuffers==1.12
- future==0.18.2
- gast==0.4.0
- google-auth==1.30.0
- google-auth==1.30.1
- google-auth-oauthlib==0.4.4
- google-pasta==0.2.0
- googleapis-common-protos==1.53.0
- grpcio==1.34.1
- grpcio==1.38.0
- h5py==3.1.0
- idna==2.10
- jax==0.2.13
- jaxlib==0.1.67
- jaxlib==0.1.67+cuda111
- jmp==0.0.2
- keras-nightly==2.5.0.dev2021032900
- keras-nightly==2.6.0.dev2021052500
- keras-preprocessing==1.1.2
- markdown==3.3.4
- numpy==1.19.5
- oauthlib==3.1.0
- opt-einsum==3.3.0
- optax==0.0.7
- optax==0.0.6
- pip==21.1.2
- promise==2.3
- protobuf==3.17.0
- protobuf==3.17.1
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- requests==2.25.1
@@ -63,19 +63,19 @@ dependencies:
- scipy==1.6.3
- six==1.15.0
- tabulate==0.8.9
- tensorboard==2.5.0
- tb-nightly==2.6.0a20210525
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.0
- tensorflow==2.5.0
- tensorflow-datasets==4.3.0
- tensorflow-estimator==2.5.0
- tensorflow-metadata==1.0.0
- termcolor==1.1.0
- tfp-nightly==0.14.0.dev20210521
- tf-estimator-nightly==2.5.0.dev2021032601
- tf-nightly==2.6.0.dev20210525
- tfp-nightly==0.14.0.dev20210525
- toolz==0.11.1
- tqdm==4.60.0
- tqdm==4.61.0
- typing-extensions==3.7.4.3
- urllib3==1.26.4
- werkzeug==2.0.1
- wrapt==1.12.1
prefix: /home/jaan/miniconda3/envs/jax
prefix: /scratch/gpfs/altosaar/environment-jax
64 changes: 64 additions & 0 deletions environment-pytorch.yml
@@ -0,0 +1,64 @@
name: /scratch/gpfs/altosaar/environment-pytorch
channels:
- pytorch
- nvidia
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- blas=1.0=mkl
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2021.4.13=h06a4308_1
- certifi=2020.12.5=py38h06a4308_0
- cudatoolkit=11.1.74=h6bb024c_0
- ffmpeg=4.3=hf484d3e_0
- freetype=2.10.4=h5ab3b9f_0
- gmp=6.2.1=h2531618_2
- gnutls=3.6.15=he1e5248_0
- h5py=2.10.0=py38hd6299e0_1
- hdf5=1.10.6=hb1b8bf9_0
- intel-openmp=2021.2.0=h06a4308_610
- jpeg=9b=h024ee3a_2
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.33.1=h53a641e_7
- libffi=3.3=he6710b0_2
- libgcc-ng=9.1.0=hdf63c60_0
- libgfortran-ng=7.3.0=hdf63c60_0
- libiconv=1.15=h63c8f33_5
- libidn2=2.3.1=h27cfd23_0
- libpng=1.6.37=hbc83047_0
- libstdcxx-ng=9.1.0=hdf63c60_0
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.1.0=h2733197_1
- libunistring=0.9.10=h27cfd23_0
- libuv=1.40.0=h7b6447c_0
- lz4-c=1.9.3=h2531618_0
- mkl=2021.2.0=h06a4308_296
- mkl-service=2.3.0=py38h27cfd23_1
- mkl_fft=1.3.0=py38h42c9631_2
- mkl_random=1.2.1=py38ha9443f7_2
- ncurses=6.2=he6710b0_1
- nettle=3.7.2=hbbd107a_1
- ninja=1.10.2=hff7bd54_1
- numpy=1.20.2=py38h2d18471_0
- numpy-base=1.20.2=py38hfae3a4d_0
- olefile=0.46=py_0
- openh264=2.1.0=hd408876_0
- openssl=1.1.1k=h27cfd23_0
- pillow=8.2.0=py38he98fc37_0
- pip=21.1.1=py38h06a4308_0
- python=3.8.10=hdb3f193_7
- pytorch=1.8.1=py3.8_cuda11.1_cudnn8.0.5_0
- readline=8.1=h27cfd23_0
- setuptools=52.0.0=py38h06a4308_0
- six=1.15.0=py38h06a4308_0
- sqlite=3.35.4=hdfb4753_0
- tk=8.6.10=hbc83047_0
- torchaudio=0.8.1=py38
- torchvision=0.9.1=py38_cu111
- typing_extensions=3.7.4.3=pyha847dfd_0
- wheel=0.36.2=pyhd3eb1b0_0
- xz=5.2.5=h7b6447c_0
- zlib=1.2.11=h7b6447c_3
- zstd=1.4.9=haebb681_0
prefix: /scratch/gpfs/altosaar/environment-pytorch