
BYOL tutorial: self-supervised learning on CIFAR images with code in Pytorch

After presenting SimCLR, a contrastive self-supervised learning framework, I decided to demonstrate another famous method, called BYOL. Bootstrap Your Own Latent (BYOL) is a new algorithm for self-supervised learning of image representations. BYOL has two main advantages:

  • It does not explicitly use negative samples. Instead, it directly minimizes the distance between the representations of the same image under different augmented views (positive pair). Negative samples are images from the batch other than the positive pair.

  • As a result, BYOL is claimed to require smaller batch sizes, which makes it an attractive choice.

Below, you can inspect the method. Unlike the original paper, I call the online network student and the target network teacher.


Figure: Overview of BYOL method. Source: BYOL paper

Online network aka student: compared to SimCLR, there is a second MLP, called predictor, which makes the whole method asymmetric. Asymmetric compared to what? Well, to the teacher model (target network).

Why is that important?

Because the teacher model is updated only through an exponential moving average (EMA) of the student’s parameters. At each iteration, every teacher weight moves only a small step toward the corresponding student weight (a fraction 1 − α, i.e. 1% for α = 0.99). Thus, gradients flow only through the student network. This can be implemented as:

```python
class EMA():
    def __init__(self, alpha):
        super().__init__()
        self.alpha = alpha

    def update_average(self, old, new):
        if old is None:
            return new
        # keep a fraction alpha of the old (teacher) weight and 1 - alpha of the new (student) weight
        return old * self.alpha + (1 - self.alpha) * new


ema = EMA(0.99)

# student_model and teacher_model are two networks with the same architecture
for student_params, teacher_params in zip(student_model.parameters(), teacher_model.parameters()):
    old_weight, up_weight = teacher_params.data, student_params.data
    teacher_params.data = ema.update_average(old_weight, up_weight)
```
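The loop above reassigns `.data`. As an alternative, here is a minimal sketch of the same update written with in-place tensor ops under `torch.no_grad()`. It assumes the student and teacher share an identical architecture; the helper name `ema_update_` and the toy `nn.Linear` networks are hypothetical stand-ins.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def ema_update_(teacher: nn.Module, student: nn.Module, alpha: float = 0.99) -> None:
    """Hypothetical helper: teacher <- alpha * teacher + (1 - alpha) * student, in place."""
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        # lerp_(end, weight) computes t + weight * (s - t) = (1 - weight) * t + weight * s
        t_param.lerp_(s_param, 1.0 - alpha)


student = nn.Linear(4, 2)
teacher = copy.deepcopy(student)
for p in teacher.parameters():
    p.requires_grad_(False)  # the teacher is never updated by the optimizer

# Pretend an optimizer step moved the student weights
with torch.no_grad():
    for p in student.parameters():
        p.add_(1.0)

old = [p.clone() for p in teacher.parameters()]
ema_update_(teacher, student, alpha=0.99)
```

`lerp_` performs the same convex combination as `EMA.update_average`, without allocating new tensors or touching `.data`.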

Another key difference between SimCLR and BYOL is the loss function.

Loss function

The predictor MLP is only applied to the student, making the architecture asymmetric. This is a key design choice to avoid mode collapse. Mode collapse here would mean outputting the same projection for all inputs.
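To see why collapse is a real risk, consider a minimal sketch that assumes a hypothetical encoder (`Collapsed`) which ignores its input and always emits the same vector. The L2-normalized loss, repeated here so the block runs on its own, is then zero for every pair of views. The objective alone cannot tell this trivial solution apart from a useful one.

```python
import torch
import torch.nn.functional as F


def loss_fn(x, y):
    # normalized L2 loss, as used for BYOL in this post
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)


class Collapsed(torch.nn.Module):
    """Hypothetical collapsed encoder: the same projection for every input."""

    def __init__(self, dim=8):
        super().__init__()
        self.constant = torch.nn.Parameter(torch.randn(dim))

    def forward(self, x):
        # broadcast one learned vector to the whole batch, ignoring x
        return self.constant.expand(x.shape[0], -1)


net = Collapsed()
view_one, view_two = torch.randn(16, 3, 32, 32), torch.randn(16, 3, 32, 32)
loss = loss_fn(net(view_one), net(view_two))
```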


Figure: Overview of BYOL method. Source: BYOL paper

Finally, the authors defined the following mean squared error between the L2-normalized predictions and target projections:

$$\mathcal{L}_{\theta, \xi} \triangleq \left\| \bar{q}_{\theta}(z_{\theta}) - \bar{z}'_{\xi} \right\|_{2}^{2} = 2 - 2 \cdot \frac{\left\langle q_{\theta}(z_{\theta}), z'_{\xi} \right\rangle}{\left\| q_{\theta}(z_{\theta}) \right\|_{2} \cdot \left\| z'_{\xi} \right\|_{2}}$$
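The second equality in the loss definition follows from expanding the square of two unit vectors. Write $\bar{q}_{\theta}(z_{\theta}) = q_{\theta}(z_{\theta}) / \| q_{\theta}(z_{\theta}) \|_{2}$ and $\bar{z}'_{\xi} = z'_{\xi} / \| z'_{\xi} \|_{2}$, so that both have unit norm:

$$\left\| \bar{q}_{\theta}(z_{\theta}) - \bar{z}'_{\xi} \right\|_{2}^{2} = \underbrace{\left\| \bar{q}_{\theta}(z_{\theta}) \right\|_{2}^{2}}_{=1} - 2 \left\langle \bar{q}_{\theta}(z_{\theta}), \bar{z}'_{\xi} \right\rangle + \underbrace{\left\| \bar{z}'_{\xi} \right\|_{2}^{2}}_{=1} = 2 - 2 \cdot \frac{\left\langle q_{\theta}(z_{\theta}), z'_{\xi} \right\rangle}{\left\| q_{\theta}(z_{\theta}) \right\|_{2} \cdot \left\| z'_{\xi} \right\|_{2}}$$

The loss is therefore 2 minus twice the cosine similarity. It ranges from 0 (same direction) to 4 (opposite directions).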

The L2 loss can be implemented as follows. L2 normalization is applied beforehand.

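```python
import torch
import torch.nn.functional as F


def loss_fn(x, y):
    # L2-normalize predictions and targets along the feature dimension
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    # squared L2 distance of unit vectors = 2 - 2 * cosine similarity
    return 2 - 2 * (x * y).sum(dim=-1)
```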
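Here is a quick sanity check of `loss_fn` with random tensors; the batch size of 8 and the dimension of 256 are arbitrary choices for this sketch. The loss should match the squared L2 distance between the normalized vectors, and also 2 minus twice the cosine similarity.

```python
import torch
import torch.nn.functional as F


def loss_fn(x, y):
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)


torch.manual_seed(0)
pred, target = torch.randn(8, 256), torch.randn(8, 256)

loss = loss_fn(pred, target)
# squared L2 distance between the L2-normalized vectors
sq_dist = (F.normalize(pred, dim=-1) - F.normalize(target, dim=-1)).pow(2).sum(dim=-1)
# 2 - 2 * cosine similarity
via_cosine = 2 - 2 * F.cosine_similarity(pred, target, dim=-1)
```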

Code is available on GitHub

Tracking down what’s happening in self-supervised pretraining: KNN accuracy

However, the loss in self-supervised learning is not a reliable metric to track. What I found to be the best way to track what’s happening during training is to measure the KNN accuracy.

The critical advantage of using KNN is that we don’t have to train a linear classifier on top each time. It is faster and needs no additional training: the labels are only used for the nearest-neighbor vote.
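For intuition, here is a minimal pure-PyTorch sketch of what a cosine-similarity k-NN classifier does with frozen features. It normalizes the embeddings, finds the k most similar reference embeddings, and takes a majority vote over their labels. The helpers `knn_predict` and `knn_accuracy` are hypothetical, and the two clusters only stand in for learned embeddings. The KNN class in this post relies on scikit-learn’s `KNeighborsClassifier` instead.

```python
import torch
import torch.nn.functional as F


def knn_predict(query, reference, reference_labels, k=5, num_classes=10):
    """Majority-vote k-NN with cosine similarity (hypothetical helper)."""
    query = F.normalize(query, dim=1)
    reference = F.normalize(reference, dim=1)
    sim = query @ reference.t()                   # [n_query, n_ref] cosine similarities
    _, idx = sim.topk(k, dim=1)                   # indices of the k nearest neighbours
    neighbour_labels = reference_labels[idx]      # [n_query, k]
    votes = F.one_hot(neighbour_labels, num_classes).sum(dim=1)  # [n_query, num_classes]
    return votes.argmax(dim=1)


def knn_accuracy(query, query_labels, reference, reference_labels, k=5, num_classes=10):
    pred = knn_predict(query, reference, reference_labels, k, num_classes)
    return 100.0 * (pred == query_labels).float().mean().item()


# Toy features: two well-separated directions stand in for learned embeddings
torch.manual_seed(0)
centers = torch.tensor([[5.0, 0.0], [0.0, 5.0]])
train_labels = torch.randint(0, 2, (200,))
train_feats = centers[train_labels] + 0.5 * torch.randn(200, 2)
test_labels = torch.randint(0, 2, (50,))
test_feats = centers[test_labels] + 0.5 * torch.randn(50, 2)

acc = knn_accuracy(test_feats, test_labels, train_feats, train_labels, k=5, num_classes=2)
```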

Note: Measuring KNN accuracy only applies to image classification, but you get the idea. For this purpose, I made a class to encapsulate the logic of KNN in our context:

```python
import numpy as np
import torch
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from torch import nn


class KNN():
    def __init__(self, model, k, device):
        super(KNN, self).__init__()
        self.k = k
        self.device = device
        self.model = model.to(device)
        self.model.eval()

    def extract_features(self, loader):
        """
        Infer/Extract features from a trained model
        Args:
            loader: train or test loader
        Returns: 3 tensors of all: input_images, features, labels
        """
        x_lst = []
        features = []
        label_lst = []

        with torch.no_grad():
            for input_tensor, label in loader:
                h = self.model(input_tensor.to(self.device))
                features.append(h)
                x_lst.append(input_tensor)
                label_lst.append(label)

        # torch.cat (not torch.stack) so that a smaller last batch is handled
        x_total = torch.cat(x_lst)
        h_total = torch.cat(features)
        label_total = torch.cat(label_lst)
        return x_total, h_total, label_total

    def knn(self, features, labels, k=1):
        """
        Evaluating knn accuracy in feature space.
        Calculates only top-1 accuracy (returns 0 for top-5)
        Args:
            features: [... , dataset_size, feat_dim]
            labels: [... , dataset_size]
            k: nearest neighbours
        Returns: train accuracy, or train and test acc
        """
        feature_dim = features.shape[-1]
        with torch.no_grad():
            features_np = features.cpu().view(-1, feature_dim).numpy()
            labels_np = labels.cpu().view(-1).numpy()
            # fit a cosine-distance KNN classifier on the training features
            self.cls = KNeighborsClassifier(k, metric="cosine").fit(features_np, labels_np)
            acc = self.eval(features, labels)
        return acc

    def eval(self, features, labels):
        feature_dim = features.shape[-1]
        features = features.cpu().view(-1, feature_dim).numpy()
        labels = labels.cpu().view(-1).numpy()
        # cross_val_score clones and re-fits self.cls on (features, labels),
        # so this is a cross-validated KNN accuracy on the given features
        acc = 100 * np.mean(cross_val_score(self.cls, features, labels))
        return acc

    def _find_best_indices(self, h_query, h_ref):
        # cosine similarity between every query and reference embedding
        h_query = h_query / h_query.norm(dim=1).view(-1, 1)
        h_ref = h_ref / h_ref.norm(dim=1).view(-1, 1)
        scores = torch.matmul(h_query, h_ref.t())
        score, indices = scores.topk(1, dim=1)
        return score, indices

    def fit(self, train_loader, test_loader=None):
        with torch.no_grad():
            x_train, h_train, l_train = self.extract_features(train_loader)
            train_acc = self.knn(h_train, l_train, k=self.k)

            if test_loader is not None:
                x_test, h_test, l_test = self.extract_features(test_loader)
                test_acc = self.eval(h_test, l_test)
                return train_acc, test_acc
        return train_acc
```
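In the class above, `eval` calls `cross_val_score`, which clones and re-fits the classifier on the features it receives. The test number is therefore a cross-validated k-NN accuracy within the test set. If you would rather fit on the training embeddings and score on the held-out ones, here is a sketch using scikit-learn’s `fit`/`score`. The random clustered features are only a stand-in for real embeddings.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

rng = np.random.default_rng(0)
centers = np.array([[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0]])

# stand-in "embeddings": three well-separated directions plus noise
train_labels = rng.integers(0, 3, size=300)
train_feats = centers[train_labels] + 0.5 * rng.standard_normal((300, 3))
test_labels = rng.integers(0, 3, size=100)
test_feats = centers[test_labels] + 0.5 * rng.standard_normal((100, 3))

# Fit on the training embeddings only...
cls = KNeighborsClassifier(n_neighbors=5, metric="cosine").fit(train_feats, train_labels)
# ...then measure top-1 accuracy on embeddings the classifier has never seen
test_acc = 100 * cls.score(test_feats, test_labels)
```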

Now we can focus on the method and the BYOL model.

Modify resnet: add MLP projection heads

We’ll start with a base model (resnet18) and modify it for self-supervised learning. The last layer, which normally does the classification, is replaced with an identity function. The output features of resnet18 will be fed to the MLP projector.

```python
import copy

import torch
from torch import nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(self, dim, embedding_size=256, hidden_size=2048, batch_norm_mlp=False):
        super().__init__()
        norm = nn.BatchNorm1d(hidden_size) if batch_norm_mlp else nn.Identity()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_size),
            norm,
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, embedding_size)
        )

    def forward(self, x):
        return self.net(x)


class AddProjHead(nn.Module):
    def __init__(self, model, in_features, layer_name, hidden_size=4096,
                 embedding_size=256, batch_norm_mlp=True):
        super(AddProjHead, self).__init__()
        self.backbone = model
        # remove the classification head (e.g. resnet.fc) by replacing it with an identity
        setattr(self.backbone, layer_name, nn.Identity())
        # CIFAR-10 stem: 3x3 conv with stride 1 and no max-pooling for 32x32 inputs
        self.backbone.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.backbone.maxpool = torch.nn.Identity()
        self.projection = MLP(in_features, embedding_size, hidden_size=hidden_size, batch_norm_mlp=batch_norm_mlp)

    def forward(self, x, return_embedding=False):
        embedding = self.backbone(x)
        if return_embedding:
            return embedding
        return self.projection(embedding)
```

I also replaced the first conv layer of resnet18, changing it from a 7×7 to a 3×3 convolution with stride 1. I also removed the initial max-pooling layer, since we are playing with 32×32 images (CIFAR-10).
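A quick check of why the stem matters uses the standard output-size formula ⌊(n + 2p − k)/s⌋ + 1. The stock ResNet-18 stem (a 7×7 conv with stride 2, then a 3×3 max-pool with stride 2) shrinks a 32×32 image to 8×8 before the first residual stage. The three stride-2 stages that follow then leave a 1×1 map. The modified stem keeps the full 32×32 resolution going into the first stage. The layer list is written out by hand for this sketch; the helper names are hypothetical.

```python
def conv_out(n, kernel, stride, padding):
    """Spatial size after a conv/pool layer: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1


def resnet18_spatial_sizes(n, cifar_stem):
    """Feature-map size after layer1..layer4 of a ResNet-18 (hypothetical helper)."""
    sizes = []
    if cifar_stem:
        n = conv_out(n, kernel=3, stride=1, padding=1)   # 3x3 conv1, stride 1; max-pool removed
    else:
        n = conv_out(n, kernel=7, stride=2, padding=3)   # 7x7 conv1, stride 2
        n = conv_out(n, kernel=3, stride=2, padding=1)   # 3x3 max-pool, stride 2
    sizes.append(n)                                      # layer1 keeps the size
    for _ in range(3):                                   # layer2-4 each downsample by 2
        n = conv_out(n, kernel=3, stride=2, padding=1)
        sizes.append(n)
    return sizes


imagenet_stem = resnet18_spatial_sizes(32, cifar_stem=False)
cifar_stem = resnet18_spatial_sizes(32, cifar_stem=True)
print(imagenet_stem)  # [8, 4, 2, 1]
print(cifar_stem)     # [32, 16, 8, 4]
```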

Code is available on GitHub. If you are planning to solidify your PyTorch knowledge, there are two amazing books that we highly recommend: Deep learning with PyTorch from Manning Publications and Machine Learning with PyTorch and Scikit-Learn by Sebastian Raschka. You can always use the 35% discount code blaisummer21 for all Manning’s products.

The actual BYOL method

So far I have presented all the components needed to reach this point. Now we will build the BYOL module with our beloved student and teacher networks. Notice that the student predictor and projector are built from the same MLP class.

My implementation of BYOL was based on lucidrains’ repo. I modified it to make it simpler and to play around with it.

```python
class BYOL(nn.Module):
    def __init__(
            self,
            net,
            batch_norm_mlp=True,
            layer_name='fc',
            in_features=512,
            projection_size=256,
            projection_hidden_size=2048,
            moving_average_decay=0.99,
            use_momentum=True):
        """
        Args:
            net: model to be trained
            batch_norm_mlp: whether to use batchnorm1d in the mlp projector (the predictor is created without it)
            layer_name: name of the backbone's classification layer, replaced by an identity
            in_features: the number of features produced by the backbone net, i.e. resnet
            projection_size: the size of the output vector of the two MLPs (projector and predictor)
            projection_hidden_size: the size of the hidden vector of the two MLPs
            moving_average_decay: τ hyperparameter to control the influence in the target network weight update
            use_momentum: whether to update the target network
        """
        super().__init__()
        self.net = net
        self.student_model = AddProjHead(model=net, in_features=in_features,
                                         layer_name=layer_name,
                                         embedding_size=projection_size,
                                         hidden_size=projection_hidden_size,
                                         batch_norm_mlp=batch_norm_mlp)
        self.use_momentum = use_momentum
        self.teacher_model = self._get_teacher()
        self.target_ema_updater = EMA(moving_average_decay)
        self.student_predictor = MLP(projection_size, projection_size, projection_hidden_size)

    @torch.no_grad()
    def _get_teacher(self):
        # the teacher starts as an exact copy of the student (backbone + projector)
        return copy.deepcopy(self.student_model)

    @torch.no_grad()
    def update_moving_average(self):
        assert self.use_momentum, ('you do not need to update the moving average, since you have turned off momentum '
                                   'for the target encoder')
        assert self.teacher_model is not None, 'target encoder has not been created yet'

        for student_params, teacher_params in zip(self.student_model.parameters(), self.teacher_model.parameters()):
            old_weight, up_weight = teacher_params.data, student_params.data
            teacher_params.data = self.target_ema_updater.update_average(old_weight, up_weight)

    def forward(
            self,
            image_one, image_two=None,
            return_embedding=False):
        if return_embedding or (image_two is None):
            return self.student_model(image_one, return_embedding=True)

        student_proj_one = self.student_model(image_one)
        student_proj_two = self.student_model(image_two)

        student_pred_one = self.student_predictor(student_proj_one)
        student_pred_two = self.student_predictor(student_proj_two)

        # stop-gradient: no gradients flow through the teacher
        with torch.no_grad():
            teacher_proj_one = self.teacher_model(image_one).detach_()
            teacher_proj_two = self.teacher_model(image_two).detach_()

        # symmetric loss as in the BYOL paper: each view's prediction
        # targets the teacher projection of the other view
        loss_one = loss_fn(student_pred_one, teacher_proj_two)
        loss_two = loss_fn(student_pred_two, teacher_proj_one)
        return (loss_one + loss_two).mean()
```
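To check the claim that gradients flow only through the student, here is a toy sketch. Small linear layers stand in for the ResNet and the MLPs (an assumption for brevity), and the loss is the symmetric BYOL-style loss with the teacher branch under `torch.no_grad()`.

```python
import copy

import torch
import torch.nn.functional as F
from torch import nn


def loss_fn(x, y):
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)


torch.manual_seed(0)
student = nn.Linear(16, 8)       # stand-in for backbone + projector
predictor = nn.Linear(8, 8)      # stand-in for the predictor MLP
teacher = copy.deepcopy(student)

view_one, view_two = torch.randn(4, 16), torch.randn(4, 16)
pred_one, pred_two = predictor(student(view_one)), predictor(student(view_two))
with torch.no_grad():            # stop-gradient on the teacher branch
    target_one, target_two = teacher(view_one), teacher(view_two)

# each view's prediction is matched to the teacher projection of the other view
loss = (loss_fn(pred_one, target_two) + loss_fn(pred_two, target_one)).mean()
loss.backward()

student_has_grad = all(p.grad is not None for p in student.parameters())
predictor_has_grad = all(p.grad is not None for p in predictor.parameters())
teacher_has_grad = any(p.grad is not None for p in teacher.parameters())
```

Only the student and predictor receive gradients. The teacher changes solely through the EMA update.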

For CIFAR-10 it is enough to use 2048 as the hidden dimension and 256 as the embedding dimension. We will train a resnet18 that outputs 512 features for 100 epochs. The parts of the code that refer to data loading and augmentations are omitted to increase readability. You can look them up in the code.

You can use the Adam optimizer ($lr = 3 \cdot 10^{-4}$).
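Here is a sketch of that optimizer setup. Since the teacher is only updated by EMA, one option is to freeze its parameters and hand Adam only the parameters that still require gradients. `ToyBYOL` is a hypothetical stand-in that mirrors the student/teacher/predictor attribute names used in this post, with linear layers instead of the real networks.

```python
import copy

import torch
from torch import nn


class ToyBYOL(nn.Module):
    """Hypothetical stand-in with the same student/teacher/predictor layout."""

    def __init__(self):
        super().__init__()
        self.student_model = nn.Linear(16, 8)
        self.student_predictor = nn.Linear(8, 8)
        self.teacher_model = copy.deepcopy(self.student_model)
        # the teacher is updated only via EMA, never by the optimizer
        for p in self.teacher_model.parameters():
            p.requires_grad_(False)


model = ToyBYOL()
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=3e-4)
```

Passing all of `model.parameters()` also works in the post’s code, because the teacher’s gradients stay `None` and Adam skips such parameters. Freezing the teacher just makes that intent explicit.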

The only thing that changes in the training code is the EMA update after each optimizer step.

```python
def training_step(model, data):
    (view1, view2), _ = data
    loss = model(view1.cuda(), view2.cuda())
    return loss


def train_one_epoch(model, train_dataloader, optimizer):
    model.train()
    total_loss = 0.
    num_batches = len(train_dataloader)
    for data in train_dataloader:
        optimizer.zero_grad()
        loss = training_step(model, data)
        loss.backward()
        optimizer.step()
        # EMA update of the teacher after every optimizer step
        model.update_moving_average()
        total_loss += loss.item()

    return total_loss / num_batches
```

Let’s jump to the results!

Results: KNN accuracy VS pretraining epochs


Figure: KNN accuracy every 4 epochs. Image by author

Isn’t it amazing that without any labels we can reach a validation accuracy of 70%? I found this remarkable, especially for a method that seems to be less sensitive to the batch size.

But why does the batch size have an effect here? Isn’t it supposed to avoid negative pairs? Where does the batch-size dependence come from?

Short answer: Well, it’s batch normalization in the MLP layers!

Here are the experiments I made to cross-check it.

A note on batch norm in MLP networks and EMA momentum

I was curious to observe the mode collapse without batch normalization. You can try that yourself by setting:

```python
model = BYOL(model, in_features=512, batch_norm_mlp=False)
```

I observed that the L2 distance goes to almost zero from the very first epochs:

Epoch 0: loss:0.06423207696957084

Epoch 8: loss:0.005584242034894534

Epoch 20: loss:0.005460431350347323

The loss goes to approximately zero and the KNN accuracy stops increasing (35% VS 60% in the normal setup). That’s why it is claimed that BYOL implicitly uses a form of contrastive learning by leveraging the batch statistics in the MLPs. Here is the KNN accuracy:
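Why would batch norm matter? In training mode, `BatchNorm1d` normalizes each feature with the mean and variance of the current batch. The output for one image therefore depends on the other images in the same batch. Here is a minimal sketch with random inputs; the layer sizes and the helpers `make_mlp` and `sample0_changes` are arbitrary choices for illustration. When the MLP has batch norm, perturbing sample 1 changes the output for sample 0. Without batch norm, that output stays untouched.

```python
import torch
from torch import nn


def make_mlp(batch_norm):
    norm = nn.BatchNorm1d(32) if batch_norm else nn.Identity()
    return nn.Sequential(nn.Linear(16, 32), norm, nn.ReLU(), nn.Linear(32, 8))


def sample0_changes(mlp):
    """Does perturbing sample 1 alter the output for sample 0?"""
    mlp.train()  # batch norm uses the statistics of the current batch in training mode
    x = torch.randn(8, 16)
    x_perturbed = x.clone()
    x_perturbed[1] += 10.0
    with torch.no_grad():
        out = mlp(x)[0]
        out_perturbed = mlp(x_perturbed)[0]
    return not torch.allclose(out, out_perturbed)


torch.manual_seed(0)
with_bn = sample0_changes(make_mlp(batch_norm=True))
without_bn = sample0_changes(make_mlp(batch_norm=False))
```

This cross-sample coupling is the mechanism behind the "implicit contrastive" interpretation discussed here.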


Figure: Mode collapse in BYOL by removing batch norm in MLPs. Image by author

I am well aware of papers showing that batch statistics are not the only condition for BYOL to work. This is an experimental post, so I am not going to play that game. I was just curious to observe mode collapse here.

Conclusion

For a more detailed explanation of the method, check Yannic’s video on BYOL.

In this tutorial, we implemented BYOL step by step and pretrained it on CIFAR-10. We observed a huge increase in KNN accuracy by matching the representations of different augmented views of the same image. A random classifier would score 10%, and with 100 epochs we reach 70% KNN validation accuracy without any labels. How cool is that?

To learn more about self-supervised learning, stay tuned! Support us by social media sharing, making a donation, or buying our Deep Learning in Production book. It would be highly appreciated.


* Disclosure: Please note that some of the links above might be affiliate links, and at no additional cost to you, we will earn a commission if you decide to make a purchase after clicking through.
