from __future__ import (
division,
print_function,
absolute_import
)
from six.moves import range
import tensorflow as tf
import tflearn
from tflearn.datasets import mnist
import vae
import numpy as np
from skimage import io
# Model / data configuration shared by the encoder, decoder and plots below.
original_dim = 784        # 28 * 28; the visualisations below tile 28x28 digits
latent_dim = 2            # width of the latent code fed to the decoder
intermediate_dim = 512    # passed as intermediate_dim to vae.encode / vae.decode
model_file = 'model_variational_autoencoder-43000'  # checkpoint given to DNN.load
X, Y, testX, testY = mnist.load_data()
# Per-sample input shape (everything after the batch axis) as a plain list so
# it can be concatenated with [None] when declaring the input placeholder.
original_shape = list(X.shape[1:])
def _restored_dnn(output_tensor, restore_args):
    """Wrap *output_tensor* in a tflearn.DNN and immediately load model_file.

    The DNN is constructed and then loaded, in that order, inside whatever
    graph is currently the default.
    """
    model = tflearn.DNN(output_tensor)
    # NOTE(review): restore_args is passed positionally, exactly as before.
    # In tflearn releases whose signature is
    # DNN.load(model_file, weights_only=False, **optargs) the dict binds to
    # `weights_only` instead of being forwarded as restore options. Confirm
    # against the installed tflearn before switching to **restore_args
    # (scope_for_restore also changes how checkpoint names are matched).
    model.load(model_file, restore_args)
    return model


# Encoder graph: rebuild the encoder and expose three restored predictors,
# one each for the z, mean and logvar outputs of vae.encode.
with tf.Graph().as_default():
    input_shape = [None] + original_shape  # None leaves the batch axis open
    x = tflearn.input_data(shape=input_shape)
    z, mean, logvar = vae.encode(x, intermediate_dim=intermediate_dim,
                                 latent_dim=latent_dim)
    optargs = {'scope_for_restore': 'Encoder'}
    encoder = _restored_dnn(z, optargs)
    mean_encoder = _restored_dnn(mean, optargs)
    logvar_encoder = _restored_dnn(logvar, optargs)
# Decoder graph: a standalone generator that maps latent_dim-wide codes to
# decoded images, with its weights restored from model_file.
with tf.Graph().as_default():
    latent_codes = tflearn.input_data(shape=[None, latent_dim])
    decoded = vae.decode(latent_codes,
                         intermediate_dim=intermediate_dim,
                         original_shape=original_shape)
    generator = tflearn.DNN(decoded)
    decoder_restore = {'scope_for_restore': 'Decoder'}
    # NOTE(review): the dict is passed positionally, as before; in tflearn
    # versions with DNN.load(model_file, weights_only=False, **optargs) it
    # binds to `weights_only`. Confirm before changing to **decoder_restore.
    generator.load(model_file, decoder_restore)
# ---------------------------------------------------------------------------
# Decode a regular n x n grid of points in the first two latent dimensions and
# tile the decoded digits into one image, saved as vae_z.png.
# ---------------------------------------------------------------------------
digit_size = 28  # side length of one decoded image (28 * 28 == 784)
n = 15           # grid points per latent axis -> n * n decoded digits
# Half-width of the latent grid on both axes.
# NOTE(review): +-1000 is very wide if the model was trained against a
# unit-Gaussian prior -- confirm this range is intended.
latent_extent = 1000
figure = np.zeros((digit_size * n, digit_size * n))
grid_x = np.linspace(-latent_extent, latent_extent, n)  # z[0] values (columns)
grid_y = np.linspace(-latent_extent, latent_extent, n)  # z[1] values (rows)
# Latent dimensions beyond the first two are held at zero.
z_rest = [0] * (latent_dim - 2)
for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
        z_sample = np.array([[xi, yi] + z_rest])
        x_decoded = generator.predict(z_sample)
        digit = np.reshape(x_decoded[0], [digit_size, digit_size])
        figure[i * digit_size:(i + 1) * digit_size,
               j * digit_size:(j + 1) * digit_size] = digit
# Scale to 8-bit. Clip first: casting an out-of-range float to uint8 wraps
# around instead of saturating. (If the decoder output already lies in
# [0, 1] -- not verifiable from here -- the clip is a no-op.)
figure = np.clip(figure * 255, 0, 255).astype(np.uint8)
io.imsave('vae_z.png', figure)
# ---------------------------------------------------------------------------
# Encode one randomly chosen training image, then decode an n x n grid that
# spans +-std_span encoder standard deviations around its encoded mean in the
# first two latent dimensions. Saved as vae_std.png.
# ---------------------------------------------------------------------------
std_span = 4  # grid half-width, measured in encoder standard deviations
# Zero-initialised float64 canvas, matching the first figure (the previous
# np.ndarray(..., float16) was uninitialised and low-precision).
figure = np.zeros((digit_size * n, digit_size * n))
# Pick a single random row of X, keeping the batch axis. It gets its own name
# so the MNIST test split held in testX is not overwritten.
# NOTE(review): the sample comes from X (training data); confirm whether the
# test split was intended.
sample_index = np.random.randint(len(X))
sample_x = X[sample_index:sample_index + 1]
sample_mean = mean_encoder.predict(sample_x)[0]
sample_logvar = logvar_encoder.predict(sample_x)[0]
# Treat logvar as a log-variance: std = exp(0.5 * logvar).
std = [np.exp(0.5 * sample_logvar[d]) * std_span for d in range(2)]
grid_x = np.linspace(-std[0], std[0], n) + sample_mean[0]  # z[0] values (columns)
grid_y = np.linspace(-std[1], std[1], n) + sample_mean[1]  # z[1] values (rows)
# Remaining latent dimensions are pinned to the sample's encoded mean.
z_rest = list(sample_mean[2:latent_dim])
# Row i walks z[1] over grid_y and column j walks z[0] over grid_x, so each
# coordinate is swept around its own mean/std. Previously the grids were
# swapped, and z[0] was swept around mean[1] +- std[1].
for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
        z_sample = np.array([[xi, yi] + z_rest])
        x_decoded = generator.predict(z_sample)
        digit = np.reshape(x_decoded[0], [digit_size, digit_size])
        figure[i * digit_size:(i + 1) * digit_size,
               j * digit_size:(j + 1) * digit_size] = digit
# Scale to 8-bit, clipping so out-of-range values saturate instead of wrapping.
figure = np.clip(figure * 255, 0, 255).astype(np.uint8)
io.imsave('vae_std.png', figure)