import pickle
import random
import re

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers as L
from tensorflow.keras.utils import to_categorical
from nltk.corpus import stopwords

# Assumed setup: the notebook relies on a stop-word set and four special tokens
# being defined elsewhere; these are reasonable stand-ins, not the original
# values (run nltk.download('stopwords') once before using stop_words).
stop_words = set(stopwords.words('english'))
PAD, UNK, START, END = '#PAD#', '#UNK#', '#START#', '#END#'

def read_pickle(pickle_file):
    with open(pickle_file, 'rb') as f:
        return pickle.load(f)
def show_image(image):
    # Colab replacement for cv2.imshow, which cannot open a GUI window in a notebook.
    from google.colab.patches import cv2_imshow
    cv2_imshow(image)
def seq_generator(image_embeddings, captions, SENTENCE_LEN, padding_element, NUM_WORDS, batch_size):
    # Endless generator of ([image batch, prefix batch], next-word batch) pairs
    # for teacher-forced training.
    def pad_sequence(seq, seq_len, padding_element):
        # Truncate to seq_len, or right-pad with the padding element.
        if len(seq) >= seq_len:
            return seq[:seq_len]
        return seq + [padding_element] * (seq_len - len(seq))

    x1, x2, y = [], [], []
    n = 0
    while True:
        for i, caption in enumerate(captions):
            n += 1
            # Pick one of this image's reference captions at random.
            seq = random.choice(caption)
            # One training pair per prefix: seq[:j] predicts the next word seq[j].
            for j in range(len(seq)):
                inseq, outseq = seq[:j], seq[j]
                x1.append(image_embeddings[i])
                x2.append(pad_sequence(inseq, SENTENCE_LEN, padding_element))
                y.append(to_categorical([outseq], NUM_WORDS)[0])
            # n counts images, not prefixes, so the yielded batches vary in size.
            if n == batch_size:
                yield ([np.array(x1), np.array(x2)], np.array(y))
                x1, x2, y = [], [], []
                n = 0
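
# A minimal smoke test for the generator; a sketch only. The arguments assume
# the variables produced by the helpers further down (download_data, get_vocab,
# captions2captions), and batch_size=32 is an arbitrary illustrative choice.
def demo_generator_shapes(train_image_embeds, train_captions_ix, SENTENCE_LEN, NUM_WORDS, pad_ix):
    gen = seq_generator(train_image_embeds, train_captions_ix, SENTENCE_LEN,
                        pad_ix, NUM_WORDS, batch_size=32)
    (imgs, prefixes), targets = next(gen)
    # One row per caption prefix: imgs is (N, 2048), prefixes is (N, SENTENCE_LEN),
    # targets is (N, NUM_WORDS); N varies with the sampled caption lengths.
    print(imgs.shape, prefixes.shape, targets.shape)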
def get_num_sentences(captions):
    # Every image ships with 5 reference captions (MS COCO-style annotations).
    return len(captions) * 5
def split_sentence(sentence):
    # Lowercase, split on runs of non-word characters (which also strips punctuation),
    # then drop stop words and tokens shorter than 3 characters.
    return list(filter(lambda x: len(x) > 2 and x not in stop_words, re.split(r'\W+', sentence.lower())))
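
# Illustrative check of the tokenizer, assuming the NLTK stop-word list defined
# above (the exact output depends on that list): short tokens and stop words
# are dropped, and the \W+ split removes trailing punctuation on its own.
def demo_split_sentence():
    assert split_sentence('A dog runs through the grass.') == ['dog', 'runs', 'grass']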
def get_max_len(captions):
    # Length of the longest caption, in tokens; expects captions that have
    # already been converted to index sequences by captions2captions.
    maxlength = 0
    for caption in captions:
        for sentence in caption:
            maxlength = max(maxlength, len(sentence))
    return maxlength
def get_vocab(captions, vocab_size):
    # Count word frequencies over all training sentences. split_sentence already
    # filters short tokens and stop words and strips punctuation, so the words
    # need no further cleanup here.
    vocab = dict()
    for caption in captions:
        for sentence in caption:
            for word in split_sentence(sentence):
                vocab[word] = vocab.get(word, 0) + 1
    # Keep the (vocab_size - 4) most frequent words; the last 4 slots are
    # reserved for the special tokens.
    vocab = [word for word, count in sorted(vocab.items(), key=lambda item: item[1], reverse=True)][:vocab_size - 4] + [PAD, UNK, START, END]
    word2ix = {word: index for index, word in enumerate(vocab)}
    ix2word = {index: word for index, word in enumerate(vocab)}
    return vocab, word2ix, ix2word
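
# Sketch of building the vocabulary (VOCAB_SIZE=8000 is an illustrative choice,
# not a value from the original experiment; train_captions comes from download_data).
def demo_build_vocab(train_captions, VOCAB_SIZE=8000):
    vocab, word2ix, ix2word = get_vocab(train_captions, VOCAB_SIZE)
    # The four special tokens always occupy the last four indices.
    assert vocab[-4:] == [PAD, UNK, START, END]
    return vocab, word2ix, ix2word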
def captions2captions(captions, word2ix):
    # Convert each sentence into a list of word indices wrapped in START/END
    # markers. The original worried about words ending in '.', but split_sentence
    # already strips punctuation; the remaining gap was that out-of-vocabulary
    # words were silently dropped, so they are mapped to UNK here instead.
    new_captions = []
    for caption in captions:
        new_caption = []
        for sentence in caption:
            new_caption.append([word2ix[START]]
                               + [word2ix.get(word, word2ix[UNK]) for word in split_sentence(sentence)]
                               + [word2ix[END]])
        new_captions.append(new_caption)
    return new_captions
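
# Sketch: convert the raw captions to index sequences and derive the padding
# length for the model input (names follow the rest of this notebook).
def demo_prepare_captions(train_captions, word2ix):
    train_captions_ix = captions2captions(train_captions, word2ix)
    # Every sequence now starts with START and ends with END, so the maximum
    # from get_max_len already includes the two marker tokens.
    SENTENCE_LEN = get_max_len(train_captions_ix)
    return train_captions_ix, SENTENCE_LEN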
def build_model(SENTENCE_LEN, NUM_WORDS, WORD_EMBED_SIZE, LSTM_UNITS):
    # Image branch: project the 2048-d CNN embedding down to the LSTM state size
    # and use it as the decoder's initial hidden and cell state. The projection
    # must output LSTM_UNITS values (not WORD_EMBED_SIZE) to be a valid state.
    image_features = L.Input(shape=(2048,))
    densed_IF = L.Dense(256)(image_features)
    densed_IF = L.Dense(LSTM_UNITS)(densed_IF)
    initial_state = [densed_IF, densed_IF]
    # Text branch: embed the padded word-index prefix and run it through the LSTM.
    words = L.Input(shape=(SENTENCE_LEN,))
    words_embeddings = L.Embedding(NUM_WORDS, WORD_EMBED_SIZE)(words)
    lstm_output = L.LSTM(LSTM_UNITS)(words_embeddings, initial_state=initial_state)
    # Decoder head: two dense layers, then a softmax over the vocabulary.
    decoded_words1 = L.Dense(512, activation='relu')(lstm_output)
    decoded_words2 = L.Dense(1024, activation='relu')(decoded_words1)
    final_output = L.Dense(NUM_WORDS, activation='softmax')(decoded_words2)
    model = keras.Model(inputs=[image_features, words], outputs=final_output)
    return model
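
# A minimal compile sketch; the embedding size, LSTM width, and optimizer are
# illustrative guesses, not tuned values from the original experiment.
# Categorical cross-entropy matches the one-hot next-word targets from seq_generator.
def demo_build_and_compile(SENTENCE_LEN, NUM_WORDS):
    model = build_model(SENTENCE_LEN, NUM_WORDS, WORD_EMBED_SIZE=300, LSTM_UNITS=512)
    model.compile(optimizer='adam', loss='categorical_crossentropy')
    return model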
def download_data():
    # Fetch precomputed image embeddings, captions, and image-filename lists from
    # Google Drive. The nested wget first grabs Drive's download-confirmation
    # token via a cookie (required for files too large to virus-scan), then
    # downloads the file; the '!' shell-escape lines run only inside Colab/IPython.
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-5AhtN3za59P6WsHBhVRw7Lonx4wk6xu' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-5AhtN3za59P6WsHBhVRw7Lonx4wk6xu" -O train_image_features.pickle && rm -rf /tmp/cookies.txt
train_image_embeds = read_pickle('train_image_features.pickle')
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1--FpSwNWO9X8l5YneuTHcs--EtJlJZOP' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1--FpSwNWO9X8l5YneuTHcs--EtJlJZOP" -O val_image_embeds.pickle && rm -rf /tmp/cookies.txt
val_image_embeds = read_pickle('val_image_embeds.pickle')
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-729Stj7PWEztvH-YVJ_nivo4QmGR-ro' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-729Stj7PWEztvH-YVJ_nivo4QmGR-ro" -O train_captions.pickle && rm -rf /tmp/cookies.txt
train_captions = read_pickle('train_captions.pickle')
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-1WfbEjN052jaHSUb4h5J_t9BIgziSP3' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-1WfbEjN052jaHSUb4h5J_t9BIgziSP3" -O val_captions.pickle && rm -rf /tmp/cookies.txt
val_captions = read_pickle('val_captions.pickle')
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-77wJhnWLCnvmBOvOFuQrAv_ekjE-ooa' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-77wJhnWLCnvmBOvOFuQrAv_ekjE-ooa" -O train_image_fns.pickle && rm -rf /tmp/cookies.txt
train_image_fns = read_pickle('train_image_fns.pickle')
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-2HLbOGLT4E9V_IywjwNWBviyz49QjrY' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-2HLbOGLT4E9V_IywjwNWBviyz49QjrY" -O val_image_fns.pickle && rm -rf /tmp/cookies.txt
val_image_fns = read_pickle('val_image_fns.pickle')
return train_image_embeds, val_image_embeds, train_captions, val_captions, train_image_fns, val_image_fns
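
# End-to-end wiring sketch under the assumptions above; every hyperparameter
# here (vocab size, embedding size, LSTM units, batch size, epochs) is an
# illustrative placeholder rather than a value from the original experiment.
def demo_train(epochs=10, batch_size=32, VOCAB_SIZE=8000):
    (train_image_embeds, val_image_embeds,
     train_captions, val_captions,
     train_image_fns, val_image_fns) = download_data()
    vocab, word2ix, ix2word = get_vocab(train_captions, VOCAB_SIZE)
    train_captions_ix = captions2captions(train_captions, word2ix)
    SENTENCE_LEN = get_max_len(train_captions_ix)
    model = build_model(SENTENCE_LEN, len(vocab), WORD_EMBED_SIZE=300, LSTM_UNITS=512)
    model.compile(optimizer='adam', loss='categorical_crossentropy')
    gen = seq_generator(train_image_embeds, train_captions_ix, SENTENCE_LEN,
                        word2ix[PAD], len(vocab), batch_size)
    # The generator yields once per batch_size images, so one epoch covers
    # roughly every training image once.
    model.fit(gen, steps_per_epoch=len(train_captions_ix) // batch_size, epochs=epochs)
    return model, word2ix, ix2word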