
JAX vs Tensorflow vs Pytorch: Building a Variational Autoencoder (VAE)

I was always curious to see how JAX compares with Pytorch or Tensorflow. I figured that the best way to compare frameworks is to build the same thing from scratch in each of them, and that's exactly what I did. In this article, I'm developing a Variational Autoencoder with JAX, Tensorflow and Pytorch at the same time. I'll present the code for each component side by side so that we can spot differences, similarities, weaknesses and strengths.

Shall we start?

Prologue

A few things to note before we explore the code:

  • I'll use Flax on top of JAX, a neural network library developed by Google. It contains many ready-to-use deep learning modules, layers, functions, and operations.

  • For the Tensorflow implementation, I'll rely on Keras abstractions.

  • For Pytorch, I'll use the standard nn.Module.

Because most of us are somewhat familiar with Tensorflow and Pytorch, we will pay more attention to JAX and Flax. That's why I'll explain things along the way that may be unfamiliar to many. So you can consider this article as a light tutorial on Flax as well.

Also, I assume that you are familiar with the basic principles behind VAEs. If not, you can check out my previous article on latent variable models. If everything seems clear, let's continue.

Quick recap: The vanilla Autoencoder consists of an Encoder and a Decoder. The encoder converts the input to a latent representation z and the decoder tries to reconstruct the input based on that representation. In Variational Autoencoders, stochasticity is also added to the mix in the sense that the latent representation provides a probability distribution. This is done with the reparameterization trick.
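As a quick reminder of the math we will implement in every framework below (this is the standard Gaussian reparameterization, nothing framework-specific): instead of sampling z directly from its distribution, we sample noise from a standard normal and shift and scale it with the predicted mean and standard deviation. Since the encoder outputs the log-variance, the standard deviation is recovered with an exponential:

z = \mu + \sigma \odot \varepsilon, \quad \varepsilon \sim \mathcal{N}(0, I), \quad \sigma = \exp\left(0.5 \cdot \log \sigma^2\right)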


vae

Image by author

The encoder

For the encoder, a simple linear layer followed by a RELU activation should be enough for a toy example. The output of the layer will be both the mean and the standard deviation of the probability distribution.

The basic building block of the Flax API is the Module abstraction, which is what we'll use to implement our encoder in JAX. The module is part of the linen subpackage. Similar to Pytorch's nn.Module, we again need to define our class arguments. In Pytorch, we are used to declaring them inside the __init__ function and implementing the forward pass inside the forward method. In Flax, things are a bit different. Arguments are defined either as dataclass attributes or as method arguments. Usually, fixed properties are defined as dataclass attributes while dynamic properties are passed as method arguments. Also, instead of implementing a forward method, we implement __call__.

The dataclasses module was introduced in Python 3.7 as a utility to create structured classes specifically for storing data. These classes hold certain properties and functions to deal specifically with the data and its representation. They also reduce a lot of boilerplate code compared to regular classes.
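As a quick illustration of the syntax (the TrainingConfig class below is just a made-up example, not part of the VAE code), a dataclass declares its fields as annotated class attributes and gets __init__ and __repr__ generated automatically:

from dataclasses import dataclass


@dataclass
class TrainingConfig:
    # fields are plain annotated class attributes;
    # __init__, __repr__ and __eq__ are generated for us
    learning_rate: float = 1e-3
    batch_size: int = 128


config = TrainingConfig(batch_size=64)
print(config)  # TrainingConfig(learning_rate=0.001, batch_size=64)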

So to create a new module in Flax, we need to:

  • Initialize a class that inherits from flax.linen.nn.Module

  • Define the static arguments as dataclass attributes

  • Implement the forward pass inside the __call__ method.

To tie the arguments to the model and be able to define submodules directly within the module, we also need to annotate the __call__ method with @nn.compact.

Note that instead of defining submodules inline under the @nn.compact annotation, we could have declared them inside a setup method, in much the same way as we do in Pytorch's or Tensorflow's __init__ (see the short sketch below).
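For illustration only, here is a minimal sketch of that setup style (MLPBlock is a made-up module, not part of the VAE we build below, and it assumes flax.linen is imported as nn as in the following snippets):

class MLPBlock(nn.Module):
    features: int  # static hyperparameter, still a dataclass attribute

    def setup(self):
        # submodules are declared once here, much like Pytorch's __init__
        self.dense = nn.Dense(self.features)

    def __call__(self, x):
        # no @nn.compact needed when submodules live in setup()
        return nn.relu(self.dense(x))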

# JAX / Flax
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from flax import optim


class Encoder(nn.Module):
    latents: int  # static hyperparameter, defined as a dataclass attribute

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(500, name='fc1')(x)
        x = nn.relu(x)
        # two heads: one for the mean and one for the log-variance
        mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
        logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
        return mean_x, logvar_x

# Tensorflow / Keras
import tensorflow as tf
from tensorflow.keras import layers


class Encoder(layers.Layer):
    def __init__(self,
                 latent_dim=20,
                 name='encoder',
                 **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.enc1 = layers.Dense(500, activation='relu')
        self.mean_x = layers.Dense(latent_dim)
        self.logvar_x = layers.Dense(latent_dim)

    def call(self, inputs):
        x = self.enc1(inputs)
        z_mean = self.mean_x(x)
        z_log_var = self.logvar_x(x)
        return z_mean, z_log_var

# Pytorch
import torch
import torch.nn.functional as F


class Encoder(torch.nn.Module):
    def __init__(self, latent_dim=20):
        super(Encoder, self).__init__()
        self.enc1 = torch.nn.Linear(784, 500)
        self.mean_x = torch.nn.Linear(500, latent_dim)
        self.logvar_x = torch.nn.Linear(500, latent_dim)

    def forward(self, inputs):
        x = self.enc1(inputs)
        x = F.relu(x)
        z_mean = self.mean_x(x)
        z_log_var = self.logvar_x(x)
        return z_mean, z_log_var

A few more things to notice here before we proceed:

  • Flax's nn.linen package contains most deep learning layers and operations such as Dense, relu, and many more

  • The code in Flax, Tensorflow, and Pytorch is almost indistinguishable from one another.

The decoder

In a very similar fashion, we can develop the decoder in all 3 frameworks. The decoder will be two linear layers that receive the latent representation z and output the reconstructed input.

Again the implementations are very similar.

# JAX / Flax
class Decoder(nn.Module):
    @nn.compact
    def __call__(self, z):
        z = nn.Dense(500, name='fc1')(z)
        z = nn.relu(z)
        z = nn.Dense(784, name='fc2')(z)
        return z

# Tensorflow / Keras
class Decoder(layers.Layer):
    def __init__(self,
                 name='decoder',
                 **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.dec1 = layers.Dense(500, activation='relu')
        self.out = layers.Dense(784)

    def call(self, z):
        z = self.dec1(z)
        return self.out(z)

# Pytorch
class Decoder(torch.nn.Module):
    def __init__(self, latent_dim=20):
        super(Decoder, self).__init__()
        self.dec1 = torch.nn.Linear(latent_dim, 500)
        self.out = torch.nn.Linear(500, 784)

    def forward(self, z):
        z = self.dec1(z)
        z = F.relu(z)
        return self.out(z)

Variational Autoencoder

To combine the encoder and the decoder, let's have one more class, called VAE, that will represent the entire architecture. Here we also need to write the code for the reparameterization trick. Overall we have: the latent variable from the encoder is reparameterized and fed to the decoder, which produces the reconstructed input.

As a reminder, here is an intuitive image that explains the reparameterization trick:


reparameterization-trick

Source: Alexander Amini and Ava Soleimany, Deep Generative Modeling | MIT 6.S191, http://introtodeeplearning.com/

Notice that this time, in JAX we make use of the setup method instead of the nn.compact annotation. Also, check out how similar the reparameterization functions are. Sure, each framework uses its own functions and operations, but the general picture is almost identical.

# JAX / Flax
class VAE(nn.Module):
    latents: int = 20

    def setup(self):
        self.encoder = Encoder(self.latents)
        self.decoder = Decoder()

    def __call__(self, x, z_rng):
        mean, logvar = self.encoder(x)
        z = reparameterize(z_rng, mean, logvar)
        recon_x = self.decoder(z)
        return recon_x, mean, logvar


def reparameterize(rng, mean, logvar):
    # z = mean + std * eps with eps ~ N(0, I)
    std = jnp.exp(0.5 * logvar)
    eps = random.normal(rng, logvar.shape)
    return mean + eps * std


def model():
    return VAE(latents=LATENTS)

# Tensorflow / Keras
class VAE(tf.keras.Model):
    def __init__(self,
                 latent_dim=20,
                 name='vae',
                 **kwargs):
        super(VAE, self).__init__(name=name, **kwargs)
        self.encoder = Encoder(latent_dim=latent_dim)
        self.decoder = Decoder()

    def call(self, inputs):
        z_mean, z_log_var = self.encoder(inputs)
        z = self.reparameterize(z_mean, z_log_var)
        reconstructed = self.decoder(z)
        return reconstructed, z_mean, z_log_var

    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return mean + eps * tf.exp(logvar * .5)

# Pytorch
class VAE(torch.nn.Module):
    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def forward(self, inputs):
        z_mean, z_log_var = self.encoder(inputs)
        z = self.reparameterize(z_mean, z_log_var)
        reconstructed = self.decoder(z)
        return reconstructed, z_mean, z_log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + (eps * std)

Loss and Training step

Things start to differ when we begin implementing the training step and the loss function. But not by much.

  1. In order to fully take advantage of JAX's capabilities, we need to add automatic vectorization and XLA compilation to our code. This can be done easily with the help of the vmap and jit annotations.

  2. Moreover, we have to enable automatic differentiation, which can be accomplished with the value_and_grad transformation.

  3. We use the flax.optim package for the optimization algorithms.

Another small difference that we need to be aware of is how we pass data to our model. This is achieved through the apply method in the form of model().apply({'params': params}, batch, z_rng), where batch is our training data.

# JAX / Flax
@jax.vmap
def kl_divergence(mean, logvar):
    return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))


@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
    logits = nn.log_sigmoid(logits)
    return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))


@jax.jit
def train_step(optimizer, batch, z_rng):
    def loss_fn(params):
        recon_x, mean, logvar = model().apply({'params': params}, batch, z_rng)
        bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
        kld_loss = kl_divergence(mean, logvar).mean()
        loss = bce_loss + kld_loss
        return loss, recon_x

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    _, grad = grad_fn(optimizer.target)
    optimizer = optimizer.apply_gradient(grad)
    return optimizer

# Tensorflow / Keras
def kl_divergence(mean, logvar):
    return -0.5 * tf.reduce_sum(
        1 + logvar - tf.square(mean) -
        tf.exp(logvar), axis=1)


def binary_cross_entropy_with_logits(logits, labels):
    logits = tf.math.log_sigmoid(logits)
    return -tf.reduce_sum(
        labels * logits +
        (1 - labels) * tf.math.log(-tf.math.expm1(logits)),
        axis=1
    )


@tf.function
def train_step(model, x, optimizer):
    with tf.GradientTape() as tape:
        recon_x, mean, logvar = model(x)
        bce_loss = tf.reduce_mean(binary_cross_entropy_with_logits(recon_x, x))
        kld_loss = tf.reduce_mean(kl_divergence(mean, logvar))
        loss = bce_loss + kld_loss
    tf.print(loss, kld_loss, bce_loss)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

# Pytorch
def final_loss(reconstruction, train_x, mu, logvar):
    BCE = torch.nn.BCEWithLogitsLoss(reduction='sum')(reconstruction, train_x)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD


def train_step(model, optimizer, train_x):
    train_x = torch.from_numpy(train_x)
    optimizer.zero_grad()
    reconstruction, mu, logvar = model(train_x)
    loss = final_loss(reconstruction, train_x, mu, logvar)
    loss.backward()
    optimizer.step()
    return loss.item()

Remember that VAEs are trained by maximizing the evidence lower bound, known as ELBO.

L_{\theta,\phi}(x) = \mathbf{E}_{q_{\phi}(z|x)}\left[\log p_{\theta}(x|z)\right] - \mathbf{KL}\left(q_{\phi}(z|x) \,\|\, p_{\theta}(z)\right)

Training loop

Finally, it's time for the entire training loop, which will execute the train_step function iteratively.

In Flax, the model has to be initialized before training, which is done with the init function such as: params = model().init(key, init_data, rng)['params']. A similar initialization is necessary for the optimizer as well: optimizer = optim.Adam(learning_rate=LEARNING_RATE).create(params).

jax.device_put is used to transfer the optimizer into the GPU's memory.

# JAX / Flax
rng = random.PRNGKey(0)
rng, key = random.split(rng)

init_data = jnp.ones((BATCH_SIZE, 784), jnp.float32)
params = model().init(key, init_data, rng)['params']

optimizer = optim.Adam(learning_rate=LEARNING_RATE).create(params)
optimizer = jax.device_put(optimizer)

rng, z_key, eval_rng = random.split(rng, 3)
z = random.normal(z_key, (64, LATENTS))

steps_per_epoch = 50000 // BATCH_SIZE

for epoch in range(NUM_EPOCHS):
    for _ in range(steps_per_epoch):
        batch = next(train_ds)
        rng, key = random.split(rng)
        optimizer = train_step(optimizer, batch, key)

# Tensorflow / Keras
vae = VAE(latent_dim=LATENTS)
optimizer = tf.keras.optimizers.Adam(1e-4)

for epoch in range(NUM_EPOCHS):
    for train_x in train_ds:
        train_step(vae, train_x, optimizer)

# Pytorch
def train(model, training_data):
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    running_loss = 0.0
    for epoch in range(NUM_EPOCHS):
        for i, train_x in enumerate(training_data, 0):
            running_loss += train_step(model, optimizer, train_x)


vae = VAE(LATENTS)
train(vae, train_ds)

Load and Process Data

One thing I haven't mentioned is data. How do we load and preprocess data in Flax? Well, Flax doesn't include data manipulation packages yet besides the basic operations of jax.numpy. Right now, our best option is to borrow packages from other frameworks such as Tensorflow datasets (tfds) or Torchvision. To make the article self-contained, I'll include the code I used to load a sample training dataset with tfds. Feel free though to use your own dataloader if you're planning to run the implementations presented in this article.

import tensorflow_datasets as tfds

# hide GPUs from Tensorflow so that it does not allocate memory that JAX will need
tf.config.experimental.set_visible_devices([], 'GPU')


def prepare_image(x):
    x = tf.cast(x['image'], tf.float32)
    x = tf.reshape(x, (-1,))
    return x


ds_builder = tfds.builder('binarized_mnist')
ds_builder.download_and_prepare()

train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN)
train_ds = train_ds.map(prepare_image)
train_ds = train_ds.cache()
train_ds = train_ds.repeat()
train_ds = train_ds.shuffle(50000)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = iter(tfds.as_numpy(train_ds))

test_ds = ds_builder.as_dataset(split=tfds.Split.TEST)
test_ds = test_ds.map(prepare_image).batch(10000)
test_ds = np.array(list(test_ds)[0])

Final observations

To close the article, let's discuss a few final observations that emerge after a close analysis of the code:

  • All 3 frameworks have reduced the boilerplate code to a minimum, with Flax being the one that requires a bit more, especially on the training part. However, this is only to ensure that we exploit all the available transformations such as automatic differentiation, vectorization and the just-in-time compiler.

  • The definition of modules, layers and models is almost identical in all of them

  • Flax and JAX are by design quite flexible and expandable

  • Flax doesn't have data loading and processing capabilities yet

  • In terms of ready-to-use layers and optimizers, Flax doesn't need to be jealous of Tensorflow and Pytorch. For sure it lacks the giant library of its competitors, but it's gradually getting there.
