5.1 RNN, 그 이후

RNN은 분명 언어 모델로서 매력적이지만 그 정보가 오래 가지 못한다는 단점이 있었다. 따라서 이후 LSTM, GRU 등이 존재하였고, 요 근래의 Transformer 또한 등장하였다.

5.2 트랜스포머를 이용한 예제

이번 img2txt task를 진행하기 위해 tensorflow public documentation에 있는img2txt 코드를 통해 실습을 진행할 것이다.

이제 코드를 살펴보자.

import os
import re
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import efficientnet
from tensorflow.keras.layers import TextVectorization

위 코드는 필요한 라이브러리를 가지고 오는 코드이다. 또한 런타임을 GPU로 설정해주도록 하자.

from google.colab import drive
drive.mount('/content/drive')

위 코드는 google colab에서 내 드라이브를 마운트 하는 단계이다. 만약 개인 컴퓨터 환경을 사용한다면 이는 무시하도록 하자.

!mkdir img2txt_data
cd img2txt_data
!wget -q <https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip>
!wget -q <https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip>
!unzip -qq Flickr8k_Dataset.zip
!unzip -qq Flickr8k_text.zip
!rm Flickr8k_Dataset.zip Flickr8k_text.zip

위 리눅스 명령어들은 데이터를 특정 위치에 다운로드 받는 것이다. 개인 환경에서 진행한다면 원하는 폴더로 cd를 이용하여 먼저 이동하자

seed = 777
np.random.seed(seed)
tf.random.set_seed(seed)

결과 값을 고정하기 위해 seed를 주도록 하자.

IMAGES_PATH = "/content/img2txt_data/Flicker8k_Dataset"

IMAGE_SIZE = (299, 299)

VOCAB_SIZE = 10000

SEQ_LENGTH = 25

EMBED_DIM = 512

FF_DIM = 512

BATCH_SIZE = 64
EPOCHS = 30
AUTOTUNE = tf.data.AUTOTUNE

하이퍼 파라미터들을 정의한다. 이 때 이미지 path는 본인이 데이터를 저장한 path로 지정해야 한다.

def load_captions_data(filename):
    with open(filename) as caption_file:
        caption_data = caption_file.readlines()
        caption_mapping = {}
        text_data = []
        images_to_skip = set()

        for line in caption_data:
            line = line.rstrip("\n")
            img_name, caption = line.split("\t")

            img_name = img_name.split("#")[0]
            img_name = os.path.join(IMAGES_PATH, img_name.strip())

            tokens = caption.strip().split()

            if len(tokens) < 5 or len(tokens) > SEQ_LENGTH:
                images_to_skip.add(img_name)
                continue

            if img_name.endswith("jpg") and img_name not in images_to_skip:
                caption = " " + caption.strip() + " "
                text_data.append(caption)

                if img_name in caption_mapping:
                    caption_mapping[img_name].append(caption)
                else:
                    caption_mapping[img_name] = [caption]

        for img_name in images_to_skip:
            if img_name in caption_mapping:
                del caption_mapping[img_name]

        return caption_mapping, text_data

위 함수는 파일을 읽어 preprocess를 진행하는 함수이다. 토큰을 확인해 너무 작아 학습에 방해가 될 만한 문장들을 제외하고 처리를 진행한다.