-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathvae.py
More file actions
261 lines (212 loc) · 10.2 KB
/
vae.py
File metadata and controls
261 lines (212 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
""" vae.py: Run the Variational Auto-encoder.
TODO:
- Incorporate L (MC samples) without blowing up the decoder variable count
- There's a blow-up of the log_var output of the encoder which makes
the KL-divergence term of the error function go to infinity since there
is a var term (where var = e^{log_var}). This seems to be irreversible
when the learning rate is high.
UPDATE: The network weights seem to be such that at the beginning, there
is high variance output (exploding gradients when NN still malleable)
UPDATE: Also exacerbated by batch size. Gradient clipping ineffective
when gradient becomes nan!
UPDATE: This problem is *super* sensitive to learning rate and highly
nondeterministic. At a learning rate of 0.016 (using Adam), gradients will
sometimes blow up to 3e+28, and when they don't, they go no higher
than 100! However, runs are much more stable across repetitions at 0.015!
UPDATE: Also sensitive to latent space dimensionality changes (reduction
from 10 -> 2 made it go haywire even with Adam rate of 0.01)
TODO: Try value clipping of the KL divergence. Norm clipping too?
- We use the output of a sigmoid to parameterize the multivariate Bernoulli.
Does its steepness affect learning?
- Apply batch normalization as in [3]
- Try different reconstruction loss instead of log-likelihood:
- Cross entropy
- Quantitatively assess Adagrad, SGD, Adam performance
- Use tf.nn.dropout to perform Variational Dropout
- Other datasets besides MNIST
- Examples:
- DRAW network
- Generative adversarial network
- Music composition network (similar to DRAW)
- Picasso network (similar to DRAW)
"""
__author__ = "shraman-rc"
import tensorflow as tf
import numpy as np
import click as cl
from nets import BernoulliMLP, GaussianMLP
import likelihoods as lh
from tensorflow.examples.tutorials.mnist import input_data
# Module-level MNIST handle consumed by VAE._train() via
# mnist.train.next_batch(). It is loaded at import time from the
# 'MNIST_data' directory with one-hot labels; the TF tutorial helper
# presumably downloads the data there if it is missing — TODO confirm.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Graph-level TF random seed, intended to make random ops (e.g. the
# reparameterization noise VAE.ep) reproducible.
# NOTE(review): the module docstring still reports nondeterministic runs.
tf.set_random_seed(0)
class VAE(object):
    '''
    Variational auto-encoder (VAE) for flattened images.

    The encoder (GaussianMLP) outputs the mean and log-variance of a
    diagonal Gaussian q(z|x). z is drawn with the reparameterization trick,
    and the decoder (BernoulliMLP) outputs the parameters p of a
    multivariate Bernoulli likelihood p(x|z). Training maximizes a
    single-sample ELBO estimate using element-wise clipped gradients.
    '''

    def __init__(self, ARCH, DIMS, OPT, PARAMS, TRAIN):
        '''
        Initializes VAE with parameter dictionaries that should follow the
        same format as config/nn_config.yaml.

        Keys read while building the graph:
            - ARCH["encoder"]["n_units"], ARCH["decoder"]["n_units"]
            - DIMS["data"], DIMS["latent"]
            - OPT["type"] ("adagrad" or "adam", case-insensitive),
              OPT["Adagrad_rate"] or OPT["Adam_rate"], OPT["max_grad"]
        TRAIN is read later by train() and decoder_fp(); PARAMS is only
        stored. A tf.InteractiveSession is opened as self.sess.

        Raises:
            - ValueError: if OPT["type"] is not "adagrad" or "adam"
        '''
        # tf.concat's argument order changed in TF 1.0; reduce(tf.maximum)
        # below works on either side of that change.
        from functools import reduce

        self.ARCH, self.DIMS, self.OPT, self.PARAMS, self.TRAIN = \
            ARCH, DIMS, OPT, PARAMS, TRAIN

        # Inputs - mini-batches of (flattened) images
        self.x_batch = tf.placeholder(tf.float32,
                                      shape=[None, self.DIMS["data"]])

        # Encoder parameterizes posterior Gaussian approximation q(z|x)
        self.encoder = GaussianMLP(self.x_batch,
                                   self.ARCH["encoder"]["n_units"],
                                   self.DIMS["latent"])
        self.latent = {}
        self.latent["mu"] = self.encoder.out_params.mu
        self.latent["log_var"] = self.encoder.out_params.log_var
        self.latent["var"] = tf.exp(self.latent["log_var"])
        # exp(log_var / 2) equals sqrt(exp(log_var)) but overflows later and
        # avoids sqrt's infinite gradient at 0. That gradient turns into NaN
        # once var underflows to 0.
        self.latent["stddev"] = tf.exp(0.5 * self.latent["log_var"])

        # Reparameterize latent space (z = g(ep,x); ep ~ p(ep))
        # Note: Element-wise univariate Gaussian sampling <=>
        # multivariate Gaussian sampling.
        # ep takes its shape from mu, so any number of datapoints may be fed
        # (not only TRAIN["batch_size"]).
        self.ep = tf.random_normal(tf.shape(self.latent["mu"]),
                                   mean=0, stddev=1)
        self.z_batch = self.latent["mu"] + self.latent["stddev"] * self.ep

        # Decoder samples from latent distribution, parameterizes likelihood
        # (in this case, a multivariate Bernoulli - working with images)
        self.decoder = BernoulliMLP(self.z_batch,
                                    self.ARCH["decoder"]["n_units"],
                                    self.DIMS["data"])

        # The (negative) KL divergence between the variational approx. and the
        # standard normal *prior* p_theta(z) = N(0, I) acts as a regularizing
        # term so that the latent distribution doesn't overfit. The
        # closed-form eq. is derived in [1]: Appendix B. One value per example.
        self.neg_KL_pr = 0.5 * tf.reduce_sum(
            1 + self.latent["log_var"]
            - self.latent["mu"] ** 2
            - self.latent["var"], 1)

        # The 'reconstruction error' (predictive likelihood): log p_theta(x_batch|z)
        self.ll = lh.ll_bernoulli(self.x_batch, self.decoder.out_params.p)

        # ELBO estimate (total reward function), averaged over the minibatch
        self.ELBO_est = tf.reduce_mean(self.neg_KL_pr + self.ll)

        # Pick a flavor of gradient descent
        opt_type = self.OPT["type"].lower()
        if opt_type == "adagrad":
            self.optimizer = tf.train.AdagradOptimizer(self.OPT["Adagrad_rate"])
        elif opt_type == "adam":
            self.optimizer = tf.train.AdamOptimizer(self.OPT["Adam_rate"])
        else:
            raise ValueError(
                "Unsupported optimizer type {!r}; expected 'adagrad' or "
                "'adam'".format(self.OPT["type"]))

        # We minimize the negative ELBO (i.e. maximize the ELBO). Gradients
        # are value-clipped to [-max_grad, max_grad] to curb blow-ups during
        # the first few training phases. Variables with no gradient are skipped.
        clip = self.OPT["max_grad"]
        gvs = self.optimizer.compute_gradients(-self.ELBO_est)
        capped_gvs = [(tf.clip_by_value(grad, -clip, clip), var)
                      for grad, var in gvs if grad is not None]
        # Largest (signed) clipped gradient entry, fetched for monitoring
        self.max_grad = reduce(tf.maximum,
                               [tf.reduce_max(grad) for grad, _ in capped_gvs])
        self.train_op = self.optimizer.apply_gradients(capped_gvs)

        # Default session to use for operations
        self.sess = tf.InteractiveSession()

    def _train_step(self, sess, data, verbose=True):
        ''' Common helper to run one optimization step, see _train().
        Returns the salient quantities computed in that step:
            (ELBO, ll, neg_KL, mu, log_var, ep, max_grad)
        where ELBO, ll and neg_KL are minibatch means, and mu, log_var and ep
        are the values for the first example of the minibatch.
        Params:
            - sess,verbose: See _train()
            - data: batch of raw data with correct dimensions
        '''
        fetches = [self.train_op,
                   self.ELBO_est,
                   self.ll,
                   self.neg_KL_pr,
                   self.latent["mu"],
                   self.latent["log_var"],
                   self.ep,
                   self.max_grad]
        _, ELBO, ll, neg_KL, mu, log_var, ep, max_grad = sess.run(
            fetches, feed_dict={self.x_batch: data})

        # Reduce per-example quantities over the minibatch; only the first
        # example's latent parameters / noise are reported.
        ELBO, neg_KL, ll = np.mean(ELBO), np.mean(neg_KL), np.mean(ll)
        mu, log_var, ep = mu[0], log_var[0], ep[0]

        # Print stats
        if verbose:
            cl.secho((
                "ELBO (estimate): {}\n"
                "KL Div (prior): {}\n"
                "Likelihood: {}\n"
                "Mu: {}\n"
                "Log var: {}\n"
                "Epsilon: {}\n"
                "Max grad: {}")
                .format(ELBO, -neg_KL, ll, mu, log_var, ep, max_grad), fg='cyan')
        return ELBO, ll, neg_KL, mu, log_var, ep, max_grad

    def _train(self, iters, mbsize, sess, verbose=True):
        ''' Trains the VAE end-to-end on MNIST (handwriting) dataset.
        All variables in the default graph are (re)initialized first.
        Returns a dict of per-iteration numpy arrays "ELBO", "KL" and "LL",
        plus "iters" (the range of iteration indices).
        Params:
            - iters: Number of training iterations
            - mbsize: Number of datapoints per minibatch
            - sess: TF session to run the training ops in
            - verbose: Print training progress after each timestep
        '''
        # To keep track of progress
        progress = {
            "ELBO": np.zeros(iters),
            "KL": np.zeros(iters),
            "LL": np.zeros(iters),
        }

        # Optimize VAE
        sess.run(tf.initialize_all_variables())
        timesteps = range(iters)
        for t in timesteps:
            if verbose:
                cl.secho('Minibatch {}'.format(t), fg='green', bold=False)
            batch = mnist.train.next_batch(mbsize)[0]
            ELBO, ll, neg_KL = self._train_step(sess, batch, verbose)[:3]
            progress["ELBO"][t] = ELBO
            progress["LL"][t] = ll
            progress["KL"][t] = -neg_KL
        progress["iters"] = timesteps
        cl.secho('Success!', fg='green', bold=True)
        return progress

    def train(self, iters=None, mbsize=None, sess=None, verbose=True):
        ''' Wrapper for _train() that fills in defaults.
        Params:
            - iters: defaults to TRAIN["n_iters"] when None
            - mbsize: defaults to TRAIN["batch_size"] when None
            - sess: defaults to self.sess (InteractiveSession) when None
            - verbose: see _train()
        '''
        if iters is None:
            iters = self.TRAIN["n_iters"]
        if mbsize is None:
            mbsize = self.TRAIN["batch_size"]
        if sess is None:
            sess = self.sess
        return self._train(iters, mbsize, sess, verbose)

    def encoder_fp(self, data, sess=None):
        ''' Performs one forward pass of the encoder.
        Returns [mu, stddev], the parameters of the latent Gaussian q(z|x).
        Params:
            - data: Datapoint(s) to be processed by encoder
            - sess: TF session to use if already instantiated one
        '''
        sess = self.sess if sess is None else sess
        # Both tensors must be fetched as one list; passing stddev as a
        # second positional argument would bind it to feed_dict.
        return sess.run([self.latent["mu"], self.latent["stddev"]],
                        feed_dict={self.x_batch: data})

    def decoder_fp(self, sess=None, z=None):
        ''' Performs one forward pass of the decoder.
        Returns p, the parameters of the Bernoulli data distribution.
        Params:
            - sess: TF session to use if already instantiated one
            - z: latent codes of shape (n, DIMS["latent"]). If None,
              TRAIN["batch_size"] codes are drawn from the prior N(0, I).
        '''
        sess = self.sess if sess is None else sess
        if z is None:
            z = np.random.randn(self.TRAIN["batch_size"], self.DIMS["latent"])
        # Feeding z_batch directly bypasses the encoder. Otherwise the
        # decoder would require the x_batch placeholder to be fed.
        z = np.asarray(z, dtype=np.float32)
        return sess.run(self.decoder.out_params.p,
                        feed_dict={self.z_batch: z})

    def full_fp(self, data, sess=None):
        ''' Performs one full forward pass of VAE end-to-end.
        Returns [mu, stddev, p] for the given datapoint(s).
        Params:
            - data: Datapoint(s) to be processed by encoder
            - sess: TF session to use if already instantiated one
        '''
        sess = self.sess if sess is None else sess
        return sess.run([self.latent["mu"], self.latent["stddev"],
                         self.decoder.out_params.p],
                        feed_dict={self.x_batch: data})