IMPORT & DEFINE

In [ ]:
%%capture 

from tensorflow import keras
from tensorflow.keras import layers as L
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from tensorflow.keras.utils import to_categorical

import numpy as np 
import cv2 
import pickle 
import random 

import re
import nltk
nltk.download("popular")
from nltk.corpus import stopwords 
stop_words = set(stopwords.words('english'))


# special tokens: PAD fills sequences to a fixed length, UNK is reserved for
# out-of-vocabulary words, START/END delimit every caption
PAD = "#PAD#"
UNK = "#UNK#"
START = "#START#"
END = "#END#"

FUNCTIONS

In [ ]:
def read_pickle(pickle_file):
  with open(pickle_file, 'rb') as f:
    return pickle.load(f)

def show_image(image):
  # cv2.imshow does not work in Colab; use the Colab-specific replacement
  from google.colab.patches import cv2_imshow
  cv2_imshow(image)

def seq_generator(image_embeddings, captions, SENTENCE_LEN, padding_element, NUM_WORDS, batch_size):
  # endlessly yields ([image embeddings, padded input prefixes], one-hot next words)

  def pad_sequence(seq, seq_len, padding_element):
    # truncate to seq_len or right-pad with the padding element
    if len(seq) >= seq_len: return seq[:seq_len]
    else: return seq + [padding_element]*(seq_len - len(seq))

  x1, x2, y = [], [], []
  n = 0

  while True:
    for i, caption in enumerate(captions):
      n += 1
      # pick one of the reference captions of image i at random
      seq = random.choice(caption)
      # expand the caption into (prefix -> next word) training pairs
      for j in range(len(seq)):
        inseq, outseq = seq[:j], seq[j]
        x1.append(image_embeddings[i])
        x2.append(pad_sequence(inseq, SENTENCE_LEN, padding_element))
        y.append(to_categorical([outseq], NUM_WORDS)[0])
      # n counts captions, so each batch holds every pair from batch_size captions
      if n == batch_size:
        yield ([np.array(x1), np.array(x2)], np.array(y))
        x1, x2, y = [], [], []
        n = 0
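
# Example of the expansion above (with hypothetical word indices): an encoded
# caption [START, 12, 7, END] produces the pairs [] -> START, [START] -> 12,
# [START, 12] -> 7 and [START, 12, 7] -> END, each prefix padded to SENTENCE_LEN.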

def get_num_sentences(captions):
  # every image comes with 5 reference captions
  return len(captions) * 5

def split_sentence(sentence):
  # lowercase, split on non-word characters, drop stop words and tokens shorter than 3 characters
  return list(filter(lambda x: len(x) > 2 and x not in stop_words, re.split(r'\W+', sentence.lower())))

def get_max_len(captions):
  maxlength = 0 
  for caption in captions:
    for sentence in caption:
      maxlength = max(maxlength, len(sentence))
  return maxlength

def get_vocab(captions, vocab_size):
  # count word frequencies; split_sentence already strips punctuation,
  # stop words and short tokens, so no extra filtering is needed here
  vocab = dict()
  for caption in captions:
    for sentence in caption:
      for word in split_sentence(sentence):
        vocab[word] = vocab.get(word, 0) + 1

  # keep the (vocab_size - 4) most frequent words and append the special tokens
  vocab = [word for word, count in sorted(vocab.items(), key = lambda item : item[1], reverse = True)][0: vocab_size-4] + [PAD, UNK, START, END]
  word2ix = {word:index for index, word in enumerate(vocab)}
  ix2word = {index:word for index, word in enumerate(vocab)}
  return vocab, word2ix, ix2word

def captions2captions(captions, word2ix):
  # encode every caption as a list of word indices wrapped in START/END markers;
  # words that fell below the vocabulary cutoff are dropped (they could instead be mapped to UNK)
  new_captions = []
  for caption in captions:
    new_caption = []
    for sentence in caption:
      new_caption.append([word2ix[START]] + [word2ix[word] for word in split_sentence(sentence) if word in word2ix] + [word2ix[END]])
    new_captions.append(new_caption)
  return new_captions
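
# For example: "A man is riding a horse." is tokenized to
# ['man', 'riding', 'horse'] and encoded as
# [word2ix[START], word2ix['man'], word2ix['riding'], word2ix['horse'], word2ix[END]].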


def build_model(SENTENCE_LEN, NUM_WORDS, WORD_EMBED_SIZE, LSTM_UNITS):
  # project the 2048-dim image embedding down to the LSTM state size and use it
  # as the decoder's initial hidden and cell state
  image_features = L.Input(shape = (2048,))
  densed_IF = L.Dense(256)(image_features)
  densed_IF = L.Dense(LSTM_UNITS)(densed_IF)  # must match LSTM_UNITS to be a valid initial state
  initial_state = [densed_IF, densed_IF]

  words = L.Input(shape = (SENTENCE_LEN,))
  words_embeddings = L.Embedding(NUM_WORDS, WORD_EMBED_SIZE)(words)

  lstm_output = L.LSTM(LSTM_UNITS)(words_embeddings, initial_state = initial_state)

  # two dense layers, then a softmax over the vocabulary for the next word
  decoded_words1 = L.Dense(512, activation = 'relu')(lstm_output)
  decoded_words2 = L.Dense(1024, activation = 'relu')(decoded_words1)
  final_output = L.Dense(NUM_WORDS, activation = 'softmax')(decoded_words2)

  model = keras.Model(inputs = [image_features, words], outputs = final_output)
  return model

def download_data():
  # fetch the precomputed image embeddings, raw captions and image filenames
  # (pickles hosted on Google Drive; the cookie dance handles the large-file confirmation page)

  !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-5AhtN3za59P6WsHBhVRw7Lonx4wk6xu' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-5AhtN3za59P6WsHBhVRw7Lonx4wk6xu" -O train_image_features.pickle && rm -rf /tmp/cookies.txt
  train_image_embeds = read_pickle('train_image_features.pickle') 

  !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1--FpSwNWO9X8l5YneuTHcs--EtJlJZOP' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1--FpSwNWO9X8l5YneuTHcs--EtJlJZOP" -O val_image_embeds.pickle && rm -rf /tmp/cookies.txt
  val_image_embeds = read_pickle('val_image_embeds.pickle') 

  !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-729Stj7PWEztvH-YVJ_nivo4QmGR-ro' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-729Stj7PWEztvH-YVJ_nivo4QmGR-ro" -O train_captions.pickle && rm -rf /tmp/cookies.txt
  train_captions = read_pickle('train_captions.pickle') 

  !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-1WfbEjN052jaHSUb4h5J_t9BIgziSP3' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-1WfbEjN052jaHSUb4h5J_t9BIgziSP3" -O val_captions.pickle && rm -rf /tmp/cookies.txt
  val_captions = read_pickle('val_captions.pickle') 

  !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-77wJhnWLCnvmBOvOFuQrAv_ekjE-ooa' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-77wJhnWLCnvmBOvOFuQrAv_ekjE-ooa" -O train_image_fns.pickle && rm -rf /tmp/cookies.txt
  train_image_fns = read_pickle('train_image_fns.pickle') 

  !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-2HLbOGLT4E9V_IywjwNWBviyz49QjrY' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-2HLbOGLT4E9V_IywjwNWBviyz49QjrY" -O val_image_fns.pickle && rm -rf /tmp/cookies.txt
  val_image_fns = read_pickle('val_image_fns.pickle') 

  return train_image_embeds, val_image_embeds, train_captions, val_captions, train_image_fns, val_image_fns 
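
A quick sanity check of the tokenizer (output shown as a comment; the exact result depends on the NLTK stop word list):

In [ ]:
split_sentence("A man is riding a horse on the beach.")
# ['man', 'riding', 'horse', 'beach']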

CAPTIONS

In [ ]:
%%capture 
train_image_embeds, val_image_embeds, train_captions, val_captions, train_image_fns, val_image_fns = download_data() 
all_captions = train_captions + val_captions 
In [ ]:
NUM_WORDS = 5000
NUM_SENTENCES = get_num_sentences(train_captions) 
SENTENCE_LEN = 20
WORD_EMBED_SIZE = 200

VOCAB, word2ix, ix2word = get_vocab(all_captions, vocab_size = NUM_WORDS) 
train_captions_ix = captions2captions(train_captions, word2ix)
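To verify the encoding, one caption can be decoded back through ix2word (a quick check; the indices vary from run to run):

In [ ]:
print(train_captions[0][0])
print([ix2word[ix] for ix in train_captions_ix[0][0]])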

MODEL

TRAIN & SAVE

In [ ]:
K.clear_session()
model = build_model(SENTENCE_LEN, NUM_WORDS, WORD_EMBED_SIZE, LSTM_UNITS = 200)

# decay the learning rate by 7% every 10000 steps
scheduler = ExponentialDecay(initial_learning_rate = 1e-3, decay_rate = 0.93, decay_steps = 10000)
model.compile(loss = 'categorical_crossentropy', optimizer = Adam(scheduler))
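Printing the summary is a cheap way to check the wiring before training:

In [ ]:
model.summary()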
In [ ]:
epochs = 10
BS = 50
steps = NUM_SENTENCES // BS
train_generator = seq_generator(train_image_embeds, train_captions_ix, SENTENCE_LEN, word2ix[PAD], NUM_WORDS, batch_size = BS)

# one fit call per epoch; equivalent to a single call with epochs = 10
for epoch in range(epochs):
  model.fit(train_generator, epochs = 1, steps_per_epoch = steps, verbose = 1)
8278/8278 [==============================] - 246s 30ms/step - loss: 3.7889
8278/8278 [==============================] - 240s 29ms/step - loss: 3.0566
8278/8278 [==============================] - 238s 29ms/step - loss: 2.9034
8278/8278 [==============================] - 240s 29ms/step - loss: 2.8114
8278/8278 [==============================] - 237s 29ms/step - loss: 2.7491
8278/8278 [==============================] - 236s 29ms/step - loss: 2.6986
8278/8278 [==============================] - 237s 29ms/step - loss: 2.6576
8278/8278 [==============================] - 237s 29ms/step - loss: 2.6211
8278/8278 [==============================] - 236s 29ms/step - loss: 2.5885
8278/8278 [==============================] - 235s 28ms/step - loss: 2.5622
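
The weights below go to Google Drive, so in Colab the drive has to be mounted first:

In [ ]:
from google.colab import drive
drive.mount('/content/drive')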
In [ ]:
model_path = '/content/drive/My Drive/image_caption_project/partial_model.h5'
model.save_weights(model_path)

LOAD & PLAY

In [ ]:
model_path = '/content/drive/My Drive/image_caption_project/partial_model.h5'

K.clear_session() 
model = build_model(SENTENCE_LEN, NUM_WORDS, WORD_EMBED_SIZE, LSTM_UNITS = 200) 
model.load_weights(model_path) 
In [ ]:
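Finally, a minimal greedy-decoding sketch for the "play" step (my own illustration, not from the original notebook; it assumes val_image_embeds holds one 2048-dim vector per image, and the helper name generate_caption is made up here):

In [ ]:
def generate_caption(image_embedding):
  # start from START and repeatedly feed the padded prefix back into the model,
  # taking the argmax word at each step
  caption = [word2ix[START]]
  for _ in range(SENTENCE_LEN - 1):
    inseq = caption + [word2ix[PAD]] * (SENTENCE_LEN - len(caption))
    probs = model.predict([np.array([image_embedding]), np.array([inseq])], verbose = 0)[0]
    next_ix = int(np.argmax(probs))
    if next_ix == word2ix[END]: break
    caption.append(next_ix)
  return ' '.join(ix2word[ix] for ix in caption[1:])

print(generate_caption(val_image_embeds[0]))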