Hypernet and Neural Processes on CelebA

Deep Learning
Meta Learning
Author

Suraj Jaiswal

Published

November 8, 2023

Our goal is to reconstruct a whole image from a few of its context points (pixels).

Motivation: why are we doing this? Suppose you have limited bandwidth but want to send an image to a friend. You can compress the image down to a few context points using a neural network model, and send your friend this compressed representation together with the model. Your friend can then use the model and the few context points to reconstruct the whole image. The same idea applies to image inpainting, super-resolution, etc.

Two simple approaches come to mind, but each has an issue:

  • If we learn a single neural network for all images, we get a general model; but at testing time we don't have the whole image, so we won't be able to use that model as-is.

  • If we learn a separate neural network for each image, we get task-specific models; but at testing time we don't know which image's model to pick for a new image.

So, we use a meta-learning setup, with a hypernetwork and with neural processes, to learn a task-specific neural network that predicts the whole image given a few context points.

Why meta? Because we are learning a model that learns the parameters of another model.

We use our own versions of the following to reproduce Figure 4 from the Conditional Neural Processes paper (link):

  • Hypernet
  • Neural Processes


import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
from sklearn import preprocessing
from tqdm import trange

from PIL import Image
import os
from tabulate import tabulate
# select gpu
device = torch.device("cuda:3")
print(device)
current_device = device #torch.cuda.current_device()
device_name = torch.cuda.get_device_name(current_device)
print(f"Current GPU assigned: {current_device}, Name: {device_name}")
cuda:3
Current GPU assigned: cuda:3, Name: NVIDIA A100-SXM4-80GB
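
If cuda:3 is not available on your machine, a more portable fallback (a small tweak, not part of the original run) is:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # first GPU if present, else CPU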

Loading and preprocessing

  • You can download the CelebA data from link. We only need the images for our task; extract them into the same folder as this notebook.

  • There are 202,599 images in total in the CelebA dataset. We will use only 10,000 images for training and 2,599 for the test set. For better results, you can use the full dataset.

class CustomImageDataset(Dataset):
    def __init__(self, data_root, transform=None):
        self.data_root = data_root
        self.transform = transform
        self.image_files = [f for f in os.listdir(data_root) if f.endswith('.jpg')]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.data_root, self.image_files[idx])
        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        return image
batch_size = 1 # keep this to 1
img_size = 32 # Change as needed

# Specify the root directory where the dataset is located
data_root = 'data/celeba/img_align_celeba_10000'

# Define the data transformations
transform = transforms.Compose([
    # transforms.Resize((img_size, img_size)),  # Resize the images to a common size (adjust as needed)
    transforms.ToTensor(),   # Convert images to tensors
])
# default shape is torch.Size([3, 218, 178])
# Create the custom dataset
celeba_dataset = CustomImageDataset(data_root, transform=transform)

# Create a data loader
data_loader = DataLoader(celeba_dataset, batch_size=batch_size, shuffle=False)

Original image

plt.imshow(torch.einsum('chw -> hwc', data_loader.dataset[33]))
<matplotlib.image.AxesImage at 0x7f800aac09d0>

# Define the data transformations
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),  # Resize the images to a common size (adjust as needed)
    transforms.ToTensor(),   # Convert images to tensors
])
# default shape is torch.Size([3, 218, 178])
# Create the custom dataset
celeba_dataset = CustomImageDataset(data_root, transform=transform)

# Create a data loader
data_loader = DataLoader(celeba_dataset, batch_size=batch_size, shuffle=False)

Original image after transformation (applying resize to 32x32 to reduce computational cost)

plt.imshow(torch.einsum('chw -> hwc', data_loader.dataset[33]))
<matplotlib.image.AxesImage at 0x7ff83c94fdf0>

Create a coordinate dataset from the image

def create_coordinate_map(img):
    """
    img: torch.Tensor of shape (num_channels, height, width)

    returns: tuple (X, Y), where X is a torch.Tensor of shape (height*width, 2)
    holding the (x, y) coordinate of every pixel and Y is a torch.Tensor of
    shape (height*width, num_channels) holding the corresponding pixel values.
    """
    num_channels, height, width = img.shape

    # Create a 2D grid of (x,y) coordinates
    x_coords = torch.arange(width).repeat(height, 1)
    y_coords = torch.arange(height).repeat(width, 1).t()
    x_coords = x_coords.reshape(-1)
    y_coords = y_coords.reshape(-1)

    # Combine the x and y coordinates into a single tensor
    X = torch.stack([x_coords, y_coords], dim=1).float()

    # Move X to GPU if available
    X = X.to(device)

    # Create a tensor containing the image pixel values
    Y = img.reshape(-1, num_channels).float().to(device)
    return X, Y

Below is our loss function: the negative log-likelihood. We will use it to train our model. It takes:

  • y_pred — the mean of the predicted distribution for each RGB channel,
  • log_sigma — the log-variance of the predicted distribution for each RGB channel,
  • y_true — the actual RGB values.

We exponentiate log_sigma and place the result on the diagonal of a covariance matrix, then compute the log-probability of the actual values under a multivariate normal with mean y_pred and that diagonal covariance, summed over pixels and negated.

def neg_loglikelihood(y_pred, log_sigma, y_true):
    # exp(log_sigma) serves as the per-channel variance on the diagonal
    cov_matrix = torch.diag_embed(log_sigma.exp())
    dist = torch.distributions.MultivariateNormal(y_pred, cov_matrix, validate_args=False)
    return - dist.log_prob(y_true).sum()
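
As a quick sanity check (toy values, for illustration only), the loss at the mean with unit variance:

# 4 pixels, RGB; zero means, log_sigma = 0 so each per-channel variance is 1
y_pred = torch.zeros(4, 3)
log_sigma = torch.zeros(4, 3)
y_true = torch.zeros(4, 3)
print(neg_loglikelihood(y_pred, log_sigma, y_true))  # 4 * 0.5 * 3 * log(2*pi) ≈ 11.03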

To count the parameters in any model, you can use the following code:

def count_params(model):
    # returns a 0-dim tensor with the total number of parameters
    return torch.sum(torch.tensor([p.numel() for p in model.parameters()]))

Hyper Network

Training phase architecture

Model definition

Target net definition

# Create an MLP with 4 hidden layers of 128 neurons each and ReLU activations.
# Input is the (x, y) pixel location; output is n_out values: (r, g, b) or (g) for grayscale.
# Here n_out = 6: 3 for the RGB means and 3 for the RGB log sigmas.
s = 128 # hidden dim of model

class TargetNet(nn.Module):
    def _init_siren(self, activation_scale):
        self.fc1.weight.data.uniform_(-1/self.fc1.in_features, 1/self.fc1.in_features)
        for layers in [self.fc2, self.fc3, self.fc4, self.fc5]:
            layers.weight.data.uniform_(-np.sqrt(6/self.fc2.in_features)/activation_scale,
                                        np.sqrt(6/self.fc2.in_features)/activation_scale)

    def __init__(self, activation=torch.relu, n_out=1, activation_scale=1.0):
        super().__init__()
        self.activation = activation
        self.activation_scale = activation_scale
        self.fc1 = nn.Linear(2, s) # input size is 2 (x, y) location of pixel
        self.fc2 = nn.Linear(s, s)
        self.fc3 = nn.Linear(s, s)
        self.fc4 = nn.Linear(s, s)
        self.fc5 = nn.Linear(s, n_out) #gray scale image (1) or RGB (3)
        if self.activation == torch.sin:
            # init weights and biases for sine activation
            self._init_siren(activation_scale=self.activation_scale)

    def forward(self, x):
        x = self.activation(self.activation_scale*self.fc1(x))
        x = self.activation(self.activation_scale*self.fc2(x))
        x = self.activation(self.activation_scale*self.fc3(x))
        x = self.activation(self.activation_scale*self.fc4(x))
        return self.fc5(x)
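
The class also supports SIREN-style sine activations (with the corresponding weight initialization); for example, a sine-activated target net could be built as follows (not used in the runs below):

targetnet_siren = TargetNet(activation=torch.sin, n_out=6, activation_scale=30.0).to(device)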

Hypernetwork definition

Input: (x, y, R, G, B)

Output: a vector whose length equals the total number of parameters in the target network.

# pass total params of target network before calling the hypernetwork model
class HyperNet(nn.Module):
    def __init__(self, total_params, num_neurons=128, activation=torch.relu):
        super().__init__()
        self.activation = activation
        self.n_out = total_params
        self.fc1 = nn.Linear(5, num_neurons)
        self.fc2 = nn.Linear(num_neurons, num_neurons)
        self.fc3 = nn.Linear(num_neurons, self.n_out)

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        return self.fc3(x)
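
Before wiring the two networks together, here is a minimal sketch (toy dimensions, for illustration) of the key mechanic we will use below: slicing a flat parameter vector into a module's weights and running the module with those weights via torch.func.functional_call:

tiny = nn.Linear(2, 1)                  # toy "target" with 3 parameters (2 weights + 1 bias)
flat = torch.tensor([0.5, -0.5, 0.1])   # stand-in for a hypernetwork output
params, start = {}, 0
for name, p in tiny.named_parameters():
    end = start + p.numel()
    params[name] = flat[start:end].reshape(p.shape)
    start = end
# Forward pass that uses the generated parameters instead of tiny's own
print(torch.func.functional_call(tiny, params, torch.ones(1, 2)))  # tensor([[0.1000]])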

Initialize the model and input

Initialize the target network

from torchinfo import summary
targetnet = TargetNet(activation=torch.relu, n_out=6, activation_scale=1).to(device)
summary(targetnet, input_size=(img_size*img_size, 2)) # 32*32 = 1024 pixels; 2 is the (x, y) coordinate
# 6 outputs: entries 1-3 are the per-channel means, entries 4-6 the per-channel log sigmas
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
TargetNet                                [1024, 6]                 --
├─Linear: 1-1                            [1024, 128]               384
├─Linear: 1-2                            [1024, 128]               16,512
├─Linear: 1-3                            [1024, 128]               16,512
├─Linear: 1-4                            [1024, 128]               16,512
├─Linear: 1-5                            [1024, 6]                 774
==========================================================================================
Total params: 50,694
Trainable params: 50,694
Non-trainable params: 0
Total mult-adds (M): 51.91
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 4.24
Params size (MB): 0.20
Estimated Total Size (MB): 4.45
==========================================================================================
targetnet
TargetNet(
  (fc1): Linear(in_features=2, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=128, bias=True)
  (fc4): Linear(in_features=128, out_features=128, bias=True)
  (fc5): Linear(in_features=128, out_features=6, bias=True)
)
count_params(targetnet)
tensor(50694)

Initialize the hypernetwork model

hypernet = HyperNet(total_params=count_params(targetnet), activation=torch.sin).to(device)
print(hypernet)
HyperNet(
  (fc1): Linear(in_features=5, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=50694, bias=True)
)
summary(hypernet,input_size=(img_size*img_size,5))  # 32*32 = 1024 pixels; 5 is the (x, y, r, g, b) input to the hypernet
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
HyperNet                                 [1024, 50694]             --
├─Linear: 1-1                            [1024, 128]               768
├─Linear: 1-2                            [1024, 128]               16,512
├─Linear: 1-3                            [1024, 50694]             6,539,526
==========================================================================================
Total params: 6,556,806
Trainable params: 6,556,806
Non-trainable params: 0
Total mult-adds (G): 6.71
==========================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 417.38
Params size (MB): 26.23
Estimated Total Size (MB): 443.63
==========================================================================================
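Next, we record for every target-network parameter the [start, end) slice it occupies in the hypernetwork's flat output vector; the training loop below uses this mapping to reshape slices back into weights and biases: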
table_data = []
total_params = 0
start = 0
start_end_mapping = {}
for name, param in targetnet.named_parameters():
    param_count = torch.prod(torch.tensor(param.shape)).item()
    total_params += param_count
    end = total_params
    table_data.append([name, param.shape, param_count, start, end])
    start_end_mapping[name] = (start, end)
    start = end

print(tabulate(table_data, headers=["Layer Name", "Shape", "Parameter Count", "Start Index", "End Index"]))
print(f"Total number of parameters: {total_params}")
Layer Name    Shape                     Parameter Count    Start Index    End Index
------------  ----------------------  -----------------  -------------  -----------
fc1.weight    torch.Size([128, 2])                  256              0          256
fc1.bias      torch.Size([128])                     128            256          384
fc2.weight    torch.Size([128, 128])              16384            384        16768
fc2.bias      torch.Size([128])                     128          16768        16896
fc3.weight    torch.Size([128, 128])              16384          16896        33280
fc3.bias      torch.Size([128])                     128          33280        33408
fc4.weight    torch.Size([128, 128])              16384          33408        49792
fc4.bias      torch.Size([128])                     128          49792        49920
fc5.weight    torch.Size([6, 128])                  768          49920        50688
fc5.bias      torch.Size([6])                         6          50688        50694
Total number of parameters: 50694

Initialize the input

corr, vals = create_coordinate_map(data_loader.dataset[0])
corr, vals
(tensor([[ 0.,  0.],
         [ 1.,  0.],
         [ 2.,  0.],
         ...,
         [29., 31.],
         [30., 31.],
         [31., 31.]], device='cuda:3'),
 tensor([[0.4510, 0.4706, 0.4824],
         [0.4745, 0.4745, 0.4471],
         [0.4667, 0.4353, 0.5412],
         ...,
         [0.0314, 0.0549, 0.0471],
         [0.0431, 0.0392, 0.0510],
         [0.0549, 0.0392, 0.0549]], device='cuda:3'))
scaler_img = preprocessing.MinMaxScaler().fit(corr.cpu())
xy = torch.tensor(scaler_img.transform(corr.cpu())).float().to(device)
xy, xy.shape
(tensor([[0.0000, 0.0000],
         [0.0323, 0.0000],
         [0.0645, 0.0000],
         ...,
         [0.9355, 1.0000],
         [0.9677, 1.0000],
         [1.0000, 1.0000]], device='cuda:3'),
 torch.Size([1024, 2]))

Training loop
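
In each epoch we sample n_context random pixel indices. For every image, the hypernetwork takes the (x, y, R, G, B) values of the context pixels and predicts one parameter vector per context point; we average these into a single parameter vector, slice it into the target network's weights using the start/end mapping above, run the target network on all pixel coordinates with torch.func.functional_call, and minimize the negative log-likelihood. Note that only the hypernetwork's parameters are updated.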

n_epochs=20
lr = 0.003

targetnet = TargetNet(activation=torch.relu, n_out=6, activation_scale=1).to(device)
hypernet = HyperNet(total_params=count_params(targetnet), activation=torch.relu).to(device)
optimizer = optim.Adam(hypernet.parameters(),lr=lr) # only hypernet is updated

n_context = 100
print("Context Points=",n_context)
for epoch in trange(n_epochs):

    c_idx = np.array(random.sample(range(img_size*img_size),n_context))  # sample context pixels from all 1024 positions

    print("Epoch=",epoch+1)
    epoch_loss = 0
    i=1

    for data in data_loader:
        # print(data.shape)
        optimizer.zero_grad()

        pixel_intensity = data.reshape(3,-1).T.to(device).float()  # (1024, 3) RGB values
        input = torch.concatenate([xy[c_idx],pixel_intensity[c_idx]],axis=1).float()  # context (x, y, r, g, b)

        hyper_out = hypernet(input)  # (n_context, total_params): one parameter vector per context point
        hyper_out = torch.mean(hyper_out,dim=0)  # aggregate across context points

        target_dict ={}
        for name,param in targetnet.named_parameters():
            start,end = start_end_mapping[name]
            target_dict[name] = hyper_out[start:end].reshape(param.shape)

        img_out = torch.func.functional_call(targetnet, target_dict, xy)
        # print(img_out.shape, img_out[:,:3].shape, img_out[:,3:].shape, pixel_intensity.shape)
        # print( img_out[:,:3], img_out[:,3:], pixel_intensity)
        loss = neg_loglikelihood(img_out[:,:3],img_out[:,3:],pixel_intensity)
        loss.backward()
        optimizer.step()

        epoch_loss = epoch_loss + loss.item()
        i=i+1

    print("Epoch Loss=",epoch_loss/len(data_loader))
Context Points= 100
  0%|          | 0/20 [00:00<?, ?it/s]
Epoch= 1
  5%|▌         | 1/20 [01:05<20:42, 65.40s/it]
Epoch Loss= -395.47315481672285
Epoch= 2
 10%|█         | 2/20 [02:10<19:37, 65.41s/it]
Epoch Loss= -915.5049703121185
Epoch= 3
 15%|█▌        | 3/20 [03:15<18:29, 65.26s/it]
Epoch Loss= -1166.0503022047044
Epoch= 4
 20%|██        | 4/20 [04:21<17:24, 65.28s/it]
Epoch Loss= -1349.56748127985
Epoch= 5
 25%|██▌       | 5/20 [05:26<16:19, 65.30s/it]
Epoch Loss= -1396.8538594449997
Epoch= 6
 30%|███       | 6/20 [06:31<15:14, 65.31s/it]
Epoch Loss= -1479.6613237543106
Epoch= 7
 35%|███▌      | 7/20 [07:37<14:09, 65.36s/it]
Epoch Loss= -1449.8615832103728
Epoch= 8
 40%|████      | 8/20 [08:42<13:04, 65.38s/it]
Epoch Loss= -1528.4998937654495
Epoch= 9
 45%|████▌     | 9/20 [09:47<11:58, 65.31s/it]
Epoch Loss= -1538.7266953744888
Epoch= 10
 50%|█████     | 10/20 [10:53<10:53, 65.36s/it]
Epoch Loss= -1574.4109719749451
Epoch= 11
 55%|█████▌    | 11/20 [11:58<09:48, 65.35s/it]
Epoch Loss= -1558.2231241334914
Epoch= 12
 60%|██████    | 12/20 [13:03<08:42, 65.32s/it]
Epoch Loss= -1585.886608516693
Epoch= 13
 65%|██████▌   | 13/20 [14:09<07:36, 65.27s/it]
Epoch Loss= -1586.6056880670546
Epoch= 14
 70%|███████   | 14/20 [15:14<06:31, 65.23s/it]
Epoch Loss= -1561.2374246302604
Epoch= 15
 75%|███████▌  | 15/20 [16:19<05:25, 65.20s/it]
Epoch Loss= -1606.3488553873062
Epoch= 16
 80%|████████  | 16/20 [17:24<04:20, 65.24s/it]
Epoch Loss= -1637.4123486403466
Epoch= 17
 85%|████████▌ | 17/20 [18:30<03:15, 65.29s/it]
Epoch Loss= -1656.406247360611
Epoch= 18
 90%|█████████ | 18/20 [19:33<02:09, 64.69s/it]
Epoch Loss= -1621.5405502536773
Epoch= 19
 95%|█████████▌| 19/20 [20:38<01:04, 64.88s/it]
Epoch Loss= -1708.3212175039291
Epoch= 20
100%|██████████| 20/20 [21:44<00:00, 65.21s/it]
Epoch Loss= -1700.857941632271

Saving and loading the model

torch.save(hypernet.state_dict(), 'hypernet_model_10000.pth')
torch.save(targetnet.state_dict(), 'targetnet_model_10000.pth')
# Load the hypernet and targetnet models
hypernet = HyperNet(total_params=count_params(targetnet), activation=torch.relu).to(device)
hypernet.load_state_dict(torch.load('hypernet_model_10000.pth'))
hypernet.eval()  # Set the model to evaluation mode
HyperNet(
  (fc1): Linear(in_features=5, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=50694, bias=True)
)
targetnet = TargetNet(activation=torch.relu, n_out=6, activation_scale=1).to(device)
targetnet.load_state_dict(torch.load('targetnet_model_10000.pth'))
targetnet.eval()
TargetNet(
  (fc1): Linear(in_features=2, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=128, bias=True)
  (fc4): Linear(in_features=128, out_features=128, bias=True)
  (fc5): Linear(in_features=128, out_features=6, bias=True)
)

Testing

Testing phase architecture

Loading the test data

batch_size = 1 # keep this to 1
img_size = 32 # Change as needed

# Specify the root directory where the dataset is located
data_root = 'data/celeba/img_align_celeba_2599'

# Define the data transformations
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),  # Resize the images to a common size (adjust as needed)
    transforms.ToTensor(),   # Convert images to tensors
])
# default shape is torch.Size([3, 218, 178])
# Create the custom dataset
celeba_dataset = CustomImageDataset(data_root, transform=transform)

# Create a data loader
test_data_loader = DataLoader(celeba_dataset, batch_size=batch_size, shuffle=False)

Plotting the results

def plot_hypernet(data,hypernet,targetnet,c_idx):

    pixel_intensity = data.reshape(3,-1).T.to(device).float()
    input = torch.concatenate([xy[c_idx],pixel_intensity[c_idx]],axis=1).float()

    hyper_out = hypernet(input) # hyper_out is a tensor of shape (n_context, total_params)
    hyper_out = torch.mean(hyper_out,dim=0) # aggregate across context points

    target_dict ={}
    start = 0
    for name,param in targetnet.named_parameters():
        end = start + param.numel()
        target_dict[name] = hyper_out[start:end].reshape(param.shape)
        start = end

    img_out = torch.func.functional_call(targetnet, target_dict, xy)
    return img_out.cpu().detach()
c_1 = np.array(random.sample(range(img_size*img_size),1))
c_10 = np.array(random.sample(range(img_size*img_size),10))
c_100 = np.array(random.sample(range(img_size*img_size),100))
c_1000 = np.array(random.sample(range(img_size*img_size),1000))

idx = 0
data = test_data_loader.dataset[idx]
plt.figure(figsize=(9,7),constrained_layout=True)
plt.suptitle("HyperNetworks",fontsize=20)
def plot_image(i,j,k, data,hypernet,targetnet, c_idx):
    plt.subplot(i,j,k)
    img = data.permute(1,2,0)
    mask = np.zeros((32,32,3))
    mask[c_idx//32,c_idx%32,:] = 1
    plt.imshow(img*mask)
    plt.title(f"Context: {len(c_idx)}")
    plt.axis('off')

    plt.subplot(i,j,k+4)
    pred = plot_hypernet(data,hypernet,targetnet,c_idx)  # (1024, 6): RGB means and log sigmas
    plt.imshow(pred[:,:3].T.reshape(3,32,32).permute(1,2,0))
    plt.axis('off')

    plt.subplot(i,j,k+8)
    var = pred[:,3:].exp().T.reshape(3,32,32).permute(1,2,0)  # per-pixel variance
    var = var-var.min()
    var = var/var.max()  # normalize to [0, 1] for display
    plt.imshow(var)
    plt.axis('off')
<Figure size 900x700 with 0 Axes>
plot_image(3,4,1,data,hypernet,targetnet,c_1)
plot_image(3,4,2,data,hypernet,targetnet,c_10)
plot_image(3,4,3,data,hypernet,targetnet,c_100)
plot_image(3,4,4,data,hypernet,targetnet,c_1000)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
/home/jaiswalsuraj/miniconda3/envs/tf_gpu/lib/python3.10/site-packages/matplotlib/cm.py:478: RuntimeWarning: invalid value encountered in cast
  xx = (xx * 255).astype(np.uint8)

The first row shows the test context points, the second row shows our model's prediction, and the third row shows the (normalized) variance of the predicted image.

Neural Processes

Training phase architecture

Encoder-decoder model definition
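
Following the conditional neural process setup, the encoder maps each context point (x, y, r, g, b) to a latent representation; these representations are mean-aggregated into a single vector, which is concatenated with every target coordinate (x, y) and passed to the decoder, which outputs the per-pixel mean and log-variance.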

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, z_dim,activation=torch.sin,activation_scale=30.0):
        super().__init__()
        self.activation = activation
        self.activation_scale = activation_scale
        if activation != torch.sin:
            self.activation_scale = 1.0

        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, z_dim)

    def forward(self, x):
        x = self.activation(self.linear1(x)*self.activation_scale)
        x = self.activation(self.linear2(x)*self.activation_scale)
        return self.linear3(x)

class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim, output_dim,activation=torch.sin,activation_scale=30.0):
        super().__init__()
        self.activation = activation
        self.activation_scale = activation_scale
        if activation != torch.sin:
            self.activation_scale = 1.0
        self.linear1 = nn.Linear(z_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)
        self.linear4 = nn.Linear(hidden_dim, hidden_dim)
        self.linear5 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.activation(self.linear1(x)*self.activation_scale)
        x = self.activation(self.linear2(x)*self.activation_scale)
        x = self.activation(self.linear3(x)*self.activation_scale)
        x = self.activation(self.linear4(x)*self.activation_scale)
        return self.linear5(x)
from torchinfo import summary
encoder = Encoder(5, 256, 128, activation=torch.relu,activation_scale=1)
summary(encoder,input_size=(img_size*img_size,5)) # 32*32 = 1024 pixels; 5 is the (x, y, r, g, b) input to the encoder
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Encoder                                  [1024, 128]               --
├─Linear: 1-1                            [1024, 256]               1,536
├─Linear: 1-2                            [1024, 256]               65,792
├─Linear: 1-3                            [1024, 128]               32,896
==========================================================================================
Total params: 100,224
Trainable params: 100,224
Non-trainable params: 0
Total mult-adds (M): 102.63
==========================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 5.24
Params size (MB): 0.40
Estimated Total Size (MB): 5.66
==========================================================================================
print(encoder)
Encoder(
  (linear1): Linear(in_features=5, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=256, bias=True)
  (linear3): Linear(in_features=256, out_features=128, bias=True)
)
decoder = Decoder(130, 256, 6, activation=torch.relu,activation_scale=1)
summary(decoder,input_size=(img_size*img_size,130))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Decoder                                  [1024, 6]                 --
├─Linear: 1-1                            [1024, 256]               33,536
├─Linear: 1-2                            [1024, 256]               65,792
├─Linear: 1-3                            [1024, 256]               65,792
├─Linear: 1-4                            [1024, 256]               65,792
├─Linear: 1-5                            [1024, 6]                 1,542
==========================================================================================
Total params: 232,454
Trainable params: 232,454
Non-trainable params: 0
Total mult-adds (M): 238.03
==========================================================================================
Input size (MB): 0.53
Forward/backward pass size (MB): 8.44
Params size (MB): 0.93
Estimated Total Size (MB): 9.90
==========================================================================================
print(decoder)
Decoder(
  (linear1): Linear(in_features=130, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=256, bias=True)
  (linear3): Linear(in_features=256, out_features=256, bias=True)
  (linear4): Linear(in_features=256, out_features=256, bias=True)
  (linear5): Linear(in_features=256, out_features=6, bias=True)
)

Initialize the input

corr, vals = create_coordinate_map(data_loader.dataset[0])
corr, vals
(tensor([[ 0.,  0.],
         [ 1.,  0.],
         [ 2.,  0.],
         ...,
         [29., 31.],
         [30., 31.],
         [31., 31.]], device='cuda:2'),
 tensor([[0.4510, 0.4706, 0.4824],
         [0.4745, 0.4745, 0.4471],
         [0.4667, 0.4353, 0.5412],
         ...,
         [0.0314, 0.0549, 0.0471],
         [0.0431, 0.0392, 0.0510],
         [0.0549, 0.0392, 0.0549]], device='cuda:2'))
scaler_img = preprocessing.MinMaxScaler().fit(corr.cpu())
xy = torch.tensor(scaler_img.transform(corr.cpu())).float().to(device)
xy, xy.shape
(tensor([[0.0000, 0.0000],
         [0.0323, 0.0000],
         [0.0645, 0.0000],
         ...,
         [0.9355, 1.0000],
         [0.9677, 1.0000],
         [1.0000, 1.0000]], device='cuda:2'),
 torch.Size([1024, 2]))

Training loop
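
The procedure mirrors the hypernetwork loop: for each image we encode the sampled context points, mean-aggregate the encodings into one representation, tile it across all 1024 pixel coordinates, concatenate it with (x, y), and decode into per-pixel means and log-variances, again minimizing the negative log-likelihood. Here both the encoder and the decoder are updated.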

n_epochs=20
lr = 0.003
n_context = 200
print("Context Points=",n_context)

encoder = Encoder(input_dim=5, hidden_dim=512, z_dim=128,activation=torch.relu,activation_scale=1).to(device)
decoder = Decoder(z_dim=130, hidden_dim=512, output_dim=6,activation=torch.relu,activation_scale=1).to(device)
optimizer = optim.Adam(list(encoder.parameters())+list(decoder.parameters()),lr=lr)

for epoch in trange(n_epochs):

    c_idx = np.array(random.sample(range(img_size*img_size),n_context))  # sample context pixels from all 1024 positions

    print("Epoch=",epoch+1)
    epoch_loss = 0
    i=1
    for data in data_loader:
        # print(data.shape)

        optimizer.zero_grad()

        pixel_intensity = data.reshape(3,-1).T.to(device).float()
        input = torch.concatenate([xy[c_idx],pixel_intensity[c_idx]],axis=1).float()

        encoder_out = encoder(input)  # (n_context, z_dim)
        encoder_out = torch.mean(encoder_out,dim=0)  # aggregate across context points

        decoder_in = encoder_out.repeat(1024,1)  # tile the representation over all pixels
        decoder_in = torch.concatenate([xy,decoder_in],axis=1)  # (1024, 2 + z_dim)

        img_out = decoder(decoder_in)

        loss = neg_loglikelihood(img_out[:,:3],img_out[:,3:],pixel_intensity)
        loss.backward()
        optimizer.step()

        epoch_loss = epoch_loss + loss.item()
        i=i+1
    print("Epoch Loss=",epoch_loss/len(data_loader))
Context Points= 200
  0%|          | 0/20 [00:00<?, ?it/s]
Epoch= 1
  5%|▌         | 1/20 [00:58<18:27, 58.30s/it]
Epoch Loss= 47.26116828255653
Epoch= 2
 10%|█         | 2/20 [01:50<16:27, 54.89s/it]
Epoch Loss= -455.4421627301693
Epoch= 3
 15%|█▌        | 3/20 [02:43<15:16, 53.92s/it]
Epoch Loss= -683.1707640041351
Epoch= 4
 20%|██        | 4/20 [03:35<14:13, 53.32s/it]
Epoch Loss= -761.2885692318916
Epoch= 5
 25%|██▌       | 5/20 [04:36<13:56, 55.78s/it]
Epoch Loss= -827.9153870079041
Epoch= 6
 30%|███       | 6/20 [05:35<13:20, 57.17s/it]
Epoch Loss= -938.4062066322326
Epoch= 7
 35%|███▌      | 7/20 [06:30<12:11, 56.24s/it]
Epoch Loss= -1007.5277942465782
Epoch= 8
 40%|████      | 8/20 [07:30<11:28, 57.39s/it]
Epoch Loss= -1048.726986592865
Epoch= 9
 45%|████▌     | 9/20 [08:20<10:06, 55.15s/it]
Epoch Loss= -1057.8311284263611
Epoch= 10
 50%|█████     | 10/20 [09:17<09:17, 55.74s/it]
Epoch Loss= -1070.1760208235742
Epoch= 11
 55%|█████▌    | 11/20 [09:56<07:34, 50.50s/it]
Epoch Loss= -1065.5062245418549
Epoch= 12
 60%|██████    | 12/20 [10:35<06:16, 47.06s/it]
Epoch Loss= -1078.0465439793586
Epoch= 13
 65%|██████▌   | 13/20 [11:14<05:13, 44.75s/it]
Epoch Loss= -1088.4120592634201
Epoch= 14
 70%|███████   | 14/20 [11:53<04:17, 42.99s/it]
Epoch Loss= -1078.4957633354188
Epoch= 15
 75%|███████▌  | 15/20 [12:32<03:28, 41.78s/it]
Epoch Loss= -1084.8360796244622
Epoch= 16
 80%|████████  | 16/20 [13:11<02:43, 40.94s/it]
Epoch Loss= -1093.4410486424447
Epoch= 17
 85%|████████▌ | 17/20 [13:50<02:01, 40.44s/it]
Epoch Loss= -1113.8205106693267
Epoch= 18
 90%|█████████ | 18/20 [14:29<01:19, 39.99s/it]
Epoch Loss= -1110.6333978479386
Epoch= 19
 95%|█████████▌| 19/20 [15:25<00:44, 44.66s/it]
Epoch Loss= -1098.3281953195572
Epoch= 20
100%|██████████| 20/20 [16:25<00:00, 49.28s/it]
Epoch Loss= -1106.2516599431992

Saving and loading the model

torch.save(encoder.state_dict(), 'encoder_model_10000.pth')
torch.save(decoder.state_dict(), 'decoder_model_10000.pth')
# Load the encoder and decoder models
encoder = Encoder(input_dim=5, hidden_dim=512, z_dim=128,activation=torch.relu,activation_scale=1).to(device) # dims must match the training run, else load_state_dict fails

encoder.load_state_dict(torch.load('encoder_model_10000.pth'))
encoder.eval()  # Set the model to evaluation mode
Encoder(
  (linear1): Linear(in_features=5, out_features=512, bias=True)
  (linear2): Linear(in_features=512, out_features=512, bias=True)
  (linear3): Linear(in_features=512, out_features=128, bias=True)
)
decoder = Decoder(z_dim=130, hidden_dim=512, output_dim=6,activation=torch.relu,activation_scale=1).to(device) # dims must match the training run
decoder.load_state_dict(torch.load('decoder_model_10000.pth'))
decoder.eval()
Decoder(
  (linear1): Linear(in_features=130, out_features=512, bias=True)
  (linear2): Linear(in_features=512, out_features=512, bias=True)
  (linear3): Linear(in_features=512, out_features=512, bias=True)
  (linear4): Linear(in_features=512, out_features=512, bias=True)
  (linear5): Linear(in_features=512, out_features=6, bias=True)
)

Testing

Testing phase architecture

Loading the test data

batch_size = 1 # keep this to 1
img_size = 32 # Change as needed

# Specify the root directory where the dataset is located
data_root = 'data/celeba/img_align_celeba_2599'

# Define the data transformations
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),  # Resize the images to a common size (adjust as needed)
    transforms.ToTensor(),   # Convert images to tensors
])
# default shape is torch.Size([3, 218, 178])
# Create the custom dataset
celeba_dataset = CustomImageDataset(data_root, transform=transform)

# Create a data loader
test_data_loader = DataLoader(celeba_dataset, batch_size=batch_size, shuffle=False)

Plotting the results

def plot_enc_dec(data,encoder,decoder,c_idx):

    pixel_intensity = data.reshape(3,-1).T.to(device).float()
    input = torch.concatenate([xy[c_idx],pixel_intensity[c_idx]],axis=1).float()

    encoder_out = encoder(input)
    encoder_out = torch.mean(encoder_out,dim=0)

    decoder_in = encoder_out.repeat(1024,1)
    decoder_in = torch.concatenate([xy,decoder_in],axis=1)

    img_out = decoder(decoder_in)
    return img_out.cpu().detach()
c_1 = np.array(random.sample(range(img_size*img_size),1))
c_10 = np.array(random.sample(range(img_size*img_size),10))
c_100 = np.array(random.sample(range(img_size*img_size),100))
c_1000 = np.array(random.sample(range(img_size*img_size),1000))

idx = 5
image_any = test_data_loader.dataset[idx]
data = image_any
plt.figure(figsize=(9,7),constrained_layout=True)
plt.suptitle("Neural process",fontsize=20)
def plot_image(i,j,k, data,encoder,decoder, c_idx):
    plt.subplot(i,j,k)
    img = data.permute(1,2,0)
    mask = np.zeros((32,32,3))
    mask[c_idx//32,c_idx%32,:] = 1
    plt.imshow(img*mask)
    plt.title(f"Context: {len(c_idx)}")
    plt.axis('off')

    plt.subplot(i,j,k+4)
    pred = plot_enc_dec(data,encoder,decoder,c_idx)  # (1024, 6): RGB means and log sigmas
    plt.imshow(pred[:,:3].T.reshape(3,32,32).permute(1,2,0))
    plt.axis('off')

    plt.subplot(i,j,k+8)
    var = pred[:,3:].exp().T.reshape(3,32,32).permute(1,2,0)  # per-pixel variance
    var = var-var.min()
    var = var/var.max()  # normalize to [0, 1] for display
    plt.imshow(var)
    plt.axis('off')
<Figure size 900x700 with 0 Axes>
plot_image(3,4,1,data,encoder,decoder,c_1)
plot_image(3,4,2,data,encoder,decoder,c_10)
plot_image(3,4,3,data,encoder,decoder,c_100)
plot_image(3,4,4,data,encoder,decoder,c_1000)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

The first row shows the test context points, the second row shows our model's prediction, and the third row shows the (normalized) variance of the predicted image.

Conclusion

  • Here we have seen an implementation of meta-learning using hypernetworks and neural processes.
  • We can further improve results by using the sine activation function in the model for image-like data; refer to link for more details.

That concludes our walkthrough: both hypernetworks and neural processes let us reconstruct CelebA images from just a few context points, and the reconstructions improve as the number of context points grows.