from tensorflow.keras.datasets import mnist
import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Input, Add, Activation, BatchNormalization, Conv2D, UpSampling2D
from tensorflow.keras.models import Model
import time
# a helper to format elapsed seconds as h:mm:ss.ss for the training logs
def hmsString(sec_elapsed):
    h = int(sec_elapsed / (60 * 60))
    m = int((sec_elapsed % (60 * 60)) / 60)
    s = sec_elapsed % 60
    return "{}:{:>02}:{:>05.2f}".format(h, m, s)

# downsample and introduce noise in the images
def downSampleAndNoisyfi(X):
    shape = X[0].shape
    X_down = []
    for x_i in X:
        x_c = cv2.resize(x_i, (shape[0] // 4, shape[1] // 4), interpolation=cv2.INTER_AREA)
        x_c = np.clip(x_c + np.random.normal(0, 5, x_c.shape), 0, 255).astype('uint8')
        X_down.append(x_c)
    X_down = np.array(X_down, dtype='uint8')
    return X_down
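# Quick shape check (added for illustration): a batch of 28x28 images should come
# back as noisy 7x7 uint8 images.
_demo = downSampleAndNoisyfi(np.zeros((2, 28, 28), dtype='uint8'))
assert _demo.shape == (2, 7, 7) and _demo.dtype == np.uint8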
################# CODE FOR GENERATOR BLOCK
def Generator(input_shape):
    X_input = Input(input_shape)
    X = Conv2D(filters = 32, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(X_input)
    X = BatchNormalization(momentum=0.5)(X)
    X = Activation('relu')(X)
    X_shortcut = X
    X = Conv2D(filters = 32, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(X)
    X = BatchNormalization(momentum=0.5)(X)
    X = Activation('relu')(X)  
    X = Add()([X_shortcut, X])  
    X_shortcut = X
    X = Conv2D(filters = 32, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(X)
    X = BatchNormalization(momentum=0.5)(X)
    X = Activation('relu')(X)  
    X = Add()([X_shortcut, X])
    X = Activation('relu')(X)
    X = UpSampling2D(size=2)(X)
    
    X = Conv2D(filters = 32, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(X)
    X = BatchNormalization(momentum=0.5)(X)
    X = Activation('relu')(X)
    X_shortcut = X
    X = Conv2D(filters = 32, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(X)
    X = BatchNormalization(momentum=0.5)(X)
    X = Activation('relu')(X)
    X = Add()([X_shortcut, X])
    X_shortcut = X
    X = Conv2D(filters = 32, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(X)
    X = BatchNormalization(momentum=0.5)(X)
    X = Activation('relu')(X)   
    X = Add()([X_shortcut, X])
    X = Activation('relu')(X)
    X = UpSampling2D(size=2)(X)
    
    X_shortcut = X
    X = Conv2D(filters = 32, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(X)
    X = BatchNormalization(momentum=0.5)(X)
    X = Activation('relu')(X)
    
    X = Conv2D(filters = 1, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(X)
    X = BatchNormalization(momentum=0.5)(X)
    X = Activation('relu')(X)
    
    generator_model = Model(inputs=X_input, outputs=X)
    return generator_model
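# Note (added): the two UpSampling2D(size=2) stages take the (7, 7, 1) input up to
# (28, 28, 1), i.e. a 4x super-resolution factor, matching the HR MNIST images.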
################# CODE FOR DISCRIMINATOR BLOCK
def Discriminator(input_shape):
    X_input = Input(input_shape)
    X = Conv2D(filters = 32, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(X_input)
    X = Activation('relu')(X)
    
    X = Conv2D(filters = 64, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(X)
    X = BatchNormalization(momentum=0.8)(X)
    X = Activation('relu')(X)
    
    discriminator_model = Model(inputs=X_input, outputs=X)
    return discriminator_model
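# Note (added): this discriminator has no dense/sigmoid head; it outputs a
# (28, 28, 64) feature map, and train_step below compares the real and generated
# feature maps directly with an MSE loss.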
# A single training step (compiled as a tf.function)
@tf.function
def train_step(X, Y, generator, discriminator, generator_optimizer, discriminator_optimizer):
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    generated_images = generator(X, training=True)

    real_output = discriminator(Y, training=True)
    fake_output = discriminator(generated_images, training=False)

    gen_loss = tf.keras.losses.MSE(Y, generated_images)
    disc_loss = tf.keras.losses.MSE(real_output, fake_output)
    

  gradients_of_generator = gen_tape.gradient(
      gen_loss, generator.trainable_variables)
  gradients_of_discriminator = disc_tape.gradient(
      disc_loss, discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(
      gradients_of_generator, generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(
      gradients_of_discriminator,
      discriminator.trainable_variables))
  return gen_loss, disc_loss
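# Note (added): gen_loss is a pixel-wise MSE between the HR target Y and the
# generated image, and disc_loss is an MSE between the discriminator feature maps
# for real and generated images, rather than the usual adversarial cross-entropy.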

# The main function to train the GAN
def train(X_train, Y_train, generator, discriminator, batch_size=100, epochs=50):
    generator_optimizer = tf.keras.optimizers.Adam(1.5e-4,0.5)
    discriminator_optimizer = tf.keras.optimizers.Adam(1.5e-4,0.5)
    start = time.time()
    for epoch in range(epochs):
        epoch_start = time.time()
        gen_loss_list = []
        disc_loss_list = []
        
        prev_i = 0
        for i in range(X_train.shape[0]):
            if((i+1)%batch_size == 0):
                t = train_step(X_train[prev_i:i+1], Y_train[prev_i:i+1], generator, discriminator, generator_optimizer, discriminator_optimizer)
                gen_loss_list.append(t[0])
                disc_loss_list.append(t[1])
                prev_i = i+1
        g_loss = np.sum(np.array(gen_loss_list)) / np.sum(np.array(gen_loss_list).shape)
        d_loss = np.sum(np.array(disc_loss_list)) / np.sum(np.array(disc_loss_list).shape)
        
        epoch_elapsed = time.time()-epoch_start
        print (f'Epoch {epoch+1}, gen loss={g_loss},disc loss={d_loss}, {hmsString(epoch_elapsed)}')
        
    elapsed = time.time()-start
    print (f'Training time: {hmsString(elapsed)}')
    
    
# loading the dataset (the original 28x28 images serve as the HR targets)
(Y_train, _), (Y_test, _) = mnist.load_data()
# downsampling and adding Gaussian noise
# this downsampled, noisy dataset is our X (the model inputs)
X_train = downSampleAndNoisyfi(Y_train)
X_test = downSampleAndNoisyfi(Y_test)

# add a channel dimension: X becomes (N, 7, 7, 1) and Y becomes (N, 28, 28, 1)
X_test = X_test[..., np.newaxis]
X_train = X_train[..., np.newaxis]
Y_train = Y_train[..., np.newaxis]
Y_test = Y_test[..., np.newaxis]

# Creating a generator and discriminator model
generator = Generator((7,7,1))
discriminator = Discriminator((28,28,1))

# Showing the summary of generator and discriminator
generator.summary()
discriminator.summary()
# training with a batch size of 100 for 5 epochs (reduced from the intended 50)
train(X_train, Y_train, generator, discriminator, 100, 5)  # use 50 for a full run

# save the generator model for future use
generator.save("mnist_generator_model")
generator.save("mnist_generator_model.h5")
Model: "model_4"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
==================================================================================================
 input_5 (InputLayer)        [(None, 7, 7, 1)]            0         []                            
                                                                                                  
 conv2d_20 (Conv2D)          (None, 7, 7, 32)             320       ['input_5[0][0]']             
                                                                                                  
 batch_normalization_18 (Ba  (None, 7, 7, 32)             128       ['conv2d_20[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_24 (Activation)  (None, 7, 7, 32)             0         ['batch_normalization_18[0][0]
                                                                    ']                            
                                                                                                  
 conv2d_21 (Conv2D)          (None, 7, 7, 32)             9248      ['activation_24[0][0]']       
                                                                                                  
 batch_normalization_19 (Ba  (None, 7, 7, 32)             128       ['conv2d_21[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_25 (Activation)  (None, 7, 7, 32)             0         ['batch_normalization_19[0][0]
                                                                    ']                            
                                                                                                  
 add_8 (Add)                 (None, 7, 7, 32)             0         ['activation_24[0][0]',       
                                                                     'activation_25[0][0]']       
                                                                                                  
 conv2d_22 (Conv2D)          (None, 7, 7, 32)             9248      ['add_8[0][0]']               
                                                                                                  
 batch_normalization_20 (Ba  (None, 7, 7, 32)             128       ['conv2d_22[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_26 (Activation)  (None, 7, 7, 32)             0         ['batch_normalization_20[0][0]
                                                                    ']                            
                                                                                                  
 add_9 (Add)                 (None, 7, 7, 32)             0         ['add_8[0][0]',               
                                                                     'activation_26[0][0]']       
                                                                                                  
 activation_27 (Activation)  (None, 7, 7, 32)             0         ['add_9[0][0]']               
                                                                                                  
 up_sampling2d_4 (UpSamplin  (None, 14, 14, 32)           0         ['activation_27[0][0]']       
 g2D)                                                                                             
                                                                                                  
 conv2d_23 (Conv2D)          (None, 14, 14, 32)           9248      ['up_sampling2d_4[0][0]']     
                                                                                                  
 batch_normalization_21 (Ba  (None, 14, 14, 32)           128       ['conv2d_23[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_28 (Activation)  (None, 14, 14, 32)           0         ['batch_normalization_21[0][0]
                                                                    ']                            
                                                                                                  
 conv2d_24 (Conv2D)          (None, 14, 14, 32)           9248      ['activation_28[0][0]']       
                                                                                                  
 batch_normalization_22 (Ba  (None, 14, 14, 32)           128       ['conv2d_24[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_29 (Activation)  (None, 14, 14, 32)           0         ['batch_normalization_22[0][0]
                                                                    ']                            
                                                                                                  
 add_10 (Add)                (None, 14, 14, 32)           0         ['activation_28[0][0]',       
                                                                     'activation_29[0][0]']       
                                                                                                  
 conv2d_25 (Conv2D)          (None, 14, 14, 32)           9248      ['add_10[0][0]']              
                                                                                                  
 batch_normalization_23 (Ba  (None, 14, 14, 32)           128       ['conv2d_25[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_30 (Activation)  (None, 14, 14, 32)           0         ['batch_normalization_23[0][0]
                                                                    ']                            
                                                                                                  
 add_11 (Add)                (None, 14, 14, 32)           0         ['add_10[0][0]',              
                                                                     'activation_30[0][0]']       
                                                                                                  
 activation_31 (Activation)  (None, 14, 14, 32)           0         ['add_11[0][0]']              
                                                                                                  
 up_sampling2d_5 (UpSamplin  (None, 28, 28, 32)           0         ['activation_31[0][0]']       
 g2D)                                                                                             
                                                                                                  
 conv2d_26 (Conv2D)          (None, 28, 28, 32)           9248      ['up_sampling2d_5[0][0]']     
                                                                                                  
 batch_normalization_24 (Ba  (None, 28, 28, 32)           128       ['conv2d_26[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_32 (Activation)  (None, 28, 28, 32)           0         ['batch_normalization_24[0][0]
                                                                    ']                            
                                                                                                  
 conv2d_27 (Conv2D)          (None, 28, 28, 1)            289       ['activation_32[0][0]']       
                                                                                                  
 batch_normalization_25 (Ba  (None, 28, 28, 1)            4         ['conv2d_27[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_33 (Activation)  (None, 28, 28, 1)            0         ['batch_normalization_25[0][0]
                                                                    ']                            
                                                                                                  
==================================================================================================
Total params: 56997 (222.64 KB)
Trainable params: 56547 (220.89 KB)
Non-trainable params: 450 (1.76 KB)
__________________________________________________________________________________________________
Model: "model_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_6 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d_28 (Conv2D)          (None, 28, 28, 32)        320       
                                                                 
 activation_34 (Activation)  (None, 28, 28, 32)        0         
                                                                 
 conv2d_29 (Conv2D)          (None, 28, 28, 64)        18496     
                                                                 
 batch_normalization_26 (Ba  (None, 28, 28, 64)        256       
 tchNormalization)                                               
                                                                 
 activation_35 (Activation)  (None, 28, 28, 64)        0         
                                                                 
=================================================================
Total params: 19072 (74.50 KB)
Trainable params: 18944 (74.00 KB)
Non-trainable params: 128 (512.00 Byte)
_________________________________________________________________
Epoch 1, gen loss=444240933.9259259,disc loss=5967.341931216931, 0:06:59.43
Epoch 2, gen loss=442459325.6296296,disc loss=2971.5945767195767, 0:06:50.58
Epoch 3, gen loss=441174395.2592593,disc loss=1824.6559193121693, 0:05:31.16
Epoch 4, gen loss=439941780.994709,disc loss=1088.3287037037037, 0:05:21.85
Epoch 5, gen loss=438728021.3333333,disc loss=618.3444527116402, 0:05:26.83
Training time: 0:30:09.87
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: mnist_generator_model\assets
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
C:\Users\HP\AppData\Roaming\Python\Python311\site-packages\keras\src\engine\training.py:3000: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')`.
  saving_api.save_model(
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
# testing the model
Y_pred = generator.predict(X_test)
# showing the first 5 test examples: LR input, SR prediction, HR ground truth
fig,a =  plt.subplots(3,5)
fig.subplots_adjust(hspace=0.5, wspace=0.1)
for i in range(5):
    a[0][i].imshow(X_test[i])
    a[0][i].axes.get_xaxis().set_visible(False)
    a[0][i].axes.get_yaxis().set_visible(False)
    a[0][i].title.set_text("LR: "+str(i+1))
    
    a[1][i].imshow(Y_pred[i])
    a[1][i].axes.get_xaxis().set_visible(False)
    a[1][i].axes.get_yaxis().set_visible(False)
    a[1][i].title.set_text("SR: "+str(i+1)) 
    
    a[2][i].imshow(Y_test[i])
    a[2][i].axes.get_xaxis().set_visible(False)
    a[2][i].axes.get_yaxis().set_visible(False)
    a[2][i].title.set_text("HR: "+str(i+1)) 
    
313/313 [==============================] - 9s 27ms/step

# showing 5 randomly chosen test examples
import random
figb,ab =  plt.subplots(3,5)
figb.subplots_adjust(hspace=0.5, wspace=0.1)
for i in range(5):
    ii = random.randint(0, len(X_test) - 1)
    
    ab[0][i].imshow(X_test[ii])
    ab[0][i].axes.get_xaxis().set_visible(False)
    ab[0][i].axes.get_yaxis().set_visible(False)
    ab[0][i].title.set_text("LR: "+str(i+1))
    
    ab[1][i].imshow(Y_pred[ii])
    ab[1][i].axes.get_xaxis().set_visible(False)
    ab[1][i].axes.get_yaxis().set_visible(False)
    ab[1][i].title.set_text("SR: "+str(i+1)) 
    
    ab[2][i].imshow(Y_test[ii])
    ab[2][i].axes.get_xaxis().set_visible(False)
    ab[2][i].axes.get_yaxis().set_visible(False)
    ab[2][i].title.set_text("HR: "+str(i+1))