import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
from sklearn import preprocessing
from tqdm import trange
from PIL import Image
import os
from tabulate import tabulate
Hypernet and Neural Processes on CelebA
Our goal is to reconstruct a whole image from a few of its context points.
Motivation: Why are we doing this? Suppose you have limited bandwidth but want to send an image to a friend. You can compress the image down to a few context points using a neural network model, and send your friend this compressed representation together with the model. Your friend can then use the model and these few context points of the original image to reconstruct the whole image. The same idea also applies to image inpainting, super-resolution, and similar tasks.
There are simpler approaches, but each has a problem:
- If we learn a single neural network for all images, we get a general model; but at test time we do not have the whole image, so we cannot use that model directly.
- If we learn a separate neural network for each image, we get task-specific models; but at test time we do not know which image's model to pick for a new image.
So, we use a meta-learning setup with a hypernetwork and neural processes to learn a task-specific neural network that predicts the whole image given a few context points of an image.
Why meta? Because we are learning a model that learns the parameters of another model.
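To make that concrete, here is a minimal toy sketch (illustrative only; every name here is hypothetical and this is not the model built below): a hypernetwork is simply a network whose output vector is reshaped into the weights of another network.

# Toy illustration: a "hypernet" maps one context point (x, y, r, g, b) to the
# weights of a tiny 1-layer target net that maps (x, y) -> (r, g, b).
target_in, target_out = 2, 3
n_target_params = target_out * target_in + target_out  # weights + biases = 9

hyper = nn.Linear(5, n_target_params)          # the hypernetwork
flat = hyper(torch.rand(5)).detach()           # predicted target-net parameters

W = flat[:target_out * target_in].reshape(target_out, target_in)
b = flat[target_out * target_in:]
pred = torch.rand(4, 2) @ W.T + b              # run the target net on 4 query pixels
print(pred.shape)                              # torch.Size([4, 3])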
We use our own versions of the following to reproduce Figure 4 from the conditional neural processes paper at link:
- Hypernetwork
- Neural Processes
# select gpu
device = torch.device("cuda:3")
print(device)
current_device = device  # torch.cuda.current_device()
device_name = torch.cuda.get_device_name(current_device)
print(f"Current GPU assigned: {current_device}, Name: {device_name}")
cuda:3
Current GPU assigned: cuda:3, Name: NVIDIA A100-SXM4-80GB
Loading and preprocessing
You can download the CelebA data from link. We only need the images for our task. Extract these images into the same folder as this notebook.
There are 202,599 images in total in the CelebA dataset. We use only 10,000 images for training and 2,599 for the test set. For better results, you can use the full dataset. One way to create the two subset folders is sketched below.
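A rough sketch, assuming the full archive was extracted to data/celeba/img_align_celeba (this path and the copy strategy are assumptions, not from the original notebook):

import shutil

src = 'data/celeba/img_align_celeba'  # assumed location of all 202,599 extracted images
files = sorted(f for f in os.listdir(src) if f.endswith('.jpg'))

# first 10,000 images for training, the next 2,599 for testing
for folder, subset in [('data/celeba/img_align_celeba_10000', files[:10000]),
                       ('data/celeba/img_align_celeba_2599', files[10000:12599])]:
    os.makedirs(folder, exist_ok=True)
    for f in subset:
        shutil.copy(os.path.join(src, f), folder)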
class CustomImageDataset(Dataset):
    def __init__(self, data_root, transform=None):
        self.data_root = data_root
        self.transform = transform
        self.image_files = [f for f in os.listdir(data_root) if f.endswith('.jpg')]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.data_root, self.image_files[idx])
        image = Image.open(img_name)
        if self.transform:
            image = self.transform(image)
        return image
batch_size = 1  # keep this at 1
img_size = 32   # change as needed

# Specify the root directory where the dataset is located
data_root = 'data/celeba/img_align_celeba_10000'

# Define the data transformations
transform = transforms.Compose([
    # transforms.Resize((img_size, img_size)),  # resize the images to a common size (adjust as needed)
    transforms.ToTensor(),  # convert images to tensors
])  # default shape is torch.Size([3, 218, 178])

# Create the custom dataset
celeba_dataset = CustomImageDataset(data_root, transform=transform)

# Create a data loader
data_loader = DataLoader(celeba_dataset, batch_size=batch_size, shuffle=False)
Original image
plt.imshow(torch.einsum('chw -> hwc', data_loader.dataset[33]))
<matplotlib.image.AxesImage at 0x7f800aac09d0>
# Define the data transformations
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),  # resize the images to a common size (adjust as needed)
    transforms.ToTensor(),  # convert images to tensors
])  # default shape is torch.Size([3, 218, 178])

# Create the custom dataset
celeba_dataset = CustomImageDataset(data_root, transform=transform)

# Create a data loader
data_loader = DataLoader(celeba_dataset, batch_size=batch_size, shuffle=False)
Original image after transformation (resized to 32x32 to reduce computational cost)
plt.imshow(torch.einsum('chw -> hwc', data_loader.dataset[33]))
<matplotlib.image.AxesImage at 0x7ff83c94fdf0>
Create a coordinate dataset from the image
def create_coordinate_map(img):
    """
    img: torch.Tensor of shape (num_channels, height, width)
    return: tuple of torch.Tensor of shape (height * width, 2) and
            torch.Tensor of shape (height * width, num_channels)
    """
    num_channels, height, width = img.shape

    # Create a 2D grid of (x, y) coordinates
    x_coords = torch.arange(width).repeat(height, 1)
    y_coords = torch.arange(height).repeat(width, 1).t()
    x_coords = x_coords.reshape(-1)
    y_coords = y_coords.reshape(-1)

    # Combine the x and y coordinates into a single tensor
    X = torch.stack([x_coords, y_coords], dim=1).float()

    # Move X to GPU if available
    X = X.to(device)

    # Create a tensor containing the image pixel values
    Y = img.reshape(-1, num_channels).float().to(device)
    return X, Y
Below is our loss function: the negative log-likelihood. We use it to train our model. It takes:
- y_pred, the mean of the predicted distribution for each RGB channel;
- log_sigma, the log of the predicted variance for each RGB channel (the code uses exp(log_sigma) directly as the diagonal of the covariance matrix);
- y_true, the actual values of the RGB channels.
We place the per-channel variances exp(log_sigma) on the diagonal of a covariance matrix and then compute the log-probability of the actual values under a multivariate normal with mean y_pred and that diagonal covariance.
def neg_loglikelihood(y_pred, log_sigma, y_true):
    # exp(log_sigma) is used directly as the diagonal of the covariance,
    # so log_sigma is effectively a log-variance
    cov_matrix = torch.diag_embed(log_sigma.exp())
    dist = torch.distributions.MultivariateNormal(y_pred, cov_matrix, validate_args=False)
    return -dist.log_prob(y_true).sum()
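As a quick sanity check (illustrative only, not from the original notebook): with a diagonal covariance, this multivariate NLL equals the sum of three independent one-dimensional Gaussian NLLs, where exp(log_sigma) plays the role of the variance.

y_pred = torch.zeros(4, 3)     # predicted means for 4 pixels, RGB
log_sigma = torch.zeros(4, 3)  # exp(0) = 1 -> unit variance
y_true = torch.rand(4, 3)      # ground-truth RGB values

nll = neg_loglikelihood(y_pred, log_sigma, y_true)

# the same value from independent per-channel normals (std = sqrt(variance))
ref = -torch.distributions.Normal(y_pred, log_sigma.exp().sqrt()).log_prob(y_true).sum()
print(torch.allclose(nll, ref))  # True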
To count the parameters of any model, you can use the following code:
def count_params(model):
    # return sum(p.numel() for p in model.parameters() if p.requires_grad)
    return torch.sum(torch.tensor([p.numel() for p in model.parameters()]))
Hyper Network
Training phase architecture
Model definition
Target net definition
# Create an MLP with 5 linear layers (4 hidden layers of s=128 neurons each) and ReLU activations.
# Input is (x, y) and output is (r, g, b) or (g) for grayscale.
# Here we output 6 values (3 for the RGB means and 3 for the RGB log-variances).
s = 128  # hidden dim of model

class TargetNet(nn.Module):
    def _init_siren(self, activation_scale):
        # SIREN-style initialization for sine activations
        self.fc1.weight.data.uniform_(-1/self.fc1.in_features, 1/self.fc1.in_features)
        for layer in [self.fc2, self.fc3, self.fc4, self.fc5]:
            layer.weight.data.uniform_(-np.sqrt(6/layer.in_features)/activation_scale,
                                        np.sqrt(6/layer.in_features)/activation_scale)

    def __init__(self, activation=torch.relu, n_out=1, activation_scale=1.0):
        super().__init__()
        self.activation = activation
        self.activation_scale = activation_scale
        self.fc1 = nn.Linear(2, s)  # input size is 2: the (x, y) location of a pixel
        self.fc2 = nn.Linear(s, s)
        self.fc3 = nn.Linear(s, s)
        self.fc4 = nn.Linear(s, s)
        self.fc5 = nn.Linear(s, n_out)  # grayscale (1) or RGB (3); here 6 for means and log-variances
        if self.activation == torch.sin:
            # init weights and biases for sine activation
            self._init_siren(activation_scale=self.activation_scale)

    def forward(self, x):
        x = self.activation(self.activation_scale*self.fc1(x))
        x = self.activation(self.activation_scale*self.fc2(x))
        x = self.activation(self.activation_scale*self.fc3(x))
        x = self.activation(self.activation_scale*self.fc4(x))
        return self.fc5(x)
Hypernetwork definition
Input: (x, y, R, G, B)
Output: the hypernetwork's output size equals the number of parameters in the target network.
# Pass the total parameter count of the target network when constructing the hypernetwork
class HyperNet(nn.Module):
    def __init__(self, total_params, num_neurons=128, activation=torch.relu):
        super().__init__()
        self.activation = activation
        self.n_out = total_params
        self.fc1 = nn.Linear(5, num_neurons)
        self.fc2 = nn.Linear(num_neurons, num_neurons)
        self.fc3 = nn.Linear(num_neurons, self.n_out)

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        return self.fc3(x)
Initialize the model and input
Initialize the target network
from torchinfo import summary

targetnet = TargetNet(activation=torch.relu, n_out=6, activation_scale=1).to(device)
# outputs 6 values: 1-3 are the per-channel means, 4-6 the per-channel log-variances
summary(targetnet, input_size=(img_size*img_size, 2))  # 32*32 = 1024 pixels, 2 is the (x, y) coordinate
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
TargetNet [1024, 6] --
├─Linear: 1-1 [1024, 128] 384
├─Linear: 1-2 [1024, 128] 16,512
├─Linear: 1-3 [1024, 128] 16,512
├─Linear: 1-4 [1024, 128] 16,512
├─Linear: 1-5 [1024, 6] 774
==========================================================================================
Total params: 50,694
Trainable params: 50,694
Non-trainable params: 0
Total mult-adds (M): 51.91
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 4.24
Params size (MB): 0.20
Estimated Total Size (MB): 4.45
==========================================================================================
targetnet
TargetNet(
(fc1): Linear(in_features=2, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=128, bias=True)
(fc3): Linear(in_features=128, out_features=128, bias=True)
(fc4): Linear(in_features=128, out_features=128, bias=True)
(fc5): Linear(in_features=128, out_features=6, bias=True)
)
count_params(targetnet)
tensor(50694)
Initialize the hypernetwork model
hypernet = HyperNet(total_params=count_params(targetnet), activation=torch.sin).to(device)
print(hypernet)
HyperNet(
(fc1): Linear(in_features=5, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=128, bias=True)
(fc3): Linear(in_features=128, out_features=50694, bias=True)
)
summary(hypernet, input_size=(img_size*img_size, 5))  # 32*32 = 1024 pixels, 5 is the input (x, y, r, g, b) to the hypernet
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
HyperNet [1024, 50694] --
├─Linear: 1-1 [1024, 128] 768
├─Linear: 1-2 [1024, 128] 16,512
├─Linear: 1-3 [1024, 50694] 6,539,526
==========================================================================================
Total params: 6,556,806
Trainable params: 6,556,806
Non-trainable params: 0
Total mult-adds (G): 6.71
==========================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 417.38
Params size (MB): 26.23
Estimated Total Size (MB): 443.63
==========================================================================================
# Map each target-net parameter to its (start, end) slice in the flat parameter vector
table_data = []
total_params = 0
start = 0
start_end_mapping = {}
for name, param in targetnet.named_parameters():
    param_count = torch.prod(torch.tensor(param.shape)).item()
    total_params += param_count
    end = total_params
    table_data.append([name, param.shape, param_count, start, end])
    start_end_mapping[name] = (start, end)
    start = end

print(tabulate(table_data, headers=["Layer Name", "Shape", "Parameter Count", "Start Index", "End Index"]))
print(f"Total number of parameters: {total_params}")
Layer Name Shape Parameter Count Start Index End Index
------------ ---------------------- ----------------- ------------- -----------
fc1.weight torch.Size([128, 2]) 256 0 256
fc1.bias torch.Size([128]) 128 256 384
fc2.weight torch.Size([128, 128]) 16384 384 16768
fc2.bias torch.Size([128]) 128 16768 16896
fc3.weight torch.Size([128, 128]) 16384 16896 33280
fc3.bias torch.Size([128]) 128 33280 33408
fc4.weight torch.Size([128, 128]) 16384 33408 49792
fc4.bias torch.Size([128]) 128 49792 49920
fc5.weight torch.Size([6, 128]) 768 49920 50688
fc5.bias torch.Size([6]) 6 50688 50694
Total number of parameters: 50694
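As a minimal check of how this mapping is used later (the random vector below is a stand-in for a hypernet output; this cell is illustrative and not from the original notebook):

# Slice a flat parameter vector into a state dict and run the target net
# with those externally supplied weights via torch.func.functional_call.
flat = torch.randn(total_params, device=device)  # stand-in for a hypernet output
params = {}
for name, p in targetnet.named_parameters():
    s_idx, e_idx = start_end_mapping[name]
    params[name] = flat[s_idx:e_idx].reshape(p.shape)

coords = torch.rand(img_size*img_size, 2, device=device)  # dummy (x, y) inputs
out = torch.func.functional_call(targetnet, params, coords)
print(out.shape)  # torch.Size([1024, 6])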
Initialize the input
corr, vals = create_coordinate_map(data_loader.dataset[0])
corr, vals
(tensor([[ 0., 0.],
[ 1., 0.],
[ 2., 0.],
...,
[29., 31.],
[30., 31.],
[31., 31.]], device='cuda:3'),
tensor([[0.4510, 0.4706, 0.4824],
[0.4745, 0.4745, 0.4471],
[0.4667, 0.4353, 0.5412],
...,
[0.0314, 0.0549, 0.0471],
[0.0431, 0.0392, 0.0510],
[0.0549, 0.0392, 0.0549]], device='cuda:3'))
scaler_img = preprocessing.MinMaxScaler().fit(corr.cpu())
xy = torch.tensor(scaler_img.transform(corr.cpu())).float().to(device)
xy, xy.shape
(tensor([[0.0000, 0.0000],
[0.0323, 0.0000],
[0.0645, 0.0000],
...,
[0.9355, 1.0000],
[0.9677, 1.0000],
[1.0000, 1.0000]], device='cuda:3'),
torch.Size([1024, 2]))
Training loop
n_epochs = 20
lr = 0.003

targetnet = TargetNet(activation=torch.relu, n_out=6, activation_scale=1).to(device)
hypernet = HyperNet(total_params=count_params(targetnet), activation=torch.relu).to(device)
optimizer = optim.Adam(hypernet.parameters(), lr=lr)  # only the hypernet is updated

n_context = 100
print("Context Points=", n_context)
for epoch in trange(n_epochs):
    # sample n_context pixel indices out of all 32*32 = 1024 pixels
    c_idx = np.array(random.sample(range(img_size*img_size), n_context))
    print("Epoch=", epoch+1)
    epoch_loss = 0
    i = 1
    for data in data_loader:
        optimizer.zero_grad()
        pixel_intensity = data.reshape(3, -1).T.to(device).float()
        context_in = torch.concatenate([xy[c_idx], pixel_intensity[c_idx]], axis=1).float()
        hyper_out = hypernet(context_in)          # (n_context, total_params)
        hyper_out = torch.mean(hyper_out, dim=0)  # aggregate across context points

        # slice the flat parameter vector into target-net weights
        target_dict = {}
        for name, param in targetnet.named_parameters():
            start, end = start_end_mapping[name]
            target_dict[name] = hyper_out[start:end].reshape(param.shape)

        img_out = torch.func.functional_call(targetnet, target_dict, xy)
        loss = neg_loglikelihood(img_out[:, :3], img_out[:, 3:], pixel_intensity)

        loss.backward()
        optimizer.step()
        epoch_loss = epoch_loss + loss.item()
        i = i + 1
    print("Epoch Loss=", epoch_loss/len(data_loader))
Context Points= 100
0%| | 0/20 [00:00<?, ?it/s]
Epoch= 1
5%|▌ | 1/20 [01:05<20:42, 65.40s/it]
Epoch Loss= -395.47315481672285
Epoch= 2
10%|█ | 2/20 [02:10<19:37, 65.41s/it]
Epoch Loss= -915.5049703121185
Epoch= 3
15%|█▌ | 3/20 [03:15<18:29, 65.26s/it]
Epoch Loss= -1166.0503022047044
Epoch= 4
20%|██ | 4/20 [04:21<17:24, 65.28s/it]
Epoch Loss= -1349.56748127985
Epoch= 5
25%|██▌ | 5/20 [05:26<16:19, 65.30s/it]
Epoch Loss= -1396.8538594449997
Epoch= 6
30%|███ | 6/20 [06:31<15:14, 65.31s/it]
Epoch Loss= -1479.6613237543106
Epoch= 7
35%|███▌ | 7/20 [07:37<14:09, 65.36s/it]
Epoch Loss= -1449.8615832103728
Epoch= 8
40%|████ | 8/20 [08:42<13:04, 65.38s/it]
Epoch Loss= -1528.4998937654495
Epoch= 9
45%|████▌ | 9/20 [09:47<11:58, 65.31s/it]
Epoch Loss= -1538.7266953744888
Epoch= 10
50%|█████ | 10/20 [10:53<10:53, 65.36s/it]
Epoch Loss= -1574.4109719749451
Epoch= 11
55%|█████▌ | 11/20 [11:58<09:48, 65.35s/it]
Epoch Loss= -1558.2231241334914
Epoch= 12
60%|██████ | 12/20 [13:03<08:42, 65.32s/it]
Epoch Loss= -1585.886608516693
Epoch= 13
65%|██████▌ | 13/20 [14:09<07:36, 65.27s/it]
Epoch Loss= -1586.6056880670546
Epoch= 14
70%|███████ | 14/20 [15:14<06:31, 65.23s/it]
Epoch Loss= -1561.2374246302604
Epoch= 15
75%|███████▌ | 15/20 [16:19<05:25, 65.20s/it]
Epoch Loss= -1606.3488553873062
Epoch= 16
80%|████████ | 16/20 [17:24<04:20, 65.24s/it]
Epoch Loss= -1637.4123486403466
Epoch= 17
85%|████████▌ | 17/20 [18:30<03:15, 65.29s/it]
Epoch Loss= -1656.406247360611
Epoch= 18
90%|█████████ | 18/20 [19:33<02:09, 64.69s/it]
Epoch Loss= -1621.5405502536773
Epoch= 19
95%|█████████▌| 19/20 [20:38<01:04, 64.88s/it]
Epoch Loss= -1708.3212175039291
Epoch= 20
100%|██████████| 20/20 [21:44<00:00, 65.21s/it]
Epoch Loss= -1700.857941632271
Saving and loading the model
torch.save(hypernet.state_dict(), 'hypernet_model_10000.pth')
torch.save(targetnet.state_dict(), 'targetnet_model_10000.pth')
# Load the hypernet and targetnet models
hypernet = HyperNet(total_params=count_params(targetnet), activation=torch.relu).to(device)
hypernet.load_state_dict(torch.load('hypernet_model_10000.pth'))
hypernet.eval()  # set the model to evaluation mode
HyperNet(
(fc1): Linear(in_features=5, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=128, bias=True)
(fc3): Linear(in_features=128, out_features=50694, bias=True)
)
targetnet = TargetNet(activation=torch.relu, n_out=6, activation_scale=1).to(device)
targetnet.load_state_dict(torch.load('targetnet_model_10000.pth'))
targetnet.eval()
TargetNet(
(fc1): Linear(in_features=2, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=128, bias=True)
(fc3): Linear(in_features=128, out_features=128, bias=True)
(fc4): Linear(in_features=128, out_features=128, bias=True)
(fc5): Linear(in_features=128, out_features=6, bias=True)
)
Testing
Testing phase architecture
Loading the test data
batch_size = 1  # keep this at 1
img_size = 32   # change as needed

# Specify the root directory where the dataset is located
data_root = '/home/jaiswalsuraj/suraj_work/projects/data/celeba/img_align_celeba_2599'

# Define the data transformations
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),  # resize the images to a common size (adjust as needed)
    transforms.ToTensor(),  # convert images to tensors
])  # default shape is torch.Size([3, 218, 178])

# Create the custom dataset
celeba_dataset = CustomImageDataset(data_root, transform=transform)

# Create a data loader
test_data_loader = DataLoader(celeba_dataset, batch_size=batch_size, shuffle=False)
Plotting the results
def plot_hypernet(data, hypernet, targetnet, c_idx):
    pixel_intensity = data.reshape(3, -1).T.to(device).float()
    context_in = torch.concatenate([xy[c_idx], pixel_intensity[c_idx]], axis=1).float()
    hyper_out = hypernet(context_in)          # (n_context, total_params)
    hyper_out = torch.mean(hyper_out, dim=0)  # aggregate across context points

    target_dict = {}
    start = 0
    for name, param in targetnet.named_parameters():
        end = start + param.numel()
        target_dict[name] = hyper_out[start:end].reshape(param.shape)
        start = end

    img_out = torch.func.functional_call(targetnet, target_dict, xy)
    return img_out.cpu().detach()
c_1 = np.array(random.sample(range(img_size*img_size), 1))
c_10 = np.array(random.sample(range(img_size*img_size), 10))
c_100 = np.array(random.sample(range(img_size*img_size), 100))
c_1000 = np.array(random.sample(range(img_size*img_size), 1000))

image_any = test_data_loader.dataset[0]
idx = 0
data = image_any
plt.figure(figsize=(9, 7), constrained_layout=True)
plt.suptitle("HyperNetworks", fontsize=20)

def plot_image(i, j, k, data, hypernet, targetnet, c_idx):
    # row 1: the masked context pixels
    plt.subplot(i, j, k)
    img = data.permute(1, 2, 0)
    mask = np.zeros((32, 32, 3))
    mask[c_idx//32, c_idx%32, :] = 1
    plt.imshow(img*mask)
    plt.title(f"Context: {len(c_idx)}")
    plt.axis('off')

    # row 2: the predicted mean image
    plt.subplot(i, j, k+4)
    pred = plot_hypernet(data, hypernet, targetnet, c_idx)
    plt.imshow(pred[:, :3].T.reshape(3, 32, 32).permute(1, 2, 0))
    plt.axis('off')

    # row 3: the (min-max normalized) predicted variance
    plt.subplot(i, j, k+8)
    var = pred[:, 3:].exp().T.reshape(3, 32, 32).permute(1, 2, 0)
    var = var - var.min()
    var = var / var.max()
    plt.imshow(var)
    plt.axis('off')
<Figure size 900x700 with 0 Axes>
plot_image(3, 4, 1, data, hypernet, targetnet, c_1)
plot_image(3, 4, 2, data, hypernet, targetnet, c_10)
plot_image(3, 4, 3, data, hypernet, targetnet, c_100)
plot_image(3, 4, 4, data, hypernet, targetnet, c_1000)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
/home/jaiswalsuraj/miniconda3/envs/tf_gpu/lib/python3.10/site-packages/matplotlib/cm.py:478: RuntimeWarning: invalid value encountered in cast
xx = (xx * 255).astype(np.uint8)
The first row shows the test context points, the second row shows our model's prediction, and the third row shows the variance of the predicted image.
Neural Processes
Training phase architecture
Encoder-decoder model definition
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, z_dim, activation=torch.sin, activation_scale=30.0):
        super().__init__()
        self.activation = activation
        self.activation_scale = activation_scale
        if activation != torch.sin:
            self.activation_scale = 1.0
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, z_dim)

    def forward(self, x):
        x = self.activation(self.linear1(x)*self.activation_scale)
        x = self.activation(self.linear2(x)*self.activation_scale)
        return self.linear3(x)
class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim, output_dim, activation=torch.sin, activation_scale=30.0):
        super().__init__()
        self.activation = activation
        self.activation_scale = activation_scale
        if activation != torch.sin:
            self.activation_scale = 1.0
        self.linear1 = nn.Linear(z_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)
        self.linear4 = nn.Linear(hidden_dim, hidden_dim)
        self.linear5 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.activation(self.linear1(x)*self.activation_scale)
        x = self.activation(self.linear2(x)*self.activation_scale)
        x = self.activation(self.linear3(x)*self.activation_scale)
        x = self.activation(self.linear4(x)*self.activation_scale)
        return self.linear5(x)
from torchinfo import summary

encoder = Encoder(5, 256, 128, activation=torch.relu, activation_scale=1)
summary(encoder, input_size=(img_size*img_size, 5))  # 32*32 = 1024 pixels, 5 is the input (x, y, r, g, b) to the encoder
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Encoder [1024, 128] --
├─Linear: 1-1 [1024, 256] 1,536
├─Linear: 1-2 [1024, 256] 65,792
├─Linear: 1-3 [1024, 128] 32,896
==========================================================================================
Total params: 100,224
Trainable params: 100,224
Non-trainable params: 0
Total mult-adds (M): 102.63
==========================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 5.24
Params size (MB): 0.40
Estimated Total Size (MB): 5.66
==========================================================================================
print(encoder)
Encoder(
(linear1): Linear(in_features=5, out_features=256, bias=True)
(linear2): Linear(in_features=256, out_features=256, bias=True)
(linear3): Linear(in_features=256, out_features=128, bias=True)
)
decoder = Decoder(130, 256, 6, activation=torch.relu, activation_scale=1)
summary(decoder, input_size=(img_size*img_size, 130))  # 130 = 2 coordinates + 128-dim representation
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Decoder [1024, 6] --
├─Linear: 1-1 [1024, 256] 33,536
├─Linear: 1-2 [1024, 256] 65,792
├─Linear: 1-3 [1024, 256] 65,792
├─Linear: 1-4 [1024, 256] 65,792
├─Linear: 1-5 [1024, 6] 1,542
==========================================================================================
Total params: 232,454
Trainable params: 232,454
Non-trainable params: 0
Total mult-adds (M): 238.03
==========================================================================================
Input size (MB): 0.53
Forward/backward pass size (MB): 8.44
Params size (MB): 0.93
Estimated Total Size (MB): 9.90
==========================================================================================
print(decoder)
Decoder(
(linear1): Linear(in_features=130, out_features=256, bias=True)
(linear2): Linear(in_features=256, out_features=256, bias=True)
(linear3): Linear(in_features=256, out_features=256, bias=True)
(linear4): Linear(in_features=256, out_features=256, bias=True)
(linear5): Linear(in_features=256, out_features=6, bias=True)
)
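Before training, here is an illustrative shape check (not from the original notebook) of the full neural-process pass with the encoder and decoder instantiated above: encode the context points, average them into one representation, tile it across all pixels, and decode.

ctx = torch.rand(10, 5)                     # 10 context points: (x, y, r, g, b)
z = encoder(ctx).mean(dim=0)                # aggregate into a single 128-d representation
coords = torch.rand(1024, 2)                # all pixel coordinates
dec_in = torch.cat([coords, z.repeat(1024, 1)], dim=1)  # (1024, 130)
out = decoder(dec_in)                       # (1024, 6): RGB means and log-variances
print(out.shape)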
Initialize the input
corr, vals = create_coordinate_map(data_loader.dataset[0])
corr, vals
(tensor([[ 0., 0.],
[ 1., 0.],
[ 2., 0.],
...,
[29., 31.],
[30., 31.],
[31., 31.]], device='cuda:2'),
tensor([[0.4510, 0.4706, 0.4824],
[0.4745, 0.4745, 0.4471],
[0.4667, 0.4353, 0.5412],
...,
[0.0314, 0.0549, 0.0471],
[0.0431, 0.0392, 0.0510],
[0.0549, 0.0392, 0.0549]], device='cuda:2'))
scaler_img = preprocessing.MinMaxScaler().fit(corr.cpu())
xy = torch.tensor(scaler_img.transform(corr.cpu())).float().to(device)
xy, xy.shape
(tensor([[0.0000, 0.0000],
[0.0323, 0.0000],
[0.0645, 0.0000],
...,
[0.9355, 1.0000],
[0.9677, 1.0000],
[1.0000, 1.0000]], device='cuda:2'),
torch.Size([1024, 2]))
Training loop
n_epochs = 20
lr = 0.003
n_context = 200
print("Context Points=", n_context)

encoder = Encoder(input_dim=5, hidden_dim=512, z_dim=128, activation=torch.relu, activation_scale=1).to(device)
decoder = Decoder(z_dim=130, hidden_dim=512, output_dim=6, activation=torch.relu, activation_scale=1).to(device)
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=lr)

for epoch in trange(n_epochs):
    # sample n_context pixel indices out of all 32*32 = 1024 pixels
    c_idx = np.array(random.sample(range(img_size*img_size), n_context))
    print("Epoch=", epoch+1)
    epoch_loss = 0
    i = 1
    for data in data_loader:
        optimizer.zero_grad()
        pixel_intensity = data.reshape(3, -1).T.to(device).float()
        context_in = torch.concatenate([xy[c_idx], pixel_intensity[c_idx]], axis=1).float()
        encoder_out = encoder(context_in)
        encoder_out = torch.mean(encoder_out, dim=0)  # aggregate context into one representation

        decoder_in = encoder_out.repeat(1024, 1)                  # tile the representation per pixel
        decoder_in = torch.concatenate([xy, decoder_in], axis=1)  # append (x, y) for every pixel

        img_out = decoder(decoder_in)
        loss = neg_loglikelihood(img_out[:, :3], img_out[:, 3:], pixel_intensity)

        loss.backward()
        optimizer.step()
        epoch_loss = epoch_loss + loss.item()
        i = i + 1
    print("Epoch Loss=", epoch_loss/len(data_loader))
Context Points= 200
0%| | 0/20 [00:00<?, ?it/s]
Epoch= 1
5%|▌ | 1/20 [00:58<18:27, 58.30s/it]
Epoch Loss= 47.26116828255653
Epoch= 2
10%|█ | 2/20 [01:50<16:27, 54.89s/it]
Epoch Loss= -455.4421627301693
Epoch= 3
15%|█▌ | 3/20 [02:43<15:16, 53.92s/it]
Epoch Loss= -683.1707640041351
Epoch= 4
20%|██ | 4/20 [03:35<14:13, 53.32s/it]
Epoch Loss= -761.2885692318916
Epoch= 5
25%|██▌ | 5/20 [04:36<13:56, 55.78s/it]
Epoch Loss= -827.9153870079041
Epoch= 6
30%|███ | 6/20 [05:35<13:20, 57.17s/it]
Epoch Loss= -938.4062066322326
Epoch= 7
35%|███▌ | 7/20 [06:30<12:11, 56.24s/it]
Epoch Loss= -1007.5277942465782
Epoch= 8
40%|████ | 8/20 [07:30<11:28, 57.39s/it]
Epoch Loss= -1048.726986592865
Epoch= 9
45%|████▌ | 9/20 [08:20<10:06, 55.15s/it]
Epoch Loss= -1057.8311284263611
Epoch= 10
50%|█████ | 10/20 [09:17<09:17, 55.74s/it]
Epoch Loss= -1070.1760208235742
Epoch= 11
55%|█████▌ | 11/20 [09:56<07:34, 50.50s/it]
Epoch Loss= -1065.5062245418549
Epoch= 12
60%|██████ | 12/20 [10:35<06:16, 47.06s/it]
Epoch Loss= -1078.0465439793586
Epoch= 13
65%|██████▌ | 13/20 [11:14<05:13, 44.75s/it]
Epoch Loss= -1088.4120592634201
Epoch= 14
70%|███████ | 14/20 [11:53<04:17, 42.99s/it]
Epoch Loss= -1078.4957633354188
Epoch= 15
75%|███████▌ | 15/20 [12:32<03:28, 41.78s/it]
Epoch Loss= -1084.8360796244622
Epoch= 16
80%|████████ | 16/20 [13:11<02:43, 40.94s/it]
Epoch Loss= -1093.4410486424447
Epoch= 17
85%|████████▌ | 17/20 [13:50<02:01, 40.44s/it]
Epoch Loss= -1113.8205106693267
Epoch= 18
90%|█████████ | 18/20 [14:29<01:19, 39.99s/it]
Epoch Loss= -1110.6333978479386
Epoch= 19
95%|█████████▌| 19/20 [15:25<00:44, 44.66s/it]
Epoch Loss= -1098.3281953195572
Epoch= 20
100%|██████████| 20/20 [16:25<00:00, 49.28s/it]
Epoch Loss= -1106.2516599431992
Saving and loading the model
torch.save(encoder.state_dict(), 'encoder_model_10000.pth')
torch.save(decoder.state_dict(), 'decoder_model_10000.pth')
# Load the encoder and decoder models
encoder = Encoder(input_dim=5, hidden_dim=128, z_dim=128, activation=torch.relu, activation_scale=1).to(device)
encoder.load_state_dict(torch.load('encoder_model_10000.pth'))
encoder.eval()  # set the model to evaluation mode
Encoder(
(linear1): Linear(in_features=5, out_features=128, bias=True)
(linear2): Linear(in_features=128, out_features=128, bias=True)
(linear3): Linear(in_features=128, out_features=128, bias=True)
)
decoder = Decoder(z_dim=130, hidden_dim=256, output_dim=6, activation=torch.relu, activation_scale=1).to(device)
decoder.load_state_dict(torch.load('decoder_model_10000.pth'))
decoder.eval()
Decoder(
(linear1): Linear(in_features=130, out_features=256, bias=True)
(linear2): Linear(in_features=256, out_features=256, bias=True)
(linear3): Linear(in_features=256, out_features=256, bias=True)
(linear4): Linear(in_features=256, out_features=256, bias=True)
(linear5): Linear(in_features=256, out_features=6, bias=True)
)
Testing
Testing phase architecture
Loading the test data
batch_size = 1  # keep this at 1
img_size = 32   # change as needed

# Specify the root directory where the dataset is located
data_root = 'data/celeba/img_align_celeba_2599'

# Define the data transformations
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),  # resize the images to a common size (adjust as needed)
    transforms.ToTensor(),  # convert images to tensors
])  # default shape is torch.Size([3, 218, 178])

# Create the custom dataset
celeba_dataset = CustomImageDataset(data_root, transform=transform)

# Create a data loader
test_data_loader = DataLoader(celeba_dataset, batch_size=batch_size, shuffle=False)
Plotting the results
def plot_enc_dec(data, encoder, decoder, c_idx):
    pixel_intensity = data.reshape(3, -1).T.to(device).float()
    context_in = torch.concatenate([xy[c_idx], pixel_intensity[c_idx]], axis=1).float()
    encoder_out = encoder(context_in)
    encoder_out = torch.mean(encoder_out, dim=0)  # aggregate across context points

    decoder_in = encoder_out.repeat(1024, 1)
    decoder_in = torch.concatenate([xy, decoder_in], axis=1)

    img_out = decoder(decoder_in)
    return img_out.cpu().detach()
c_1 = np.array(random.sample(range(img_size*img_size), 1))
c_10 = np.array(random.sample(range(img_size*img_size), 10))
c_100 = np.array(random.sample(range(img_size*img_size), 100))
c_1000 = np.array(random.sample(range(img_size*img_size), 1000))

idx = 5
image_any = test_data_loader.dataset[idx]
data = image_any
plt.figure(figsize=(9, 7), constrained_layout=True)
plt.suptitle("Neural process", fontsize=20)

def plot_image(i, j, k, data, encoder, decoder, c_idx):
    # row 1: the masked context pixels
    plt.subplot(i, j, k)
    img = data.permute(1, 2, 0)
    mask = np.zeros((32, 32, 3))
    mask[c_idx//32, c_idx%32, :] = 1
    plt.imshow(img*mask)
    plt.title(f"Context: {len(c_idx)}")
    plt.axis('off')

    # row 2: the predicted mean image
    plt.subplot(i, j, k+4)
    pred = plot_enc_dec(data, encoder, decoder, c_idx)
    plt.imshow(pred[:, :3].T.reshape(3, 32, 32).permute(1, 2, 0))
    plt.axis('off')

    # row 3: the (min-max normalized) predicted variance
    plt.subplot(i, j, k+8)
    var = pred[:, 3:].exp().T.reshape(3, 32, 32).permute(1, 2, 0)
    var = var - var.min()
    var = var / var.max()
    plt.imshow(var)
    plt.axis('off')
<Figure size 900x700 with 0 Axes>
plot_image(3, 4, 1, data, encoder, decoder, c_1)
plot_image(3, 4, 2, data, encoder, decoder, c_10)
plot_image(3, 4, 3, data, encoder, decoder, c_100)
plot_image(3, 4, 4, data, encoder, decoder, c_1000)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
The first row shows the test context points, the second row shows our model's prediction, and the third row shows the variance of the predicted image.
Conclusion
- Here we have seen an implementation of meta-learning using a hypernetwork and neural processes.
- We could further improve results on image-like data by using the sine activation function in the model; refer to link for more details. A minimal sketch of that switch follows this list.
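For example (a hedged sketch, untrained and untested here; activation_scale=30 follows the SIREN convention, and the classes defined above already accept a sine activation):

targetnet = TargetNet(activation=torch.sin, n_out=6, activation_scale=30.0).to(device)  # triggers the SIREN init
encoder = Encoder(input_dim=5, hidden_dim=512, z_dim=128, activation=torch.sin, activation_scale=30.0).to(device)
decoder = Decoder(z_dim=130, hidden_dim=512, output_dim=6, activation=torch.sin, activation_scale=30.0).to(device)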
The saga concludes with a symphony of learning, as Hypernets and Neural Processes rewrite the narrative of image reconstruction on CelebA, transforming a complex challenge into a melodious solution.