I have always been curious to see how JAX compares with Pytorch or Tensorflow. I figured that one of the best ways to compare frameworks is to build the same thing from scratch in each of them, and that's exactly what I did. In this article, I'm developing a Variational Autoencoder with JAX, Tensorflow and Pytorch at the same time. I'll present the code for each component side by side so that you can spot the differences, similarities, weaknesses and strengths.
Shall we start?
Prologue
Some things to note before we explore the code:
- I'll use Flax on top of JAX, which is a neural network library developed by Google. It contains many ready-to-use deep learning modules, layers, functions, and operations.
- For the Tensorflow implementation, I'll rely on Keras abstractions.
- For Pytorch, I'll use the standard nn.Module.
Because most of us are somewhat familiar with Tensorflow and Pytorch, we will pay more attention to JAX and Flax. That's why I'll explain along the way the things that may be unfamiliar to many, so you can also treat this article as a lightweight tutorial on Flax.
Also, I assume that you are familiar with the basic principles behind VAEs. If not, you can refer to my previous article on latent variable models. If everything seems clear, let's continue.
Quick recap: the vanilla Autoencoder consists of an Encoder and a Decoder. The encoder converts the input into a latent representation and the decoder tries to reconstruct the input based on that representation. In Variational Autoencoders, stochasticity is also added to the mix, in the sense that the latent representation defines a probability distribution. This is accomplished with the reparameterization trick.
Image by author
The encoder
For the encoder, a simple linear layer followed by a RELU activation should be enough for a toy example. The output of the layer will be both the mean and the log-variance of the latent probability distribution.
The basic building block of the Flax API is the Module abstraction, which is what we'll use to implement our encoder in JAX. The Module is part of the linen subpackage. Similar to Pytorch's nn.Module, we again need to define our class arguments. In Pytorch, we're used to declaring them inside the __init__ function and implementing the forward pass inside the forward method. In Flax, things are a bit different: arguments are defined either as dataclass attributes or as method arguments. Usually, fixed properties are defined as dataclass attributes, while dynamic properties are passed as method arguments. Also, instead of implementing a forward method, we implement __call__.
The Dataclass module was introduced in Python 3.7 as a utility to build structured classes specifically for storing data. These classes hold certain properties and functions to deal specifically with the data and its representation. They also reduce a lot of boilerplate code compared to regular classes.
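If you haven't used dataclasses before, here is a tiny, self-contained example (unrelated to Flax; the class name is purely illustrative) of how a field declaration replaces the usual __init__ boilerplate:

from dataclasses import dataclass

@dataclass
class LayerConfig:
    features: int           # required field
    use_bias: bool = True   # optional field with a default value

cfg = LayerConfig(features=500)
print(cfg)  # LayerConfig(features=500, use_bias=True): __init__ and __repr__ were generated for us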
So to create a new module in Flax, we need to:
- Initialize a class that inherits from flax.linen.nn.Module
- Define the static arguments as dataclass attributes
- Implement the forward pass inside the __call__ method.
To tie the arguments to the model and to be able to define submodules directly within the module, we also need to annotate the __call__ method with @nn.compact.
Note that instead of using dataclass attributes and the @nn.compact annotation, we could have declared all submodules inside a setup method, in the exact same way as we do in Pytorch's or Tensorflow's __init__.
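To make that concrete, here is a minimal sketch of what the same encoder would look like with setup instead of @nn.compact (it mirrors the compact version shown right below; the class name is just illustrative):

from flax import linen as nn

class EncoderWithSetup(nn.Module):
    latents: int

    def setup(self):
        # submodules are declared up front, much like in Pytorch's __init__
        self.fc1 = nn.Dense(500)
        self.fc2_mean = nn.Dense(self.latents)
        self.fc2_logvar = nn.Dense(self.latents)

    def __call__(self, x):
        x = nn.relu(self.fc1(x))
        return self.fc2_mean(x), self.fc2_logvar(x)

Both variants define exactly the same model; @nn.compact simply lets us declare the submodules inline, at the point where they are used.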
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from flax import optim
class Encoder(nn.Module):
    latents: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(500, name='fc1')(x)
        x = nn.relu(x)
        mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
        logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
        return mean_x, logvar_x
import tensorflow as tf
from tensorflow.keras import layers

class Encoder(layers.Layer):
    def __init__(self,
                 latent_dim=20,
                 name='encoder',
                 **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.enc1 = layers.Dense(500, activation='relu')
        self.mean_x = layers.Dense(latent_dim)
        self.logvar_x = layers.Dense(latent_dim)

    def call(self, inputs):
        x = self.enc1(inputs)
        z_mean = self.mean_x(x)
        z_log_var = self.logvar_x(x)
        return z_mean, z_log_var
import torch
import torch.nn.functional as F

class Encoder(torch.nn.Module):
    def __init__(self, latent_dim=20):
        super(Encoder, self).__init__()
        self.enc1 = torch.nn.Linear(784, 500)
        self.mean_x = torch.nn.Linear(500, latent_dim)
        self.logvar_x = torch.nn.Linear(500, latent_dim)

    def forward(self, inputs):
        x = self.enc1(inputs)
        x = F.relu(x)
        z_mean = self.mean_x(x)
        z_log_var = self.logvar_x(x)
        return z_mean, z_log_var
A few more things to notice here before we proceed:
- Flax's linen package (imported as nn) contains most deep learning layers and operations such as Dense, relu, and many more.
- The code in Flax, Tensorflow, and Pytorch is almost indistinguishable from one another. The one practical difference, which we will meet again in the training section, is that a Flax module does not hold its own parameters; they are created and passed around explicitly, as the short sketch below shows.
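A minimal sketch (the variable names are just illustrative) of how the Flax encoder above would be initialized and called:

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
dummy_batch = jnp.ones((1, 784))                          # an MNIST-sized dummy input
encoder = Encoder(latents=20)
variables = encoder.init(rng, dummy_batch)                # parameters live outside the module
mean_x, logvar_x = encoder.apply(variables, dummy_batch)  # apply runs the forward pass

In Tensorflow and Pytorch, on the other hand, the module objects own their parameters and we simply call encoder(inputs).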
The decoder
In a very similar fashion, we can develop the decoder in all 3 frameworks. The decoder will be two linear layers that receive the latent representation and output the reconstructed input.
Again, the implementations are very similar.
class Decoder(nn.Module):
    @nn.compact
    def __call__(self, z):
        z = nn.Dense(500, name='fc1')(z)
        z = nn.relu(z)
        z = nn.Dense(784, name='fc2')(z)
        return z
class Decoder(layers.Layer):
    def __init__(self,
                 name='decoder',
                 **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.dec1 = layers.Dense(500, activation='relu')
        self.out = layers.Dense(784)

    def call(self, z):
        z = self.dec1(z)
        return self.out(z)
class Decoder(torch.nn.Module):
    def __init__(self, latent_dim=20):
        super(Decoder, self).__init__()
        self.dec1 = torch.nn.Linear(latent_dim, 500)
        self.out = torch.nn.Linear(500, 784)

    def forward(self, z):
        z = self.dec1(z)
        z = F.relu(z)
        return self.out(z)
Variational Autoencoder
To combine the encoder and the decoder, let's have one more class, called VAE, that will represent the entire architecture. Here we also need to write the code for the reparameterization trick. Overall we have: the latent variable from the encoder is reparameterized and fed to the decoder, which produces the reconstructed input.
As a reminder, here is an intuitive image that explains the reparameterization trick:
Source: Alexander Amini and Ava Soleimany, Deep Generative Modeling | MIT 6.S191, http://introtodeeplearning.com/
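In equation form (this mirrors the reparameterize functions below, where logvar is the log-variance predicted by the encoder):

z = μ + σ ⊙ ε,   with ε ~ N(0, I) and σ = exp(0.5 · logvar)

Sampling ε externally keeps the path from μ and σ to z deterministic, which is what allows gradients to flow through the sampling step.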
Notice that this time, in JAX we make use of the setup method instead of the @nn.compact annotation. Also, check out how similar the reparameterization functions are. Sure, each framework uses its own functions and operations, but the general picture is almost identical.
class VAE(nn.Module):
    latents: int = 20

    def setup(self):
        self.encoder = Encoder(self.latents)
        self.decoder = Decoder()

    def __call__(self, x, z_rng):
        mean, logvar = self.encoder(x)
        z = reparameterize(z_rng, mean, logvar)
        recon_x = self.decoder(z)
        return recon_x, mean, logvar

def reparameterize(rng, mean, logvar):
    std = jnp.exp(0.5 * logvar)
    eps = random.normal(rng, logvar.shape)
    return mean + eps * std

def model():
    return VAE(latents=LATENTS)
class VAE(tf.keras.Model):
    def __init__(self,
                 latent_dim=20,
                 name='vae',
                 **kwargs):
        super(VAE, self).__init__(name=name, **kwargs)
        self.encoder = Encoder(latent_dim=latent_dim)
        self.decoder = Decoder()

    def call(self, inputs):
        z_mean, z_log_var = self.encoder(inputs)
        z = self.reparameterize(z_mean, z_log_var)
        reconstructed = self.decoder(z)
        return reconstructed, z_mean, z_log_var

    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return mean + eps * tf.exp(logvar * .5)
class VAE(torch.nn.Module):
    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def forward(self, inputs):
        z_mean, z_log_var = self.encoder(inputs)
        z = self.reparameterize(z_mean, z_log_var)
        reconstructed = self.decoder(z)
        return reconstructed, z_mean, z_log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + (eps * std)
Loss and Training step
Things start to differ once we begin implementing the training step and the loss function. But not by much.
- In order to take full advantage of JAX's capabilities, we need to add automatic vectorization and XLA compilation to our code. This can be done easily with the vmap and jit annotations (see the small sketch after this list if these are new to you).
- Moreover, we have to enable automatic differentiation, which is accomplished with the grad transformation (used below through jax.value_and_grad).
- We use the flax.optim package for the optimization algorithms.
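Here is a minimal, self-contained sketch of what these three transformations do on a toy function (the function and names are purely illustrative):

import jax
import jax.numpy as jnp

def squared_error(w, x, y):
    # toy scalar loss for a single example
    return (w * x - y) ** 2

grad_fn = jax.grad(squared_error)                        # differentiate w.r.t. the first argument
batched = jax.vmap(squared_error, in_axes=(None, 0, 0))  # vectorize over a batch of (x, y) pairs
batched_jit = jax.jit(batched)                           # compile the vectorized version with XLA

xs = jnp.arange(4.0)
ys = 2.0 * xs
print(grad_fn(3.0, 1.0, 2.0))    # 2.0, i.e. d/dw of (3*1 - 2)^2
print(batched_jit(3.0, xs, ys))  # per-example losses computed in one call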
Another small difference that we need to be aware of is how we pass data to our model. This is achieved through the apply method, in the form of model().apply({'params': params}, batch, z_rng), where batch is our training data.
@jax.vmap
def kl_divergence(mean, logvar):
    return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))

@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
    logits = nn.log_sigmoid(logits)
    return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))

@jax.jit
def train_step(optimizer, batch, z_rng):
    def loss_fn(params):
        recon_x, mean, logvar = model().apply({'params': params}, batch, z_rng)
        bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
        kld_loss = kl_divergence(mean, logvar).mean()
        loss = bce_loss + kld_loss
        return loss, recon_x
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    _, grad = grad_fn(optimizer.target)
    optimizer = optimizer.apply_gradient(grad)
    return optimizer
def kl_divergence(mean, logvar):
    return -0.5 * tf.reduce_sum(
        1 + logvar - tf.square(mean) -
        tf.exp(logvar), axis=1)

def binary_cross_entropy_with_logits(logits, labels):
    # treat the decoder output as raw logits, mirroring the JAX version
    logits = tf.math.log_sigmoid(logits)
    return - tf.reduce_sum(
        labels * logits +
        (1 - labels) * tf.math.log(- tf.math.expm1(logits)),
        axis=1
    )

@tf.function
def train_step(model, x, optimizer):
    with tf.GradientTape() as tape:
        recon_x, mean, logvar = model(x)
        bce_loss = tf.reduce_mean(binary_cross_entropy_with_logits(recon_x, x))
        kld_loss = tf.reduce_mean(kl_divergence(mean, logvar))
        loss = bce_loss + kld_loss
    print(loss, kld_loss, bce_loss)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
def final_loss(reconstruction, train_x, mu, logvar):
    BCE = torch.nn.BCEWithLogitsLoss(reduction='sum')(reconstruction, train_x)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

def train_step(train_x, model, optimizer):
    train_x = torch.from_numpy(train_x)
    optimizer.zero_grad()
    reconstruction, mu, logvar = model(train_x)
    loss = final_loss(reconstruction, train_x, mu, logvar)
    loss.backward()
    optimizer.step()
    return loss.item()
Remember that VAEs are trained by maximizing the evidence lower bound, known as the ELBO.
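In other words, the loss implemented above is the negative ELBO:

ELBO = E_{q(z|x)}[ log p(x|z) ] − KL( q(z|x) ‖ p(z) )

The binary cross-entropy term plays the role of the (negative) reconstruction log-likelihood, and the KLD term is the divergence between the approximate posterior and the prior, so minimizing BCE + KLD is the same as maximizing the ELBO.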
Training loop
Finally, it's time for the entire training loop, which will execute the train_step function iteratively.
In Flax, the model has to be initialized before training, which is done by the init function, as in: params = model().init(key, init_data, rng)['params']. A similar initialization is necessary for the optimizer as well: optimizer = optim.Adam(learning_rate=LEARNING_RATE).create(params).
jax.device_put is used to transfer the optimizer into the GPU's memory.
rng = random.PRNGKey(0)
rng, key = random.split(rng)

init_data = jnp.ones((BATCH_SIZE, 784), jnp.float32)
params = model().init(key, init_data, rng)['params']

optimizer = optim.Adam(learning_rate=LEARNING_RATE).create(params)
optimizer = jax.device_put(optimizer)

rng, z_key, eval_rng = random.split(rng, 3)
z = random.normal(z_key, (64, LATENTS))

steps_per_epoch = 50000 // BATCH_SIZE

for epoch in range(NUM_EPOCHS):
    for _ in range(steps_per_epoch):
        batch = next(train_ds)
        rng, key = random.split(rng)
        optimizer = train_step(optimizer, batch, key)
vae = VAE(latent_dim=LATENTS)
optimizer = tf.keras.optimizers.Adam(1e-4)

for epoch in range(NUM_EPOCHS):
    for train_x in train_ds:
        train_step(vae, train_x, optimizer)
def train(model, training_data):
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    running_loss = 0.0
    for epoch in range(NUM_EPOCHS):
        for i, train_x in enumerate(training_data, 0):
            running_loss += train_step(train_x, model, optimizer)

vae = VAE(LATENTS)
train(vae, train_ds)
Load and Process Data
One thing I haven't mentioned yet is data. How do we load and preprocess data in Flax? Well, Flax doesn't include data manipulation packages yet, besides the basic operations of jax.numpy. Right now, our best bet is to borrow packages from other frameworks, such as Tensorflow datasets (tfds) or Torchvision. To keep the article self-contained, I'll include the code I used to load a sample training dataset with tfds. Feel free, though, to use your own dataloader if you're planning to run the implementations presented in this article.
import tensorflow_datasets as tfds

tf.config.experimental.set_visible_devices([], 'GPU')

def prepare_image(x):
    x = tf.cast(x['image'], tf.float32)
    x = tf.reshape(x, (-1,))
    return x

ds_builder = tfds.builder('binarized_mnist')
ds_builder.download_and_prepare()

train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN)
train_ds = train_ds.map(prepare_image)
train_ds = train_ds.cache()
train_ds = train_ds.repeat()
train_ds = train_ds.shuffle(50000)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = iter(tfds.as_numpy(train_ds))

test_ds = ds_builder.as_dataset(split=tfds.Split.TEST)
test_ds = test_ds.map(prepare_image).batch(10000)
test_ds = np.array(list(test_ds)[0])
Closing observations
To close the article, let's discuss a few final observations that emerge after a close comparison of the code:
- All 3 frameworks have reduced the boilerplate code to a minimum, with Flax being the one that requires a bit more, especially on the training part. However, this is only to ensure that we exploit all the available transformations such as automatic differentiation, vectorization and just-in-time compilation.
- The definition of modules, layers and models is almost identical in all of them.
- Flax and JAX are quite flexible and extensible by design.
- Flax doesn't have data loading and processing capabilities yet.
- In terms of ready-to-use layers and optimizers, Flax doesn't need to be jealous of Tensorflow and Pytorch. For sure, it lacks the enormous libraries of its competitors, but it's gradually getting there.