import matplotlib.pyplot as plt
import torch
%matplotlib inline
%config InlineBackend.figure_format='retina'
Drawing the model (using ONNX and Netron)
# Download some MNIST to demonstrate super-resolution
from torchvision import datasets, transforms
# Training split and held-out test split; ToTensor converts each image to a
# tensor so that mnist[i][0] is the image and mnist[i][1] its label.
mnist = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST('data', train=False, download=True, transform=transforms.ToTensor())
# Displaying an image
def show_image(img):
    """Display a single image tensor in grayscale with the axes hidden.

    Args:
        img: Image tensor with three dimensions, channels first. It is
            permuted to channels-last and singleton dimensions are squeezed
            away before plotting.
    """
    # matplotlib expects (H, W) or (H, W, C), so move the channel axis last
    plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
    plt.axis('off')
# Displaying a batch of images in 1 row and n columns
def show_batch(batch):
    """Display a batch of images in one row with ``len(batch)`` columns.

    Args:
        batch: Sized iterable of three-dimensional, channels-first image
            tensors (for example a stacked 4-D tensor).
    """
    # squeeze=False keeps the axes array 2-D, so a batch holding a single
    # image is handled the same way as a larger batch.
    fig, ax = plt.subplots(1, len(batch), figsize=(20, 20), squeeze=False)
    for i, img in enumerate(batch):
        # matplotlib expects (H, W) or (H, W, C), so move the channel axis last
        ax[0][i].imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
        ax[0][i].axis('off')
# Shape of a single image tensor from the training set
mnist[0][0].shape
torch.Size([1, 28, 28])
# Downsample the images
# NOTE: Resize is built without an explicit `antialias` argument, which is what
# triggers torchvision's UserWarning about the default changing in v0.17.
# Pass `antialias` explicitly to pin the behaviour across versions.
downsample = transforms.Resize(7)

# First 10000 images X
mnist_small = [downsample(mnist[i][0]) for i in range(10000)]
mnist_small = torch.stack(mnist_small)

# First 10000 images Y
mnist_large = torch.stack([mnist[i][0] for i in range(10000)])

# Test set X
mnist_test_small = [downsample(mnist_test[i][0]) for i in range(10000)]
mnist_test_small = torch.stack(mnist_test_small)

# Test set Y
mnist_test_large = torch.stack([mnist_test[i][0] for i in range(10000)])
C:\Users\HP\AppData\Roaming\Python\Python311\site-packages\torchvision\transforms\functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).
warnings.warn(
# Show the downsampled images and the original images side-by-side
show_batch(torch.stack([mnist_small[i] for i in range(10)]))
# NOTE: show_batch creates its own figure through plt.subplots, so this call
# only produces an additional empty figure between the two rows.
plt.figure()
show_batch(torch.stack([mnist[i][0] for i in range(10)]))
<Figure size 640x480 with 0 Axes>
# Compare the shape of the downsampled stack with the raw training-set data tensor
mnist_small.shape, mnist.data.shape
(torch.Size([10000, 1, 7, 7]), torch.Size([60000, 28, 28]))
import torch
import torch.nn as nn
class SinActivation(nn.Module):
    """Activation module that applies an element-wise sine to its input."""

    def forward(self, x):
        """Return ``torch.sin(x)``."""
        return torch.sin(x)


# Create an instance of the custom SinActivation module
sin_activation = SinActivation()
class UNet(nn.Module):
    """Small encoder / bottleneck / decoder CNN for 4x image upsampling.

    Maps a batch of 1-channel 7x7 images to a batch of 1-channel 28x28 images.
    The three stages are applied strictly in sequence; the encoder output is
    not reused by the decoder (there are no skip connections).

    Args:
        activation: Module applied after every convolution except the last
            one. The same module instance is inserted at every position.
    """

    def __init__(self, activation=sin_activation):
        super(UNet, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            # Input: (batch_size, 1, 7, 7), Output: (batch_size, 16, 7, 7)
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            activation,  # Use the custom activation function
            # Input: (batch_size, 16, 7, 7), Output: (batch_size, 32, 7, 7)
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            activation,
            # Input: (batch_size, 32, 7, 7), Output: (batch_size, 32, 3, 3)
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Bottleneck
        self.bottleneck = nn.Sequential(
            # Input: (batch_size, 32, 3, 3), Output: (batch_size, 64, 3, 3)
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            activation,
        )

        # Decoder
        self.decoder = nn.Sequential(
            # Input: (batch_size, 64, 3, 3), Output: (batch_size, 32, 12, 12)
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=4, padding=0),
            activation,
            # Input: (batch_size, 32, 12, 12), Output: (batch_size, 16, 14, 14)
            # (kernel 3, stride 1, no padding grows each side by 2)
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=1, padding=0),
            activation,
            # Input: (batch_size, 16, 14, 14), Output: (batch_size, 1, 28, 28)
            nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1)
        )

    def forward(self, x):
        """Run the encoder, bottleneck and decoder in order and return the result."""
        # Encoder
        x1 = self.encoder(x)

        # Bottleneck
        x = self.bottleneck(x1)

        # Decoder
        x = self.decoder(x)

        return x
# Create an instance of the modified UNet model
model = UNet(nn.GELU())

# Print the model architecture with input and output shape
batch_size = 1
input_size = (batch_size, 1, 7, 7)
dummy_input = torch.randn(input_size)
output = model(dummy_input)
print(model)
print(f"Input shape: {input_size}")
print(f"Output shape: {output.shape}")
UNet(
(encoder): Sequential(
(0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): GELU(approximate='none')
(2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): GELU(approximate='none')
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(bottleneck): Sequential(
(0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): GELU(approximate='none')
)
(decoder): Sequential(
(0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(4, 4))
(1): GELU(approximate='none')
(2): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(1, 1))
(3): GELU(approximate='none')
(4): ConvTranspose2d(16, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)
)
Input shape: (1, 1, 7, 7)
Output shape: torch.Size([1, 1, 28, 28])
# Provide an example input to the model
batch_size = 1
input_size = (batch_size, 1, 7, 7)
dummy_input = torch.randn(input_size)

# Export the model to ONNX
onnx_path = "unet_model.onnx"
torch.onnx.export(model, dummy_input, onnx_path, verbose=False)
print("Model exported to ONNX successfully.")
============== Diagnostic Run torch.onnx.export version 2.0.1+cpu ==============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
Model exported to ONNX successfully.
# Input to the model is a batch of 1-channel 7x7 images
batch_size = 1
input_size = (batch_size, 1, 7, 7)

# Output of the model is a batch of 1-channel 28x28 images
output_size = (batch_size, 1, 28, 28)
# Input to the model is a batch of 1-channel 7x7 images
batch_size = 1
input_size = (batch_size, 1, 7, 7)

# Output of the model is a batch of 1-channel 28x28 images
output_size = (batch_size, 1, 28, 28)
# Create X_train, Y_train, X_test, Y_test
# Use the first CUDA device when one is available, otherwise fall back to CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Low-resolution inputs and full-resolution targets for training
X_train = mnist_small.float().to(device)
Y_train = mnist_large.float().to(device)

# Low-resolution inputs and full-resolution targets for evaluation
X_test = mnist_test_small.float().to(device)
Y_test = mnist_test_large.float().to(device)

X_train.shape, Y_train.shape, X_test.shape, Y_test.shape
model = UNet(activation=sin_activation).to(device)

# Define the loss function
loss_fn = nn.MSELoss()

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Number of epochs
# n_epochs = 5001
n_epochs = 500

# List to store losses
losses = []

# Loop over epochs
# Each epoch is one full-batch step: the whole of X_train goes through the
# model in a single forward pass.
for epoch in range(n_epochs):
    # Forward pass
    Y_pred = model(X_train)

    # Compute Loss
    loss = loss_fn(Y_pred, Y_train)
    # Read the scalar once so it is shared by the log line and the history
    loss_value = loss.item()

    # Print loss
    if epoch % 100 == 0:
        print(f"Epoch {epoch+1} loss: {loss_value}")

    # Store loss
    losses.append(loss_value)

    # Zero the gradients
    optimizer.zero_grad()

    # Backpropagation
    loss.backward()

    # Update the weights
    optimizer.step()
Epoch 1 loss: 0.15383368730545044
Epoch 101 loss: 0.05964525789022446
Epoch 201 loss: 0.04472550377249718
Epoch 301 loss: 0.03792179003357887
Epoch 401 loss: 0.034905895590782166
# Extract a mini-batch of 10 images
X_mini = X_train[:10]
Y_mini = Y_train[:10]

# Forward pass
Y_hat = model(X_mini)

# Move the tensors to CPU
X_mini = X_mini.cpu()
Y_mini = Y_mini.cpu()
Y_hat = Y_hat.cpu()
def plot_images(X_mini, Y_mini, Y_hat=None):
    """Plot inputs, ground truth and (optionally) predictions in a grid.

    The first row shows the input images, the second row the ground-truth
    images, and the third row the predicted images. At most the first 10
    images are shown.

    Args:
        X_mini: Sized, indexable collection of input images.
        Y_mini: Indexable collection of ground-truth images.
        Y_hat: Optional indexable collection of predicted image tensors.
            Each is detached before plotting. When ``None``, the prediction
            row is omitted.
    """
    # 10 images per row, or fewer when the mini-batch is smaller
    n_cols = min(10, len(X_mini))

    rows = [
        (X_mini, "Input Images"),
        (Y_mini, "Ground Truth Images"),
    ]
    if Y_hat is not None:
        # Predictions may still be attached to the autograd graph
        rows.append(([img.detach() for img in Y_hat[:n_cols]], "Predicted Images"))

    # squeeze=False keeps `ax` 2-D even for a single column
    fig, ax = plt.subplots(len(rows), n_cols, figsize=(20, 6), squeeze=False)

    for i, (images, _) in enumerate(rows):
        for j in range(n_cols):
            ax[i][j].imshow(images[j].squeeze(), cmap="gray")
            ax[i][j].axis("off")

    # Put labels for the rows using suptitle() and per-row titles
    fig.suptitle("MNIST Image Generation using U-Net", fontsize=16)

    for i, (_, title) in enumerate(rows):
        ax[i][0].set_title(title)
# Render the inputs, ground truth and model predictions for the mini-batch
plot_images(X_mini, Y_mini, Y_hat)