Skip to content

Commit cdf9000

Browse files
author
Jaan Altosaar
committed
add mnist download utility
1 parent 1d0f523 commit cdf9000

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

train_variational_autoencoder_pytorch.py

+5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pathlib
1616
import h5py
1717
import random
18+
import data
1819
import flow
1920

2021
config = """
@@ -30,6 +31,7 @@
3031
n_samples: 128
3132
use_gpu: true
3233
train_dir: $TMPDIR
34+
data_dir: $TMPDIR
3335
seed: 582838
3436
"""
3537

@@ -147,6 +149,9 @@ def cycle(iterable):
147149

148150

149151
def load_binary_mnist(cfg, **kwcfg):
152+
fname = cfg.data_dir / 'binary_mnist.h5'
153+
if not fname.exists():
154+
data.download_binary_mnist(fname)
150155
f = h5py.File(pathlib.os.path.join(pathlib.os.environ['DAT'], 'binarized_mnist.hdf5'), 'r')
151156
x_train = f['train'][::]
152157
x_val = f['valid'][::]

0 commit comments

Comments
 (0)