import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
import torch
import os
import torch.utils.data as data
import torch.nn as nn
import glob
# from torchsummary import summary
from torchinfo import summary
from tqdm import trange
Image-to-Image for Climate Modelling using Auto-Encoders
We have implemented MLP autoencoder, CNN autoencoder, and UNET autoencoder for image to image prediction on Camx dataset.
Go through the link to see the presentation of the project to get the flow.
Git Repo: link
= torch.device("cuda" if torch.cuda.is_available() else "cpu")
device # device = torch.device("cuda:2")
print(device)
= device #torch.cuda.current_device()
current_device = torch.cuda.get_device_name(current_device)
device_name print(f"Current GPU assigned: {current_device}, Name: {device_name}")
cuda
Current GPU assigned: cuda, Name: Quadro RTX 5000
Consider it a prediction task, use any channel like pressure and/or wind speed inputs to start with, because they are the indicators of horizontal movement of air pollution. Multiple inputs can go as multiple channels of images similar to RGB. Output can be P25 or P10 in cam 120hr files.
def get_latitudes():
= 76.8499984741211
lat_start=0.009999999776482582
lat_step
=[]
latitudes
for i in range(80):
+i*lat_step)
latitudes.append(lat_start
latitudes.reverse()
return latitudes
def get_longitudes():
= 28.200000762939453
long_start=0.009999999776482582
long_step
=[]
longitudes
for i in range(80):
+i*long_step)
longitudes.append(long_start
# longitudes.reverse()
return longitudes
=get_latitudes()
latitudes=get_longitudes()
longitudes
def create_plot(data,hour,var_name):
# print(data[var_name].shape) #shape (120, 1, 80, 80)
=data[var_name]['TSTEP'==hour] # shape (1, 80, 80)
p10_hour=p10_hour[0,:,:] # shape (80, 80)
p10_hour
plt.imshow(p10_hour)f'{var_name} at hour '+str(hour))
plt.title(# plt.colorbar()
# only show every latitude and longitude of end points
# round to 2 decimal places
=latitudes[0]
top=round(top,2)
top=latitudes[-1]
bottom=round(bottom,2)
bottom=longitudes[0]
left=round(left,2)
left=longitudes[-1]
right=round(right,2)
right
0,79],[left,right])
plt.xticks(['Longitude')
plt.xlabel(0,79],[top,bottom])
plt.yticks(['Latitude')
plt.ylabel(# plt.savefig(f'plots/120/{var_name}_{day}.png')
# plt.close()
Exploratory Data Analysis
Visualizing 96 hr files
= xr.open_dataset('data/camxmet2d.delhi.20230717.96hours.nc')
data_96 = data_96.to_dataframe().reset_index() data_96_df
data_96
<xarray.Dataset> Dimensions: (TSTEP: 96, VAR: 14, DATE-TIME: 2, LAY: 1, ROW: 80, COL: 80) Dimensions without coordinates: TSTEP, VAR, DATE-TIME, LAY, ROW, COL Data variables: (12/15) TFLAG (TSTEP, VAR, DATE-TIME) int32 2023198 0 ... 2023201 230000 TSURF_K (TSTEP, LAY, ROW, COL) float32 302.3 302.3 302.3 ... 300.6 301.2 SNOWEW_M (TSTEP, LAY, ROW, COL) float32 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 SNOWAGE_HR (TSTEP, LAY, ROW, COL) float32 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 PRATE_MMpH (TSTEP, LAY, ROW, COL) float32 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 CLOUD_OD (TSTEP, LAY, ROW, COL) float32 62.24 61.67 61.1 ... 37.1 36.78 ... ... SWSFC_WpM2 (TSTEP, LAY, ROW, COL) float32 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 SOLM_M3pM3 (TSTEP, LAY, ROW, COL) float32 0.3131 0.3114 ... 0.3278 0.3292 CLDTOP_KM (TSTEP, LAY, ROW, COL) float32 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 CAPE (TSTEP, LAY, ROW, COL) float32 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 PBL_WRF_M (TSTEP, LAY, ROW, COL) float32 17.21 17.21 17.21 ... 94.4 120.0 PBL_YSU_M (TSTEP, LAY, ROW, COL) float32 17.21 17.21 17.21 ... 64.43 94.71 Attributes: (12/33) IOAPI_VERSION: $Id: @(#) ioapi library version 3.0 $ ... EXEC_ID: ???????????????? ... FTYPE: 1 CDATE: 2023198 CTIME: 73941 WDATE: 2023198 ... ... VGLVLS: [0. 0.] GDNAM: ???????????????? UPNAM: CAMx2IOAPI VAR-LIST: TSURF_K SNOWEW_M SNOWAGE_HR PRATE_MMp... FILEDESC: I/O API formatted CAMx AVRG output ... HISTORY:
data_96_df
TSTEP | VAR | DATE-TIME | LAY | ROW | COL | TFLAG | TSURF_K | SNOWEW_M | SNOWAGE_HR | ... | CLOUD_OD | U10_MpS | V10_MpS | T2_K | SWSFC_WpM2 | SOLM_M3pM3 | CLDTOP_KM | CAPE | PBL_WRF_M | PBL_YSU_M | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0 | 0 | 0 | 0 | 0 | 2023198 | 302.314636 | 0.0 | 0.0 | ... | 62.239326 | 1.630373 | 0.148722 | 302.571838 | 0.0 | 0.313128 | 0.0 | 0.0 | 17.212452 | 17.212452 |
1 | 0 | 0 | 0 | 0 | 0 | 1 | 2023198 | 302.312042 | 0.0 | 0.0 | ... | 61.671093 | 1.585281 | 0.158011 | 302.584351 | 0.0 | 0.311416 | 0.0 | 0.0 | 17.212652 | 17.212652 |
2 | 0 | 0 | 0 | 0 | 0 | 2 | 2023198 | 302.309479 | 0.0 | 0.0 | ... | 61.102859 | 1.540188 | 0.167299 | 302.596832 | 0.0 | 0.309703 | 0.0 | 0.0 | 17.212852 | 17.212852 |
3 | 0 | 0 | 0 | 0 | 0 | 3 | 2023198 | 302.306885 | 0.0 | 0.0 | ... | 60.534630 | 1.495095 | 0.176587 | 302.609344 | 0.0 | 0.307990 | 0.0 | 0.0 | 17.213055 | 17.213055 |
4 | 0 | 0 | 0 | 0 | 0 | 4 | 2023198 | 302.303558 | 0.0 | 0.0 | ... | 59.912636 | 1.450304 | 0.184773 | 302.621948 | 0.0 | 0.306286 | 0.0 | 0.0 | 17.213356 | 17.213356 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
17203195 | 95 | 13 | 1 | 0 | 79 | 75 | 230000 | 298.979095 | 0.0 | 0.0 | ... | 37.972748 | 0.600130 | -1.341342 | 299.389465 | 0.0 | 0.324286 | 0.0 | 0.0 | 26.661556 | 17.082989 |
17203196 | 95 | 13 | 1 | 0 | 79 | 76 | 230000 | 299.372223 | 0.0 | 0.0 | ... | 37.747482 | 0.658128 | -1.301898 | 299.638062 | 0.0 | 0.324927 | 0.0 | 0.0 | 43.123440 | 17.094244 |
17203197 | 95 | 13 | 1 | 0 | 79 | 77 | 230000 | 299.967468 | 0.0 | 0.0 | ... | 37.423786 | 0.732263 | -1.252025 | 300.018402 | 0.0 | 0.326345 | 0.0 | 0.0 | 68.759804 | 17.112419 |
17203198 | 95 | 13 | 1 | 0 | 79 | 78 | 230000 | 300.562744 | 0.0 | 0.0 | ... | 37.100090 | 0.806398 | -1.202152 | 300.398743 | 0.0 | 0.327763 | 0.0 | 0.0 | 94.396065 | 64.426094 |
17203199 | 95 | 13 | 1 | 0 | 79 | 79 | 230000 | 301.157501 | 0.0 | 0.0 | ... | 36.776646 | 0.880477 | -1.152316 | 300.778778 | 0.0 | 0.329180 | 0.0 | 0.0 | 120.012794 | 94.710510 |
17203200 rows × 21 columns
data_96_df.columns
Index(['TSTEP', 'VAR', 'DATE-TIME', 'LAY', 'ROW', 'COL', 'TFLAG', 'TSURF_K',
'SNOWEW_M', 'SNOWAGE_HR', 'PRATE_MMpH', 'CLOUD_OD', 'U10_MpS',
'V10_MpS', 'T2_K', 'SWSFC_WpM2', 'SOLM_M3pM3', 'CLDTOP_KM', 'CAPE',
'PBL_WRF_M', 'PBL_YSU_M'],
dtype='object')
'U10_MpS'] #shape (96, 1, 80, 80) data_96[
<xarray.DataArray 'U10_MpS' (TSTEP: 96, LAY: 1, ROW: 80, COL: 80)> array([[[[ 1.630373, ..., -0.434344], ..., [-0.544169, ..., -2.070899]]], ..., [[[ 0.569056, ..., 0.423629], ..., [ 0.774855, ..., 0.880477]]]], dtype=float32) Dimensions without coordinates: TSTEP, LAY, ROW, COL Attributes: long_name: U10_MpS units: ppmV var_desc: VARIABLE U10_MpS ...
'U10_MpS'].shape , data_96['U10_MpS']['TSTEP'==0].shape, data_96['U10_MpS']['TSTEP'==0][0].shape data_96[
((96, 1, 80, 80), (1, 80, 80), (80, 80))
'U10_MpS']['TSTEP'==0][0]) #shape (80, 80) plt.imshow(data_96[
<matplotlib.image.AxesImage at 0x7f76abc7cf70>
= data_96, hour = 1, var_name = 'U10_MpS') create_plot(data
Visualizing 120 hr files
= xr.open_dataset('data/camx120hr_merged_20230717.nc')
data_120 = data_120.to_dataframe().reset_index() data_120_df
data_120
<xarray.Dataset> Dimensions: (TSTEP: 120, LAY: 1, ROW: 80, COL: 80, VAR: 9, DATE-TIME: 2) Dimensions without coordinates: TSTEP, LAY, ROW, COL, VAR, DATE-TIME Data variables: P10 (TSTEP, LAY, ROW, COL) float32 23.86 23.86 24.07 ... 10.1 10.1 P25 (TSTEP, LAY, ROW, COL) float32 19.24 19.24 19.62 ... 9.63 9.63 TFLAG (TSTEP, VAR, DATE-TIME) int32 2023197 0 2023197 ... 2023201 230000 Attributes: (12/34) IOAPI_VERSION: $Id: @(#) ioapi library version 3.0 $ ... EXEC_ID: ???????????????? ... FTYPE: 1 CDATE: 2023197 CTIME: 83911 WDATE: 2023197 ... ... GDNAM: ???????????????? UPNAM: CAMXMETOU VAR-LIST: P10 P25 FILEDESC: I/O API formatted CAMx AVRG output ... HISTORY: Mon Jul 17 08:45:22 2023: ncrcat camxout.2023.07.16.nc ca... NCO: netCDF Operators version 4.9.1 (Homepage = http://nco.sf....
data_120_df
TSTEP | LAY | ROW | COL | VAR | DATE-TIME | P10 | P25 | TFLAG | |
---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0 | 0 | 0 | 0 | 0 | 23.857107 | 19.240587 | 2023197 |
1 | 0 | 0 | 0 | 0 | 0 | 1 | 23.857107 | 19.240587 | 0 |
2 | 0 | 0 | 0 | 0 | 1 | 0 | 23.857107 | 19.240587 | 2023197 |
3 | 0 | 0 | 0 | 0 | 1 | 1 | 23.857107 | 19.240587 | 0 |
4 | 0 | 0 | 0 | 0 | 2 | 0 | 23.857107 | 19.240587 | 2023197 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
13823995 | 119 | 0 | 79 | 79 | 6 | 1 | 10.097546 | 9.629862 | 230000 |
13823996 | 119 | 0 | 79 | 79 | 7 | 0 | 10.097546 | 9.629862 | 2023201 |
13823997 | 119 | 0 | 79 | 79 | 7 | 1 | 10.097546 | 9.629862 | 230000 |
13823998 | 119 | 0 | 79 | 79 | 8 | 0 | 10.097546 | 9.629862 | 2023201 |
13823999 | 119 | 0 | 79 | 79 | 8 | 1 | 10.097546 | 9.629862 | 230000 |
13824000 rows × 9 columns
data_120_df.describe()
TSTEP | LAY | ROW | COL | VAR | DATE-TIME | P10 | P25 | TFLAG | |
---|---|---|---|---|---|---|---|---|---|
count | 1.382400e+07 | 13824000.0 | 1.382400e+07 | 1.382400e+07 | 1.382400e+07 | 13824000.0 | 1.382400e+07 | 1.382400e+07 | 1.382400e+07 |
mean | 5.950000e+01 | 0.0 | 3.950000e+01 | 3.950000e+01 | 4.000000e+00 | 0.5 | 4.101470e+01 | 2.856279e+01 | 1.069100e+06 |
std | 3.463981e+01 | 0.0 | 2.309221e+01 | 2.309221e+01 | 2.581989e+00 | 0.5 | 2.361775e+01 | 1.582030e+01 | 9.553543e+05 |
min | 0.000000e+00 | 0.0 | 0.000000e+00 | 0.000000e+00 | 0.000000e+00 | 0.0 | 1.111396e+00 | 8.114247e-01 | 0.000000e+00 |
25% | 2.975000e+01 | 0.0 | 1.975000e+01 | 1.975000e+01 | 2.000000e+00 | 0.0 | 2.546643e+01 | 1.880392e+01 | 1.175000e+05 |
50% | 5.950000e+01 | 0.0 | 3.950000e+01 | 3.950000e+01 | 4.000000e+00 | 0.5 | 3.590978e+01 | 2.589731e+01 | 1.126598e+06 |
75% | 8.925000e+01 | 0.0 | 5.925000e+01 | 5.925000e+01 | 6.000000e+00 | 1.0 | 5.103806e+01 | 3.433007e+01 | 2.023199e+06 |
max | 1.190000e+02 | 0.0 | 7.900000e+01 | 7.900000e+01 | 8.000000e+00 | 1.0 | 6.492913e+02 | 6.130998e+02 | 2.023201e+06 |
'COL'] data_120[
<xarray.DataArray 'COL' (COL: 80)> array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79]) Dimensions without coordinates: COL
'P10'] #shape (120, 1, 80, 80) data_120[
<xarray.DataArray 'P10' (TSTEP: 120, LAY: 1, ROW: 80, COL: 80)> array([[[[23.857107, ..., 28.82367 ], ..., [34.902046, ..., 18.506985]]], ..., [[[38.434433, ..., 22.536842], ..., [20.921093, ..., 10.097546]]]], dtype=float32) Dimensions without coordinates: TSTEP, LAY, ROW, COL Attributes: long_name: CPRM units: micrograms/m**3 var_desc: VARIABLE CPRM ...
'P10']['TSTEP'==1]# shape 1x80x80 data_120[
<xarray.DataArray 'P10' (LAY: 1, ROW: 80, COL: 80)> array([[[23.857107, 23.857107, ..., 28.82367 , 28.82367 ], [23.857107, 23.857107, ..., 28.82367 , 28.82367 ], ..., [34.902046, 34.902046, ..., 18.506985, 18.506985], [34.902046, 34.902046, ..., 18.506985, 18.506985]]], dtype=float32) Dimensions without coordinates: LAY, ROW, COL Attributes: long_name: CPRM units: micrograms/m**3 var_desc: VARIABLE CPRM ...
'P10']['TSTEP'==1][0] # shape 80x80
data_120['P25']['TSTEP'==1][0])#, vmin=0, vmax=100) plt.imshow(data_120[
<matplotlib.image.AxesImage at 0x7f190d573e80>
1,'P25') create_plot(data_120,
Constructing data
# Define the folder path where your files are located
= 'data/'
folder_path
# Use glob to retrieve all netCDF files in the folder
= 'camx120*.nc'
file_pattern = glob.glob(folder_path + file_pattern)
files_120 # files_120
= '*96hours.nc'
file_pattern = glob.glob(folder_path + file_pattern)
files_96 # files_96
print(len(files_120))
files_120.sort()print(files_120[0], files_120[-1])
70
data/camx120hr_merged_20230717.nc data/camx120hr_merged_20230924.nc
print(len(files_96))
files_96.sort()print(files_96[0], files_96[-1])
69
data/camxmet2d.delhi.20230717.96hours.nc data/camxmet2d.delhi.20230924.96hours.nc
camx96hr has 20230808 day file missing. so removing it from camx120hr as well ‘data/camx120hr_merged_20230808.nc’
'data/camx120hr_merged_20230808.nc')
files_120.remove(len(files_120)
69
camx96hr file contains 96 hrs for days [0,1,2,3]
camx120hr file contains 120 hrs for days [-1,0,1,2,3]
We need current day data from each file to avoid redundancy. So we will take 0-23 hrs from camx96hr and 24-47hrs from camx120hr
def get_data(target_var_96_list, target_var_120_list):
= []
X = []
y
for target_var_96 in target_var_96_list:
= []
column for i in files_96:
= xr.open_dataset(i)
data_96 for j in range(0,24):
= data_96[target_var_96]['TSTEP'==j]
variable96
column.append(variable96)= np.array(column)
column # print(column.shape) # (69*24, 1, 80, 80)
X.append(column)= np.array(X)
X = np.concatenate([X[i] for i in range(X.shape[0])], axis=1)
X print('X shape ',X.shape)
for target_var_120 in target_var_120_list:
= []
column for i in files_120:
= xr.open_dataset(i)
data_120 for j in range(0,24):
= data_120[target_var_120]['TSTEP'==24+j]
variable120
column.append(variable120)= np.array(column)
column # print(column.shape) # (69*24, 1, 80, 80)
y.append(column)= np.array(y)
y = np.concatenate([y[i] for i in range(y.shape[0])], axis=1)
y print('y shape', y.shape)
return X,y
=['TSURF_K',
target_var_96_list 'SNOWEW_M', 'SNOWAGE_HR', 'PRATE_MMpH', 'CLOUD_OD', 'U10_MpS',
'V10_MpS', 'T2_K', 'SWSFC_WpM2', 'SOLM_M3pM3', 'CLDTOP_KM', 'CAPE',
'PBL_WRF_M', 'PBL_YSU_M'] # ['U10_MpS', 'T2_K', 'V10_MpS']
= ['P25','P10']
target_var_120_list = get_data(target_var_96_list, target_var_120_list) X,y
X shape (1656, 14, 80, 80)
y shape (1656, 2, 80, 80)
class CustomDataset(data.Dataset):
def __init__(self, input_data, output_data):
self.input_data = input_data
self.output_data = output_data
def __len__(self):
return len(self.input_data)
def __getitem__(self, index):
= torch.Tensor(self.input_data[index])
input_sample = torch.Tensor(self.output_data[index])
output_sample return input_sample, output_sample
Simple autoencoder
Model defination
class Autoencoder_MLP(nn.Module):
def __init__(self):
super(Autoencoder_MLP, self).__init__()
# Define the layers
self.flatten = nn.Flatten() # Flatten the 2D input matrix
self.fc1 = nn.Linear(80*80, 1024) # Fully connected layer 1
self.fc2 = nn.Linear(1024, 512)
self.fc3 = nn.Linear(512, 1024) # Fully connected layer 2
self.fc4 = nn.Linear(1024, 80*80) # Fully connected layer 2
self.relu = nn.ReLU() # Activation function
def forward(self, x):
# Flatten the input
= self.flatten(x)
x
# Forward pass through the fully connected layers
= self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
x = self.relu(x)
x = self.fc4(x)
x
# Reshape the output to match the 2D matrix size
= x.view(-1, 80, 80)
x
return x
=(1, 80, 80)) summary(Autoencoder_MLP(), input_size
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Autoencoder_MLP [1, 80, 80] --
├─Flatten: 1-1 [1, 6400] --
├─Linear: 1-2 [1, 1024] 6,554,624
├─ReLU: 1-3 [1, 1024] --
├─Linear: 1-4 [1, 512] 524,800
├─ReLU: 1-5 [1, 512] --
├─Linear: 1-6 [1, 1024] 525,312
├─ReLU: 1-7 [1, 1024] --
├─Linear: 1-8 [1, 6400] 6,560,000
==========================================================================================
Total params: 14,164,736
Trainable params: 14,164,736
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 14.16
==========================================================================================
Input size (MB): 0.03
Forward/backward pass size (MB): 0.07
Params size (MB): 56.66
Estimated Total Size (MB): 56.76
==========================================================================================
Training single channel input and output(P25)
X.shape, y.shape
((1656, 14, 80, 80), (1656, 2, 80, 80))
from sklearn.model_selection import train_test_split
= train_test_split(X, y, test_size=0.2, random_state=42)
X_train_all, X_test_all, y_train_all, y_test_all X_train_all.shape, X_test_all.shape, y_train_all.shape, y_test_all.shape
((1324, 14, 80, 80), (332, 14, 80, 80), (1324, 2, 80, 80), (332, 2, 80, 80))
= []
test_loss_list_P_25 = 0
y_channel # x_channel = 0
for x_channel in range(X.shape[1]):
####################### Selecting the channel #######################
print('X Channel name : ', target_var_96_list[x_channel])
= X_train_all[:, x_channel, :,:]
X_train = X_test_all[:, x_channel, :,:]
X_test = y_train_all[:, y_channel, :,:]
y_train = y_test_all[:, y_channel, :,:]
y_test print('Shapes: ', X_train.shape, X_test.shape, y_train.shape, y_test.shape)
####################### Creating the dataset loader #######################
= CustomDataset(X_train, y_train)
train_custom_dataset # print(len(train_custom_dataset))
= 32
batch_size = data.DataLoader(train_custom_dataset, batch_size=batch_size, shuffle=True)
train_loader # print(len(train_loader))
= CustomDataset(X_test, y_test)
test_custom_dataset # print(len(test_custom_dataset))
= 32
batch_size = data.DataLoader(test_custom_dataset, batch_size=batch_size, shuffle=False)
test_loader # print(len(test_loader))
#################### Training the model ####################
= Autoencoder_MLP()
model
model.to(device)# Define the loss function and optimizer
= nn.MSELoss()
criterion = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer = []
losses # Training loop
= 200
num_epochs for epoch in range(num_epochs):
# Set the model to training mode
model.train() = 0.0
total_loss
for inputs, targets in train_loader:
# Zero the gradients
optimizer.zero_grad() = inputs.to(device)
inputs = targets.to(device)
targets # Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
# Backpropagation
loss.backward()
optimizer.step()
+= loss.item()
total_loss
# Print the average loss for this epoch
= total_loss / len(train_loader)
average_loss
losses.append(average_loss)if epoch % 20 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {average_loss:.4f}")
############################# testing the model #############################
eval() # Set the model to evaluation mode
model.= 0.0
test_loss
with torch.no_grad():
for inputs, targets in test_loader:
= inputs.to(device)
inputs = targets.to(device)
targets
# Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
+= loss.item()
test_loss
# Print the average test loss
= test_loss / len(test_loader)
average_test_loss
test_loss_list_P_25.append(average_test_loss)print(f"Average Test Loss: {average_test_loss:.4f}")
range(1, num_epochs + 1), losses)
plt.plot('Epoch')
plt.xlabel('Loss')
plt.ylabel('Train Loss vs. Epoch for channel '+target_var_96_list[x_channel])
plt.title(True)
plt.grid( plt.show()
X Channel name : TSURF_K
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 405.4315
Epoch [21/200] Loss: 214.5978
Epoch [41/200] Loss: 215.0236
Epoch [61/200] Loss: 207.2482
Epoch [81/200] Loss: 206.7534
Epoch [101/200] Loss: 213.1840
Epoch [121/200] Loss: 207.5252
Epoch [141/200] Loss: 208.2888
Epoch [161/200] Loss: 209.6494
Epoch [181/200] Loss: 209.7600
Average Test Loss: 190.0971
X Channel name : SNOWEW_M
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 464.9160
Epoch [21/200] Loss: 206.1055
Epoch [41/200] Loss: 210.0196
Epoch [61/200] Loss: 209.2211
Epoch [81/200] Loss: 211.8091
Epoch [101/200] Loss: 205.4592
Epoch [121/200] Loss: 203.9590
Epoch [141/200] Loss: 208.5030
Epoch [161/200] Loss: 207.1320
Epoch [181/200] Loss: 204.3017
Average Test Loss: 189.7174
X Channel name : SNOWAGE_HR
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 460.0334
Epoch [21/200] Loss: 207.1043
Epoch [41/200] Loss: 206.5882
Epoch [61/200] Loss: 208.3192
Epoch [81/200] Loss: 205.4035
Epoch [101/200] Loss: 209.7349
Epoch [121/200] Loss: 205.5515
Epoch [141/200] Loss: 206.1994
Epoch [161/200] Loss: 205.4001
Epoch [181/200] Loss: 206.3801
Average Test Loss: 190.0263
X Channel name : PRATE_MMpH
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 476.9623
Epoch [21/200] Loss: 207.3108
Epoch [41/200] Loss: 208.4855
Epoch [61/200] Loss: 206.4735
Epoch [81/200] Loss: 205.5895
Epoch [101/200] Loss: 205.6536
Epoch [121/200] Loss: 205.9680
Epoch [141/200] Loss: 205.7129
Epoch [161/200] Loss: 208.6799
Epoch [181/200] Loss: 205.6486
Average Test Loss: 189.6717
X Channel name : CLOUD_OD
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 599.1686
Epoch [21/200] Loss: 454.3787
Epoch [41/200] Loss: 149.9270
Epoch [61/200] Loss: 23.7072
Epoch [81/200] Loss: 7.4556
Epoch [101/200] Loss: 18.4050
Epoch [121/200] Loss: 3.0840
Epoch [141/200] Loss: 100.6344
Epoch [161/200] Loss: 25.9631
Epoch [181/200] Loss: 38.8335
Average Test Loss: 9.8902
X Channel name : U10_MpS
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 373.7542
Epoch [21/200] Loss: 51.2008
Epoch [41/200] Loss: 37.7057
Epoch [61/200] Loss: 25.3702
Epoch [81/200] Loss: 0.2791
Epoch [101/200] Loss: 20.1649
Epoch [121/200] Loss: 0.2297
Epoch [141/200] Loss: 84.8078
Epoch [161/200] Loss: 2.4612
Epoch [181/200] Loss: 1.4404
Average Test Loss: 0.7598
X Channel name : V10_MpS
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 399.2307
Epoch [21/200] Loss: 1.1263
Epoch [41/200] Loss: 0.0137
Epoch [61/200] Loss: 0.0820
Epoch [81/200] Loss: 2.3733
Epoch [101/200] Loss: 0.0002
Epoch [121/200] Loss: 0.0113
Epoch [141/200] Loss: 0.0393
Epoch [161/200] Loss: 11.0771
Epoch [181/200] Loss: 0.0003
Average Test Loss: 0.0507
X Channel name : T2_K
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 412.2657
Epoch [21/200] Loss: 210.9129
Epoch [41/200] Loss: 209.2686
Epoch [61/200] Loss: 206.7048
Epoch [81/200] Loss: 206.0187
Epoch [101/200] Loss: 209.1361
Epoch [121/200] Loss: 212.1444
Epoch [141/200] Loss: 206.8286
Epoch [161/200] Loss: 205.8083
Epoch [181/200] Loss: 208.7421
Average Test Loss: 190.9819
X Channel name : SWSFC_WpM2
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 485.4561
Epoch [21/200] Loss: 210.1752
Epoch [41/200] Loss: 207.2874
Epoch [61/200] Loss: 207.9026
Epoch [81/200] Loss: 206.2404
Epoch [101/200] Loss: 205.9271
Epoch [121/200] Loss: 205.6481
Epoch [141/200] Loss: 208.2400
Epoch [161/200] Loss: 205.4081
Epoch [181/200] Loss: 205.1859
Average Test Loss: 199.1508
X Channel name : SOLM_M3pM3
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 305.6402
Epoch [21/200] Loss: 176.6024
Epoch [41/200] Loss: 161.6851
Epoch [61/200] Loss: 157.7673
Epoch [81/200] Loss: 138.7664
Epoch [101/200] Loss: 132.1217
Epoch [121/200] Loss: 109.4900
Epoch [141/200] Loss: 104.2382
Epoch [161/200] Loss: 69.5860
Epoch [181/200] Loss: 50.0248
Average Test Loss: 42.1182
X Channel name : CLDTOP_KM
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 461.0668
Epoch [21/200] Loss: 205.2268
Epoch [41/200] Loss: 213.0740
Epoch [61/200] Loss: 206.9702
Epoch [81/200] Loss: 205.7769
Epoch [101/200] Loss: 210.3533
Epoch [121/200] Loss: 206.2578
Epoch [141/200] Loss: 207.2727
Epoch [161/200] Loss: 205.3765
Epoch [181/200] Loss: 206.2520
Average Test Loss: 192.2221
X Channel name : CAPE
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 468.8439
Epoch [21/200] Loss: 207.2851
Epoch [41/200] Loss: 211.7429
Epoch [61/200] Loss: 205.8536
Epoch [81/200] Loss: 205.3124
Epoch [101/200] Loss: 207.8155
Epoch [121/200] Loss: 208.3488
Epoch [141/200] Loss: 206.0120
Epoch [161/200] Loss: 205.9442
Epoch [181/200] Loss: 204.1215
Average Test Loss: 190.8858
X Channel name : PBL_WRF_M
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 287.4021
Epoch [21/200] Loss: 209.5118
Epoch [41/200] Loss: 209.2537
Epoch [61/200] Loss: 206.0375
Epoch [81/200] Loss: 207.1711
Epoch [101/200] Loss: 205.2734
Epoch [121/200] Loss: 210.4286
Epoch [141/200] Loss: 207.5243
Epoch [161/200] Loss: 208.9856
Epoch [181/200] Loss: 204.6973
Average Test Loss: 191.9669
X Channel name : PBL_YSU_M
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 323.2034
Epoch [21/200] Loss: 207.3925
Epoch [41/200] Loss: 210.0555
Epoch [61/200] Loss: 206.0877
Epoch [81/200] Loss: 207.9315
Epoch [101/200] Loss: 207.4798
Epoch [121/200] Loss: 206.1072
Epoch [141/200] Loss: 207.4259
Epoch [161/200] Loss: 209.3318
Epoch [181/200] Loss: 207.3006
Average Test Loss: 194.3565
test_loss_list_P_25
[190.09705144708806,
189.7173545143821,
190.02632834694603,
189.67173489657316,
9.890190774744207,
0.7598354017192667,
0.050707115029746834,
190.98193498091265,
199.15079567649147,
42.118152445012875,
192.22213883833453,
190.88581431995738,
191.9668634588068,
194.3564910888672]
= list(zip(target_var_96_list, test_loss_list_P_25))
name_loss_pairs
# Sort based on loss values
= sorted(name_loss_pairs, key=lambda x: x[1])
sorted_name_loss_pairs for pair in sorted_name_loss_pairs:
print(pair)
# print(sorted_name_loss_pairs)
# Select the lowest 3 names
= [pair[0] for pair in sorted_name_loss_pairs[:3]]
lowest_3_names
print("Lowest 3 names:", lowest_3_names)
('V10_MpS', 0.050707115029746834)
('U10_MpS', 0.7598354017192667)
('CLOUD_OD', 9.890190774744207)
('SOLM_M3pM3', 42.118152445012875)
('PRATE_MMpH', 189.67173489657316)
('SNOWEW_M', 189.7173545143821)
('SNOWAGE_HR', 190.02632834694603)
('TSURF_K', 190.09705144708806)
('CAPE', 190.88581431995738)
('T2_K', 190.98193498091265)
('PBL_WRF_M', 191.9668634588068)
('CLDTOP_KM', 192.22213883833453)
('PBL_YSU_M', 194.3564910888672)
('SWSFC_WpM2', 199.15079567649147)
Lowest 3 names: ['V10_MpS', 'U10_MpS', 'CLOUD_OD']
Training single channel input and output(P25)
= []
test_loss_list_P_10 = 1
y_channel # x_channel = 0
for x_channel in range(X.shape[1]):
####################### Selecting the channel #######################
print('X Channel name : ', target_var_96_list[x_channel])
= X_train_all[:, x_channel, :,:]
X_train = X_test_all[:, x_channel, :,:]
X_test = y_train_all[:, y_channel, :,:]
y_train = y_test_all[:, y_channel, :,:]
y_test print('Shapes: ', X_train.shape, X_test.shape, y_train.shape, y_test.shape)
####################### Creating the dataset loader #######################
= CustomDataset(X_train, y_train)
train_custom_dataset # print(len(train_custom_dataset))
= 32
batch_size = data.DataLoader(train_custom_dataset, batch_size=batch_size, shuffle=True)
train_loader # print(len(train_loader))
= CustomDataset(X_test, y_test)
test_custom_dataset # print(len(test_custom_dataset))
= 32
batch_size = data.DataLoader(test_custom_dataset, batch_size=batch_size, shuffle=False)
test_loader # print(len(test_loader))
#################### Training the model ####################
= Autoencoder_MLP()
model
model.to(device)# Define the loss function and optimizer
= nn.MSELoss()
criterion = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer = []
losses # Training loop
= 200
num_epochs for epoch in range(num_epochs):
# Set the model to training mode
model.train() = 0.0
total_loss
for inputs, targets in train_loader:
# Zero the gradients
optimizer.zero_grad() = inputs.to(device)
inputs = targets.to(device)
targets # Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
# Backpropagation
loss.backward()
optimizer.step()
+= loss.item()
total_loss
# Print the average loss for this epoch
= total_loss / len(train_loader)
average_loss
losses.append(average_loss)if epoch % 20 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {average_loss:.4f}")
############################# testing the model #############################
eval() # Set the model to evaluation mode
model.= 0.0
test_loss
with torch.no_grad():
for inputs, targets in test_loader:
= inputs.to(device)
inputs = targets.to(device)
targets
# Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
+= loss.item()
test_loss
# Print the average test loss
= test_loss / len(test_loader)
average_test_loss
test_loss_list_P_10.append(average_test_loss)print(f"Average Test Loss: {average_test_loss:.4f}")
range(1, num_epochs + 1), losses)
plt.plot('Epoch')
plt.xlabel('Loss')
plt.ylabel('Train Loss vs. Epoch for channel '+target_var_96_list[x_channel])
plt.title(True)
plt.grid( plt.show()
X Channel name : TSURF_K
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 689.7490
Epoch [21/200] Loss: 418.6676
Epoch [41/200] Loss: 410.9685
Epoch [61/200] Loss: 404.8066
Epoch [81/200] Loss: 400.5351
Epoch [101/200] Loss: 401.0250
Epoch [121/200] Loss: 401.0659
Epoch [141/200] Loss: 403.6432
Epoch [161/200] Loss: 402.2360
Epoch [181/200] Loss: 407.2921
Average Test Loss: 401.8882
X Channel name : SNOWEW_M
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 1031.9999
Epoch [21/200] Loss: 405.7656
Epoch [41/200] Loss: 402.2811
Epoch [61/200] Loss: 403.8337
Epoch [81/200] Loss: 398.2591
Epoch [101/200] Loss: 409.4181
Epoch [121/200] Loss: 396.3737
Epoch [141/200] Loss: 395.1861
Epoch [161/200] Loss: 403.0066
Epoch [181/200] Loss: 397.6094
Average Test Loss: 361.2384
X Channel name : SNOWAGE_HR
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 1050.1397
Epoch [21/200] Loss: 397.1798
Epoch [41/200] Loss: 401.0453
Epoch [61/200] Loss: 401.7171
Epoch [81/200] Loss: 400.2689
Epoch [101/200] Loss: 396.7101
Epoch [121/200] Loss: 404.5816
Epoch [141/200] Loss: 397.7231
Epoch [161/200] Loss: 394.6489
Epoch [181/200] Loss: 399.8899
Average Test Loss: 362.0641
X Channel name : PRATE_MMpH
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 1034.3143
Epoch [21/200] Loss: 399.8861
Epoch [41/200] Loss: 399.7876
Epoch [61/200] Loss: 395.9112
Epoch [81/200] Loss: 400.4803
Epoch [101/200] Loss: 396.5552
Epoch [121/200] Loss: 401.0467
Epoch [141/200] Loss: 406.4987
Epoch [161/200] Loss: 401.3759
Epoch [181/200] Loss: 397.4671
Average Test Loss: 361.1263
X Channel name : CLOUD_OD
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 1214.4916
Epoch [21/200] Loss: 725.7131
Epoch [41/200] Loss: 220.5272
Epoch [61/200] Loss: 23.5934
Epoch [81/200] Loss: 8.8063
Epoch [101/200] Loss: 274.7881
Epoch [121/200] Loss: 218.0335
Epoch [141/200] Loss: 232.4360
Epoch [161/200] Loss: 212.6520
Epoch [181/200] Loss: 195.4785
Average Test Loss: 180.2198
X Channel name : U10_MpS
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 822.4214
Epoch [21/200] Loss: 182.0370
Epoch [41/200] Loss: 50.0221
Epoch [61/200] Loss: 60.5883
Epoch [81/200] Loss: 37.8001
Epoch [101/200] Loss: 2.6386
Epoch [121/200] Loss: 0.0629
Epoch [141/200] Loss: 76.5106
Epoch [161/200] Loss: 41.2225
Epoch [181/200] Loss: 4.6194
Average Test Loss: 0.6711
X Channel name : V10_MpS
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 1160.5402
Epoch [21/200] Loss: 8.7026
Epoch [41/200] Loss: 0.3045
Epoch [61/200] Loss: 0.1493
Epoch [81/200] Loss: 0.1235
Epoch [101/200] Loss: 5.4404
Epoch [121/200] Loss: 0.0000
Epoch [141/200] Loss: 10.2387
Epoch [161/200] Loss: 0.2644
Epoch [181/200] Loss: 0.0284
Average Test Loss: 0.6849
X Channel name : T2_K
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 646.6076
Epoch [21/200] Loss: 412.2161
Epoch [41/200] Loss: 407.6458
Epoch [61/200] Loss: 403.7684
Epoch [81/200] Loss: 399.9570
Epoch [101/200] Loss: 397.7386
Epoch [121/200] Loss: 402.0990
Epoch [141/200] Loss: 397.6674
Epoch [161/200] Loss: 398.1612
Epoch [181/200] Loss: 403.3343
Average Test Loss: 360.5232
X Channel name : SWSFC_WpM2
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 1027.1539
Epoch [21/200] Loss: 403.9408
Epoch [41/200] Loss: 401.9444
Epoch [61/200] Loss: 399.7967
Epoch [81/200] Loss: 401.7070
Epoch [101/200] Loss: 404.1586
Epoch [121/200] Loss: 398.2909
Epoch [141/200] Loss: 396.0758
Epoch [161/200] Loss: 396.4020
Epoch [181/200] Loss: 402.5728
Average Test Loss: 362.6420
X Channel name : SOLM_M3pM3
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 696.3634
Epoch [21/200] Loss: 331.1351
Epoch [41/200] Loss: 296.5541
Epoch [61/200] Loss: 301.9112
Epoch [81/200] Loss: 272.0759
Epoch [101/200] Loss: 251.3878
Epoch [121/200] Loss: 232.4279
Epoch [141/200] Loss: 170.0628
Epoch [161/200] Loss: 138.8842
Epoch [181/200] Loss: 108.4864
Average Test Loss: 110.6638
X Channel name : CLDTOP_KM
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 1017.6373
Epoch [21/200] Loss: 400.4651
Epoch [41/200] Loss: 405.2223
Epoch [61/200] Loss: 402.9269
Epoch [81/200] Loss: 400.7290
Epoch [101/200] Loss: 400.5154
Epoch [121/200] Loss: 404.1065
Epoch [141/200] Loss: 403.8573
Epoch [161/200] Loss: 397.0463
Epoch [181/200] Loss: 398.1819
Average Test Loss: 361.3055
X Channel name : CAPE
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 1029.8392
Epoch [21/200] Loss: 401.8889
Epoch [41/200] Loss: 399.1993
Epoch [61/200] Loss: 400.2419
Epoch [81/200] Loss: 396.8849
Epoch [101/200] Loss: 395.3639
Epoch [121/200] Loss: 401.5739
Epoch [141/200] Loss: 403.1417
Epoch [161/200] Loss: 402.3704
Epoch [181/200] Loss: 400.0658
Average Test Loss: 361.0597
X Channel name : PBL_WRF_M
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 536.0374
Epoch [21/200] Loss: 401.9832
Epoch [41/200] Loss: 402.4697
Epoch [61/200] Loss: 404.1890
Epoch [81/200] Loss: 400.4949
Epoch [101/200] Loss: 399.0872
Epoch [121/200] Loss: 398.4192
Epoch [141/200] Loss: 400.4667
Epoch [161/200] Loss: 398.1382
Epoch [181/200] Loss: 408.3430
Average Test Loss: 360.3921
X Channel name : PBL_YSU_M
Shapes: (1324, 80, 80) (332, 80, 80) (1324, 80, 80) (332, 80, 80)
Epoch [1/200] Loss: 538.3320
Epoch [21/200] Loss: 405.8272
Epoch [41/200] Loss: 430.2409
Epoch [61/200] Loss: 398.9199
Epoch [81/200] Loss: 404.8572
Epoch [101/200] Loss: 397.2876
Epoch [121/200] Loss: 403.3052
Epoch [141/200] Loss: 401.7307
Epoch [161/200] Loss: 396.1157
Epoch [181/200] Loss: 399.0870
Average Test Loss: 361.4502
test_loss_list_P_10
[401.88824185458094,
361.23838112571025,
362.0641340775923,
361.1262997713956,
180.2197820490057,
0.6710819737477736,
0.6848962740464644,
360.5232141668146,
362.6420177112926,
110.6638252951882,
361.3054504394531,
361.05967018821025,
360.39211342551494,
361.45016063343394]
= list(zip(target_var_96_list, test_loss_list_P_10))
name_loss_pairs
# Sort based on loss values
= sorted(name_loss_pairs, key=lambda x: x[1])
sorted_name_loss_pairs for pair in sorted_name_loss_pairs:
print(pair)
# print(sorted_name_loss_pairs)
# Select the lowest 3 names
= [pair[0] for pair in sorted_name_loss_pairs[:3]]
lowest_3_names
print("Lowest 3 names:", lowest_3_names)
('U10_MpS', 0.6710819737477736)
('V10_MpS', 0.6848962740464644)
('SOLM_M3pM3', 110.6638252951882)
('CLOUD_OD', 180.2197820490057)
('PBL_WRF_M', 360.39211342551494)
('T2_K', 360.5232141668146)
('CAPE', 361.05967018821025)
('PRATE_MMpH', 361.1262997713956)
('SNOWEW_M', 361.23838112571025)
('CLDTOP_KM', 361.3054504394531)
('PBL_YSU_M', 361.45016063343394)
('SNOWAGE_HR', 362.0641340775923)
('SWSFC_WpM2', 362.6420177112926)
('TSURF_K', 401.88824185458094)
Lowest 3 names: ['U10_MpS', 'V10_MpS', 'SOLM_M3pM3']
Insights
# Adding values at corresponding indices
= [(x + y)/2 for x, y in zip(test_loss_list_P_10, test_loss_list_P_25)]
result
# Creating a DataFrame
= {'Input channel': target_var_96_list, 'P10': test_loss_list_P_10, 'P25': test_loss_list_P_25, 'P10+P25 avg': result}
data_frame = pd.DataFrame(data_frame)
df
# Sorting the DataFrame based on "List1 + List2"
= df.sort_values(by='P10+P25 avg')
df_sorted
# Displaying the sorted DataFrame
= df_sorted.round(1)
df_rounded '/home/rishabh.mondal/climax_alternative/Climax_2/results/test_loss_list_P_10_P25_MLP_auto.csv', index=False)
df_rounded.to_csv( df_rounded
Input channel | P10 | P25 | P10+P25 avg | |
---|---|---|---|---|
6 | V10_MpS | 0.7 | 0.1 | 0.4 |
5 | U10_MpS | 0.7 | 0.8 | 0.7 |
9 | SOLM_M3pM3 | 110.7 | 42.1 | 76.4 |
4 | CLOUD_OD | 180.2 | 9.9 | 95.1 |
3 | PRATE_MMpH | 361.1 | 189.7 | 275.4 |
1 | SNOWEW_M | 361.2 | 189.7 | 275.5 |
7 | T2_K | 360.5 | 191.0 | 275.8 |
11 | CAPE | 361.1 | 190.9 | 276.0 |
2 | SNOWAGE_HR | 362.1 | 190.0 | 276.0 |
12 | PBL_WRF_M | 360.4 | 192.0 | 276.2 |
10 | CLDTOP_KM | 361.3 | 192.2 | 276.8 |
13 | PBL_YSU_M | 361.5 | 194.4 | 277.9 |
8 | SWSFC_WpM2 | 362.6 | 199.2 | 280.9 |
0 | TSURF_K | 401.9 | 190.1 | 296.0 |
Autoencoder with convolutional layers
Model defination
class Encoder(nn.Module):
def __init__(self, image_size, num_input_channels, num_output_channels = 1, c_hid = 16, latent_dim = 1024, activation= nn.GELU):
super(Encoder, self).__init__()
self.image_size = image_size
self.num_input_channels = num_input_channels
self.num_output_channels = num_output_channels
self.c_hid = c_hid
self.latent_dim = latent_dim
self.activation = activation
self.net = nn.Sequential(
=self.num_input_channels, out_channels=self.c_hid, kernel_size=3, stride=2, padding=1),
nn.Conv2d(in_channelsself.activation(),
=self.c_hid, out_channels=self.c_hid, kernel_size=3, padding=1),
nn.Conv2d(in_channelsself.activation(),
=self.c_hid, out_channels=2 * self.c_hid, kernel_size=3, stride=2, padding=1),
nn.Conv2d(in_channelsself.activation(),
=2 * self.c_hid, out_channels=2 * self.c_hid, kernel_size=3, padding=1),
nn.Conv2d(in_channelsself.activation(),
=2 * self.c_hid, out_channels=4*self.c_hid, kernel_size=3, stride=2, padding=1),
nn.Conv2d(in_channelsself.activation(),
nn.Flatten(),4 * self.c_hid) * self.image_size//8 * self.image_size//8, self.latent_dim)
nn.Linear((
)
def forward(self, x):
return self.net(x)
class Decoder(nn.Module):
def __init__(self, image_size, num_input_channels, num_output_channels = 1,c_hid = 16, latent_dim = 1024, activation= nn.GELU):
super(Decoder, self).__init__()
self.image_size = image_size
self.num_input_channels = num_input_channels
self.num_output_channels = num_output_channels
self.c_hid = c_hid
self.latent_dim = latent_dim
self.activation = activation
self.decoder = nn.Sequential(
self.latent_dim, (4 * self.c_hid) * self.image_size//8 * self.image_size//8),
nn.Linear(1, (4*self.c_hid, self.image_size//8, self.image_size//8)),
nn.Unflatten(=4*self.c_hid, out_channels= 2*self.c_hid, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.ConvTranspose2d(in_channelsself.activation(),
= 2*self.c_hid, out_channels= 2*self.c_hid, kernel_size=3, padding=1),
nn.Conv2d(in_channels self.activation(),
= 2*self.c_hid, out_channels = self.c_hid, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.ConvTranspose2d(in_channels self.activation(),
= self.c_hid,out_channels= self.c_hid, kernel_size=3, padding=1),
nn.Conv2d(in_channels self.activation(),
= self.c_hid, out_channels= self.num_output_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.ConvTranspose2d(in_channels
)def forward(self, x):
return self.decoder(x)
class Autoencoder_CNN(nn.Module):
def __init__(self, image_size, num_input_channels, num_output_channels = 1,c_hid = 16, latent_dim = 1024, activation= nn.GELU, encoder = Encoder, decoder = Decoder):
super(Autoencoder_CNN, self).__init__()
self.image_size = image_size
self.num_input_channels = num_input_channels
self.num_output_channels = num_output_channels
self.c_hid = c_hid
self.latent_dim = latent_dim
self.activation = activation
self.encoder = encoder(self.image_size, self.num_input_channels, self.num_output_channels, self.c_hid, self.latent_dim, self.activation)
self.decoder = decoder(self.image_size, self.num_input_channels, self.num_output_channels, self.c_hid, self.latent_dim, self.activation)
def forward(self, x):
= self.encoder(x)
x # print(x.shape)
= self.decoder(x)
x return x
= Autoencoder_CNN(image_size = 80, num_input_channels = 14, num_output_channels=2, c_hid = 16, latent_dim = 1024, activation= nn.GELU)
model_auto
model_auto= torch.randn(11, 14, 80, 80) # Assuming input size of (batch_size, num_num_input_channels, height, width)
dummy_input = model_auto(dummy_input)
out print(out.shape)
print(model_auto)
torch.Size([11, 2, 80, 80])
Autoencoder_CNN(
(encoder): Encoder(
(net): Sequential(
(0): Conv2d(14, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): GELU(approximate='none')
(2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): GELU(approximate='none')
(4): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(5): GELU(approximate='none')
(6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): GELU(approximate='none')
(8): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(9): GELU(approximate='none')
(10): Flatten(start_dim=1, end_dim=-1)
(11): Linear(in_features=6400, out_features=1024, bias=True)
)
)
(decoder): Decoder(
(decoder): Sequential(
(0): Linear(in_features=1024, out_features=6400, bias=True)
(1): Unflatten(dim=1, unflattened_size=(64, 10, 10))
(2): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(3): GELU(approximate='none')
(4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(5): GELU(approximate='none')
(6): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(7): GELU(approximate='none')
(8): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): GELU(approximate='none')
(10): ConvTranspose2d(16, 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
)
)
)
from torchinfo import summary
=(1656, 14, 80, 80)) summary(model_auto, input_size
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Autoencoder_CNN [1656, 2, 80, 80] --
├─Encoder: 1-1 [1656, 1024] --
│ └─Sequential: 2-1 [1656, 1024] --
│ │ └─Conv2d: 3-1 [1656, 16, 40, 40] 2,032
│ │ └─GELU: 3-2 [1656, 16, 40, 40] --
│ │ └─Conv2d: 3-3 [1656, 16, 40, 40] 2,320
│ │ └─GELU: 3-4 [1656, 16, 40, 40] --
│ │ └─Conv2d: 3-5 [1656, 32, 20, 20] 4,640
│ │ └─GELU: 3-6 [1656, 32, 20, 20] --
│ │ └─Conv2d: 3-7 [1656, 32, 20, 20] 9,248
│ │ └─GELU: 3-8 [1656, 32, 20, 20] --
│ │ └─Conv2d: 3-9 [1656, 64, 10, 10] 18,496
│ │ └─GELU: 3-10 [1656, 64, 10, 10] --
│ │ └─Flatten: 3-11 [1656, 6400] --
│ │ └─Linear: 3-12 [1656, 1024] 6,554,624
├─Decoder: 1-2 [1656, 2, 80, 80] --
│ └─Sequential: 2-2 [1656, 2, 80, 80] --
│ │ └─Linear: 3-13 [1656, 6400] 6,560,000
│ │ └─Unflatten: 3-14 [1656, 64, 10, 10] --
│ │ └─ConvTranspose2d: 3-15 [1656, 32, 20, 20] 18,464
│ │ └─GELU: 3-16 [1656, 32, 20, 20] --
│ │ └─Conv2d: 3-17 [1656, 32, 20, 20] 9,248
│ │ └─GELU: 3-18 [1656, 32, 20, 20] --
│ │ └─ConvTranspose2d: 3-19 [1656, 16, 40, 40] 4,624
│ │ └─GELU: 3-20 [1656, 16, 40, 40] --
│ │ └─Conv2d: 3-21 [1656, 16, 40, 40] 2,320
│ │ └─GELU: 3-22 [1656, 16, 40, 40] --
│ │ └─ConvTranspose2d: 3-23 [1656, 2, 80, 80] 290
==========================================================================================
Total params: 13,186,306
Trainable params: 13,186,306
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 85.34
==========================================================================================
Input size (MB): 593.51
Forward/backward pass size (MB): 2387.61
Params size (MB): 52.75
Estimated Total Size (MB): 3033.86
==========================================================================================
from torchsummary import summary
=(14, 80, 80)) summary(model_auto, input_size
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 16, 40, 40] 2,032
GELU-2 [-1, 16, 40, 40] 0
Conv2d-3 [-1, 16, 40, 40] 2,320
GELU-4 [-1, 16, 40, 40] 0
Conv2d-5 [-1, 32, 20, 20] 4,640
GELU-6 [-1, 32, 20, 20] 0
Conv2d-7 [-1, 32, 20, 20] 9,248
GELU-8 [-1, 32, 20, 20] 0
Conv2d-9 [-1, 64, 10, 10] 18,496
GELU-10 [-1, 64, 10, 10] 0
Flatten-11 [-1, 6400] 0
Linear-12 [-1, 1024] 6,554,624
Encoder-13 [-1, 1024] 0
Linear-14 [-1, 6400] 6,560,000
Unflatten-15 [-1, 64, 10, 10] 0
ConvTranspose2d-16 [-1, 32, 20, 20] 18,464
GELU-17 [-1, 32, 20, 20] 0
Conv2d-18 [-1, 32, 20, 20] 9,248
GELU-19 [-1, 32, 20, 20] 0
ConvTranspose2d-20 [-1, 16, 40, 40] 4,624
GELU-21 [-1, 16, 40, 40] 0
Conv2d-22 [-1, 16, 40, 40] 2,320
GELU-23 [-1, 16, 40, 40] 0
ConvTranspose2d-24 [-1, 2, 80, 80] 290
Decoder-25 [-1, 2, 80, 80] 0
================================================================
Total params: 13,186,306
Trainable params: 13,186,306
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.34
Forward/backward pass size (MB): 2.80
Params size (MB): 50.30
Estimated Total Size (MB): 53.44
----------------------------------------------------------------
training on all channel and predicting on all channel
=['TSURF_K',
target_var_96_list 'SNOWEW_M', 'SNOWAGE_HR', 'PRATE_MMpH', 'CLOUD_OD', 'U10_MpS',
'V10_MpS', 'T2_K', 'SWSFC_WpM2', 'SOLM_M3pM3', 'CLDTOP_KM', 'CAPE',
'PBL_WRF_M', 'PBL_YSU_M'] # ['U10_MpS', 'T2_K', 'V10_MpS']
= ['P25','P10']
target_var_120_list = get_data(target_var_96_list, target_var_120_list)
X,y
from sklearn.model_selection import train_test_split
= train_test_split(X, y, test_size=0.2, random_state=42)
X_train_all, X_test_all, y_train_all, y_test_all X_train_all.shape, X_test_all.shape, y_train_all.shape, y_test_all.shape
X shape (1656, 14, 80, 80)
y shape (1656, 2, 80, 80)
((1324, 14, 80, 80), (332, 14, 80, 80), (1324, 2, 80, 80), (332, 2, 80, 80))
print(X_train_all.shape, X_test_all.shape, y_train_all.shape, y_test_all.shape)
= CustomDataset(X_train_all, y_train_all)
train_custom_dataset # print(len(train_custom_dataset))
= 32
batch_size = data.DataLoader(train_custom_dataset, batch_size=batch_size, shuffle=True)
train_loader # print(len(train_loader))
= CustomDataset(X_test_all, y_test_all)
test_custom_dataset # print(len(test_custom_dataset))
= 32
batch_size = data.DataLoader(test_custom_dataset, batch_size=batch_size, shuffle=False)
test_loader # print(len(test_loader))
#################### Training the model ####################
= Autoencoder_CNN(image_size = 80, num_input_channels = 14, num_output_channels=2, c_hid = 64, latent_dim = 2048, activation= nn.GELU)
model
model.to(device) from torchinfo import summary
print(summary(model, input_size=(1656, 14, 80, 80)))
# Define the loss function and optimizer
= nn.MSELoss()
criterion = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer = []
losses # Training loop
= 50 #200
num_epochs for epoch in trange(num_epochs):
# Set the model to training mode
model.train() = 0.0
total_loss
for inputs, targets in train_loader:
# Zero the gradients
optimizer.zero_grad() = inputs.to(device)
inputs = targets.to(device)
targets # Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
# Backpropagation
loss.backward()
optimizer.step()
+= loss.item()
total_loss
# Print the average loss for this epoch
= total_loss / len(train_loader)
average_loss
losses.append(average_loss)if epoch % 10 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {average_loss:.4f}")
############################# testing the model #############################
eval() # Set the model to evaluation mode
model.= 0.0
test_loss
with torch.no_grad():
for inputs, targets in test_loader:
= inputs.to(device)
inputs = targets.to(device)
targets
# Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
+= loss.item()
test_loss
# Print the average test loss
= test_loss / len(test_loader)
average_test_loss # test_loss_list_P_25.append(average_test_loss)
print(f"Average Test Loss: {average_test_loss:.4f}")
range(1, num_epochs + 1), losses)
plt.plot('Epoch')
plt.xlabel('Loss')
plt.ylabel(# plt.title('Train Loss vs. Epoch for channel '+target_var_96_list[x_channel])
True)
plt.grid( plt.show()
(1324, 14, 80, 80) (332, 14, 80, 80) (1324, 2, 80, 80) (332, 2, 80, 80)
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Autoencoder_CNN [1656, 2, 80, 80] --
├─Encoder: 1-1 [1656, 2048] --
│ └─Sequential: 2-1 [1656, 2048] --
│ │ └─Conv2d: 3-1 [1656, 64, 40, 40] 8,128
│ │ └─GELU: 3-2 [1656, 64, 40, 40] --
│ │ └─Conv2d: 3-3 [1656, 64, 40, 40] 36,928
│ │ └─GELU: 3-4 [1656, 64, 40, 40] --
│ │ └─Conv2d: 3-5 [1656, 128, 20, 20] 73,856
│ │ └─GELU: 3-6 [1656, 128, 20, 20] --
│ │ └─Conv2d: 3-7 [1656, 128, 20, 20] 147,584
│ │ └─GELU: 3-8 [1656, 128, 20, 20] --
│ │ └─Conv2d: 3-9 [1656, 256, 10, 10] 295,168
│ │ └─GELU: 3-10 [1656, 256, 10, 10] --
│ │ └─Flatten: 3-11 [1656, 25600] --
│ │ └─Linear: 3-12 [1656, 2048] 52,430,848
├─Decoder: 1-2 [1656, 2, 80, 80] --
│ └─Sequential: 2-2 [1656, 2, 80, 80] --
│ │ └─Linear: 3-13 [1656, 25600] 52,454,400
│ │ └─Unflatten: 3-14 [1656, 256, 10, 10] --
│ │ └─ConvTranspose2d: 3-15 [1656, 128, 20, 20] 295,040
│ │ └─GELU: 3-16 [1656, 128, 20, 20] --
│ │ └─Conv2d: 3-17 [1656, 128, 20, 20] 147,584
│ │ └─GELU: 3-18 [1656, 128, 20, 20] --
│ │ └─ConvTranspose2d: 3-19 [1656, 64, 40, 40] 73,792
│ │ └─GELU: 3-20 [1656, 64, 40, 40] --
│ │ └─Conv2d: 3-21 [1656, 64, 40, 40] 36,928
│ │ └─GELU: 3-22 [1656, 64, 40, 40] --
│ │ └─ConvTranspose2d: 3-23 [1656, 2, 80, 80] 1,154
==========================================================================================
Total params: 106,001,410
Trainable params: 106,001,410
Non-trainable params: 0
Total mult-adds (Units.TERABYTES): 1.09
==========================================================================================
Input size (MB): 593.51
Forward/backward pass size (MB): 9014.58
Params size (MB): 424.01
Estimated Total Size (MB): 10032.09
==========================================================================================
2%|▏ | 1/50 [00:04<04:03, 4.96s/it]
Epoch [1/50] Loss: 1620.3361
Average Test Loss: 377.9475
22%|██▏ | 11/50 [00:58<03:15, 5.00s/it]
Epoch [11/50] Loss: 249.2443
Average Test Loss: 228.7919
42%|████▏ | 21/50 [01:33<01:40, 3.45s/it]
Epoch [21/50] Loss: 168.6332
Average Test Loss: 172.4825
62%|██████▏ | 31/50 [02:21<01:14, 3.90s/it]
Epoch [31/50] Loss: 102.0034
Average Test Loss: 111.1758
80%|████████ | 40/50 [02:49<00:32, 3.22s/it]
Epoch [41/50] Loss: 65.1520
82%|████████▏ | 41/50 [02:53<00:30, 3.39s/it]
Average Test Loss: 60.7862
100%|██████████| 50/50 [03:30<00:00, 4.22s/it]
eval() # Set the model to evaluation mode
model.= 0.0
test_loss
with torch.no_grad():
for inputs, targets in test_loader:
= inputs.to(device)
inputs = targets.to(device)
targets
# Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
+= loss.item()
test_loss
# Print the average test loss
= test_loss / len(test_loader)
average_test_loss # test_loss_list_P_25.append(average_test_loss)
print(f"Average Test Loss: {average_test_loss:.4f}")
Average Test Loss: 20.4597
'model/auto_conv_in14_out2.pt')
torch.save(model.state_dict(), = Autoencoder_CNN(image_size = 80, num_input_channels = 14, num_output_channels=2, c_hid = 64, latent_dim = 2048, activation= nn.GELU)
model
model.to(device) 'model/auto_conv_in14_out2.pt')) model.load_state_dict(torch.load(
<All keys matched successfully>
Single channel input single channel output (P25)
=['TSURF_K',
target_var_96_list 'SNOWEW_M', 'SNOWAGE_HR', 'PRATE_MMpH', 'CLOUD_OD', 'U10_MpS',
'V10_MpS', 'T2_K', 'SWSFC_WpM2', 'SOLM_M3pM3', 'CLDTOP_KM', 'CAPE',
'PBL_WRF_M', 'PBL_YSU_M'] # ['U10_MpS', 'T2_K', 'V10_MpS']
= ['P25','P10']
target_var_120_list = get_data(target_var_96_list, target_var_120_list)
X,y
from sklearn.model_selection import train_test_split
= train_test_split(X, y, test_size=0.2, random_state=42)
X_train_all, X_test_all, y_train_all, y_test_all X_train_all.shape, X_test_all.shape, y_train_all.shape, y_test_all.shape
X shape (1656, 14, 80, 80)
y shape (1656, 2, 80, 80)
((1324, 14, 80, 80), (332, 14, 80, 80), (1324, 2, 80, 80), (332, 2, 80, 80))
= []
test_loss_list_P_25_conv_auto = 0 # selecting P25 as output
y_channel # x_channel = 0
for x_channel in range(X.shape[1]):
####################### Selecting the channel #######################
print('X Channel name : ', target_var_96_list[x_channel])
= X_train_all[:, x_channel:x_channel+1, :,:]
X_train = X_test_all[:, x_channel:x_channel+1, :,:]
X_test = y_train_all[:, y_channel:y_channel+1, :,:]
y_train = y_test_all[:, y_channel:y_channel+1, :,:]
y_test print('Shapes: ', X_train.shape, X_test.shape, y_train.shape, y_test.shape)
####################### Creating the dataset loader #######################
= CustomDataset(X_train, y_train)
train_custom_dataset # print(len(train_custom_dataset))
= 32
batch_size = data.DataLoader(train_custom_dataset, batch_size=batch_size, shuffle=True)
train_loader # print(len(train_loader))
= CustomDataset(X_test, y_test)
test_custom_dataset # print(len(test_custom_dataset))
= 32
batch_size = data.DataLoader(test_custom_dataset, batch_size=batch_size, shuffle=False)
test_loader # print(len(test_loader))
#################### Training the model ####################
= Autoencoder_CNN(image_size = 80, num_input_channels = 1, num_output_channels=1, c_hid = 8, latent_dim = 512, activation= nn.GELU)
model
model.to(device)# Define the loss function and optimizer
= nn.MSELoss()
criterion = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer = []
losses # Training loop
= 200
num_epochs for epoch in trange(num_epochs):
# Set the model to training mode
model.train() = 0.0
total_loss
for inputs, targets in train_loader:
# Zero the gradients
optimizer.zero_grad() = inputs.to(device)
inputs = targets.to(device)
targets # Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
# Backpropagation
loss.backward()
optimizer.step()
+= loss.item()
total_loss
# Print the average loss for this epoch
= total_loss / len(train_loader)
average_loss
losses.append(average_loss)if epoch % 20 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {average_loss:.4f}")
############################# testing the model #############################
eval() # Set the model to evaluation mode
model.= 0.0
test_loss
with torch.no_grad():
for inputs, targets in test_loader:
= inputs.to(device)
inputs = targets.to(device)
targets # Forward pass
= model(inputs)
outputs # Calculate the loss
= criterion(outputs, targets)
loss += loss.item()
test_loss
# Print the average test loss
= test_loss / len(test_loader)
average_test_loss
test_loss_list_P_25_conv_auto.append(average_test_loss)print(f"Average Test Loss: {average_test_loss:.4f}")
range(1, num_epochs + 1), losses)
plt.plot('Epoch')
plt.xlabel('Loss')
plt.ylabel('Train Loss vs. Epoch for channel '+target_var_96_list[x_channel])
plt.title(True)
plt.grid( plt.show()
X Channel name : TSURF_K
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<02:14, 1.48it/s]
Epoch [1/200] Loss: 462.4725
10%|█ | 21/200 [00:06<00:52, 3.42it/s]
Epoch [21/200] Loss: 212.0043
20%|██ | 41/200 [00:12<00:48, 3.29it/s]
Epoch [41/200] Loss: 213.3173
30%|███ | 61/200 [00:18<00:40, 3.41it/s]
Epoch [61/200] Loss: 234.0997
40%|████ | 81/200 [00:24<00:34, 3.43it/s]
Epoch [81/200] Loss: 211.7991
50%|█████ | 101/200 [00:30<00:29, 3.39it/s]
Epoch [101/200] Loss: 216.9750
60%|██████ | 121/200 [00:35<00:23, 3.42it/s]
Epoch [121/200] Loss: 211.1435
70%|███████ | 141/200 [00:41<00:17, 3.46it/s]
Epoch [141/200] Loss: 214.2518
80%|████████ | 161/200 [00:47<00:11, 3.45it/s]
Epoch [161/200] Loss: 209.3667
90%|█████████ | 181/200 [00:53<00:05, 3.44it/s]
Epoch [181/200] Loss: 208.2786
100%|██████████| 200/200 [00:58<00:00, 3.39it/s]
Average Test Loss: 196.8076
X Channel name : SNOWEW_M
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:58, 3.37it/s]
Epoch [1/200] Loss: 540.4606
10%|█ | 21/200 [00:06<00:52, 3.44it/s]
Epoch [21/200] Loss: 216.1673
20%|██ | 41/200 [00:12<00:47, 3.35it/s]
Epoch [41/200] Loss: 211.7699
30%|███ | 61/200 [00:17<00:40, 3.39it/s]
Epoch [61/200] Loss: 211.9133
40%|████ | 81/200 [00:23<00:34, 3.40it/s]
Epoch [81/200] Loss: 214.4790
50%|█████ | 101/200 [00:29<00:29, 3.39it/s]
Epoch [101/200] Loss: 212.5737
60%|██████ | 121/200 [00:35<00:23, 3.38it/s]
Epoch [121/200] Loss: 212.8330
70%|███████ | 141/200 [00:41<00:17, 3.44it/s]
Epoch [141/200] Loss: 209.0824
80%|████████ | 161/200 [00:47<00:11, 3.41it/s]
Epoch [161/200] Loss: 207.7827
90%|█████████ | 181/200 [00:53<00:05, 3.44it/s]
Epoch [181/200] Loss: 208.2993
100%|██████████| 200/200 [00:58<00:00, 3.42it/s]
Average Test Loss: 191.6095
X Channel name : SNOWAGE_HR
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:57, 3.49it/s]
Epoch [1/200] Loss: 590.1837
10%|█ | 21/200 [00:06<00:52, 3.43it/s]
Epoch [21/200] Loss: 211.4096
20%|██ | 41/200 [00:11<00:46, 3.42it/s]
Epoch [41/200] Loss: 213.1281
30%|███ | 61/200 [00:17<00:40, 3.45it/s]
Epoch [61/200] Loss: 208.5330
40%|████ | 81/200 [00:23<00:33, 3.58it/s]
Epoch [81/200] Loss: 209.0068
50%|█████ | 101/200 [00:28<00:27, 3.61it/s]
Epoch [101/200] Loss: 211.9243
60%|██████ | 121/200 [00:34<00:21, 3.60it/s]
Epoch [121/200] Loss: 212.2428
70%|███████ | 141/200 [00:40<00:16, 3.48it/s]
Epoch [141/200] Loss: 210.3188
80%|████████ | 161/200 [00:46<00:11, 3.40it/s]
Epoch [161/200] Loss: 211.6509
90%|█████████ | 181/200 [00:51<00:05, 3.43it/s]
Epoch [181/200] Loss: 212.8352
100%|██████████| 200/200 [00:57<00:00, 3.47it/s]
Average Test Loss: 191.7548
X Channel name : PRATE_MMpH
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:56, 3.52it/s]
Epoch [1/200] Loss: 750.4891
10%|█ | 21/200 [00:06<00:52, 3.44it/s]
Epoch [21/200] Loss: 213.9414
20%|██ | 41/200 [00:12<00:46, 3.39it/s]
Epoch [41/200] Loss: 214.3703
30%|███ | 61/200 [00:17<00:41, 3.38it/s]
Epoch [61/200] Loss: 209.9882
40%|████ | 81/200 [00:23<00:34, 3.43it/s]
Epoch [81/200] Loss: 210.4570
50%|█████ | 101/200 [00:29<00:29, 3.37it/s]
Epoch [101/200] Loss: 211.3801
60%|██████ | 121/200 [00:35<00:21, 3.76it/s]
Epoch [121/200] Loss: 214.6792
70%|███████ | 141/200 [00:41<00:17, 3.42it/s]
Epoch [141/200] Loss: 210.8055
80%|████████ | 161/200 [00:47<00:11, 3.41it/s]
Epoch [161/200] Loss: 214.0414
90%|█████████ | 181/200 [00:52<00:05, 3.61it/s]
Epoch [181/200] Loss: 206.5140
100%|██████████| 200/200 [00:58<00:00, 3.45it/s]
Average Test Loss: 194.9402
X Channel name : CLOUD_OD
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:54, 3.63it/s]
Epoch [1/200] Loss: 563.3871
10%|█ | 21/200 [00:05<00:46, 3.82it/s]
Epoch [21/200] Loss: 63.2729
20%|██ | 41/200 [00:11<00:45, 3.51it/s]
Epoch [41/200] Loss: 23.8580
30%|███ | 61/200 [00:17<00:40, 3.42it/s]
Epoch [61/200] Loss: 15.2393
40%|████ | 81/200 [00:23<00:37, 3.14it/s]
Epoch [81/200] Loss: 11.0129
50%|█████ | 101/200 [00:29<00:29, 3.32it/s]
Epoch [101/200] Loss: 9.3018
60%|██████ | 121/200 [00:35<00:22, 3.45it/s]
Epoch [121/200] Loss: 8.1090
70%|███████ | 141/200 [00:41<00:17, 3.40it/s]
Epoch [141/200] Loss: 7.4336
80%|████████ | 161/200 [00:47<00:11, 3.38it/s]
Epoch [161/200] Loss: 8.8601
90%|█████████ | 181/200 [00:53<00:05, 3.42it/s]
Epoch [181/200] Loss: 6.1420
100%|██████████| 200/200 [00:58<00:00, 3.40it/s]
Average Test Loss: 7.4988
X Channel name : U10_MpS
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:57, 3.48it/s]
Epoch [1/200] Loss: 684.8967
10%|█ | 21/200 [00:06<00:52, 3.41it/s]
Epoch [21/200] Loss: 111.6166
20%|██ | 41/200 [00:11<00:46, 3.41it/s]
Epoch [41/200] Loss: 22.6558
30%|███ | 61/200 [00:17<00:40, 3.41it/s]
Epoch [61/200] Loss: 10.9635
40%|████ | 81/200 [00:23<00:35, 3.40it/s]
Epoch [81/200] Loss: 12.9050
50%|█████ | 101/200 [00:29<00:29, 3.38it/s]
Epoch [101/200] Loss: 6.2923
60%|██████ | 121/200 [00:35<00:23, 3.40it/s]
Epoch [121/200] Loss: 9.0327
70%|███████ | 141/200 [00:41<00:16, 3.49it/s]
Epoch [141/200] Loss: 4.1764
80%|████████ | 161/200 [00:47<00:11, 3.52it/s]
Epoch [161/200] Loss: 3.8160
90%|█████████ | 181/200 [00:52<00:05, 3.57it/s]
Epoch [181/200] Loss: 3.3957
100%|██████████| 200/200 [00:57<00:00, 3.46it/s]
Average Test Loss: 3.1739
X Channel name : V10_MpS
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:56, 3.51it/s]
Epoch [1/200] Loss: 609.8626
10%|█ | 21/200 [00:06<00:53, 3.34it/s]
Epoch [21/200] Loss: 58.6357
20%|██ | 41/200 [00:12<00:46, 3.40it/s]
Epoch [41/200] Loss: 14.3107
30%|███ | 61/200 [00:18<00:41, 3.37it/s]
Epoch [61/200] Loss: 10.6218
40%|████ | 81/200 [00:23<00:35, 3.39it/s]
Epoch [81/200] Loss: 6.9871
50%|█████ | 101/200 [00:29<00:29, 3.39it/s]
Epoch [101/200] Loss: 5.3444
60%|██████ | 121/200 [00:35<00:23, 3.39it/s]
Epoch [121/200] Loss: 4.6127
70%|███████ | 141/200 [00:41<00:17, 3.43it/s]
Epoch [141/200] Loss: 4.0642
80%|████████ | 161/200 [00:47<00:11, 3.43it/s]
Epoch [161/200] Loss: 3.5636
90%|█████████ | 181/200 [00:53<00:05, 3.42it/s]
Epoch [181/200] Loss: 3.2190
100%|██████████| 200/200 [00:58<00:00, 3.40it/s]
Average Test Loss: 2.9050
X Channel name : T2_K
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:57, 3.45it/s]
Epoch [1/200] Loss: 458.4550
10%|█ | 21/200 [00:06<00:52, 3.41it/s]
Epoch [21/200] Loss: 213.3441
20%|██ | 41/200 [00:11<00:47, 3.38it/s]
Epoch [41/200] Loss: 211.7652
30%|███ | 61/200 [00:17<00:41, 3.39it/s]
Epoch [61/200] Loss: 212.4970
40%|████ | 81/200 [00:23<00:35, 3.38it/s]
Epoch [81/200] Loss: 210.7104
50%|█████ | 101/200 [00:29<00:28, 3.52it/s]
Epoch [101/200] Loss: 208.7177
60%|██████ | 121/200 [00:35<00:22, 3.44it/s]
Epoch [121/200] Loss: 212.2414
70%|███████ | 141/200 [00:41<00:17, 3.47it/s]
Epoch [141/200] Loss: 207.2175
80%|████████ | 161/200 [00:46<00:11, 3.43it/s]
Epoch [161/200] Loss: 211.3121
90%|█████████ | 181/200 [00:52<00:05, 3.42it/s]
Epoch [181/200] Loss: 208.7926
100%|██████████| 200/200 [00:58<00:00, 3.42it/s]
Average Test Loss: 191.9081
X Channel name : SWSFC_WpM2
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:56, 3.52it/s]
Epoch [1/200] Loss: 760.0155
10%|█ | 21/200 [00:06<00:53, 3.37it/s]
Epoch [21/200] Loss: 224.7744
20%|██ | 41/200 [00:12<00:47, 3.38it/s]
Epoch [41/200] Loss: 215.4670
30%|███ | 61/200 [00:18<00:41, 3.35it/s]
Epoch [61/200] Loss: 214.6593
40%|████ | 81/200 [00:24<00:35, 3.37it/s]
Epoch [81/200] Loss: 211.3864
50%|█████ | 101/200 [00:29<00:28, 3.43it/s]
Epoch [101/200] Loss: 214.5846
60%|██████ | 121/200 [00:35<00:23, 3.38it/s]
Epoch [121/200] Loss: 215.2816
70%|███████ | 141/200 [00:41<00:17, 3.32it/s]
Epoch [141/200] Loss: 210.3355
80%|████████ | 161/200 [00:47<00:11, 3.49it/s]
Epoch [161/200] Loss: 213.1056
90%|█████████ | 181/200 [00:53<00:05, 3.41it/s]
Epoch [181/200] Loss: 209.3009
100%|██████████| 200/200 [00:59<00:00, 3.38it/s]
Average Test Loss: 193.2200
X Channel name : SOLM_M3pM3
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:58, 3.42it/s]
Epoch [1/200] Loss: 518.8583
10%|█ | 21/200 [00:06<00:51, 3.46it/s]
Epoch [21/200] Loss: 183.2842
20%|██ | 41/200 [00:12<00:51, 3.10it/s]
Epoch [41/200] Loss: 106.5493
30%|███ | 61/200 [00:18<00:41, 3.35it/s]
Epoch [61/200] Loss: 47.8360
40%|████ | 81/200 [00:24<00:35, 3.39it/s]
Epoch [81/200] Loss: 17.4837
50%|█████ | 101/200 [00:30<00:30, 3.28it/s]
Epoch [101/200] Loss: 12.6859
60%|██████ | 121/200 [00:36<00:23, 3.31it/s]
Epoch [121/200] Loss: 9.7521
70%|███████ | 141/200 [00:42<00:17, 3.30it/s]
Epoch [141/200] Loss: 8.4768
80%|████████ | 161/200 [00:47<00:11, 3.46it/s]
Epoch [161/200] Loss: 7.1218
90%|█████████ | 181/200 [00:53<00:05, 3.44it/s]
Epoch [181/200] Loss: 7.4880
100%|██████████| 200/200 [00:58<00:00, 3.40it/s]
Average Test Loss: 5.5514
X Channel name : CLDTOP_KM
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:57, 3.43it/s]
Epoch [1/200] Loss: 603.6223
10%|█ | 21/200 [00:06<00:52, 3.41it/s]
Epoch [21/200] Loss: 216.0129
20%|██ | 41/200 [00:12<00:50, 3.18it/s]
Epoch [41/200] Loss: 212.5011
30%|███ | 61/200 [00:18<00:41, 3.35it/s]
Epoch [61/200] Loss: 215.2624
40%|████ | 81/200 [00:24<00:35, 3.38it/s]
Epoch [81/200] Loss: 210.4688
50%|█████ | 101/200 [00:30<00:29, 3.40it/s]
Epoch [101/200] Loss: 211.2793
60%|██████ | 121/200 [00:35<00:23, 3.35it/s]
Epoch [121/200] Loss: 208.5935
70%|███████ | 141/200 [00:41<00:17, 3.38it/s]
Epoch [141/200] Loss: 209.1620
80%|████████ | 161/200 [00:47<00:11, 3.42it/s]
Epoch [161/200] Loss: 208.9249
90%|█████████ | 181/200 [00:53<00:05, 3.41it/s]
Epoch [181/200] Loss: 207.8753
100%|██████████| 200/200 [00:59<00:00, 3.38it/s]
Average Test Loss: 205.4745
X Channel name : CAPE
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:58, 3.39it/s]
Epoch [1/200] Loss: 480.7086
10%|█ | 21/200 [00:06<00:52, 3.39it/s]
Epoch [21/200] Loss: 215.4105
20%|██ | 41/200 [00:12<00:46, 3.41it/s]
Epoch [41/200] Loss: 211.5271
30%|███ | 61/200 [00:17<00:40, 3.39it/s]
Epoch [61/200] Loss: 211.8129
40%|████ | 81/200 [00:23<00:35, 3.40it/s]
Epoch [81/200] Loss: 208.5153
50%|█████ | 101/200 [00:29<00:28, 3.44it/s]
Epoch [101/200] Loss: 208.6479
60%|██████ | 121/200 [00:35<00:23, 3.39it/s]
Epoch [121/200] Loss: 215.5420
70%|███████ | 141/200 [00:41<00:17, 3.44it/s]
Epoch [141/200] Loss: 207.9001
80%|████████ | 161/200 [00:47<00:11, 3.45it/s]
Epoch [161/200] Loss: 206.7139
90%|█████████ | 181/200 [00:53<00:05, 3.41it/s]
Epoch [181/200] Loss: 208.8671
100%|██████████| 200/200 [00:58<00:00, 3.41it/s]
Average Test Loss: 191.5815
X Channel name : PBL_WRF_M
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:57, 3.44it/s]
Epoch [1/200] Loss: 487.6684
10%|█ | 21/200 [00:06<00:53, 3.33it/s]
Epoch [21/200] Loss: 210.6665
20%|██ | 41/200 [00:12<00:46, 3.40it/s]
Epoch [41/200] Loss: 211.8588
30%|███ | 61/200 [00:18<00:41, 3.38it/s]
Epoch [61/200] Loss: 209.3883
40%|████ | 81/200 [00:24<00:35, 3.36it/s]
Epoch [81/200] Loss: 210.8873
50%|█████ | 101/200 [00:29<00:29, 3.39it/s]
Epoch [101/200] Loss: 209.3165
60%|██████ | 121/200 [00:35<00:23, 3.36it/s]
Epoch [121/200] Loss: 208.2383
70%|███████ | 141/200 [00:41<00:17, 3.40it/s]
Epoch [141/200] Loss: 208.4831
80%|████████ | 161/200 [00:47<00:11, 3.37it/s]
Epoch [161/200] Loss: 209.2067
90%|█████████ | 181/200 [00:53<00:05, 3.35it/s]
Epoch [181/200] Loss: 207.0104
100%|██████████| 200/200 [00:59<00:00, 3.37it/s]
Average Test Loss: 190.8408
X Channel name : PBL_YSU_M
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:58, 3.43it/s]
Epoch [1/200] Loss: 538.0652
10%|█ | 21/200 [00:06<00:53, 3.38it/s]
Epoch [21/200] Loss: 213.2395
20%|██ | 41/200 [00:12<00:46, 3.40it/s]
Epoch [41/200] Loss: 209.5626
30%|███ | 61/200 [00:18<00:40, 3.41it/s]
Epoch [61/200] Loss: 218.8836
40%|████ | 81/200 [00:23<00:34, 3.44it/s]
Epoch [81/200] Loss: 213.1157
50%|█████ | 101/200 [00:29<00:28, 3.42it/s]
Epoch [101/200] Loss: 211.7407
60%|██████ | 121/200 [00:35<00:23, 3.39it/s]
Epoch [121/200] Loss: 214.4044
70%|███████ | 141/200 [00:41<00:17, 3.40it/s]
Epoch [141/200] Loss: 209.6131
80%|████████ | 161/200 [00:47<00:11, 3.39it/s]
Epoch [161/200] Loss: 209.1906
90%|█████████ | 181/200 [00:53<00:05, 3.39it/s]
Epoch [181/200] Loss: 207.2547
100%|██████████| 200/200 [00:58<00:00, 3.39it/s]
Average Test Loss: 196.5285
test_loss_list_P_25_conv_auto
[196.80755892666903,
191.6094970703125,
191.75477183948863,
194.94017861106178,
7.4987756772474805,
3.1738592061129483,
2.9050224044106225,
191.90811157226562,
193.21996238014916,
5.551396109841087,
205.47452198375356,
191.58150828968394,
190.84084944291547,
196.5285311612216]
= list(zip(target_var_96_list, test_loss_list_P_25_conv_auto))
name_loss_pairs
# Sort based on loss values
= sorted(name_loss_pairs, key=lambda x: x[1])
sorted_name_loss_pairs for pair in sorted_name_loss_pairs:
print(pair)
# print(sorted_name_loss_pairs)
# Select the lowest 3 names
= [pair[0] for pair in sorted_name_loss_pairs[:3]]
lowest_3_names
print("Lowest 3 names:", lowest_3_names)
('V10_MpS', 2.9050224044106225)
('U10_MpS', 3.1738592061129483)
('SOLM_M3pM3', 5.551396109841087)
('CLOUD_OD', 7.4987756772474805)
('PBL_WRF_M', 190.84084944291547)
('CAPE', 191.58150828968394)
('SNOWEW_M', 191.6094970703125)
('SNOWAGE_HR', 191.75477183948863)
('T2_K', 191.90811157226562)
('SWSFC_WpM2', 193.21996238014916)
('PRATE_MMpH', 194.94017861106178)
('PBL_YSU_M', 196.5285311612216)
('TSURF_K', 196.80755892666903)
('CLDTOP_KM', 205.47452198375356)
Lowest 3 names: ['V10_MpS', 'U10_MpS', 'SOLM_M3pM3']
=(1000, 1, 80, 80)) summary(model, input_size
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Autoencoder_CNN [1000, 1, 80, 80] --
├─Encoder: 1-1 [1000, 512] --
│ └─Sequential: 2-1 [1000, 512] --
│ │ └─Conv2d: 3-1 [1000, 8, 40, 40] 80
│ │ └─GELU: 3-2 [1000, 8, 40, 40] --
│ │ └─Conv2d: 3-3 [1000, 8, 40, 40] 584
│ │ └─GELU: 3-4 [1000, 8, 40, 40] --
│ │ └─Conv2d: 3-5 [1000, 16, 20, 20] 1,168
│ │ └─GELU: 3-6 [1000, 16, 20, 20] --
│ │ └─Conv2d: 3-7 [1000, 16, 20, 20] 2,320
│ │ └─GELU: 3-8 [1000, 16, 20, 20] --
│ │ └─Conv2d: 3-9 [1000, 32, 10, 10] 4,640
│ │ └─GELU: 3-10 [1000, 32, 10, 10] --
│ │ └─Flatten: 3-11 [1000, 3200] --
│ │ └─Linear: 3-12 [1000, 512] 1,638,912
├─Decoder: 1-2 [1000, 1, 80, 80] --
│ └─Sequential: 2-2 [1000, 1, 80, 80] --
│ │ └─Linear: 3-13 [1000, 3200] 1,641,600
│ │ └─Unflatten: 3-14 [1000, 32, 10, 10] --
│ │ └─ConvTranspose2d: 3-15 [1000, 16, 20, 20] 4,624
│ │ └─GELU: 3-16 [1000, 16, 20, 20] --
│ │ └─Conv2d: 3-17 [1000, 16, 20, 20] 2,320
│ │ └─GELU: 3-18 [1000, 16, 20, 20] --
│ │ └─ConvTranspose2d: 3-19 [1000, 8, 40, 40] 1,160
│ │ └─GELU: 3-20 [1000, 8, 40, 40] --
│ │ └─Conv2d: 3-21 [1000, 8, 40, 40] 584
│ │ └─GELU: 3-22 [1000, 8, 40, 40] --
│ │ └─ConvTranspose2d: 3-23 [1000, 1, 80, 80] 73
==========================================================================================
Total params: 3,298,065
Trainable params: 3,298,065
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 12.24
==========================================================================================
Input size (MB): 25.60
Forward/backward pass size (MB): 720.90
Params size (MB): 13.19
Estimated Total Size (MB): 759.69
==========================================================================================
Single channel input multiple channel output (P10)
= []
test_loss_list_P_10_conv_auto = 1 # selecting P10 as output
y_channel # x_channel = 0
for x_channel in range(X.shape[1]):
####################### Selecting the channel #######################
print('X Channel name : ', target_var_96_list[x_channel])
= X_train_all[:, x_channel:x_channel+1, :,:]
X_train = X_test_all[:, x_channel:x_channel+1, :,:]
X_test = y_train_all[:, y_channel:y_channel+1, :,:]
y_train = y_test_all[:, y_channel:y_channel+1, :,:]
y_test print('Shapes: ', X_train.shape, X_test.shape, y_train.shape, y_test.shape)
####################### Creating the dataset loader #######################
= CustomDataset(X_train, y_train)
train_custom_dataset # print(len(train_custom_dataset))
= 32
batch_size = data.DataLoader(train_custom_dataset, batch_size=batch_size, shuffle=True)
train_loader # print(len(train_loader))
= CustomDataset(X_test, y_test)
test_custom_dataset # print(len(test_custom_dataset))
= 32
batch_size = data.DataLoader(test_custom_dataset, batch_size=batch_size, shuffle=False)
test_loader # print(len(test_loader))
#################### Training the model ####################
= Autoencoder_CNN(image_size = 80, num_input_channels = 1, num_output_channels=1, c_hid = 8, latent_dim = 512, activation= nn.GELU)
model
model.to(device)# Define the loss function and optimizer
= nn.MSELoss()
criterion = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer = []
losses # Training loop
= 200
num_epochs for epoch in trange(num_epochs):
# Set the model to training mode
model.train() = 0.0
total_loss
for inputs, targets in train_loader:
# Zero the gradients
optimizer.zero_grad() = inputs.to(device)
inputs = targets.to(device)
targets # Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
# Backpropagation
loss.backward()
optimizer.step()
+= loss.item()
total_loss
# Print the average loss for this epoch
= total_loss / len(train_loader)
average_loss
losses.append(average_loss)if epoch % 20 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {average_loss:.4f}")
############################# testing the model #############################
eval() # Set the model to evaluation mode
model.= 0.0
test_loss
with torch.no_grad():
for inputs, targets in test_loader:
= inputs.to(device)
inputs = targets.to(device)
targets # Forward pass
= model(inputs)
outputs # Calculate the loss
= criterion(outputs, targets)
loss += loss.item()
test_loss
# Print the average test loss
= test_loss / len(test_loader)
average_test_loss
test_loss_list_P_10_conv_auto.append(average_test_loss)print(f"Average Test Loss: {average_test_loss:.4f}")
range(1, num_epochs + 1), losses)
plt.plot('Epoch')
plt.xlabel('Loss')
plt.ylabel('Train Loss vs. Epoch for channel '+target_var_96_list[x_channel])
plt.title(True)
plt.grid( plt.show()
X Channel name : TSURF_K
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:02, 3.18it/s]
Epoch [1/200] Loss: 831.9463
10%|█ | 21/200 [00:06<00:52, 3.39it/s]
Epoch [21/200] Loss: 413.6695
20%|██ | 41/200 [00:12<00:46, 3.39it/s]
Epoch [41/200] Loss: 416.1795
30%|███ | 61/200 [00:17<00:40, 3.40it/s]
Epoch [61/200] Loss: 414.0692
40%|████ | 81/200 [00:23<00:35, 3.40it/s]
Epoch [81/200] Loss: 408.8671
50%|█████ | 101/200 [00:29<00:28, 3.45it/s]
Epoch [101/200] Loss: 408.6212
60%|██████ | 121/200 [00:35<00:22, 3.44it/s]
Epoch [121/200] Loss: 408.1124
70%|███████ | 141/200 [00:41<00:17, 3.39it/s]
Epoch [141/200] Loss: 407.2696
80%|████████ | 161/200 [00:47<00:11, 3.40it/s]
Epoch [161/200] Loss: 419.0648
90%|█████████ | 181/200 [00:53<00:05, 3.42it/s]
Epoch [181/200] Loss: 411.9767
100%|██████████| 200/200 [00:58<00:00, 3.41it/s]
Average Test Loss: 367.9784
X Channel name : SNOWEW_M
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:58, 3.38it/s]
Epoch [1/200] Loss: 1193.0016
10%|█ | 21/200 [00:06<00:52, 3.38it/s]
Epoch [21/200] Loss: 413.8656
20%|██ | 41/200 [00:12<00:46, 3.40it/s]
Epoch [41/200] Loss: 416.9739
30%|███ | 61/200 [00:17<00:41, 3.38it/s]
Epoch [61/200] Loss: 405.7197
40%|████ | 81/200 [00:23<00:35, 3.40it/s]
Epoch [81/200] Loss: 414.4462
50%|█████ | 101/200 [00:29<00:28, 3.44it/s]
Epoch [101/200] Loss: 403.9555
60%|██████ | 121/200 [00:35<00:23, 3.35it/s]
Epoch [121/200] Loss: 406.6566
70%|███████ | 141/200 [00:41<00:17, 3.38it/s]
Epoch [141/200] Loss: 404.8177
80%|████████ | 161/200 [00:47<00:11, 3.36it/s]
Epoch [161/200] Loss: 414.5372
90%|█████████ | 181/200 [00:53<00:05, 3.19it/s]
Epoch [181/200] Loss: 418.9334
100%|██████████| 200/200 [00:59<00:00, 3.36it/s]
Average Test Loss: 373.6601
X Channel name : SNOWAGE_HR
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:57, 3.48it/s]
Epoch [1/200] Loss: 1176.4868
10%|█ | 21/200 [00:06<00:52, 3.43it/s]
Epoch [21/200] Loss: 410.2950
20%|██ | 41/200 [00:12<00:47, 3.37it/s]
Epoch [41/200] Loss: 414.8993
30%|███ | 61/200 [00:17<00:40, 3.41it/s]
Epoch [61/200] Loss: 405.8627
40%|████ | 81/200 [00:23<00:34, 3.43it/s]
Epoch [81/200] Loss: 416.4907
50%|█████ | 101/200 [00:29<00:28, 3.42it/s]
Epoch [101/200] Loss: 413.8539
60%|██████ | 121/200 [00:35<00:23, 3.39it/s]
Epoch [121/200] Loss: 406.4541
70%|███████ | 141/200 [00:41<00:17, 3.36it/s]
Epoch [141/200] Loss: 404.2276
80%|████████ | 161/200 [00:47<00:11, 3.38it/s]
Epoch [161/200] Loss: 407.0577
90%|█████████ | 181/200 [00:53<00:05, 3.37it/s]
Epoch [181/200] Loss: 402.6227
100%|██████████| 200/200 [00:58<00:00, 3.40it/s]
Average Test Loss: 393.0328
X Channel name : PRATE_MMpH
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:56, 3.54it/s]
Epoch [1/200] Loss: 1028.3382
10%|█ | 21/200 [00:06<00:53, 3.38it/s]
Epoch [21/200] Loss: 434.1988
20%|██ | 41/200 [00:12<00:46, 3.42it/s]
Epoch [41/200] Loss: 412.9816
30%|███ | 61/200 [00:17<00:40, 3.44it/s]
Epoch [61/200] Loss: 405.9959
40%|████ | 81/200 [00:23<00:34, 3.42it/s]
Epoch [81/200] Loss: 404.1220
50%|█████ | 101/200 [00:29<00:28, 3.43it/s]
Epoch [101/200] Loss: 415.6019
60%|██████ | 121/200 [00:35<00:23, 3.41it/s]
Epoch [121/200] Loss: 403.8031
70%|███████ | 141/200 [00:41<00:17, 3.46it/s]
Epoch [141/200] Loss: 407.2635
80%|████████ | 161/200 [00:47<00:11, 3.43it/s]
Epoch [161/200] Loss: 404.3506
90%|█████████ | 181/200 [00:53<00:05, 3.37it/s]
Epoch [181/200] Loss: 409.8351
100%|██████████| 200/200 [00:58<00:00, 3.41it/s]
Average Test Loss: 367.3246
X Channel name : CLOUD_OD
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:57, 3.49it/s]
Epoch [1/200] Loss: 1599.0141
10%|█ | 21/200 [00:06<00:53, 3.37it/s]
Epoch [21/200] Loss: 192.7198
20%|██ | 41/200 [00:12<00:47, 3.38it/s]
Epoch [41/200] Loss: 68.0225
30%|███ | 61/200 [00:18<00:41, 3.36it/s]
Epoch [61/200] Loss: 44.4028
40%|████ | 81/200 [00:24<00:35, 3.32it/s]
Epoch [81/200] Loss: 32.4964
50%|█████ | 101/200 [00:30<00:29, 3.38it/s]
Epoch [101/200] Loss: 25.7723
60%|██████ | 121/200 [00:35<00:23, 3.39it/s]
Epoch [121/200] Loss: 23.6108
70%|███████ | 141/200 [00:41<00:17, 3.37it/s]
Epoch [141/200] Loss: 19.3957
80%|████████ | 161/200 [00:47<00:11, 3.41it/s]
Epoch [161/200] Loss: 19.0788
90%|█████████ | 181/200 [00:53<00:05, 3.39it/s]
Epoch [181/200] Loss: 21.9598
100%|██████████| 200/200 [00:58<00:00, 3.40it/s]
Average Test Loss: 18.7717
X Channel name : U10_MpS
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:53, 3.73it/s]
Epoch [1/200] Loss: 1023.2784
10%|█ | 21/200 [00:05<00:52, 3.43it/s]
Epoch [21/200] Loss: 280.1561
20%|██ | 41/200 [00:11<00:46, 3.44it/s]
Epoch [41/200] Loss: 60.3793
30%|███ | 61/200 [00:17<00:40, 3.43it/s]
Epoch [61/200] Loss: 29.8831
40%|████ | 81/200 [00:23<00:35, 3.35it/s]
Epoch [81/200] Loss: 19.7393
50%|█████ | 101/200 [00:29<00:29, 3.35it/s]
Epoch [101/200] Loss: 16.8795
60%|██████ | 121/200 [00:35<00:23, 3.37it/s]
Epoch [121/200] Loss: 12.9342
70%|███████ | 141/200 [00:41<00:17, 3.39it/s]
Epoch [141/200] Loss: 11.2801
80%|████████ | 161/200 [00:47<00:10, 3.56it/s]
Epoch [161/200] Loss: 10.1249
90%|█████████ | 181/200 [00:52<00:05, 3.44it/s]
Epoch [181/200] Loss: 11.3411
100%|██████████| 200/200 [00:58<00:00, 3.42it/s]
Average Test Loss: 8.8163
X Channel name : V10_MpS
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:56, 3.53it/s]
Epoch [1/200] Loss: 1439.9822
10%|█ | 21/200 [00:06<00:51, 3.50it/s]
Epoch [21/200] Loss: 142.6731
20%|██ | 41/200 [00:11<00:46, 3.45it/s]
Epoch [41/200] Loss: 39.2041
30%|███ | 61/200 [00:17<00:38, 3.66it/s]
Epoch [61/200] Loss: 20.6722
40%|████ | 81/200 [00:23<00:33, 3.51it/s]
Epoch [81/200] Loss: 16.5665
50%|█████ | 101/200 [00:28<00:28, 3.42it/s]
Epoch [101/200] Loss: 12.4099
60%|██████ | 121/200 [00:34<00:23, 3.42it/s]
Epoch [121/200] Loss: 10.8300
70%|███████ | 141/200 [00:40<00:17, 3.47it/s]
Epoch [141/200] Loss: 9.6913
80%|████████ | 161/200 [00:46<00:11, 3.41it/s]
Epoch [161/200] Loss: 8.9197
90%|█████████ | 181/200 [00:52<00:05, 3.41it/s]
Epoch [181/200] Loss: 10.7642
100%|██████████| 200/200 [00:57<00:00, 3.45it/s]
Average Test Loss: 23.4499
X Channel name : T2_K
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:56, 3.54it/s]
Epoch [1/200] Loss: 1037.4250
10%|█ | 21/200 [00:06<00:53, 3.35it/s]
Epoch [21/200] Loss: 413.2349
20%|██ | 41/200 [00:12<00:47, 3.37it/s]
Epoch [41/200] Loss: 411.9822
30%|███ | 61/200 [00:17<00:41, 3.39it/s]
Epoch [61/200] Loss: 406.0331
40%|████ | 81/200 [00:23<00:34, 3.43it/s]
Epoch [81/200] Loss: 408.5417
50%|█████ | 101/200 [00:29<00:28, 3.41it/s]
Epoch [101/200] Loss: 403.6246
60%|██████ | 121/200 [00:35<00:23, 3.42it/s]
Epoch [121/200] Loss: 409.9414
70%|███████ | 141/200 [00:41<00:17, 3.41it/s]
Epoch [141/200] Loss: 402.6003
80%|████████ | 161/200 [00:47<00:11, 3.43it/s]
Epoch [161/200] Loss: 401.2209
90%|█████████ | 181/200 [00:53<00:05, 3.42it/s]
Epoch [181/200] Loss: 403.0725
100%|██████████| 200/200 [00:58<00:00, 3.41it/s]
Average Test Loss: 366.0120
X Channel name : SWSFC_WpM2
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:52, 3.82it/s]
Epoch [1/200] Loss: 1358.9732
10%|█ | 21/200 [00:06<00:52, 3.43it/s]
Epoch [21/200] Loss: 419.2162
20%|██ | 41/200 [00:11<00:46, 3.42it/s]
Epoch [41/200] Loss: 413.3607
30%|███ | 61/200 [00:17<00:40, 3.45it/s]
Epoch [61/200] Loss: 409.3754
40%|████ | 81/200 [00:23<00:34, 3.45it/s]
Epoch [81/200] Loss: 411.0923
50%|█████ | 101/200 [00:29<00:28, 3.45it/s]
Epoch [101/200] Loss: 404.1099
60%|██████ | 121/200 [00:35<00:22, 3.46it/s]
Epoch [121/200] Loss: 409.2167
70%|███████ | 141/200 [00:40<00:17, 3.43it/s]
Epoch [141/200] Loss: 408.5509
80%|████████ | 161/200 [00:46<00:11, 3.35it/s]
Epoch [161/200] Loss: 410.1073
90%|█████████ | 181/200 [00:52<00:05, 3.36it/s]
Epoch [181/200] Loss: 407.5864
100%|██████████| 200/200 [00:58<00:00, 3.42it/s]
Average Test Loss: 364.5863
X Channel name : SOLM_M3pM3
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:56, 3.51it/s]
Epoch [1/200] Loss: 1237.3417
10%|█ | 21/200 [00:06<00:52, 3.38it/s]
Epoch [21/200] Loss: 337.7638
20%|██ | 41/200 [00:12<00:46, 3.38it/s]
Epoch [41/200] Loss: 223.9836
30%|███ | 61/200 [00:17<00:40, 3.45it/s]
Epoch [61/200] Loss: 98.1059
40%|████ | 81/200 [00:23<00:34, 3.42it/s]
Epoch [81/200] Loss: 46.6351
50%|█████ | 101/200 [00:29<00:29, 3.34it/s]
Epoch [101/200] Loss: 31.9773
60%|██████ | 121/200 [00:35<00:23, 3.38it/s]
Epoch [121/200] Loss: 26.2758
70%|███████ | 141/200 [00:41<00:17, 3.40it/s]
Epoch [141/200] Loss: 19.9056
80%|████████ | 161/200 [00:47<00:11, 3.35it/s]
Epoch [161/200] Loss: 17.7226
90%|█████████ | 181/200 [00:53<00:05, 3.39it/s]
Epoch [181/200] Loss: 15.4444
100%|██████████| 200/200 [00:58<00:00, 3.39it/s]
Average Test Loss: 14.0181
X Channel name : CLDTOP_KM
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:58, 3.40it/s]
Epoch [1/200] Loss: 1227.4362
10%|█ | 21/200 [00:06<00:53, 3.37it/s]
Epoch [21/200] Loss: 430.1239
20%|██ | 41/200 [00:12<00:46, 3.42it/s]
Epoch [41/200] Loss: 414.0475
30%|███ | 61/200 [00:17<00:40, 3.40it/s]
Epoch [61/200] Loss: 414.5093
40%|████ | 81/200 [00:24<00:35, 3.34it/s]
Epoch [81/200] Loss: 412.3910
50%|█████ | 101/200 [00:29<00:28, 3.42it/s]
Epoch [101/200] Loss: 413.5266
60%|██████ | 121/200 [00:35<00:23, 3.36it/s]
Epoch [121/200] Loss: 408.5565
70%|███████ | 141/200 [00:41<00:17, 3.36it/s]
Epoch [141/200] Loss: 404.2547
80%|████████ | 161/200 [00:47<00:11, 3.32it/s]
Epoch [161/200] Loss: 417.1860
90%|█████████ | 181/200 [00:53<00:05, 3.46it/s]
Epoch [181/200] Loss: 408.9205
100%|██████████| 200/200 [00:59<00:00, 3.38it/s]
Average Test Loss: 367.3491
X Channel name : CAPE
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:59, 3.37it/s]
Epoch [1/200] Loss: 1362.6006
10%|█ | 21/200 [00:06<00:54, 3.30it/s]
Epoch [21/200] Loss: 417.5348
20%|██ | 41/200 [00:11<00:46, 3.45it/s]
Epoch [41/200] Loss: 409.7356
30%|███ | 61/200 [00:17<00:40, 3.41it/s]
Epoch [61/200] Loss: 411.1969
40%|████ | 81/200 [00:23<00:35, 3.37it/s]
Epoch [81/200] Loss: 413.3693
50%|█████ | 101/200 [00:29<00:30, 3.29it/s]
Epoch [101/200] Loss: 414.5004
60%|██████ | 121/200 [00:35<00:23, 3.36it/s]
Epoch [121/200] Loss: 405.8445
70%|███████ | 141/200 [00:41<00:17, 3.44it/s]
Epoch [141/200] Loss: 404.8460
80%|████████ | 161/200 [00:47<00:11, 3.45it/s]
Epoch [161/200] Loss: 415.0352
90%|█████████ | 181/200 [00:53<00:05, 3.41it/s]
Epoch [181/200] Loss: 414.2297
100%|██████████| 200/200 [00:59<00:00, 3.39it/s]
Average Test Loss: 366.4497
X Channel name : PBL_WRF_M
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:57, 3.47it/s]
Epoch [1/200] Loss: 1069.1285
10%|█ | 21/200 [00:05<00:52, 3.42it/s]
Epoch [21/200] Loss: 408.2721
20%|██ | 41/200 [00:11<00:45, 3.51it/s]
Epoch [41/200] Loss: 408.3345
30%|███ | 61/200 [00:17<00:39, 3.50it/s]
Epoch [61/200] Loss: 411.1250
40%|████ | 81/200 [00:23<00:34, 3.43it/s]
Epoch [81/200] Loss: 406.6349
50%|█████ | 101/200 [00:29<00:29, 3.40it/s]
Epoch [101/200] Loss: 409.1261
60%|██████ | 121/200 [00:35<00:23, 3.43it/s]
Epoch [121/200] Loss: 403.4630
70%|███████ | 141/200 [00:41<00:17, 3.43it/s]
Epoch [141/200] Loss: 405.4774
80%|████████ | 161/200 [00:46<00:11, 3.39it/s]
Epoch [161/200] Loss: 403.7881
90%|█████████ | 181/200 [00:52<00:05, 3.37it/s]
Epoch [181/200] Loss: 408.1804
100%|██████████| 200/200 [00:58<00:00, 3.42it/s]
Average Test Loss: 367.3865
X Channel name : PBL_YSU_M
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:58, 3.40it/s]
Epoch [1/200] Loss: 1065.3006
10%|█ | 21/200 [00:06<00:50, 3.55it/s]
Epoch [21/200] Loss: 443.1136
20%|██ | 41/200 [00:11<00:45, 3.52it/s]
Epoch [41/200] Loss: 407.0351
30%|███ | 61/200 [00:17<00:40, 3.46it/s]
Epoch [61/200] Loss: 409.3611
40%|████ | 81/200 [00:23<00:33, 3.50it/s]
Epoch [81/200] Loss: 403.8964
50%|█████ | 101/200 [00:29<00:28, 3.44it/s]
Epoch [101/200] Loss: 402.2005
60%|██████ | 121/200 [00:35<00:22, 3.45it/s]
Epoch [121/200] Loss: 402.1708
70%|███████ | 141/200 [00:40<00:17, 3.40it/s]
Epoch [141/200] Loss: 403.3740
80%|████████ | 161/200 [00:46<00:11, 3.37it/s]
Epoch [161/200] Loss: 398.9786
90%|█████████ | 181/200 [00:52<00:05, 3.45it/s]
Epoch [181/200] Loss: 408.3548
100%|██████████| 200/200 [00:58<00:00, 3.44it/s]
Average Test Loss: 380.6446
test_loss_list_P_10_conv_auto
[367.97836858575994,
373.660117409446,
393.0328174937855,
367.32460992986506,
18.771740219809793,
8.816288991407914,
23.44990175420588,
366.0119906338778,
364.58628151633525,
14.018102992664684,
367.34911554509944,
366.4497375488281,
367.38653841885656,
380.64462835138494]
= list(zip(target_var_96_list, test_loss_list_P_10_conv_auto))
name_loss_pairs
# Sort based on loss values
= sorted(name_loss_pairs, key=lambda x: x[1])
sorted_name_loss_pairs for pair in sorted_name_loss_pairs:
print(pair)
# print(sorted_name_loss_pairs)
# Select the lowest 3 names
= [pair[0] for pair in sorted_name_loss_pairs[:3]]
lowest_3_names
print("Lowest 3 names:", lowest_3_names)
('U10_MpS', 8.816288991407914)
('SOLM_M3pM3', 14.018102992664684)
('CLOUD_OD', 18.771740219809793)
('V10_MpS', 23.44990175420588)
('SWSFC_WpM2', 364.58628151633525)
('T2_K', 366.0119906338778)
('CAPE', 366.4497375488281)
('PRATE_MMpH', 367.32460992986506)
('CLDTOP_KM', 367.34911554509944)
('PBL_WRF_M', 367.38653841885656)
('TSURF_K', 367.97836858575994)
('SNOWEW_M', 373.660117409446)
('PBL_YSU_M', 380.64462835138494)
('SNOWAGE_HR', 393.0328174937855)
Lowest 3 names: ['U10_MpS', 'SOLM_M3pM3', 'CLOUD_OD']
Insights
# Adding values at corresponding indices
= [(x + y)/2 for x, y in zip(test_loss_list_P_10_conv_auto, test_loss_list_P_25_conv_auto)]
result
# Creating a DataFrame
= {'Input channel': target_var_96_list, 'P10': test_loss_list_P_10_conv_auto, 'P25': test_loss_list_P_25_conv_auto, 'P10+P25 avg': result}
data_frame = pd.DataFrame(data_frame)
df
# Sorting the DataFrame based on "List1 + List2"
= df.sort_values(by='P10+P25 avg')
df_sorted
# Displaying the sorted DataFrame
= df_sorted.round(1)
df_rounded '/home/rishabh.mondal/climax_alternative/Climax_2/results/test_loss_list_P_10_P25_conv_auto.csv', index=False)
df_rounded.to_csv( df_rounded
Input channel | P10 | P25 | P10+P25 avg | |
---|---|---|---|---|
5 | U10_MpS | 8.8 | 3.2 | 6.0 |
9 | SOLM_M3pM3 | 14.0 | 5.6 | 9.8 |
4 | CLOUD_OD | 18.8 | 7.5 | 13.1 |
6 | V10_MpS | 23.4 | 2.9 | 13.2 |
8 | SWSFC_WpM2 | 364.6 | 193.2 | 278.9 |
7 | T2_K | 366.0 | 191.9 | 279.0 |
11 | CAPE | 366.4 | 191.6 | 279.0 |
12 | PBL_WRF_M | 367.4 | 190.8 | 279.1 |
3 | PRATE_MMpH | 367.3 | 194.9 | 281.1 |
0 | TSURF_K | 368.0 | 196.8 | 282.4 |
1 | SNOWEW_M | 373.7 | 191.6 | 282.6 |
10 | CLDTOP_KM | 367.3 | 205.5 | 286.4 |
13 | PBL_YSU_M | 380.6 | 196.5 | 288.6 |
2 | SNOWAGE_HR | 393.0 | 191.8 | 292.4 |
Training on top 4 channel and predicting on all channel CONV
# target_var_96_list =['TSURF_K',
# 'SNOWEW_M', 'SNOWAGE_HR', 'PRATE_MMpH', 'CLOUD_OD', 'U10_MpS',
# 'V10_MpS', 'T2_K', 'SWSFC_WpM2', 'SOLM_M3pM3', 'CLDTOP_KM', 'CAPE',
# 'PBL_WRF_M', 'PBL_YSU_M'] # ['U10_MpS', 'T2_K', 'V10_MpS']
= ['V10_MpS','U10_MpS','SOLM_M3pM3', 'CLOUD_OD']
target_var_96_list= ['P25','P10']
target_var_120_list = get_data(target_var_96_list, target_var_120_list)
X,y
from sklearn.model_selection import train_test_split
= train_test_split(X, y, test_size=0.2, random_state=42)
X_train_all, X_test_all, y_train_all, y_test_all X_train_all.shape, X_test_all.shape, y_train_all.shape, y_test_all.shape
X shape (1656, 4, 80, 80)
y shape (1656, 2, 80, 80)
((1324, 4, 80, 80), (332, 4, 80, 80), (1324, 2, 80, 80), (332, 2, 80, 80))
print(X_train_all.shape, X_test_all.shape, y_train_all.shape, y_test_all.shape)
= CustomDataset(X_train_all, y_train_all)
train_custom_dataset # print(len(train_custom_dataset))
= 32
batch_size = data.DataLoader(train_custom_dataset, batch_size=batch_size, shuffle=True)
train_loader # print(len(train_loader))
= CustomDataset(X_test_all, y_test_all)
test_custom_dataset # print(len(test_custom_dataset))
= 32
batch_size = data.DataLoader(test_custom_dataset, batch_size=batch_size, shuffle=False)
test_loader # print(len(test_loader))
#################### Training the model ####################
= Autoencoder_CNN(image_size = 80, num_input_channels = 4, num_output_channels=2, c_hid = 64, latent_dim = 2048, activation= nn.GELU)
model
model.to(device) from torchinfo import summary
print(summary(model, input_size=(1656, 4, 80, 80)))
# Define the loss function and optimizer
= nn.MSELoss()
criterion = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer = []
losses # Training loop
= 80
num_epochs for epoch in trange(num_epochs):
# Set the model to training mode
model.train() = 0.0
total_loss
for inputs, targets in train_loader:
# Zero the gradients
optimizer.zero_grad() = inputs.to(device)
inputs = targets.to(device)
targets # Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
# Backpropagation
loss.backward()
optimizer.step()
+= loss.item()
total_loss
# Print the average loss for this epoch
= total_loss / len(train_loader)
average_loss
losses.append(average_loss)if epoch % 20 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {average_loss:.4f}")
############################# testing the model #############################
eval() # Set the model to evaluation mode
model.= 0.0
test_loss
with torch.no_grad():
for inputs, targets in test_loader:
= inputs.to(device)
inputs = targets.to(device)
targets
# Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
+= loss.item()
test_loss
# Print the average test loss
= test_loss / len(test_loader)
average_test_loss # test_loss_list_P_25.append(average_test_loss)
print(f"Average Test Loss: {average_test_loss:.4f}")
eval() # Set the model to evaluation mode
model.= 0.0
test_loss
with torch.no_grad():
for inputs, targets in test_loader:
= inputs.to(device)
inputs = targets.to(device)
targets
# Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
+= loss.item()
test_loss
# Print the average test loss
= test_loss / len(test_loader)
average_test_loss # test_loss_list_P_25.append(average_test_loss)
print(f"Average Test Loss: {average_test_loss:.4f}")
range(1, num_epochs + 1), losses)
plt.plot('Epoch')
plt.xlabel('Loss')
plt.ylabel(# plt.title('Train Loss vs. Epoch for channel '+target_var_96_list[x_channel])
True)
plt.grid( plt.show()
(1324, 4, 80, 80) (332, 4, 80, 80) (1324, 2, 80, 80) (332, 2, 80, 80)
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Autoencoder_CNN [1656, 2, 80, 80] --
├─Encoder: 1-1 [1656, 2048] --
│ └─Sequential: 2-1 [1656, 2048] --
│ │ └─Conv2d: 3-1 [1656, 64, 40, 40] 2,368
│ │ └─GELU: 3-2 [1656, 64, 40, 40] --
│ │ └─Conv2d: 3-3 [1656, 64, 40, 40] 36,928
│ │ └─GELU: 3-4 [1656, 64, 40, 40] --
│ │ └─Conv2d: 3-5 [1656, 128, 20, 20] 73,856
│ │ └─GELU: 3-6 [1656, 128, 20, 20] --
│ │ └─Conv2d: 3-7 [1656, 128, 20, 20] 147,584
│ │ └─GELU: 3-8 [1656, 128, 20, 20] --
│ │ └─Conv2d: 3-9 [1656, 256, 10, 10] 295,168
│ │ └─GELU: 3-10 [1656, 256, 10, 10] --
│ │ └─Flatten: 3-11 [1656, 25600] --
│ │ └─Linear: 3-12 [1656, 2048] 52,430,848
├─Decoder: 1-2 [1656, 2, 80, 80] --
│ └─Sequential: 2-2 [1656, 2, 80, 80] --
│ │ └─Linear: 3-13 [1656, 25600] 52,454,400
│ │ └─Unflatten: 3-14 [1656, 256, 10, 10] --
│ │ └─ConvTranspose2d: 3-15 [1656, 128, 20, 20] 295,040
│ │ └─GELU: 3-16 [1656, 128, 20, 20] --
│ │ └─Conv2d: 3-17 [1656, 128, 20, 20] 147,584
│ │ └─GELU: 3-18 [1656, 128, 20, 20] --
│ │ └─ConvTranspose2d: 3-19 [1656, 64, 40, 40] 73,792
│ │ └─GELU: 3-20 [1656, 64, 40, 40] --
│ │ └─Conv2d: 3-21 [1656, 64, 40, 40] 36,928
│ │ └─GELU: 3-22 [1656, 64, 40, 40] --
│ │ └─ConvTranspose2d: 3-23 [1656, 2, 80, 80] 1,154
==========================================================================================
Total params: 105,995,650
Trainable params: 105,995,650
Non-trainable params: 0
Total mult-adds (Units.TERABYTES): 1.07
==========================================================================================
Input size (MB): 169.57
Forward/backward pass size (MB): 9014.58
Params size (MB): 423.98
Estimated Total Size (MB): 9608.13
==========================================================================================
1%|▏ | 1/80 [00:02<02:40, 2.03s/it]
Epoch [1/80] Loss: 1343.9761
Average Test Loss: 313.2759
26%|██▋ | 21/80 [00:35<01:40, 1.71s/it]
Epoch [21/80] Loss: 51.4848
Average Test Loss: 41.3688
51%|█████▏ | 41/80 [01:09<01:07, 1.72s/it]
Epoch [41/80] Loss: 19.3960
Average Test Loss: 15.7720
76%|███████▋ | 61/80 [01:43<00:32, 1.71s/it]
Epoch [61/80] Loss: 10.1627
Average Test Loss: 6.6679
100%|██████████| 80/80 [02:15<00:00, 1.70s/it]
Average Test Loss: 3.3631
'model/auto_conv_in4_out2.pt')
torch.save(model.state_dict(), = Autoencoder_CNN(image_size = 80, num_input_channels = 4, num_output_channels=2, c_hid = 64, latent_dim = 2048, activation= nn.GELU)
model
model.to(device) 'model/auto_conv_in4_out2.pt')) model.load_state_dict(torch.load(
<All keys matched successfully>
UNET Autoencoder with convolutional layers
Model defination
class Encoder(nn.Module):
def __init__(self, image_size, num_input_channels, num_output_channels = 1, c_hid = 16, latent_dim = 1024, activation= nn.GELU):
super(Encoder, self).__init__()
self.image_size = image_size
self.num_input_channels = num_input_channels
self.num_output_channels = num_output_channels
self.c_hid = c_hid
self.latent_dim = latent_dim
self.activation = activation
self.net = nn.Sequential(
=self.num_input_channels, out_channels=self.c_hid, kernel_size=3, stride=2, padding=1),
nn.Conv2d(in_channelsself.activation(),
=self.c_hid, out_channels=self.c_hid, kernel_size=3, padding=1),
nn.Conv2d(in_channelsself.activation(),
=self.c_hid, out_channels=2 * self.c_hid, kernel_size=3, stride=2, padding=1),
nn.Conv2d(in_channelsself.activation(),
=2 * self.c_hid, out_channels=2 * self.c_hid, kernel_size=3, padding=1),
nn.Conv2d(in_channelsself.activation(),
=2 * self.c_hid, out_channels=4*self.c_hid, kernel_size=3, stride=2, padding=1),
nn.Conv2d(in_channelsself.activation(),
nn.Flatten(),4 * self.c_hid) * self.image_size//8 * self.image_size//8, self.latent_dim)
nn.Linear((
)def forward(self, x):
return self.net(x)
class Decoder(nn.Module):
def __init__(self, image_size, num_input_channels, num_output_channels = 1,c_hid = 16, latent_dim = 1024, activation= nn.GELU):
super(Decoder, self).__init__()
self.image_size = image_size
self.num_input_channels = num_input_channels
self.num_output_channels = num_output_channels
self.c_hid = c_hid
self.latent_dim = latent_dim
self.activation = activation
self.decoder = nn.Sequential(
self.latent_dim, (4 * self.c_hid) * self.image_size//8 * self.image_size//8),
nn.Linear(1, (4*self.c_hid, self.image_size//8, self.image_size//8)),
nn.Unflatten(=4*self.c_hid, out_channels= 2*self.c_hid, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.ConvTranspose2d(in_channelsself.activation(),
= 2*self.c_hid, out_channels= 2*self.c_hid, kernel_size=3, padding=1),
nn.Conv2d(in_channels self.activation(),
= 2*self.c_hid, out_channels = self.c_hid, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.ConvTranspose2d(in_channels self.activation(),
= self.c_hid,out_channels= self.c_hid, kernel_size=3, padding=1),
nn.Conv2d(in_channels self.activation(),
= self.c_hid, out_channels= self.num_output_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.ConvTranspose2d(in_channels
)def forward(self, x):
return self.decoder(x)
class Autoencoder_UNET(nn.Module):
def __init__(self, image_size, num_input_channels, num_output_channels = 1,c_hid = 16, latent_dim = 1024, activation= nn.GELU, encoder = Encoder, decoder = Decoder):
super(Autoencoder_UNET, self).__init__()
self.image_size = image_size
self.num_input_channels = num_input_channels
self.num_output_channels = num_output_channels
self.c_hid = c_hid
self.latent_dim = latent_dim
self.activation = activation
self.encoder = encoder(self.image_size, self.num_input_channels, self.num_output_channels, self.c_hid, self.latent_dim, self.activation)
self.decoder = decoder(self.image_size, self.num_input_channels, self.num_output_channels, self.c_hid, self.latent_dim, self.activation)
# def forward(self, x):
# x = self.encoder(x)
# # print(x.shape)
# x = self.decoder(x)
# return x
def forward(self, x):
= self.encoder.net[0]
conv1 = self.encoder.net[1]
activation1 = self.encoder.net[2]
conv2 = self.encoder.net[3]
activation2 = self.encoder.net[4]
conv3 = self.encoder.net[5]
activation3 = self.encoder.net[6]
conv4 = self.encoder.net[7]
activation4 = self.encoder.net[8]
conv5 = self.encoder.net[9]
activation5 = self.encoder.net[10]
flatten = self.encoder.net[11]
linear = self.decoder.decoder[0]
lineart = self.decoder.decoder[1]
unflattent = self.decoder.decoder[2]
convt2 = self.decoder.decoder[3]
activation7 = self.decoder.decoder[4]
convt3 = self.decoder.decoder[5]
activation8 = self.decoder.decoder[6]
convt4 = self.decoder.decoder[7]
activation9 = self.decoder.decoder[8]
convt5 = self.decoder.decoder[9]
activation10 = self.decoder.decoder[10]
convt6 = activation1(conv1(x))
x1 = activation2(conv2(x1))
x2 = activation3(conv3(x2))
x3 = activation4(conv4(x3))
x4 = activation5(conv5(x4))
x5 = flatten(x5)
x6 = linear(x6)
x7 = lineart(x7)
x8 = unflattent(x8)
x9 = activation7(convt2(x9+x5))
x10 = activation8(convt3(x10+x4))
x11 = activation9(convt4(x11+x3))
x12 = activation10(convt5(x12+x2))
x13 = convt6(x13+x1)
x13 return x13
= Autoencoder_UNET(image_size = 80, num_input_channels = 14, num_output_channels=2, c_hid = 16, latent_dim = 1024, activation= nn.GELU)
model_auto
model_auto= torch.randn(11, 14, 80, 80) # Assuming input size of (batch_size, num_num_input_channels, height, width)
dummy_input = model_auto(dummy_input)
out print(out.shape)
print(model_auto)
torch.Size([11, 2, 80, 80])
Autoencoder_UNET(
(encoder): Encoder(
(net): Sequential(
(0): Conv2d(14, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): GELU(approximate='none')
(2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): GELU(approximate='none')
(4): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(5): GELU(approximate='none')
(6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): GELU(approximate='none')
(8): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(9): GELU(approximate='none')
(10): Flatten(start_dim=1, end_dim=-1)
(11): Linear(in_features=6400, out_features=1024, bias=True)
)
)
(decoder): Decoder(
(decoder): Sequential(
(0): Linear(in_features=1024, out_features=6400, bias=True)
(1): Unflatten(dim=1, unflattened_size=(64, 10, 10))
(2): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(3): GELU(approximate='none')
(4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(5): GELU(approximate='none')
(6): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(7): GELU(approximate='none')
(8): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): GELU(approximate='none')
(10): ConvTranspose2d(16, 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
)
)
)
from torchinfo import summary
=(1656, 14, 80, 80)) summary(model_auto, input_size
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Autoencoder_UNET [1656, 2, 80, 80] --
├─Encoder: 1-1 -- --
│ └─Sequential: 2-1 -- --
│ │ └─Conv2d: 3-1 [1656, 16, 40, 40] 2,032
│ │ └─GELU: 3-2 [1656, 16, 40, 40] --
│ │ └─Conv2d: 3-3 [1656, 16, 40, 40] 2,320
│ │ └─GELU: 3-4 [1656, 16, 40, 40] --
│ │ └─Conv2d: 3-5 [1656, 32, 20, 20] 4,640
│ │ └─GELU: 3-6 [1656, 32, 20, 20] --
│ │ └─Conv2d: 3-7 [1656, 32, 20, 20] 9,248
│ │ └─GELU: 3-8 [1656, 32, 20, 20] --
│ │ └─Conv2d: 3-9 [1656, 64, 10, 10] 18,496
│ │ └─GELU: 3-10 [1656, 64, 10, 10] --
│ │ └─Flatten: 3-11 [1656, 6400] --
│ │ └─Linear: 3-12 [1656, 1024] 6,554,624
├─Decoder: 1-2 -- --
│ └─Sequential: 2-2 -- --
│ │ └─Linear: 3-13 [1656, 6400] 6,560,000
│ │ └─Unflatten: 3-14 [1656, 64, 10, 10] --
│ │ └─ConvTranspose2d: 3-15 [1656, 32, 20, 20] 18,464
│ │ └─GELU: 3-16 [1656, 32, 20, 20] --
│ │ └─Conv2d: 3-17 [1656, 32, 20, 20] 9,248
│ │ └─GELU: 3-18 [1656, 32, 20, 20] --
│ │ └─ConvTranspose2d: 3-19 [1656, 16, 40, 40] 4,624
│ │ └─GELU: 3-20 [1656, 16, 40, 40] --
│ │ └─Conv2d: 3-21 [1656, 16, 40, 40] 2,320
│ │ └─GELU: 3-22 [1656, 16, 40, 40] --
│ │ └─ConvTranspose2d: 3-23 [1656, 2, 80, 80] 290
==========================================================================================
Total params: 13,186,306
Trainable params: 13,186,306
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 85.34
==========================================================================================
Input size (MB): 593.51
Forward/backward pass size (MB): 2387.61
Params size (MB): 52.75
Estimated Total Size (MB): 3033.86
==========================================================================================
from torchsummary import summary
=(14, 80, 80)) summary(model_auto, input_size
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 16, 40, 40] 2,032
GELU-2 [-1, 16, 40, 40] 0
Conv2d-3 [-1, 16, 40, 40] 2,320
GELU-4 [-1, 16, 40, 40] 0
Conv2d-5 [-1, 32, 20, 20] 4,640
GELU-6 [-1, 32, 20, 20] 0
Conv2d-7 [-1, 32, 20, 20] 9,248
GELU-8 [-1, 32, 20, 20] 0
Conv2d-9 [-1, 64, 10, 10] 18,496
GELU-10 [-1, 64, 10, 10] 0
Flatten-11 [-1, 6400] 0
Linear-12 [-1, 1024] 6,554,624
Linear-13 [-1, 6400] 6,560,000
Unflatten-14 [-1, 64, 10, 10] 0
ConvTranspose2d-15 [-1, 32, 20, 20] 18,464
GELU-16 [-1, 32, 20, 20] 0
Conv2d-17 [-1, 32, 20, 20] 9,248
GELU-18 [-1, 32, 20, 20] 0
ConvTranspose2d-19 [-1, 16, 40, 40] 4,624
GELU-20 [-1, 16, 40, 40] 0
Conv2d-21 [-1, 16, 40, 40] 2,320
GELU-22 [-1, 16, 40, 40] 0
ConvTranspose2d-23 [-1, 2, 80, 80] 290
================================================================
Total params: 13,186,306
Trainable params: 13,186,306
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.34
Forward/backward pass size (MB): 2.69
Params size (MB): 50.30
Estimated Total Size (MB): 53.34
----------------------------------------------------------------
Training on all channel and predicting on all channel UNET
=['TSURF_K',
target_var_96_list 'SNOWEW_M', 'SNOWAGE_HR', 'PRATE_MMpH', 'CLOUD_OD', 'U10_MpS',
'V10_MpS', 'T2_K', 'SWSFC_WpM2', 'SOLM_M3pM3', 'CLDTOP_KM', 'CAPE',
'PBL_WRF_M', 'PBL_YSU_M'] # ['U10_MpS', 'T2_K', 'V10_MpS']
= ['P25','P10']
target_var_120_list = get_data(target_var_96_list, target_var_120_list)
X,y
from sklearn.model_selection import train_test_split
= train_test_split(X, y, test_size=0.2, random_state=42)
X_train_all, X_test_all, y_train_all, y_test_all X_train_all.shape, X_test_all.shape, y_train_all.shape, y_test_all.shape
X shape (1656, 14, 80, 80)
y shape (1656, 2, 80, 80)
((1324, 14, 80, 80), (332, 14, 80, 80), (1324, 2, 80, 80), (332, 2, 80, 80))
print(X_train_all.shape, X_test_all.shape, y_train_all.shape, y_test_all.shape)
= CustomDataset(X_train_all, y_train_all)
train_custom_dataset # print(len(train_custom_dataset))
= 32
batch_size = data.DataLoader(train_custom_dataset, batch_size=batch_size, shuffle=True)
train_loader # print(len(train_loader))
= CustomDataset(X_test_all, y_test_all)
test_custom_dataset # print(len(test_custom_dataset))
= 32
batch_size = data.DataLoader(test_custom_dataset, batch_size=batch_size, shuffle=False)
test_loader # print(len(test_loader))
#################### Training the model ####################
= Autoencoder_UNET(image_size = 80, num_input_channels = 14, num_output_channels=2, c_hid = 64, latent_dim = 2048, activation= nn.GELU)
model
model.to(device) from torchinfo import summary
print(summary(model, input_size=(1656, 14, 80, 80)))
# Define the loss function and optimizer
= nn.MSELoss()
criterion = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer = []
losses # Training loop
= 200
num_epochs for epoch in trange(num_epochs):
# Set the model to training mode
model.train() = 0.0
total_loss
for inputs, targets in train_loader:
# Zero the gradients
optimizer.zero_grad() = inputs.to(device)
inputs = targets.to(device)
targets # Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
# Backpropagation
loss.backward()
optimizer.step()
+= loss.item()
total_loss
# Print the average loss for this epoch
= total_loss / len(train_loader)
average_loss
losses.append(average_loss)if epoch % 20 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {average_loss:.4f}")
############################# testing the model #############################
eval() # Set the model to evaluation mode
model.= 0.0
test_loss
with torch.no_grad():
for inputs, targets in test_loader:
= inputs.to(device)
inputs = targets.to(device)
targets
# Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
+= loss.item()
test_loss
# Print the average test loss
= test_loss / len(test_loader)
average_test_loss # test_loss_list_P_25.append(average_test_loss)
print(f"Average Test Loss: {average_test_loss:.4f}")
range(1, num_epochs + 1), losses)
plt.plot('Epoch')
plt.xlabel('Loss')
plt.ylabel(# plt.title('Train Loss vs. Epoch for channel '+target_var_96_list[x_channel])
True)
plt.grid( plt.show()
(1324, 14, 80, 80) (332, 14, 80, 80) (1324, 2, 80, 80) (332, 2, 80, 80)
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Autoencoder_UNET [1656, 2, 80, 80] --
├─Encoder: 1-1 -- --
│ └─Sequential: 2-1 -- --
│ │ └─Conv2d: 3-1 [1656, 64, 40, 40] 8,128
│ │ └─GELU: 3-2 [1656, 64, 40, 40] --
│ │ └─Conv2d: 3-3 [1656, 64, 40, 40] 36,928
│ │ └─GELU: 3-4 [1656, 64, 40, 40] --
│ │ └─Conv2d: 3-5 [1656, 128, 20, 20] 73,856
│ │ └─GELU: 3-6 [1656, 128, 20, 20] --
│ │ └─Conv2d: 3-7 [1656, 128, 20, 20] 147,584
│ │ └─GELU: 3-8 [1656, 128, 20, 20] --
│ │ └─Conv2d: 3-9 [1656, 256, 10, 10] 295,168
│ │ └─GELU: 3-10 [1656, 256, 10, 10] --
│ │ └─Flatten: 3-11 [1656, 25600] --
│ │ └─Linear: 3-12 [1656, 2048] 52,430,848
├─Decoder: 1-2 -- --
│ └─Sequential: 2-2 -- --
│ │ └─Linear: 3-13 [1656, 25600] 52,454,400
│ │ └─Unflatten: 3-14 [1656, 256, 10, 10] --
│ │ └─ConvTranspose2d: 3-15 [1656, 128, 20, 20] 295,040
│ │ └─GELU: 3-16 [1656, 128, 20, 20] --
│ │ └─Conv2d: 3-17 [1656, 128, 20, 20] 147,584
│ │ └─GELU: 3-18 [1656, 128, 20, 20] --
│ │ └─ConvTranspose2d: 3-19 [1656, 64, 40, 40] 73,792
│ │ └─GELU: 3-20 [1656, 64, 40, 40] --
│ │ └─Conv2d: 3-21 [1656, 64, 40, 40] 36,928
│ │ └─GELU: 3-22 [1656, 64, 40, 40] --
│ │ └─ConvTranspose2d: 3-23 [1656, 2, 80, 80] 1,154
==========================================================================================
Total params: 106,001,410
Trainable params: 106,001,410
Non-trainable params: 0
Total mult-adds (Units.TERABYTES): 1.09
==========================================================================================
Input size (MB): 593.51
Forward/backward pass size (MB): 9014.58
Params size (MB): 424.01
Estimated Total Size (MB): 10032.09
==========================================================================================
0%| | 1/200 [00:03<11:37, 3.50s/it]
Epoch [1/200] Loss: 3185.9086
10%|█ | 21/200 [01:08<10:15, 3.44s/it]
Epoch [21/200] Loss: 238.7220
20%|██ | 41/200 [02:19<09:16, 3.50s/it]
Epoch [41/200] Loss: 90.1479
30%|███ | 61/200 [03:44<09:57, 4.30s/it]
Epoch [61/200] Loss: 19.3867
40%|████ | 81/200 [05:06<08:00, 4.04s/it]
Epoch [81/200] Loss: 7.6679
50%|█████ | 101/200 [06:30<07:00, 4.24s/it]
Epoch [101/200] Loss: 5.9376
60%|██████ | 121/200 [07:57<05:25, 4.12s/it]
Epoch [121/200] Loss: 7.9944
70%|███████ | 141/200 [09:21<03:59, 4.06s/it]
Epoch [141/200] Loss: 5.3282
80%|████████ | 161/200 [10:38<02:34, 3.97s/it]
Epoch [161/200] Loss: 4.2981
90%|█████████ | 181/200 [11:58<01:13, 3.87s/it]
Epoch [181/200] Loss: 4.3096
100%|██████████| 200/200 [13:15<00:00, 3.98s/it]
Average Test Loss: 3.9575
'model/auto_conv_unet_in14_out2.pt')
torch.save(model.state_dict(), = Autoencoder_UNET(image_size = 80, num_input_channels = 14, num_output_channels=2, c_hid = 64, latent_dim = 2048, activation= nn.GELU)
model
model.to(device) 'model/auto_conv_unet_in14_out2.pt')) model.load_state_dict(torch.load(
<All keys matched successfully>
Single channel input single channel output(P25) UNET
=['TSURF_K',
target_var_96_list 'SNOWEW_M', 'SNOWAGE_HR', 'PRATE_MMpH', 'CLOUD_OD', 'U10_MpS',
'V10_MpS', 'T2_K', 'SWSFC_WpM2', 'SOLM_M3pM3', 'CLDTOP_KM', 'CAPE',
'PBL_WRF_M', 'PBL_YSU_M'] # ['U10_MpS', 'T2_K', 'V10_MpS']
= ['P25','P10']
target_var_120_list = get_data(target_var_96_list, target_var_120_list)
X,y
from sklearn.model_selection import train_test_split
= train_test_split(X, y, test_size=0.2, random_state=42)
X_train_all, X_test_all, y_train_all, y_test_all X_train_all.shape, X_test_all.shape, y_train_all.shape, y_test_all.shape
X shape (1656, 14, 80, 80)
y shape (1656, 2, 80, 80)
((1324, 14, 80, 80), (332, 14, 80, 80), (1324, 2, 80, 80), (332, 2, 80, 80))
= []
test_loss_list_P_25_conv_auto_unet = 0 # selecting P25 as output
y_channel # x_channel = 0
for x_channel in range(X.shape[1]):
####################### Selecting the channel #######################
print('X Channel name : ', target_var_96_list[x_channel])
= X_train_all[:, x_channel:x_channel+1, :,:]
X_train = X_test_all[:, x_channel:x_channel+1, :,:]
X_test = y_train_all[:, y_channel:y_channel+1, :,:]
y_train = y_test_all[:, y_channel:y_channel+1, :,:]
y_test print('Shapes: ', X_train.shape, X_test.shape, y_train.shape, y_test.shape)
####################### Creating the dataset loader #######################
= CustomDataset(X_train, y_train)
train_custom_dataset # print(len(train_custom_dataset))
= 32
batch_size = data.DataLoader(train_custom_dataset, batch_size=batch_size, shuffle=True)
train_loader # print(len(train_loader))
= CustomDataset(X_test, y_test)
test_custom_dataset # print(len(test_custom_dataset))
= 32
batch_size = data.DataLoader(test_custom_dataset, batch_size=batch_size, shuffle=False)
test_loader # print(len(test_loader))
#################### Training the model ####################
= Autoencoder_UNET(image_size = 80, num_input_channels = 1, num_output_channels=1, c_hid = 8, latent_dim = 512, activation= nn.GELU)
model
model.to(device)# Define the loss function and optimizer
= nn.MSELoss()
criterion = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer = []
losses # Training loop
= 200
num_epochs for epoch in trange(num_epochs):
# Set the model to training mode
model.train() = 0.0
total_loss
for inputs, targets in train_loader:
# Zero the gradients
optimizer.zero_grad() = inputs.to(device)
inputs = targets.to(device)
targets # Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
# Backpropagation
loss.backward()
optimizer.step()
+= loss.item()
total_loss
# Print the average loss for this epoch
= total_loss / len(train_loader)
average_loss
losses.append(average_loss)if epoch % 20 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {average_loss:.4f}")
############################# testing the model #############################
eval() # Set the model to evaluation mode
model.= 0.0
test_loss
with torch.no_grad():
for inputs, targets in test_loader:
= inputs.to(device)
inputs = targets.to(device)
targets # Forward pass
= model(inputs)
outputs # Calculate the loss
= criterion(outputs, targets)
loss += loss.item()
test_loss
# Print the average test loss
= test_loss / len(test_loader)
average_test_loss
test_loss_list_P_25_conv_auto_unet.append(average_test_loss)print(f"Average Test Loss: {average_test_loss:.4f}")
range(1, num_epochs + 1), losses)
plt.plot('Epoch')
plt.xlabel('Loss')
plt.ylabel('Train Loss vs. Epoch for channel '+target_var_96_list[x_channel])
plt.title(True)
plt.grid( plt.show()
X Channel name : TSURF_K
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 0/200 [00:00<?, ?it/s] 0%| | 1/200 [00:00<01:04, 3.09it/s]
Epoch [1/200] Loss: 2015.0104
10%|█ | 21/200 [00:06<00:57, 3.11it/s]
Epoch [21/200] Loss: 224.4622
20%|██ | 41/200 [00:13<00:49, 3.20it/s]
Epoch [41/200] Loss: 217.3994
30%|███ | 61/200 [00:19<00:43, 3.20it/s]
Epoch [61/200] Loss: 213.6307
40%|████ | 81/200 [00:25<00:37, 3.16it/s]
Epoch [81/200] Loss: 211.9555
50%|█████ | 101/200 [00:31<00:30, 3.20it/s]
Epoch [101/200] Loss: 212.0086
60%|██████ | 121/200 [00:38<00:24, 3.17it/s]
Epoch [121/200] Loss: 210.0710
70%|███████ | 141/200 [00:44<00:18, 3.18it/s]
Epoch [141/200] Loss: 216.9347
80%|████████ | 161/200 [00:50<00:12, 3.21it/s]
Epoch [161/200] Loss: 214.3995
90%|█████████ | 181/200 [00:56<00:05, 3.19it/s]
Epoch [181/200] Loss: 209.9813
100%|██████████| 200/200 [01:02<00:00, 3.18it/s]
Average Test Loss: 192.8667
X Channel name : SNOWEW_M
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:01, 3.26it/s]
Epoch [1/200] Loss: 720.3883
10%|█ | 21/200 [00:06<00:56, 3.15it/s]
Epoch [21/200] Loss: 215.2669
20%|██ | 41/200 [00:12<00:50, 3.15it/s]
Epoch [41/200] Loss: 212.9412
30%|███ | 61/200 [00:19<00:43, 3.19it/s]
Epoch [61/200] Loss: 210.6425
40%|████ | 81/200 [00:25<00:37, 3.19it/s]
Epoch [81/200] Loss: 209.0133
50%|█████ | 101/200 [00:31<00:30, 3.22it/s]
Epoch [101/200] Loss: 216.0626
60%|██████ | 121/200 [00:38<00:24, 3.21it/s]
Epoch [121/200] Loss: 209.8225
70%|███████ | 141/200 [00:44<00:18, 3.26it/s]
Epoch [141/200] Loss: 209.2404
80%|████████ | 161/200 [00:50<00:12, 3.16it/s]
Epoch [161/200] Loss: 212.0243
90%|█████████ | 181/200 [00:56<00:05, 3.18it/s]
Epoch [181/200] Loss: 208.9202
100%|██████████| 200/200 [01:02<00:00, 3.19it/s]
Average Test Loss: 191.8743
X Channel name : SNOWAGE_HR
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:01, 3.22it/s]
Epoch [1/200] Loss: 656.6864
10%|█ | 21/200 [00:06<00:56, 3.17it/s]
Epoch [21/200] Loss: 215.9682
20%|██ | 41/200 [00:12<00:50, 3.17it/s]
Epoch [41/200] Loss: 211.3083
30%|███ | 61/200 [00:19<00:43, 3.19it/s]
Epoch [61/200] Loss: 213.7716
40%|████ | 81/200 [00:25<00:36, 3.23it/s]
Epoch [81/200] Loss: 210.4161
50%|█████ | 101/200 [00:31<00:31, 3.16it/s]
Epoch [101/200] Loss: 208.3461
60%|██████ | 121/200 [00:38<00:24, 3.22it/s]
Epoch [121/200] Loss: 210.5009
70%|███████ | 141/200 [00:44<00:18, 3.19it/s]
Epoch [141/200] Loss: 210.7708
80%|████████ | 161/200 [00:50<00:11, 3.26it/s]
Epoch [161/200] Loss: 208.3222
90%|█████████ | 181/200 [00:56<00:06, 3.16it/s]
Epoch [181/200] Loss: 208.7196
100%|██████████| 200/200 [01:02<00:00, 3.18it/s]
Average Test Loss: 198.2665
X Channel name : PRATE_MMpH
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:02, 3.17it/s]
Epoch [1/200] Loss: 598.2960
10%|█ | 21/200 [00:06<00:56, 3.17it/s]
Epoch [21/200] Loss: 227.3699
20%|██ | 41/200 [00:12<00:49, 3.18it/s]
Epoch [41/200] Loss: 219.5535
30%|███ | 61/200 [00:19<00:43, 3.19it/s]
Epoch [61/200] Loss: 210.1886
40%|████ | 81/200 [00:25<00:37, 3.16it/s]
Epoch [81/200] Loss: 210.2708
50%|█████ | 101/200 [00:31<00:31, 3.14it/s]
Epoch [101/200] Loss: 209.0613
60%|██████ | 121/200 [00:38<00:24, 3.17it/s]
Epoch [121/200] Loss: 208.2215
70%|███████ | 141/200 [00:44<00:18, 3.14it/s]
Epoch [141/200] Loss: 207.7972
80%|████████ | 161/200 [00:50<00:12, 3.18it/s]
Epoch [161/200] Loss: 213.5013
90%|█████████ | 181/200 [00:56<00:05, 3.25it/s]
Epoch [181/200] Loss: 208.5548
100%|██████████| 200/200 [01:02<00:00, 3.18it/s]
Average Test Loss: 199.3592
X Channel name : CLOUD_OD
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:01, 3.23it/s]
Epoch [1/200] Loss: 586.9568
10%|█ | 21/200 [00:06<00:55, 3.20it/s]
Epoch [21/200] Loss: 62.2030
20%|██ | 41/200 [00:12<00:49, 3.21it/s]
Epoch [41/200] Loss: 21.8947
30%|███ | 61/200 [00:18<00:41, 3.35it/s]
Epoch [61/200] Loss: 13.4883
40%|████ | 81/200 [00:25<00:37, 3.16it/s]
Epoch [81/200] Loss: 13.7934
50%|█████ | 101/200 [00:31<00:31, 3.16it/s]
Epoch [101/200] Loss: 8.8402
60%|██████ | 121/200 [00:37<00:24, 3.16it/s]
Epoch [121/200] Loss: 8.3218
70%|███████ | 141/200 [00:44<00:18, 3.17it/s]
Epoch [141/200] Loss: 8.3719
80%|████████ | 161/200 [00:50<00:12, 3.25it/s]
Epoch [161/200] Loss: 7.2083
90%|█████████ | 181/200 [00:56<00:05, 3.29it/s]
Epoch [181/200] Loss: 6.5652
100%|██████████| 200/200 [01:02<00:00, 3.20it/s]
Average Test Loss: 7.9540
X Channel name : U10_MpS
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:59, 3.33it/s]
Epoch [1/200] Loss: 473.9317
10%|█ | 21/200 [00:06<00:55, 3.21it/s]
Epoch [21/200] Loss: 69.0522
20%|██ | 41/200 [00:12<00:48, 3.30it/s]
Epoch [41/200] Loss: 18.4941
30%|███ | 61/200 [00:18<00:43, 3.19it/s]
Epoch [61/200] Loss: 10.7839
40%|████ | 81/200 [00:25<00:37, 3.21it/s]
Epoch [81/200] Loss: 7.7059
50%|█████ | 101/200 [00:31<00:31, 3.19it/s]
Epoch [101/200] Loss: 6.4899
60%|██████ | 121/200 [00:37<00:24, 3.20it/s]
Epoch [121/200] Loss: 5.4837
70%|███████ | 141/200 [00:43<00:18, 3.25it/s]
Epoch [141/200] Loss: 7.5561
80%|████████ | 161/200 [00:50<00:12, 3.21it/s]
Epoch [161/200] Loss: 3.8554
90%|█████████ | 181/200 [00:56<00:05, 3.27it/s]
Epoch [181/200] Loss: 4.0530
100%|██████████| 200/200 [01:02<00:00, 3.21it/s]
Average Test Loss: 3.1830
X Channel name : V10_MpS
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:00, 3.29it/s]
Epoch [1/200] Loss: 561.4590
10%|█ | 21/200 [00:06<00:55, 3.23it/s]
Epoch [21/200] Loss: 71.1724
20%|██ | 41/200 [00:12<00:49, 3.21it/s]
Epoch [41/200] Loss: 19.3733
30%|███ | 61/200 [00:19<00:43, 3.21it/s]
Epoch [61/200] Loss: 9.0035
40%|████ | 81/200 [00:25<00:37, 3.16it/s]
Epoch [81/200] Loss: 7.7880
50%|█████ | 101/200 [00:31<00:31, 3.18it/s]
Epoch [101/200] Loss: 5.6728
60%|██████ | 121/200 [00:37<00:24, 3.23it/s]
Epoch [121/200] Loss: 8.2339
70%|███████ | 141/200 [00:44<00:18, 3.22it/s]
Epoch [141/200] Loss: 3.8740
80%|████████ | 161/200 [00:50<00:12, 3.24it/s]
Epoch [161/200] Loss: 4.9030
90%|█████████ | 181/200 [00:56<00:05, 3.22it/s]
Epoch [181/200] Loss: 4.0399
100%|██████████| 200/200 [01:02<00:00, 3.21it/s]
Average Test Loss: 3.4703
X Channel name : T2_K
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:58, 3.37it/s]
Epoch [1/200] Loss: 4816.9447
10%|█ | 21/200 [00:06<00:57, 3.13it/s]
Epoch [21/200] Loss: 234.2117
20%|██ | 41/200 [00:12<00:49, 3.21it/s]
Epoch [41/200] Loss: 222.6848
30%|███ | 61/200 [00:19<00:43, 3.19it/s]
Epoch [61/200] Loss: 215.6618
40%|████ | 81/200 [00:25<00:37, 3.18it/s]
Epoch [81/200] Loss: 213.7606
50%|█████ | 101/200 [00:31<00:30, 3.22it/s]
Epoch [101/200] Loss: 213.5892
60%|██████ | 121/200 [00:37<00:25, 3.15it/s]
Epoch [121/200] Loss: 211.2940
70%|███████ | 141/200 [00:44<00:18, 3.23it/s]
Epoch [141/200] Loss: 213.9061
80%|████████ | 161/200 [00:50<00:12, 3.23it/s]
Epoch [161/200] Loss: 214.1695
90%|█████████ | 181/200 [00:56<00:05, 3.26it/s]
Epoch [181/200] Loss: 208.9214
100%|██████████| 200/200 [01:02<00:00, 3.21it/s]
Average Test Loss: 192.4815
X Channel name : SWSFC_WpM2
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:59, 3.34it/s]
Epoch [1/200] Loss: 512.7773
10%|█ | 21/200 [00:06<00:54, 3.28it/s]
Epoch [21/200] Loss: 214.1751
20%|██ | 41/200 [00:12<00:49, 3.23it/s]
Epoch [41/200] Loss: 212.7058
30%|███ | 61/200 [00:18<00:43, 3.17it/s]
Epoch [61/200] Loss: 214.4303
40%|████ | 81/200 [00:25<00:36, 3.22it/s]
Epoch [81/200] Loss: 212.5624
50%|█████ | 101/200 [00:31<00:31, 3.18it/s]
Epoch [101/200] Loss: 208.8632
60%|██████ | 121/200 [00:37<00:24, 3.19it/s]
Epoch [121/200] Loss: 210.4221
70%|███████ | 141/200 [00:43<00:18, 3.19it/s]
Epoch [141/200] Loss: 209.8110
80%|████████ | 161/200 [00:50<00:12, 3.24it/s]
Epoch [161/200] Loss: 207.3880
90%|█████████ | 181/200 [00:56<00:05, 3.24it/s]
Epoch [181/200] Loss: 207.7816
100%|██████████| 200/200 [01:02<00:00, 3.22it/s]
Average Test Loss: 191.0656
X Channel name : SOLM_M3pM3
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:58, 3.40it/s]
Epoch [1/200] Loss: 574.8718
10%|█ | 21/200 [00:06<00:54, 3.26it/s]
Epoch [21/200] Loss: 181.5824
20%|██ | 41/200 [00:12<00:49, 3.23it/s]
Epoch [41/200] Loss: 115.5675
30%|███ | 61/200 [00:18<00:42, 3.25it/s]
Epoch [61/200] Loss: 39.5182
40%|████ | 81/200 [00:24<00:36, 3.22it/s]
Epoch [81/200] Loss: 16.1863
50%|█████ | 101/200 [00:31<00:30, 3.26it/s]
Epoch [101/200] Loss: 11.7081
60%|██████ | 121/200 [00:37<00:24, 3.21it/s]
Epoch [121/200] Loss: 8.9398
70%|███████ | 141/200 [00:43<00:18, 3.25it/s]
Epoch [141/200] Loss: 8.2249
80%|████████ | 161/200 [00:49<00:12, 3.22it/s]
Epoch [161/200] Loss: 6.6518
90%|█████████ | 181/200 [00:55<00:05, 3.21it/s]
Epoch [181/200] Loss: 5.9480
100%|██████████| 200/200 [01:01<00:00, 3.24it/s]
Average Test Loss: 5.3015
X Channel name : CLDTOP_KM
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:58, 3.39it/s]
Epoch [1/200] Loss: 647.1522
10%|█ | 21/200 [00:06<00:55, 3.21it/s]
Epoch [21/200] Loss: 215.3608
20%|██ | 41/200 [00:12<00:49, 3.19it/s]
Epoch [41/200] Loss: 211.6647
30%|███ | 61/200 [00:19<00:44, 3.10it/s]
Epoch [61/200] Loss: 209.6907
40%|████ | 81/200 [00:25<00:37, 3.18it/s]
Epoch [81/200] Loss: 211.7445
50%|█████ | 101/200 [00:31<00:31, 3.17it/s]
Epoch [101/200] Loss: 207.4063
60%|██████ | 121/200 [00:38<00:24, 3.20it/s]
Epoch [121/200] Loss: 207.9890
70%|███████ | 141/200 [00:44<00:18, 3.25it/s]
Epoch [141/200] Loss: 207.9899
80%|████████ | 161/200 [00:50<00:12, 3.21it/s]
Epoch [161/200] Loss: 209.2992
90%|█████████ | 181/200 [00:56<00:05, 3.20it/s]
Epoch [181/200] Loss: 213.8687
100%|██████████| 200/200 [01:02<00:00, 3.20it/s]
Average Test Loss: 191.1090
X Channel name : CAPE
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:00, 3.30it/s]
Epoch [1/200] Loss: 589.8784
10%|█ | 21/200 [00:06<00:56, 3.15it/s]
Epoch [21/200] Loss: 213.6897
20%|██ | 41/200 [00:12<00:50, 3.14it/s]
Epoch [41/200] Loss: 215.2983
30%|███ | 61/200 [00:19<00:43, 3.17it/s]
Epoch [61/200] Loss: 216.7662
40%|████ | 81/200 [00:25<00:35, 3.31it/s]
Epoch [81/200] Loss: 213.0695
50%|█████ | 101/200 [00:31<00:29, 3.30it/s]
Epoch [101/200] Loss: 214.7085
60%|██████ | 121/200 [00:37<00:24, 3.22it/s]
Epoch [121/200] Loss: 213.2036
70%|███████ | 141/200 [00:43<00:18, 3.24it/s]
Epoch [141/200] Loss: 212.2645
80%|████████ | 161/200 [00:49<00:11, 3.34it/s]
Epoch [161/200] Loss: 207.4999
90%|█████████ | 181/200 [00:56<00:05, 3.21it/s]
Epoch [181/200] Loss: 213.0907
100%|██████████| 200/200 [01:01<00:00, 3.23it/s]
Average Test Loss: 191.7055
X Channel name : PBL_WRF_M
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:00, 3.27it/s]
Epoch [1/200] Loss: 613.1753
10%|█ | 21/200 [00:06<00:55, 3.24it/s]
Epoch [21/200] Loss: 215.8198
20%|██ | 41/200 [00:12<00:49, 3.21it/s]
Epoch [41/200] Loss: 211.0306
30%|███ | 61/200 [00:19<00:42, 3.27it/s]
Epoch [61/200] Loss: 210.1862
40%|████ | 81/200 [00:25<00:38, 3.11it/s]
Epoch [81/200] Loss: 208.5808
50%|█████ | 101/200 [00:31<00:30, 3.24it/s]
Epoch [101/200] Loss: 208.4701
60%|██████ | 121/200 [00:37<00:24, 3.18it/s]
Epoch [121/200] Loss: 209.8240
70%|███████ | 141/200 [00:44<00:18, 3.19it/s]
Epoch [141/200] Loss: 208.0958
80%|████████ | 161/200 [00:50<00:12, 3.19it/s]
Epoch [161/200] Loss: 209.3367
90%|█████████ | 181/200 [00:56<00:05, 3.24it/s]
Epoch [181/200] Loss: 207.0369
100%|██████████| 200/200 [01:02<00:00, 3.19it/s]
Average Test Loss: 192.7273
X Channel name : PBL_YSU_M
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:00, 3.28it/s]
Epoch [1/200] Loss: 419.9077
10%|█ | 21/200 [00:06<00:55, 3.20it/s]
Epoch [21/200] Loss: 210.7698
20%|██ | 41/200 [00:12<00:48, 3.25it/s]
Epoch [41/200] Loss: 208.6584
30%|███ | 61/200 [00:18<00:43, 3.20it/s]
Epoch [61/200] Loss: 210.8284
40%|████ | 81/200 [00:25<00:37, 3.19it/s]
Epoch [81/200] Loss: 208.8175
50%|█████ | 101/200 [00:31<00:31, 3.18it/s]
Epoch [101/200] Loss: 206.5321
60%|██████ | 121/200 [00:37<00:24, 3.21it/s]
Epoch [121/200] Loss: 207.8562
70%|███████ | 141/200 [00:44<00:18, 3.20it/s]
Epoch [141/200] Loss: 207.3433
80%|████████ | 161/200 [00:50<00:12, 3.16it/s]
Epoch [161/200] Loss: 206.2741
90%|█████████ | 181/200 [00:57<00:06, 3.01it/s]
Epoch [181/200] Loss: 207.5728
100%|██████████| 200/200 [01:02<00:00, 3.18it/s]
Average Test Loss: 194.3204
test_loss_list_P_25_conv_auto_unet
[192.86672002618963,
191.87428977272728,
198.26652665571734,
199.35923628373578,
7.954025506973267,
3.18301400271329,
3.4703480113636362,
192.4814786044034,
191.06564331054688,
5.301476088437167,
191.10903930664062,
191.70547346635297,
192.72732405229047,
194.32039434259588]
= list(zip(target_var_96_list, test_loss_list_P_25_conv_auto_unet))
name_loss_pairs
# Sort based on loss values
= sorted(name_loss_pairs, key=lambda x: x[1])
sorted_name_loss_pairs for pair in sorted_name_loss_pairs:
print(pair)
# print(sorted_name_loss_pairs)
# Select the lowest 3 names
= [pair[0] for pair in sorted_name_loss_pairs[:3]]
lowest_3_names
print("Lowest 3 names:", lowest_3_names)
('U10_MpS', 3.18301400271329)
('V10_MpS', 3.4703480113636362)
('SOLM_M3pM3', 5.301476088437167)
('CLOUD_OD', 7.954025506973267)
('SWSFC_WpM2', 191.06564331054688)
('CLDTOP_KM', 191.10903930664062)
('CAPE', 191.70547346635297)
('SNOWEW_M', 191.87428977272728)
('T2_K', 192.4814786044034)
('PBL_WRF_M', 192.72732405229047)
('TSURF_K', 192.86672002618963)
('PBL_YSU_M', 194.32039434259588)
('SNOWAGE_HR', 198.26652665571734)
('PRATE_MMpH', 199.35923628373578)
Lowest 3 names: ['U10_MpS', 'V10_MpS', 'SOLM_M3pM3']
=(1, 80, 80)) summary(model, input_size
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 8, 40, 40] 80
GELU-2 [-1, 8, 40, 40] 0
Conv2d-3 [-1, 8, 40, 40] 584
GELU-4 [-1, 8, 40, 40] 0
Conv2d-5 [-1, 16, 20, 20] 1,168
GELU-6 [-1, 16, 20, 20] 0
Conv2d-7 [-1, 16, 20, 20] 2,320
GELU-8 [-1, 16, 20, 20] 0
Conv2d-9 [-1, 32, 10, 10] 4,640
GELU-10 [-1, 32, 10, 10] 0
Flatten-11 [-1, 3200] 0
Linear-12 [-1, 512] 1,638,912
Linear-13 [-1, 3200] 1,641,600
Unflatten-14 [-1, 32, 10, 10] 0
ConvTranspose2d-15 [-1, 16, 20, 20] 4,624
GELU-16 [-1, 16, 20, 20] 0
Conv2d-17 [-1, 16, 20, 20] 2,320
GELU-18 [-1, 16, 20, 20] 0
ConvTranspose2d-19 [-1, 8, 40, 40] 1,160
GELU-20 [-1, 8, 40, 40] 0
Conv2d-21 [-1, 8, 40, 40] 584
GELU-22 [-1, 8, 40, 40] 0
ConvTranspose2d-23 [-1, 1, 80, 80] 73
================================================================
Total params: 3,298,065
Trainable params: 3,298,065
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 1.35
Params size (MB): 12.58
Estimated Total Size (MB): 13.95
----------------------------------------------------------------
=(1, 80, 80)) summary(model, input_size
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 8, 40, 40] 80
GELU-2 [-1, 8, 40, 40] 0
Conv2d-3 [-1, 8, 40, 40] 584
GELU-4 [-1, 8, 40, 40] 0
Conv2d-5 [-1, 16, 20, 20] 1,168
GELU-6 [-1, 16, 20, 20] 0
Conv2d-7 [-1, 16, 20, 20] 2,320
GELU-8 [-1, 16, 20, 20] 0
Conv2d-9 [-1, 32, 10, 10] 4,640
GELU-10 [-1, 32, 10, 10] 0
Flatten-11 [-1, 3200] 0
Linear-12 [-1, 1024] 3,277,824
Encoder-13 [-1, 1024] 0
Linear-14 [-1, 3200] 3,280,000
Unflatten-15 [-1, 32, 10, 10] 0
ConvTranspose2d-16 [-1, 16, 20, 20] 4,624
GELU-17 [-1, 16, 20, 20] 0
Conv2d-18 [-1, 16, 20, 20] 2,320
GELU-19 [-1, 16, 20, 20] 0
ConvTranspose2d-20 [-1, 8, 40, 40] 1,160
GELU-21 [-1, 8, 40, 40] 0
Conv2d-22 [-1, 8, 40, 40] 584
GELU-23 [-1, 8, 40, 40] 0
ConvTranspose2d-24 [-1, 1, 80, 80] 73
Decoder-25 [-1, 1, 80, 80] 0
================================================================
Total params: 6,575,377
Trainable params: 6,575,377
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 1.41
Params size (MB): 25.08
Estimated Total Size (MB): 26.51
----------------------------------------------------------------
Single channel input single channel output(P10) UNET
= []
test_loss_list_P_10_conv_auto_unet = 1 # selecting P10 as output
y_channel # x_channel = 0
for x_channel in range(X.shape[1]):
####################### Selecting the channel #######################
print('X Channel name : ', target_var_96_list[x_channel])
= X_train_all[:, x_channel:x_channel+1, :,:]
X_train = X_test_all[:, x_channel:x_channel+1, :,:]
X_test = y_train_all[:, y_channel:y_channel+1, :,:]
y_train = y_test_all[:, y_channel:y_channel+1, :,:]
y_test print('Shapes: ', X_train.shape, X_test.shape, y_train.shape, y_test.shape)
####################### Creating the dataset loader #######################
= CustomDataset(X_train, y_train)
train_custom_dataset # print(len(train_custom_dataset))
= 32
batch_size = data.DataLoader(train_custom_dataset, batch_size=batch_size, shuffle=True)
train_loader # print(len(train_loader))
= CustomDataset(X_test, y_test)
test_custom_dataset # print(len(test_custom_dataset))
= 32
batch_size = data.DataLoader(test_custom_dataset, batch_size=batch_size, shuffle=False)
test_loader # print(len(test_loader))
#################### Training the model ####################
= Autoencoder_UNET(image_size = 80, num_input_channels = 1, num_output_channels=1, c_hid = 8, latent_dim = 512, activation= nn.GELU)
model
model.to(device)# Define the loss function and optimizer
= nn.MSELoss()
criterion = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer = []
losses # Training loop
= 200
num_epochs for epoch in trange(num_epochs):
# Set the model to training mode
model.train() = 0.0
total_loss
for inputs, targets in train_loader:
# Zero the gradients
optimizer.zero_grad() = inputs.to(device)
inputs = targets.to(device)
targets # Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
# Backpropagation
loss.backward()
optimizer.step()
+= loss.item()
total_loss
# Print the average loss for this epoch
= total_loss / len(train_loader)
average_loss
losses.append(average_loss)if epoch % 20 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {average_loss:.4f}")
############################# testing the model #############################
eval() # Set the model to evaluation mode
model.= 0.0
test_loss
with torch.no_grad():
for inputs, targets in test_loader:
= inputs.to(device)
inputs = targets.to(device)
targets # Forward pass
= model(inputs)
outputs # Calculate the loss
= criterion(outputs, targets)
loss += loss.item()
test_loss
# Print the average test loss
= test_loss / len(test_loader)
average_test_loss
test_loss_list_P_10_conv_auto_unet.append(average_test_loss)print(f"Average Test Loss: {average_test_loss:.4f}")
range(1, num_epochs + 1), losses)
plt.plot('Epoch')
plt.xlabel('Loss')
plt.ylabel('Train Loss vs. Epoch for channel '+target_var_96_list[x_channel])
plt.title(True)
plt.grid( plt.show()
X Channel name : TSURF_K
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:04, 3.10it/s]
Epoch [1/200] Loss: 9347.3935
10%|█ | 21/200 [00:06<00:55, 3.21it/s]
Epoch [21/200] Loss: 445.2209
20%|██ | 41/200 [00:12<00:49, 3.24it/s]
Epoch [41/200] Loss: 421.8352
30%|███ | 61/200 [00:19<00:43, 3.21it/s]
Epoch [61/200] Loss: 411.4913
40%|████ | 81/200 [00:25<00:37, 3.20it/s]
Epoch [81/200] Loss: 406.5355
50%|█████ | 101/200 [00:31<00:31, 3.19it/s]
Epoch [101/200] Loss: 410.0266
60%|██████ | 121/200 [00:37<00:24, 3.21it/s]
Epoch [121/200] Loss: 403.3963
70%|███████ | 141/200 [00:44<00:18, 3.24it/s]
Epoch [141/200] Loss: 402.2838
80%|████████ | 161/200 [00:50<00:12, 3.14it/s]
Epoch [161/200] Loss: 404.5083
90%|█████████ | 181/200 [00:56<00:05, 3.23it/s]
Epoch [181/200] Loss: 408.9846
100%|██████████| 200/200 [01:02<00:00, 3.18it/s]
Average Test Loss: 393.9739
X Channel name : SNOWEW_M
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:01, 3.23it/s]
Epoch [1/200] Loss: 1062.0485
10%|█ | 21/200 [00:06<00:57, 3.13it/s]
Epoch [21/200] Loss: 425.7831
20%|██ | 41/200 [00:12<00:50, 3.16it/s]
Epoch [41/200] Loss: 414.1087
30%|███ | 61/200 [00:19<00:43, 3.21it/s]
Epoch [61/200] Loss: 407.9692
40%|████ | 81/200 [00:25<00:37, 3.13it/s]
Epoch [81/200] Loss: 409.0262
50%|█████ | 101/200 [00:31<00:29, 3.32it/s]
Epoch [101/200] Loss: 413.5665
60%|██████ | 121/200 [00:37<00:24, 3.19it/s]
Epoch [121/200] Loss: 406.8342
70%|███████ | 141/200 [00:44<00:18, 3.20it/s]
Epoch [141/200] Loss: 401.0203
80%|████████ | 161/200 [00:50<00:12, 3.16it/s]
Epoch [161/200] Loss: 403.7887
90%|█████████ | 181/200 [00:56<00:05, 3.18it/s]
Epoch [181/200] Loss: 412.0851
100%|██████████| 200/200 [01:02<00:00, 3.18it/s]
Average Test Loss: 384.9228
X Channel name : SNOWAGE_HR
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:02, 3.19it/s]
Epoch [1/200] Loss: 1334.1772
10%|█ | 21/200 [00:06<00:55, 3.24it/s]
Epoch [21/200] Loss: 432.3582
20%|██ | 41/200 [00:12<00:51, 3.09it/s]
Epoch [41/200] Loss: 416.2745
30%|███ | 61/200 [00:19<00:42, 3.27it/s]
Epoch [61/200] Loss: 412.7553
40%|████ | 81/200 [00:25<00:36, 3.29it/s]
Epoch [81/200] Loss: 411.6587
50%|█████ | 101/200 [00:31<00:30, 3.28it/s]
Epoch [101/200] Loss: 409.5164
60%|██████ | 121/200 [00:37<00:25, 3.11it/s]
Epoch [121/200] Loss: 409.2917
70%|███████ | 141/200 [00:43<00:18, 3.17it/s]
Epoch [141/200] Loss: 414.7216
80%|████████ | 161/200 [00:50<00:12, 3.23it/s]
Epoch [161/200] Loss: 405.0483
90%|█████████ | 181/200 [00:56<00:05, 3.18it/s]
Epoch [181/200] Loss: 404.8374
100%|██████████| 200/200 [01:02<00:00, 3.19it/s]
Average Test Loss: 366.9726
X Channel name : PRATE_MMpH
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:07, 2.95it/s]
Epoch [1/200] Loss: 1228.4596
10%|█ | 21/200 [00:06<00:56, 3.19it/s]
Epoch [21/200] Loss: 418.4353
20%|██ | 41/200 [00:13<00:49, 3.20it/s]
Epoch [41/200] Loss: 411.4351
30%|███ | 61/200 [00:19<00:42, 3.30it/s]
Epoch [61/200] Loss: 411.3802
40%|████ | 81/200 [00:25<00:36, 3.27it/s]
Epoch [81/200] Loss: 408.2070
50%|█████ | 101/200 [00:31<00:30, 3.29it/s]
Epoch [101/200] Loss: 402.7285
60%|██████ | 121/200 [00:38<00:24, 3.26it/s]
Epoch [121/200] Loss: 405.0596
70%|███████ | 141/200 [00:44<00:18, 3.19it/s]
Epoch [141/200] Loss: 401.6986
80%|████████ | 161/200 [00:50<00:12, 3.22it/s]
Epoch [161/200] Loss: 401.9078
90%|█████████ | 181/200 [00:56<00:06, 3.16it/s]
Epoch [181/200] Loss: 405.1741
100%|██████████| 200/200 [01:02<00:00, 3.19it/s]
Average Test Loss: 397.5330
X Channel name : CLOUD_OD
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:00, 3.28it/s]
Epoch [1/200] Loss: 1086.5345
10%|█ | 21/200 [00:06<00:54, 3.29it/s]
Epoch [21/200] Loss: 115.0784
20%|██ | 41/200 [00:12<00:48, 3.26it/s]
Epoch [41/200] Loss: 55.1418
30%|███ | 61/200 [00:18<00:41, 3.35it/s]
Epoch [61/200] Loss: 35.0410
40%|████ | 81/200 [00:24<00:35, 3.32it/s]
Epoch [81/200] Loss: 25.9328
50%|█████ | 101/200 [00:31<00:30, 3.27it/s]
Epoch [101/200] Loss: 22.8977
60%|██████ | 121/200 [00:37<00:23, 3.32it/s]
Epoch [121/200] Loss: 22.6624
70%|███████ | 141/200 [00:43<00:18, 3.15it/s]
Epoch [141/200] Loss: 18.5878
80%|████████ | 161/200 [00:49<00:11, 3.25it/s]
Epoch [161/200] Loss: 16.7638
90%|█████████ | 181/200 [00:55<00:05, 3.31it/s]
Epoch [181/200] Loss: 18.7935
100%|██████████| 200/200 [01:01<00:00, 3.27it/s]
Average Test Loss: 16.7694
X Channel name : U10_MpS
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:58, 3.40it/s]
Epoch [1/200] Loss: 1370.2549
10%|█ | 21/200 [00:06<00:53, 3.35it/s]
Epoch [21/200] Loss: 253.0794
20%|██ | 41/200 [00:12<00:48, 3.29it/s]
Epoch [41/200] Loss: 73.3127
30%|███ | 61/200 [00:18<00:44, 3.12it/s]
Epoch [61/200] Loss: 34.2972
40%|████ | 81/200 [00:24<00:37, 3.16it/s]
Epoch [81/200] Loss: 23.2712
50%|█████ | 101/200 [00:31<00:30, 3.21it/s]
Epoch [101/200] Loss: 18.8157
60%|██████ | 121/200 [00:37<00:25, 3.14it/s]
Epoch [121/200] Loss: 17.0560
70%|███████ | 141/200 [00:43<00:17, 3.46it/s]
Epoch [141/200] Loss: 13.7983
80%|████████ | 161/200 [00:49<00:12, 3.23it/s]
Epoch [161/200] Loss: 12.2436
90%|█████████ | 181/200 [00:55<00:05, 3.21it/s]
Epoch [181/200] Loss: 11.2890
100%|██████████| 200/200 [01:01<00:00, 3.26it/s]
Average Test Loss: 16.9951
X Channel name : V10_MpS
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:00, 3.28it/s]
Epoch [1/200] Loss: 1202.3667
10%|█ | 21/200 [00:06<00:54, 3.29it/s]
Epoch [21/200] Loss: 284.9583
20%|██ | 41/200 [00:12<00:48, 3.26it/s]
Epoch [41/200] Loss: 49.5132
30%|███ | 61/200 [00:18<00:42, 3.29it/s]
Epoch [61/200] Loss: 23.7217
40%|████ | 81/200 [00:24<00:37, 3.18it/s]
Epoch [81/200] Loss: 16.5870
50%|█████ | 101/200 [00:31<00:29, 3.36it/s]
Epoch [101/200] Loss: 13.4451
60%|██████ | 121/200 [00:39<00:33, 2.34it/s]
Epoch [121/200] Loss: 11.5799
70%|███████ | 141/200 [00:47<00:25, 2.32it/s]
Epoch [141/200] Loss: 10.4765
80%|████████ | 161/200 [00:55<00:15, 2.44it/s]
Epoch [161/200] Loss: 13.5249
90%|█████████ | 181/200 [01:03<00:06, 3.06it/s]
Epoch [181/200] Loss: 9.1645
100%|██████████| 200/200 [01:10<00:00, 2.85it/s]
Average Test Loss: 8.1653
X Channel name : T2_K
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:01, 3.25it/s]
Epoch [1/200] Loss: 1344.0557
10%|█ | 21/200 [00:08<01:31, 1.95it/s]
Epoch [21/200] Loss: 410.2120
20%|██ | 41/200 [00:15<00:48, 3.25it/s]
Epoch [41/200] Loss: 409.9936
30%|███ | 61/200 [00:21<00:42, 3.27it/s]
Epoch [61/200] Loss: 403.0425
40%|████ | 81/200 [00:27<00:34, 3.47it/s]
Epoch [81/200] Loss: 413.2495
50%|█████ | 101/200 [00:33<00:32, 3.06it/s]
Epoch [101/200] Loss: 410.5454
60%|██████ | 121/200 [00:41<00:22, 3.51it/s]
Epoch [121/200] Loss: 408.9813
70%|███████ | 141/200 [00:46<00:16, 3.63it/s]
Epoch [141/200] Loss: 404.3268
80%|████████ | 161/200 [00:52<00:11, 3.53it/s]
Epoch [161/200] Loss: 402.1354
90%|█████████ | 181/200 [00:58<00:05, 3.47it/s]
Epoch [181/200] Loss: 403.1094
100%|██████████| 200/200 [01:03<00:00, 3.14it/s]
Average Test Loss: 370.9240
X Channel name : SWSFC_WpM2
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:13, 2.70it/s]
Epoch [1/200] Loss: 1138.2640
10%|█ | 21/200 [00:06<00:51, 3.49it/s]
Epoch [21/200] Loss: 428.3663
20%|██ | 41/200 [00:11<00:45, 3.49it/s]
Epoch [41/200] Loss: 428.3100
30%|███ | 61/200 [00:17<00:39, 3.49it/s]
Epoch [61/200] Loss: 411.9817
40%|████ | 81/200 [00:23<00:34, 3.50it/s]
Epoch [81/200] Loss: 406.6649
50%|█████ | 101/200 [00:29<00:29, 3.36it/s]
Epoch [101/200] Loss: 406.5036
60%|██████ | 121/200 [00:35<00:22, 3.47it/s]
Epoch [121/200] Loss: 404.4717
70%|███████ | 141/200 [00:41<00:16, 3.52it/s]
Epoch [141/200] Loss: 413.3260
80%|████████ | 161/200 [00:46<00:10, 3.61it/s]
Epoch [161/200] Loss: 404.8024
90%|█████████ | 181/200 [00:52<00:05, 3.50it/s]
Epoch [181/200] Loss: 408.2683
100%|██████████| 200/200 [00:58<00:00, 3.45it/s]
Average Test Loss: 375.5593
X Channel name : SOLM_M3pM3
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:57, 3.47it/s]
Epoch [1/200] Loss: 1171.9761
10%|█ | 21/200 [00:06<00:52, 3.40it/s]
Epoch [21/200] Loss: 311.1736
20%|██ | 41/200 [00:11<00:46, 3.43it/s]
Epoch [41/200] Loss: 213.8882
30%|███ | 61/200 [00:17<00:40, 3.43it/s]
Epoch [61/200] Loss: 86.1118
40%|████ | 81/200 [00:23<00:35, 3.37it/s]
Epoch [81/200] Loss: 43.9884
50%|█████ | 101/200 [00:34<00:50, 1.95it/s]
Epoch [101/200] Loss: 33.4366
60%|██████ | 121/200 [00:48<00:21, 3.66it/s]
Epoch [121/200] Loss: 26.7442
70%|███████ | 141/200 [00:54<00:17, 3.45it/s]
Epoch [141/200] Loss: 25.1023
80%|████████ | 161/200 [00:59<00:12, 3.17it/s]
Epoch [161/200] Loss: 19.1235
90%|█████████ | 181/200 [01:05<00:05, 3.38it/s]
Epoch [181/200] Loss: 16.5218
100%|██████████| 200/200 [01:12<00:00, 2.74it/s]
Average Test Loss: 17.8878
X Channel name : CLDTOP_KM
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:58, 3.40it/s]
Epoch [1/200] Loss: 1209.7468
10%|█ | 21/200 [00:06<00:51, 3.46it/s]
Epoch [21/200] Loss: 418.6483
20%|██ | 41/200 [00:12<00:48, 3.31it/s]
Epoch [41/200] Loss: 411.9460
30%|███ | 61/200 [00:17<00:40, 3.44it/s]
Epoch [61/200] Loss: 416.4541
40%|████ | 81/200 [00:23<00:34, 3.48it/s]
Epoch [81/200] Loss: 410.3450
50%|█████ | 101/200 [00:29<00:28, 3.49it/s]
Epoch [101/200] Loss: 404.6752
60%|██████ | 121/200 [00:35<00:22, 3.50it/s]
Epoch [121/200] Loss: 411.7702
70%|███████ | 141/200 [00:40<00:17, 3.45it/s]
Epoch [141/200] Loss: 404.7753
80%|████████ | 161/200 [00:46<00:11, 3.43it/s]
Epoch [161/200] Loss: 410.6092
90%|█████████ | 181/200 [00:52<00:05, 3.49it/s]
Epoch [181/200] Loss: 406.9762
100%|██████████| 200/200 [00:58<00:00, 3.43it/s]
Average Test Loss: 364.9331
X Channel name : CAPE
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<01:05, 3.05it/s]
Epoch [1/200] Loss: 1286.6324
10%|█ | 21/200 [00:06<00:50, 3.56it/s]
Epoch [21/200] Loss: 416.3279
20%|██ | 41/200 [00:11<00:45, 3.46it/s]
Epoch [41/200] Loss: 412.8959
30%|███ | 61/200 [00:17<00:39, 3.53it/s]
Epoch [61/200] Loss: 432.9754
40%|████ | 81/200 [00:23<00:35, 3.37it/s]
Epoch [81/200] Loss: 423.2503
50%|█████ | 101/200 [00:29<00:26, 3.68it/s]
Epoch [101/200] Loss: 407.6195
60%|██████ | 121/200 [00:34<00:22, 3.56it/s]
Epoch [121/200] Loss: 405.0848
70%|███████ | 141/200 [00:40<00:16, 3.55it/s]
Epoch [141/200] Loss: 418.5362
80%|████████ | 161/200 [00:46<00:11, 3.53it/s]
Epoch [161/200] Loss: 398.3066
90%|█████████ | 181/200 [00:51<00:05, 3.50it/s]
Epoch [181/200] Loss: 399.7189
100%|██████████| 200/200 [00:57<00:00, 3.49it/s]
Average Test Loss: 366.6939
X Channel name : PBL_WRF_M
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:57, 3.44it/s]
Epoch [1/200] Loss: 1160.2099
10%|█ | 21/200 [00:06<00:52, 3.40it/s]
Epoch [21/200] Loss: 410.9463
20%|██ | 41/200 [00:11<00:46, 3.43it/s]
Epoch [41/200] Loss: 407.0937
30%|███ | 61/200 [00:18<00:40, 3.42it/s]
Epoch [61/200] Loss: 406.6331
40%|████ | 81/200 [00:23<00:33, 3.57it/s]
Epoch [81/200] Loss: 405.4273
50%|█████ | 101/200 [00:29<00:27, 3.55it/s]
Epoch [101/200] Loss: 406.6442
60%|██████ | 121/200 [00:35<00:23, 3.32it/s]
Epoch [121/200] Loss: 405.2662
70%|███████ | 141/200 [00:41<00:17, 3.47it/s]
Epoch [141/200] Loss: 400.4130
80%|████████ | 161/200 [00:46<00:11, 3.43it/s]
Epoch [161/200] Loss: 403.3931
90%|█████████ | 181/200 [00:52<00:05, 3.47it/s]
Epoch [181/200] Loss: 409.0273
100%|██████████| 200/200 [00:57<00:00, 3.45it/s]
Average Test Loss: 364.0550
X Channel name : PBL_YSU_M
Shapes: (1324, 1, 80, 80) (332, 1, 80, 80) (1324, 1, 80, 80) (332, 1, 80, 80)
0%| | 1/200 [00:00<00:54, 3.64it/s]
Epoch [1/200] Loss: 1176.6661
10%|█ | 21/200 [00:06<00:49, 3.63it/s]
Epoch [21/200] Loss: 410.8530
20%|██ | 41/200 [00:12<00:47, 3.35it/s]
Epoch [41/200] Loss: 408.6113
30%|███ | 61/200 [00:17<00:40, 3.41it/s]
Epoch [61/200] Loss: 400.7242
40%|████ | 81/200 [00:23<00:36, 3.23it/s]
Epoch [81/200] Loss: 407.2048
50%|█████ | 101/200 [00:30<00:30, 3.24it/s]
Epoch [101/200] Loss: 407.0074
60%|██████ | 121/200 [00:36<00:25, 3.12it/s]
Epoch [121/200] Loss: 405.1491
70%|███████ | 141/200 [00:42<00:18, 3.27it/s]
Epoch [141/200] Loss: 428.9850
80%|████████ | 161/200 [00:48<00:11, 3.32it/s]
Epoch [161/200] Loss: 411.7955
90%|█████████ | 181/200 [00:54<00:05, 3.29it/s]
Epoch [181/200] Loss: 412.6084
100%|██████████| 200/200 [01:00<00:00, 3.29it/s]
Average Test Loss: 375.2351
test_loss_list_P_10_conv_auto_unet
[393.97390192205256,
384.9228349165483,
366.9725619229403,
397.5330338911577,
16.76939444108443,
16.99508944424716,
8.165278738195246,
370.9240167791193,
375.55933172052556,
17.88784339211204,
364.93308327414775,
366.69390869140625,
364.0550065474077,
375.23513239080256]
= list(zip(target_var_96_list, test_loss_list_P_10_conv_auto_unet))
name_loss_pairs
# Sort based on loss values
= sorted(name_loss_pairs, key=lambda x: x[1])
sorted_name_loss_pairs for pair in sorted_name_loss_pairs:
print(pair)
# print(sorted_name_loss_pairs)
# Select the lowest 3 names
= [pair[0] for pair in sorted_name_loss_pairs[:3]]
lowest_3_names
print("Lowest 3 names:", lowest_3_names)
('V10_MpS', 8.165278738195246)
('CLOUD_OD', 16.76939444108443)
('U10_MpS', 16.99508944424716)
('SOLM_M3pM3', 17.88784339211204)
('PBL_WRF_M', 364.0550065474077)
('CLDTOP_KM', 364.93308327414775)
('CAPE', 366.69390869140625)
('SNOWAGE_HR', 366.9725619229403)
('T2_K', 370.9240167791193)
('PBL_YSU_M', 375.23513239080256)
('SWSFC_WpM2', 375.55933172052556)
('SNOWEW_M', 384.9228349165483)
('TSURF_K', 393.97390192205256)
('PRATE_MMpH', 397.5330338911577)
Lowest 3 names: ['V10_MpS', 'CLOUD_OD', 'U10_MpS']
Insights
# Adding values at corresponding indices
= [(x + y)/2 for x, y in zip(test_loss_list_P_10_conv_auto_unet, test_loss_list_P_25_conv_auto_unet)]
result
# Creating a DataFrame
= {'Input channel': target_var_96_list, 'P10': test_loss_list_P_10_conv_auto_unet, 'P25': test_loss_list_P_25_conv_auto_unet, 'P10+P25 avg': result}
data_frame = pd.DataFrame(data_frame)
df
# Sorting the DataFrame based on "List1 + List2"
= df.sort_values(by='P10+P25 avg')
df_sorted
# Displaying the sorted DataFrame
= df_sorted.round(1)
df_rounded '/home/rishabh.mondal/climax_alternative/Climax_2/results/test_loss_list_P_10_P25_conv_auto_unet.csv', index=False)
df_rounded.to_csv(
df_rounded
Input channel | P10 | P25 | P10+P25 avg | |
---|---|---|---|---|
6 | V10_MpS | 8.2 | 3.5 | 5.8 |
5 | U10_MpS | 17.0 | 3.2 | 10.1 |
9 | SOLM_M3pM3 | 17.9 | 5.3 | 11.6 |
4 | CLOUD_OD | 16.8 | 8.0 | 12.4 |
10 | CLDTOP_KM | 364.9 | 191.1 | 278.0 |
12 | PBL_WRF_M | 364.1 | 192.7 | 278.4 |
11 | CAPE | 366.7 | 191.7 | 279.2 |
7 | T2_K | 370.9 | 192.5 | 281.7 |
2 | SNOWAGE_HR | 367.0 | 198.3 | 282.6 |
8 | SWSFC_WpM2 | 375.6 | 191.1 | 283.3 |
13 | PBL_YSU_M | 375.2 | 194.3 | 284.8 |
1 | SNOWEW_M | 384.9 | 191.9 | 288.4 |
0 | TSURF_K | 394.0 | 192.9 | 293.4 |
3 | PRATE_MMpH | 397.5 | 199.4 | 298.4 |
Training on top 4 channel and predicting on all channel UNET
# target_var_96_list =['TSURF_K',
# 'SNOWEW_M', 'SNOWAGE_HR', 'PRATE_MMpH', 'CLOUD_OD', 'U10_MpS',
# 'V10_MpS', 'T2_K', 'SWSFC_WpM2', 'SOLM_M3pM3', 'CLDTOP_KM', 'CAPE',
# 'PBL_WRF_M', 'PBL_YSU_M'] # ['U10_MpS', 'T2_K', 'V10_MpS']
= ['V10_MpS','U10_MpS','SOLM_M3pM3', 'CLOUD_OD']
target_var_96_list= ['P25','P10']
target_var_120_list = get_data(target_var_96_list, target_var_120_list)
X,y
from sklearn.model_selection import train_test_split
= train_test_split(X, y, test_size=0.2, random_state=42)
X_train_all, X_test_all, y_train_all, y_test_all X_train_all.shape, X_test_all.shape, y_train_all.shape, y_test_all.shape
X shape (1656, 4, 80, 80)
y shape (1656, 2, 80, 80)
((1324, 4, 80, 80), (332, 4, 80, 80), (1324, 2, 80, 80), (332, 2, 80, 80))
print(X_train_all.shape, X_test_all.shape, y_train_all.shape, y_test_all.shape)
= CustomDataset(X_train_all, y_train_all)
train_custom_dataset # print(len(train_custom_dataset))
= 32
batch_size = data.DataLoader(train_custom_dataset, batch_size=batch_size, shuffle=True)
train_loader # print(len(train_loader))
= CustomDataset(X_test_all, y_test_all)
test_custom_dataset # print(len(test_custom_dataset))
= 32
batch_size = data.DataLoader(test_custom_dataset, batch_size=batch_size, shuffle=False)
test_loader # print(len(test_loader))
#################### Training the model ####################
= Autoencoder_UNET(image_size = 80, num_input_channels = 4, num_output_channels=2, c_hid = 64, latent_dim = 2048, activation= nn.GELU)
model
model.to(device) from torchinfo import summary
print(summary(model, input_size=(1656, 4, 80, 80)))
# Define the loss function and optimizer
= nn.MSELoss()
criterion = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer = []
losses # Training loop
= 100
num_epochs for epoch in trange(num_epochs):
# Set the model to training mode
model.train() = 0.0
total_loss
for inputs, targets in train_loader:
# Zero the gradients
optimizer.zero_grad() = inputs.to(device)
inputs = targets.to(device)
targets # Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
# Backpropagation
loss.backward()
optimizer.step()
+= loss.item()
total_loss
# Print the average loss for this epoch
= total_loss / len(train_loader)
average_loss
losses.append(average_loss)if epoch % 20 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {average_loss:.4f}")
############################# testing the model #############################
eval() # Set the model to evaluation mode
model.= 0.0
test_loss
with torch.no_grad():
for inputs, targets in test_loader:
= inputs.to(device)
inputs = targets.to(device)
targets
# Forward pass
= model(inputs)
outputs
# Calculate the loss
= criterion(outputs, targets)
loss
+= loss.item()
test_loss
# Print the average test loss
= test_loss / len(test_loader)
average_test_loss # test_loss_list_P_25.append(average_test_loss)
print(f"Average Test Loss: {average_test_loss:.4f}")
range(1, num_epochs + 1), losses)
plt.plot('Epoch')
plt.xlabel('Loss')
plt.ylabel(# plt.title('Train Loss vs. Epoch for channel '+target_var_96_list[x_channel])
True)
plt.grid( plt.show()
(1324, 4, 80, 80) (332, 4, 80, 80) (1324, 2, 80, 80) (332, 2, 80, 80)
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Autoencoder_UNET [1656, 2, 80, 80] --
├─Encoder: 1-1 -- --
│ └─Sequential: 2-1 -- --
│ │ └─Conv2d: 3-1 [1656, 64, 40, 40] 2,368
│ │ └─GELU: 3-2 [1656, 64, 40, 40] --
│ │ └─Conv2d: 3-3 [1656, 64, 40, 40] 36,928
│ │ └─GELU: 3-4 [1656, 64, 40, 40] --
│ │ └─Conv2d: 3-5 [1656, 128, 20, 20] 73,856
│ │ └─GELU: 3-6 [1656, 128, 20, 20] --
│ │ └─Conv2d: 3-7 [1656, 128, 20, 20] 147,584
│ │ └─GELU: 3-8 [1656, 128, 20, 20] --
│ │ └─Conv2d: 3-9 [1656, 256, 10, 10] 295,168
│ │ └─GELU: 3-10 [1656, 256, 10, 10] --
│ │ └─Flatten: 3-11 [1656, 25600] --
│ │ └─Linear: 3-12 [1656, 2048] 52,430,848
├─Decoder: 1-2 -- --
│ └─Sequential: 2-2 -- --
│ │ └─Linear: 3-13 [1656, 25600] 52,454,400
│ │ └─Unflatten: 3-14 [1656, 256, 10, 10] --
│ │ └─ConvTranspose2d: 3-15 [1656, 128, 20, 20] 295,040
│ │ └─GELU: 3-16 [1656, 128, 20, 20] --
│ │ └─Conv2d: 3-17 [1656, 128, 20, 20] 147,584
│ │ └─GELU: 3-18 [1656, 128, 20, 20] --
│ │ └─ConvTranspose2d: 3-19 [1656, 64, 40, 40] 73,792
│ │ └─GELU: 3-20 [1656, 64, 40, 40] --
│ │ └─Conv2d: 3-21 [1656, 64, 40, 40] 36,928
│ │ └─GELU: 3-22 [1656, 64, 40, 40] --
│ │ └─ConvTranspose2d: 3-23 [1656, 2, 80, 80] 1,154
==========================================================================================
Total params: 105,995,650
Trainable params: 105,995,650
Non-trainable params: 0
Total mult-adds (Units.TERABYTES): 1.07
==========================================================================================
Input size (MB): 169.57
Forward/backward pass size (MB): 9014.58
Params size (MB): 423.98
Estimated Total Size (MB): 9608.13
==========================================================================================
1%| | 1/100 [00:01<03:17, 1.99s/it]
Epoch [1/100] Loss: 519.4141
21%|██ | 21/100 [00:36<02:15, 1.72s/it]
Epoch [21/100] Loss: 10.5276
41%|████ | 41/100 [01:10<01:41, 1.72s/it]
Epoch [41/100] Loss: 2.6655
61%|██████ | 61/100 [01:44<01:06, 1.72s/it]
Epoch [61/100] Loss: 5.3309
81%|████████ | 81/100 [02:19<00:32, 1.72s/it]
Epoch [81/100] Loss: 2.3606
100%|██████████| 100/100 [02:52<00:00, 1.72s/it]
Average Test Loss: 0.9041
'model/auto_conv_unet_in4_out2.pt')
torch.save(model.state_dict(), = Autoencoder_UNET(image_size = 80, num_input_channels = 4, num_output_channels=2, c_hid = 64, latent_dim = 2048, activation= nn.GELU)
model
model.to(device) 'model/auto_conv_unet_in4_out2.pt')) model.load_state_dict(torch.load(
<All keys matched successfully>