# Writing better code with PyTorch and einops

## Rewriting building blocks of deep learning

Below are some fragments of code taken from official tutorials and popular repositories (the fragments are used for educational purposes and are sometimes shortened). For each fragment, an enhanced version is proposed, with comments.

In most examples, `einops` was used to make things less complicated. But you'll also find some common recommendations and practices to improve your code.

In each pair of code blocks, the first shows the code as it was and the second shows the improved version.

```python
# start by importing some stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

from einops import rearrange, reduce, asnumpy, parse_shape
from einops.layers.torch import Rearrange, Reduce
```

# Simple ConvNet

```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

conv_net_old = Net()
```

```python
conv_net_new = nn.Sequential(
    nn.Conv2d(1, 10, kernel_size=5),
    nn.MaxPool2d(kernel_size=2),
    nn.ReLU(),
    nn.Conv2d(10, 20, kernel_size=5),
    nn.MaxPool2d(kernel_size=2),
    nn.ReLU(),
    nn.Dropout2d(),
    Rearrange('b c h w -> b (c h w)'),
    nn.Linear(320, 50),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(50, 10),
    nn.LogSoftmax(dim=1)
)
```

Reasons to prefer the new implementation:

• in the original code, if the input size changes and the batch size is divisible by 16 (which is usually the case), reshaping silently produces a senseless result
• the new code explicitly raises an error in this case
• with the new version we can't forget to pass the `self.training` flag to dropout
• the code is straightforward to read and analyze
• `nn.Sequential` makes printing, saving, and passing the model around trivial, and loading the model requires no custom class code (which has a number of benefits of its own)
• ... and we could also use in-place ReLU

# Super-resolution

```class SuperResolutionNetOld(nn.Module):
def __init__(self, upscale_factor):
super(SuperResolutionNetOld, self).__init__()

self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

def forward(self, x):
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = self.relu(self.conv3(x))
x = self.pixel_shuffle(self.conv4(x))
return x
```
```def SuperResolutionNetNew(upscale_factor):
return nn.Sequential(
nn.Conv2d(1, 64, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, 32, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(32, upscale_factor ** 2, kernel_size=3, padding=1),
Rearrange('b (h2 w2) h w -> b (h h2) (w w2)', h2=upscale_factor, w2=upscale_factor),
)
```

Here is the difference:

• no need for the special `pixel_shuffle` operation (and the result is transferable between frameworks)
• the output doesn't contain a fake axis (and we could do the same for the input)
• in-place ReLU is used now; for high-resolution pictures this becomes critical and saves a lot of memory
• and all the benefits of `nn.Sequential` again

# Restyling Gram matrix for style transfer

The original code is already good: its first line shows what kind of input is expected.

• the einsum operation should be read as: for each batch and for each pair of channels, we sum over h and w
• I've also changed the normalization, because that's how the Gram matrix is defined; otherwise we should call it a normalized Gram matrix or the like
```python
def gram_matrix_old(y):
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t) / (ch * h * w)
    return gram
```

```python
def gram_matrix_new(y):
    b, ch, h, w = y.shape
    return torch.einsum('bchw,bdhw->bcd', y, y) / (h * w)
```

It would be great to use just `'b c1 h w,b c2 h w->b c1 c2'`, but `torch.einsum` supports only single-letter axis names.
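
A quick check, on an arbitrary random tensor, that the two versions compute the same sums and differ only in the normalization constant:

```python
import torch

y = torch.randn(2, 3, 4, 5)  # (b, ch, h, w)
b, ch, h, w = y.shape

# bmm version: flatten the spatial axes, then multiply by the transpose.
features = y.view(b, ch, h * w)
gram_bmm = features.bmm(features.transpose(1, 2))

# einsum version: sum over h and w for every pair of channels.
gram_einsum = torch.einsum('bchw,bdhw->bcd', y, y)

same_sums = torch.allclose(gram_bmm, gram_einsum, atol=1e-5)

# Only the normalization differs: the old code divides by ch * h * w,
# the new one by h * w, so the results differ by a factor of ch.
same_up_to_ch = torch.allclose(gram_bmm / (ch * h * w) * ch,
                               gram_einsum / (h * w), atol=1e-5)
print(same_sums, same_up_to_ch)
```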

# Recurrent model

All we did here was make the information about shapes explicit, so the reader doesn't have to decipher it.

```python
class RNNModelOld(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""
    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super().__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden
```

```python
class RNNModelNew(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""
    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super().__init__()
        self.drop = nn.Dropout(p=dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

    def forward(self, input, hidden):
        t, b = input.shape
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = rearrange(self.drop(output), 't b nhid -> (t b) nhid')
        decoded = rearrange(self.decoder(output), '(t b) token -> t b token', t=t, b=b)
        return decoded, hidden
```

# Channel shuffle (from shufflenet)

```python
def channel_shuffle_old(x, groups):
    batchsize, num_channels, height, width = x.data.size()

    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups,
               channels_per_group, height, width)

    # transpose
    # - contiguous() required if transpose() is used before view().
    #   See https://github.com/pytorch/pytorch/issues/764
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x
```

```python
def channel_shuffle_new(x, groups):
    return rearrange(x, 'b (c1 c2) h w -> b (c2 c1) h w', c1=groups)
```

While progress is obvious, this is not the limit. As you'll see below, we don't even need to write these couple of lines.

# Shufflenet

```from collections import OrderedDict

def channel_shuffle(x, groups):
batchsize, num_channels, height, width = x.data.size()

channels_per_group = num_channels // groups

# reshape
x = x.view(batchsize, groups,
channels_per_group, height, width)

# transpose
# - contiguous() required if transpose() is used before view().
#   See https://github.com/pytorch/pytorch/issues/764
x = torch.transpose(x, 1, 2).contiguous()

# flatten
x = x.view(batchsize, -1, height, width)

return x

class ShuffleUnitOld(nn.Module):
def __init__(self, in_channels, out_channels, groups=3,
grouped_conv=True, combine='add'):

super(ShuffleUnitOld, self).__init__()

self.in_channels = in_channels
self.out_channels = out_channels
self.grouped_conv = grouped_conv
self.combine = combine
self.groups = groups
self.bottleneck_channels = self.out_channels // 4

# define the type of ShuffleUnit
if self.combine == 'add':
# ShuffleUnit Figure 2b
self.depthwise_stride = 1
self._combine_func = self._add
elif self.combine == 'concat':
# ShuffleUnit Figure 2c
self.depthwise_stride = 2
self._combine_func = self._concat

# ensure output of concat has the same channels as
# original output channels.
self.out_channels -= self.in_channels
else:
raise ValueError("Cannot combine tensors with \"{}\"" \
"Only \"add\" and \"concat\" are" \
"supported".format(self.combine))

# Use a 1x1 grouped or non-grouped convolution to reduce input channels
# to bottleneck channels, as in a ResNet bottleneck module.
# NOTE: Do not use group convolution for the first conv1x1 in Stage 2.
self.first_1x1_groups = self.groups if grouped_conv else 1

self.g_conv_1x1_compress = self._make_grouped_conv1x1(
self.in_channels,
self.bottleneck_channels,
self.first_1x1_groups,
batch_norm=True,
relu=True
)

# 3x3 depthwise convolution followed by batch normalization
self.depthwise_conv3x3 = conv3x3(
self.bottleneck_channels, self.bottleneck_channels,
stride=self.depthwise_stride, groups=self.bottleneck_channels)
self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels)

# Use 1x1 grouped convolution to expand from
# bottleneck_channels to out_channels
self.g_conv_1x1_expand = self._make_grouped_conv1x1(
self.bottleneck_channels,
self.out_channels,
self.groups,
batch_norm=True,
relu=False
)

@staticmethod
def _add(x, out):
# residual connection
return x + out

@staticmethod
def _concat(x, out):
# concatenate along channel axis
return torch.cat((x, out), 1)

def _make_grouped_conv1x1(self, in_channels, out_channels, groups,
batch_norm=True, relu=False):

modules = OrderedDict()
conv = conv1x1(in_channels, out_channels, groups=groups)
modules['conv1x1'] = conv

if batch_norm:
modules['batch_norm'] = nn.BatchNorm2d(out_channels)
if relu:
modules['relu'] = nn.ReLU()
if len(modules) > 1:
return nn.Sequential(modules)
else:
return conv

def forward(self, x):
# save for combining later with output
residual = x
if self.combine == 'concat':
residual = F.avg_pool2d(residual, kernel_size=3,
stride=2, padding=1)

out = self.g_conv_1x1_compress(x)
out = channel_shuffle(out, self.groups)
out = self.depthwise_conv3x3(out)
out = self.bn_after_depthwise(out)
out = self.g_conv_1x1_expand(out)

out = self._combine_func(residual, out)
return F.relu(out)
```
```class ShuffleUnitNew(nn.Module):
def __init__(self, in_channels, out_channels, groups=3,
grouped_conv=True, combine='add'):
super().__init__()
first_1x1_groups = groups if grouped_conv else 1
bottleneck_channels = out_channels // 4
self.combine = combine
if combine == 'add':
# ShuffleUnit Figure 2b
self.left = Rearrange('...->...') # identity
depthwise_stride = 1
else:
# ShuffleUnit Figure 2c
self.left = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
depthwise_stride = 2
# ensure output of concat has the same channels as original output channels.
out_channels -= in_channels
assert out_channels > 0

self.right = nn.Sequential(
# Use a 1x1 grouped or non-grouped convolution to reduce input channels
# to bottleneck channels, as in a ResNet bottleneck module.
conv1x1(in_channels, bottleneck_channels, groups=first_1x1_groups),
nn.BatchNorm2d(bottleneck_channels),
nn.ReLU(inplace=True),
# channel shuffle
Rearrange('b (c1 c2) h w -> b (c2 c1) h w', c1=groups),
# 3x3 depthwise convolution followed by batch
conv3x3(bottleneck_channels, bottleneck_channels,
stride=depthwise_stride, groups=bottleneck_channels),
nn.BatchNorm2d(bottleneck_channels),
# Use 1x1 grouped convolution to expand from
# bottleneck_channels to out_channels
conv1x1(bottleneck_channels, out_channels, groups=groups),
nn.BatchNorm2d(out_channels),
)

def forward(self, x):
if self.combine == 'add':
combined = self.left(x) + self.right(x)
else:
combined = torch.cat([self.left(x), self.right(x)], dim=1)
return F.relu(combined, inplace=True)
```

Rewriting the code helped to identify the following:

• There is no point in shuffling channels when the first convolution is not grouped (indeed, the paper doesn't do so). However, the result is an equivalent model.
• It is also strange that the first convolution may be non-grouped, while the last convolution is always grouped (and that differs from the paper)

Other comments:

• An identity layer for PyTorch is introduced here
• The last thing left is to get rid of `conv1x1` and `conv3x3` in the code: they are no better than the standard layers

# Simplifying ResNet

```class ResNetOld(nn.Module):

def __init__(self, block, layers, num_classes=1000):
self.inplanes = 64
super(ResNetOld, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool2d(7, stride=1)

self.fc = nn.Linear(512 * block.expansion, num_classes)

for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()

def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)

layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))

return nn.Sequential(*layers)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)

x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)

return x
```
```def make_layer(inplanes, planes, block, n_blocks, stride=1):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
# output size won't match input, so adjust residual
downsample = nn.Sequential(
nn.Conv2d(inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
return nn.Sequential(
block(inplanes, planes, stride, downsample),
*[block(planes * block.expansion, planes) for _ in range(1, n_blocks)]
)

def ResNetNew(block, layers, num_classes=1000):
e = block.expansion

resnet = nn.Sequential(
Rearrange('b c h w -> b c h w', c=3, h=224, w=224),
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
make_layer(64,      64,  block, layers[0], stride=1),
make_layer(64 * e,  128, block, layers[1], stride=2),
make_layer(128 * e, 256, block, layers[2], stride=2),
make_layer(256 * e, 512, block, layers[3], stride=2),
# combined AvgPool and view in one averaging operation
Reduce('b c h w -> b c', 'mean'),
nn.Linear(512 * e, num_classes),
)

# initialization
for m in resnet.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
return resnet
```

Changes:

• explicit check of the input shape
• no views and a simple sequential structure; the output is just an `nn.Sequential`, so it can always be saved, passed around, etc.
• no need for AvgPool and additional views; this place is much clearer now
• `make_layer` doesn't use internal state (which was quite an error-prone place)

# Improving RNN language modelling

```class RNNOld(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout):
super().__init__()

self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers,
bidirectional=bidirectional, dropout=dropout)
self.fc = nn.Linear(hidden_dim*2, output_dim)
self.dropout = nn.Dropout(dropout)

def forward(self, x):
#x = [sent len, batch size]

embedded = self.dropout(self.embedding(x))

#embedded = [sent len, batch size, emb dim]

output, (hidden, cell) = self.rnn(embedded)

#output = [sent len, batch size, hid dim * num directions]
#hidden = [num layers * num directions, batch size, hid dim]
#cell = [num layers * num directions, batch size, hid dim]

#concat the final forward (hidden[-2,:,:]) and backward (hidden[-1,:,:]) hidden layers
#and apply dropout

hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))

#hidden = [batch size, hid dim * num directions]

return self.fc(hidden.squeeze(0))
```
```class RNNNew(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout):
super().__init__()

self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers,
bidirectional=bidirectional, dropout=dropout)
self.dropout = nn.Dropout(dropout)
self.directions = 2 if bidirectional else 1
self.fc = nn.Linear(hidden_dim * self.directions, output_dim)

def forward(self, x):
#x = [sent len, batch size]
embedded = self.dropout(self.embedding(x))

#embedded = [sent len, batch size, emb dim]
output, (hidden, cell) = self.rnn(embedded)

hidden = rearrange(hidden, '(layer dir) b c -> layer b (dir c)',
dir=self.directions)
# take the final layer's hidden
return self.fc(self.dropout(hidden[-1]))
```
• the original code misbehaves for non-bidirectional models
• ... and fails when `bidirectional=False` and there is only one layer
• the modified code shows both how `hidden` is structured and how it is transformed

# Writing FastText faster

```class FastTextOld(nn.Module):
def __init__(self, vocab_size, embedding_dim, output_dim):
super().__init__()

self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.fc = nn.Linear(embedding_dim, output_dim)

def forward(self, x):

#x = [sent len, batch size]

embedded = self.embedding(x)

#embedded = [sent len, batch size, emb dim]

embedded = embedded.permute(1, 0, 2)

#embedded = [batch size, sent len, emb dim]

pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)

#pooled = [batch size, embedding_dim]

return self.fc(pooled)
```
```def FastTextNew(vocab_size, embedding_dim, output_dim):
return nn.Sequential(
Rearrange('t b -> t b'),
nn.Embedding(vocab_size, embedding_dim),
Reduce('t b c -> b c', 'mean'),
nn.Linear(embedding_dim, output_dim),
Rearrange('b c -> b c'),
)
```

Some comments on the new code:

• the first and last operations do nothing and can be removed
• but they were added to explicitly show the expected input and output
• this also gives you the flexibility to change the interface by editing a single line. Should you need to accept inputs as (batch, time), just change the first line to `Rearrange('b t -> t b'),`

# CNNs for text classification

```class CNNOld(nn.Module):
def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.conv_0 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[0],embedding_dim))
self.conv_1 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[1],embedding_dim))
self.conv_2 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[2],embedding_dim))
self.fc = nn.Linear(len(filter_sizes)*n_filters, output_dim)
self.dropout = nn.Dropout(dropout)

def forward(self, x):

#x = [sent len, batch size]

x = x.permute(1, 0)

#x = [batch size, sent len]

embedded = self.embedding(x)

#embedded = [batch size, sent len, emb dim]

embedded = embedded.unsqueeze(1)

#embedded = [batch size, 1, sent len, emb dim]

conved_0 = F.relu(self.conv_0(embedded).squeeze(3))
conved_1 = F.relu(self.conv_1(embedded).squeeze(3))
conved_2 = F.relu(self.conv_2(embedded).squeeze(3))

#conv_n = [batch size, n_filters, sent len - filter_sizes[n]]

pooled_0 = F.max_pool1d(conved_0, conved_0.shape[2]).squeeze(2)
pooled_1 = F.max_pool1d(conved_1, conved_1.shape[2]).squeeze(2)
pooled_2 = F.max_pool1d(conved_2, conved_2.shape[2]).squeeze(2)

#pooled_n = [batch size, n_filters]

cat = self.dropout(torch.cat((pooled_0, pooled_1, pooled_2), dim=1))

#cat = [batch size, n_filters * len(filter_sizes)]

return self.fc(cat)
```
```class CNNNew(nn.Module):
def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.convs = nn.ModuleList([
nn.Conv1d(embedding_dim, n_filters, kernel_size=size) for size in filter_sizes
])
self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)
self.dropout = nn.Dropout(dropout)

def forward(self, x):
x = rearrange(x, 't b -> t b')
emb = rearrange(self.embedding(x), 't b c -> b c t')
pooled = [reduce(conv(emb), 'b c t -> b c', 'max') for conv in self.convs]
concatenated = rearrange(pooled, 'filter b c -> b (filter c)')
return self.fc(self.dropout(F.relu(concatenated)))
```
• Original code misuses Conv2d, while Conv1d is the right choice
• Fixed code can work with any number of filter_sizes (and won't fail)
• First line in new code does nothing, but was added for simplicity

# Highway convolutions

• Highway convolutions are common in TTS systems. The code below makes the splitting a bit more explicit.
• The splitting policy may turn out to be important if the input previously had groups along the channel axis (group convolutions or bidirectional LSTMs/GRUs)
• The same applies to GLU and gated units in general
```python
class HighwayConv1dOld(nn.Conv1d):
    def forward(self, inputs):
        L = super(HighwayConv1dOld, self).forward(inputs)
        H1, H2 = torch.chunk(L, 2, 1)  # chunk at the feature dim
        torch.sigmoid_(H1)
        return H1 * H2 + (1.0 - H1) * inputs
```

```python
class HighwayConv1dNew(nn.Conv1d):
    def forward(self, inputs):
        L = super().forward(inputs)
        H1, H2 = rearrange(L, 'b (split c) t -> split b c t', split=2)
        torch.sigmoid_(H1)
        return H1 * H2 + (1.0 - H1) * inputs
```

# Tacotron's CBHG module

```class CBHG_Old(nn.Module):
"""CBHG module: a recurrent neural network composed of:
- 1-d convolution banks
- Highway networks + residual connections
- Bidirectional gated recurrent units
"""

def __init__(self, in_dim, K=16, projections=[128, 128]):
super().__init__()
self.in_dim = in_dim
self.relu = nn.ReLU()
self.conv1d_banks = nn.ModuleList(
[BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1,
padding=k // 2, activation=self.relu)
for k in range(1, K + 1)])
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

in_sizes = [K * in_dim] + projections[:-1]
activations = [self.relu] * (len(projections) - 1) + [None]
self.conv1d_projections = nn.ModuleList(
[BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
padding=1, activation=ac)
for (in_size, out_size, ac) in zip(
in_sizes, projections, activations)])

self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
self.highways = nn.ModuleList(
[Highway(in_dim, in_dim) for _ in range(4)])

self.gru = nn.GRU(
in_dim, in_dim, 1, batch_first=True, bidirectional=True)
```
```def forward_old(self, inputs):
# (B, T_in, in_dim)
x = inputs

# Needed to perform conv1d on time-axis
# (B, in_dim, T_in)
if x.size(-1) == self.in_dim:
x = x.transpose(1, 2)

T = x.size(-1)

# (B, in_dim*K, T_in)
# Concat conv1d bank outputs
x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1)
assert x.size(1) == self.in_dim * len(self.conv1d_banks)
x = self.max_pool1d(x)[:, :, :T]

for conv1d in self.conv1d_projections:
x = conv1d(x)

# (B, T_in, in_dim)
# Back to the original shape
x = x.transpose(1, 2)

if x.size(-1) != self.in_dim:
x = self.pre_highway(x)

# Residual connection
x += inputs
for highway in self.highways:
x = highway(x)

# (B, T_in, in_dim*2)
outputs, _ = self.gru(x)

return outputs
```
```python
def forward_new(self, inputs, input_lengths=None):
    x = rearrange(inputs, 'b t c -> b c t')
    _, _, T = x.shape
    # Concat conv1d bank outputs
    x = rearrange([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks],
                  'bank b c t -> b (bank c) t', c=self.in_dim)
    x = self.max_pool1d(x)[:, :, :T]

    for conv1d in self.conv1d_projections:
        x = conv1d(x)
    x = rearrange(x, 'b c t -> b t c')
    if x.size(-1) != self.in_dim:
        x = self.pre_highway(x)

    # Residual connection
    x += inputs
    for highway in self.highways:
        x = highway(x)

    # (B, T_in, in_dim*2)
    # the highway layers were already applied in the loop above;
    # self.highways is an nn.ModuleList and cannot be called directly
    outputs, _ = self.gru(x)

    return outputs
```

There is still plenty of room for improvement, but in this example only the forward function was changed.

# Simple attention

Good news: there is no more need to guess the order of dimensions, neither for inputs nor for outputs.

```python
class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()

    def forward(self, K, V, Q):
        A = torch.bmm(K.transpose(1,2), Q) / np.sqrt(Q.shape[1])
        A = F.softmax(A, 1)
        R = torch.bmm(V, A)
        return torch.cat((R, Q), dim=1)
```

```python
def attention(K, V, Q):
    _, n_channels, _ = K.shape
    A = torch.einsum('bct,bcl->btl', K, Q)
    A = F.softmax(A * n_channels ** (-0.5), 1)
    R = torch.einsum('bct,btl->bcl', V, A)
    return torch.cat((R, Q), dim=1)
```
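
The einsum strings double as shape documentation: keys and values are `(b, c, t)` and queries are `(b, c, l)`. The sketch below restates both versions under hypothetical names (`attention_bmm`, `attention_einsum`), using `math.sqrt` instead of `np.sqrt` to stay self-contained, and checks that they agree on random inputs.

```python
import math
import torch
import torch.nn.functional as F

def attention_bmm(K, V, Q):
    # bmm version, mirroring the original module's forward
    A = torch.bmm(K.transpose(1, 2), Q) / math.sqrt(Q.shape[1])
    A = F.softmax(A, 1)
    R = torch.bmm(V, A)
    return torch.cat((R, Q), dim=1)

def attention_einsum(K, V, Q):
    _, n_channels, _ = K.shape
    A = torch.einsum('bct,bcl->btl', K, Q)
    A = F.softmax(A * n_channels ** (-0.5), 1)
    R = torch.einsum('bct,btl->bcl', V, A)
    return torch.cat((R, Q), dim=1)

K = torch.randn(2, 4, 7)  # (b, c, t): keys over t positions
V = torch.randn(2, 4, 7)  # (b, c, t): values over t positions
Q = torch.randn(2, 4, 3)  # (b, c, l): queries over l positions

out = attention_einsum(K, V, Q)
same = torch.allclose(attention_bmm(K, V, Q), out, atol=1e-5)
print(out.shape, same)  # torch.Size([2, 8, 3]) True
```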

# Transformer's attention needs more attention

```class ScaledDotProductAttention(nn.Module):
''' Scaled Dot-Product Attention '''

def __init__(self, temperature, attn_dropout=0.1):
super().__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
self.softmax = nn.Softmax(dim=2)

def forward(self, q, k, v, mask=None):

attn = torch.bmm(q, k.transpose(1, 2))
attn = attn / self.temperature

if mask is not None:
attn = attn.masked_fill(mask, -np.inf)

attn = self.softmax(attn)
attn = self.dropout(attn)
output = torch.bmm(attn, v)

return output, attn

class MultiHeadAttentionOld(nn.Module):
''' Multi-Head Attention module '''

def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
super().__init__()

self.n_head = n_head
self.d_k = d_k
self.d_v = d_v

self.w_qs = nn.Linear(d_model, n_head * d_k)
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)
nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
self.layer_norm = nn.LayerNorm(d_model)

self.fc = nn.Linear(n_head * d_v, d_model)
nn.init.xavier_normal_(self.fc.weight)

self.dropout = nn.Dropout(dropout)

def forward(self, q, k, v, mask=None):

d_k, d_v, n_head = self.d_k, self.d_v, self.n_head

sz_b, len_q, _ = q.size()
sz_b, len_k, _ = k.size()
sz_b, len_v, _ = v.size()

residual = q

q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv

mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
output, attn = self.attention(q, k, v, mask=mask)

output = output.view(n_head, sz_b, len_q, d_v)
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)

output = self.dropout(self.fc(output))
output = self.layer_norm(output + residual)

return output, attn
```
```class MultiHeadAttentionNew(nn.Module):
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
super().__init__()
self.n_head = n_head

self.w_qs = nn.Linear(d_model, n_head * d_k)
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)

nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

self.fc = nn.Linear(n_head * d_v, d_model)
nn.init.xavier_normal_(self.fc.weight)
self.dropout = nn.Dropout(p=dropout)
self.layer_norm = nn.LayerNorm(d_model)

def forward(self, q, k, v, mask=None):
residual = q
q = rearrange(self.w_qs(q), 'b l (head k) -> head b l k', head=self.n_head)
k = rearrange(self.w_ks(k), 'b t (head k) -> head b t k', head=self.n_head)
v = rearrange(self.w_vs(v), 'b t (head v) -> head b t v', head=self.n_head)
attn = torch.einsum('hblk,hbtk->hblt', [q, k]) / np.sqrt(q.shape[-1])
if mask is not None:
attn = attn.masked_fill(mask[None], -np.inf)
attn = torch.softmax(attn, dim=3)
output = torch.einsum('hblt,hbtv->hblv', [attn, v])
output = rearrange(output, 'head b l v -> b l (head v)')
output = self.dropout(self.fc(output))
output = self.layer_norm(output + residual)
return output, attn

```

Benefits of the new implementation:

• we have one module, not two
• the code no longer fails when the mask is `None`
• the number of caveats in the original code that we removed is huge. Try erasing the comments and deciphering what happens there

# Self-attention GANs

SAGANs are currently SotA for image generation, and can be simplified using the same tricks.

```class Self_Attn_Old(nn.Module):
""" Self attention Layer"""
def __init__(self,in_dim,activation):
super(Self_Attn_Old,self).__init__()
self.chanel_in = in_dim
self.activation = activation

self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
self.gamma = nn.Parameter(torch.zeros(1))

self.softmax  = nn.Softmax(dim=-1) #

def forward(self, x):
"""
inputs :
x : input feature maps( B X C X W X H)
returns :
out : self attention value + input feature
attention: B X N X N (N is Width*Height)
"""

m_batchsize,C,width ,height = x.size()
proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N)
proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H)
energy =  torch.bmm(proj_query,proj_key) # transpose check
attention = self.softmax(energy) # BX (N) X (N)
proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N

out = torch.bmm(proj_value,attention.permute(0,2,1) )
out = out.view(m_batchsize,C,width,height)

out = self.gamma*out + x
return out,attention
```
```class Self_Attn_New(nn.Module):
""" Self attention Layer"""
def __init__(self, in_dim):
super().__init__()
self.query_conv = nn.Conv2d(in_dim, out_channels=in_dim//8, kernel_size=1)
self.key_conv = nn.Conv2d(in_dim, out_channels=in_dim//8, kernel_size=1)
self.value_conv = nn.Conv2d(in_dim, out_channels=in_dim, kernel_size=1)
self.gamma = nn.Parameter(torch.zeros([1]))

def forward(self, x):
proj_query = rearrange(self.query_conv(x), 'b c h w -> b (h w) c')
proj_key = rearrange(self.key_conv(x), 'b c h w -> b c (h w)')
proj_value = rearrange(self.value_conv(x), 'b c h w -> b (h w) c')
energy = torch.bmm(proj_query, proj_key)
attention = F.softmax(energy, dim=2)
out = torch.bmm(attention, proj_value)
out = x + self.gamma * rearrange(out, 'b (h w) c -> b c h w',
**parse_shape(x, 'b c h w'))
return out, attention
```

# Improving time sequence prediction

While this example was considered simplistic, I had to analyze the surrounding code to understand what kind of input was expected. You can try it yourself.

Additionally, the new code works with any dtype, not only double, and it supports running on a GPU.

```python
class SequencePredictionOld(nn.Module):
    def __init__(self):
        super(SequencePredictionOld, self).__init__()
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 51)
        self.linear = nn.Linear(51, 1)

    def forward(self, input, future = 0):
        outputs = []
        h_t = torch.zeros(input.size(0), 51, dtype=torch.double)
        c_t = torch.zeros(input.size(0), 51, dtype=torch.double)
        h_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)
        c_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)

        for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]

        for i in range(future):# if we should predict the future
            h_t, c_t = self.lstm1(output, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]
        outputs = torch.stack(outputs, 1).squeeze(2)
        return outputs
```
```python
class SequencePredictionNew(nn.Module):
    def __init__(self):
        super(SequencePredictionNew, self).__init__()
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 51)
        self.linear = nn.Linear(51, 1)

    def forward(self, input, future=0):
        b, t = input.shape
        h_t, c_t, h_t2, c_t2 = torch.zeros(4, b, 51, dtype=self.linear.weight.dtype,
                                           device=self.linear.weight.device)

        outputs = []
        for input_t in rearrange(input, 'b t -> t b ()'):
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]

        for i in range(future): # if we should predict the future
            h_t, c_t = self.lstm1(output, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]
        return rearrange(outputs, 't b () -> b t')
```
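
The dtype and device point can be checked with a small sketch. The standalone `linear` layer here is a hypothetical stand-in for the model's last layer. It only illustrates that the recurrent state follows whatever dtype and device the parameters have.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for self.linear; converted to double on purpose
linear = nn.Linear(51, 1).double()

batch = 3
# One allocation, unpacked along the first axis into four (batch, 51) states
h_t, c_t, h_t2, c_t2 = torch.zeros(4, batch, 51,
                                   dtype=linear.weight.dtype,
                                   device=linear.weight.device)
```

Calling `.float()` or `.to('cuda')` on the module would change the state tensors accordingly, with no edits to `forward`.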

# Transforming spatial transformer network (STN)

```python
class SpacialTransformOld(nn.Module):
    def __init__(self):
        super(SpacialTransformOld, self).__init__()

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)

        return x
```
```python
class SpacialTransformNew(nn.Module):
    def __init__(self):
        super().__init__()
        # Last layer of the regressor for the 3 * 2 affine matrix
        linear = nn.Linear(32, 3 * 2)
        # Initialize the weights/bias with identity transformation
        linear.weight.data.zero_()
        linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

        # Spatial transformer localization-network followed by the regressor
        self.compute_theta = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            Rearrange('b c h w -> b (c h w)', h=3, w=3),
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            linear,
            Rearrange('b (row col) -> b row col', row=2, col=3),
        )

    # Spatial transformer network forward function
    def stn(self, x):
        grid = F.affine_grid(self.compute_theta(x), x.size())
        return F.grid_sample(x, grid)
```
• the new code gives reasonable errors when the passed image size differs from the expected one
• in the old code, if the batch size is divisible by 18, the reshape accepts whatever you input, and the code fails no sooner than `affine_grid`

# Improving GLOW

That's a good old depth-to-space written manually!

Since GLOW is invertible, it frequently relies on `rearrange`-like operations.

```python
def unsqueeze2d_old(input, factor=2):
    assert factor >= 1 and isinstance(factor, int)
    factor2 = factor ** 2
    if factor == 1:
        return input
    size = input.size()
    B = size[0]
    C = size[1]
    H = size[2]
    W = size[3]
    assert C % (factor2) == 0, "{}".format(C)
    x = input.view(B, C // factor2, factor, factor, H, W)
    x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
    x = x.view(B, C // (factor2), H * factor, W * factor)
    return x

def squeeze2d_old(input, factor=2):
    assert factor >= 1 and isinstance(factor, int)
    if factor == 1:
        return input
    size = input.size()
    B = size[0]
    C = size[1]
    H = size[2]
    W = size[3]
    assert H % factor == 0 and W % factor == 0, "{}".format((H, W))
    x = input.view(B, C, H // factor, factor, W // factor, factor)
    x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
    x = x.view(B, C * factor * factor, H // factor, W // factor)
    return x
```
```python
def unsqueeze2d_new(input, factor=2):
    return rearrange(input, 'b (c h2 w2) h w -> b c (h h2) (w w2)', h2=factor, w2=factor)

def squeeze2d_new(input, factor=2):
    return rearrange(input, 'b c (h h2) (w w2) -> b (c h2 w2) h w', h2=factor, w2=factor)
```
• the term `squeeze` isn't very helpful: which dimension is squeezed? There is `torch.squeeze`, but it does something very different.
• in fact, we could skip creating these functions completely, since each of them is a single `rearrange` call

# Detecting problems in YOLO detection

```python
def YOLO_prediction_old(input, num_classes, num_anchors, anchors, stride_h, stride_w):
    bs = input.size(0)
    in_h = input.size(2)
    in_w = input.size(3)
    scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in anchors]

    prediction = input.view(bs, num_anchors,
                            5 + num_classes, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()
    # Get outputs
    x = torch.sigmoid(prediction[..., 0])  # Center x
    y = torch.sigmoid(prediction[..., 1])  # Center y
    w = prediction[..., 2]  # Width
    h = prediction[..., 3]  # Height
    conf = torch.sigmoid(prediction[..., 4])  # Conf
    pred_cls = torch.sigmoid(prediction[..., 5:])  # Cls pred.

    FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
    LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
    # Calculate offsets for each grid
    grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_w, 1).repeat(
        bs * num_anchors, 1, 1).view(x.shape).type(FloatTensor)
    grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_h, 1).t().repeat(
        bs * num_anchors, 1, 1).view(y.shape).type(FloatTensor)
    # Calculate anchor w, h
    anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
    anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
    anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape)
    anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape)
    # Add offset and scale with anchors
    pred_boxes = FloatTensor(prediction[..., :4].shape)
    pred_boxes[..., 0] = x.data + grid_x
    pred_boxes[..., 1] = y.data + grid_y
    pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
    pred_boxes[..., 3] = torch.exp(h.data) * anchor_h
    # Results
    _scale = torch.Tensor([stride_w, stride_h] * 2).type(FloatTensor)
    output = torch.cat((pred_boxes.view(bs, -1, 4) * _scale,
                        conf.view(bs, -1, 1), pred_cls.view(bs, -1, num_classes)), -1)
    return output
```
```python
def YOLO_prediction_new(input, num_classes, num_anchors, anchors, stride_h, stride_w):
    raw_predictions = rearrange(input, 'b (anchor prediction) h w -> prediction b anchor h w',
                                anchor=num_anchors, prediction=5 + num_classes)
    anchors = torch.FloatTensor(anchors).to(input.device)
    anchor_sizes = rearrange(anchors, 'anchor dim -> dim () anchor () ()')

    _, _, _, in_h, in_w = raw_predictions.shape
    grid_h = rearrange(torch.arange(in_h).float(), 'h -> () () h ()').to(input.device)
    grid_w = rearrange(torch.arange(in_w).float(), 'w -> () () () w').to(input.device)

    predicted_bboxes = torch.zeros_like(raw_predictions)
    predicted_bboxes[0] = (raw_predictions[0].sigmoid() + grid_w) * stride_w  # center x
    predicted_bboxes[1] = (raw_predictions[1].sigmoid() + grid_h) * stride_h  # center y
    predicted_bboxes[2:4] = (raw_predictions[2:4].exp()) * anchor_sizes  # bbox width and height
    predicted_bboxes[4] = raw_predictions[4].sigmoid()  # confidence
    predicted_bboxes[5:] = raw_predictions[5:].sigmoid()  # class predictions
    # merging all predicted bboxes for each image
    return rearrange(predicted_bboxes, 'prediction b anchor h w -> b (anchor h w) prediction')
```

We changed and fixed a lot:

• the new code won't fail if the input is not on the first GPU
• the old code has wrong grid_x and grid_y for non-square images
• the new code doesn't use replication when broadcasting is sufficient
• the old code strangely takes `.data` in some places, but this has no real effect, since other branches preserve the gradient till the end
• if gradients are not needed, `torch.no_grad` should be used instead, so `.data` is redundant

# Simpler output for a bunch of pictures

Next time you need to output drawings of your generative models, you can use this trick:

```python
device = 'cpu'
plt.imshow(np.transpose(vutils.make_grid(fake_batch.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
```
```python
padded = F.pad(fake_batch[:64], [1, 1, 1, 1])
plt.imshow(rearrange(padded, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=8).cpu())
```

# Instead of conclusion

Better code is a vague term; to be specific, code is expected to be:

• reliable: does what is expected and does not fail; fails explicitly for wrong inputs
• maintainable and modifiable
• reusable: understanding and modifying code should be easier than writing from scratch
• fast: in my measurements, the proposed versions have speed similar to the original code
• readable: readability counts, as a means to achieve the previous goals

The provided examples show how to improve deep learning code on these criteria. And `einops` helps a lot.

# Links

• pytorch and einops
• significant part of the code was taken from the official examples and tutorials
• (references for other code are given in source of this html)
• einops has a tutorial for a more gentle introduction