
Commit 6b2f090

Author: Jaan Altosaar
Commit message: test flow example
Parent: f212232

File tree: 1 file changed (+90, -10 lines)


train_variational_autoencoder_jax.py (+90, -10)
@@ -15,6 +15,7 @@
 import optax
 import tensorflow_datasets as tfds
 from tensorflow_probability.substrates import jax as tfp
+import distrax
 
 tfd = tfp.distributions
 
@@ -110,13 +111,91 @@ def condition(self, inputs):
         scale = jax.nn.softplus(scale_arg)
         return loc, scale
 
-    def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    def __call__(self, x: jnp.ndarray) -> tfd.Distribution:
         loc, scale = self.condition(x)
         # IMPORTANT: need to check in source code that reparameterization_type=tfd.FULLY_REPARAMETERIZED for this class
         q_z = tfd.Normal(loc=loc, scale=scale)
         return q_z
 
 
+def make_conditioner(
+    event_shape: Sequence[int], hidden_sizes: Sequence[int], num_bijector_params: int
+) -> hk.Sequential:
+    """Creates an MLP conditioner for each layer of the flow."""
+    return hk.Sequential(
+        [
+            hk.Flatten(preserve_dims=-len(event_shape)),
+            hk.nets.MLP(hidden_sizes, activate_final=True),
+            # We initialize this linear layer to zero so that the flow is initialized
+            # to the identity function.
+            hk.Linear(
+                np.prod(event_shape) * num_bijector_params,
+                w_init=jnp.zeros,
+                b_init=jnp.zeros,
+            ),
+            hk.Reshape(tuple(event_shape) + (num_bijector_params,), preserve_dims=-1),
+        ]
+    )
+
+
+def make_flow(
+    event_shape: Sequence[int],
+    num_layers: int,
+    hidden_sizes: Sequence[int],
+    num_bins: int,
+) -> distrax.Transformed:
+    """Creates the flow model."""
+    # Alternating binary mask.
+    mask = jnp.arange(0, np.prod(event_shape)) % 2
+    mask = jnp.reshape(mask, event_shape)
+    mask = mask.astype(bool)
+
+    def bijector_fn(params: jnp.array):
+        return distrax.RationalQuadraticSpline(params, range_min=0.0, range_max=1.0)
+
+    # Number of parameters for the rational-quadratic spline:
+    # - `num_bins` bin widths
+    # - `num_bins` bin heights
+    # - `num_bins + 1` knot slopes
+    # for a total of `3 * num_bins + 1` parameters.
+    num_bijector_params = 3 * num_bins + 1
+
+    layers = []
+    for _ in range(num_layers):
+        layer = distrax.MaskedCoupling(
+            mask=mask,
+            bijector=bijector_fn,
+            conditioner=make_conditioner(
+                event_shape, hidden_sizes, num_bijector_params
+            ),
+        )
+        layers.append(layer)
+        # Flip the mask after each layer.
+        mask = jnp.logical_not(mask)
+
+    # We invert the flow so that the `forward` method is called with `log_prob`.
+    flow = distrax.Inverse(distrax.Chain(layers))
+    base_distribution = distrax.MultivariateNormalDiag(
+        loc=jnp.zeros(event_shape), scale_diag=jnp.ones(event_shape)
+    )
+    return distrax.Transformed(base_distribution, flow)
+
+
+class VariationalFlow(hk.Module):
+    def __init__(self, latent_size: int, hidden_size: int):
+        super().__init__(name="variational")
+        self._latent_size = latent_size
+        self._hidden_size = hidden_size
+
+    def __call__(self, x: jnp.ndarray) -> distrax.Distribution:
+        return make_flow(
+            event_shape=(self._latent_size,),
+            num_layers=2,
+            hidden_sizes=[self._hidden_size] * 2,
+            num_bins=4,
+        )
+
+
 def main():
     start_time = time.time()
     parser = argparse.ArgumentParser()
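For orientation, here is a minimal sketch of how the new VariationalFlow could be exercised on its own outside main(). The latent size of 8, hidden size of 128, and dummy batch below are illustrative assumptions, not values from this commit; note that the flow as written ignores its input x (the coupling layers condition only on the masked part of z).

import haiku as hk
import jax
import jax.numpy as jnp

# Hypothetical smoke test, assuming VariationalFlow from this commit is importable.
variational = hk.transform(lambda x: VariationalFlow(latent_size=8, hidden_size=128)(x))

rng = jax.random.PRNGKey(0)
dummy_x = jnp.zeros((2, 28, 28, 1))  # placeholder batch; the flow does not condition on x
q_params = variational.init(rng, dummy_x)
q_z = variational.apply(q_params, rng, dummy_x)

# distrax distributions provide sample_and_log_prob, returning z and log q(z) together.
z, log_q_z = q_z.sample_and_log_prob(seed=rng, sample_shape=[1])
# Expected shapes: z is (1, 8), log_q_z is (1,), since the event is the full latent vector.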
@@ -126,8 +205,11 @@ def main():
     model = hk.transform(
         lambda x, z: Model(args.latent_size, args.hidden_size, MNIST_IMAGE_SHAPE)(x, z)
     )
+    # variational = hk.transform(
+    #     lambda x: VariationalMeanField(args.latent_size, args.hidden_size)(x)
+    # )
     variational = hk.transform(
-        lambda x: VariationalMeanField(args.latent_size, args.hidden_size)(x)
+        lambda x: VariationalFlow(args.latent_size, args.hidden_size)(x)
     )
     p_params = model.init(
         next(rng_seq),
@@ -139,16 +221,14 @@ def main():
     optimizer = optax.rmsprop(args.learning_rate)
     opt_state = optimizer.init(params)
 
-    @jax.jit
+    # @jax.jit
     def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarray:
         x = batch["image"]
         predicate = lambda module_name, name, value: "model" in module_name
         p_params, q_params = hk.data_structures.partition(predicate, params)
         q_z = variational.apply(q_params, rng_key, x)
-        z = q_z.sample(sample_shape=[1], seed=rng_key)
+        z, log_q_z = q_z.sample_and_log_prob(x, sample_shape=[1], seed=rng_key)
         p_z, p_x_given_z = model.apply(p_params, rng_key, x, z)
-        # out: ModelAndVariationalOutput = model_and_variational.apply(params, rng_key, x)
-        log_q_z = q_z.log_prob(z).sum(axis=-1)
         # sum over last three image dimensions (width, height, channels)
         log_p_x_given_z = p_x_given_z.log_prob(x).sum(axis=(-3, -2, -1))
         # sum over latent dimension
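Two notes on this hunk. First, the flow's distrax.Transformed distribution treats the whole latent vector as a single event, so sample_and_log_prob already returns the joint log q(z) and the old .sum(axis=-1) is no longer needed. Second, the x passed positionally to sample_and_log_prob appears to be a leftover from the mean-field call site; distrax normally takes only seed and sample_shape here, as in the evaluation hunk further down. For reference, a sketch of the single-sample ELBO these terms feed into (standard VAE algebra, not lines from this diff):

# Monte Carlo ELBO with a single sample z ~ q(z | x):
#   ELBO(x) ≈ log p(z) + log p(x | z) - log q(z | x)
# log_p_z is assumed to come from p_z; log_p_x_given_z is summed over image dimensions above.
elbo = log_p_z + log_p_x_given_z - log_q_z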
@@ -160,7 +240,7 @@ def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarr
         elbo = elbo.sum(axis=0)
         return -elbo
 
-    @jax.jit
+    # @jax.jit
     def train_step(
         params: hk.Params, rng_key: PRNGKey, opt_state: optax.OptState, batch: Batch
     ) -> Tuple[hk.Params, optax.OptState]:
@@ -170,7 +250,7 @@ def train_step(
         new_params = optax.apply_updates(params, updates)
         return new_params, new_opt_state
 
-    @jax.jit
+    # @jax.jit
    def importance_weighted_estimate(
         params: hk.Params, rng_key: PRNGKey, batch: Batch
     ) -> Tuple[jnp.ndarray, jnp.ndarray]:
@@ -180,9 +260,9 @@ def importance_weighted_estimate(
         predicate = lambda module_name, name, value: "model" in module_name
         p_params, q_params = hk.data_structures.partition(predicate, params)
         q_z = variational.apply(q_params, rng_key, x)
-        z = q_z.sample(args.num_eval_samples, seed=rng_key)
+        z, log_q_z = q_z.sample_and_log_prob(sample_shape=[args.num_eval_samples], seed=rng_key)
         p_z, p_x_given_z = model.apply(p_params, rng_key, x, z)
-        log_q_z = q_z.log_prob(z).sum(axis=-1)
+        # log_q_z = q_z.log_prob(z).sum(axis=-1)
         # sum over last three image dimensions (width, height, channels)
         log_p_x_given_z = p_x_given_z.log_prob(x).sum(axis=(-3, -2, -1))
         # sum over latent dimension
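On the evaluation path, the same change yields z and log q(z) for args.num_eval_samples samples in one call. For reference, a sketch of the standard importance-weighted bound such quantities are typically combined into (illustrative only; the actual reduction lives outside the lines shown in this hunk):

# Importance-weighted bound over K = args.num_eval_samples samples:
#   log p(x) >= logmeanexp_k [ log p(z_k) + log p(x | z_k) - log q(z_k | x) ]
log_w = log_p_z + log_p_x_given_z - log_q_z  # shape (num_eval_samples, batch)
iwae_bound = jax.nn.logsumexp(log_w, axis=0) - jnp.log(log_w.shape[0])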
