AI/TensorFlow & PyTorch

[TensorFlow] InceptionV3을 이용한 이미지 검색

byunghyun23 2023. 6. 19. 00:59

InceptionV3를 이용하여 이미지를 검색해 보겠습니다.
이미지 검색에 대한 내용은 이곳을 확인해 주세요.

 

이미지 검색은 쿼리 이미지와 데이터베이스 이미지들의 유사도를 계산하여 쿼리 이미지와 유사한 이미지를 찾는 것입니다. 계산하는 유사도는 이미지의 feature이며, InceptionV3의 출력입니다. InceptionV3의 출력은 2048입니다.

또한 유사도는 유클리드 거리(Euclidean distance)를 계산하여 판단합니다.

 

데이터셋은 The Oxford-IIIT Pet Dataset에서 다운로드 후 data 디렉토리를 생성하여 저장하면 됩니다.

 

쿼리 이미지는 Abyssinian_1.jpg입니다.

 

 

아래와 같이 쿼리 이미지와 유사한 이미지 50개를 확인합니다.

Abyssinian_1.jpg는 쿼리 이미지이기 때문에 유클리드 거리가 0입니다.

유사도 계산에서 유클리드 거리는 적을수록 비슷하다는 의미입니다.

50개중 13개는 잘못 찾았습니다. 하지만 눈으로 봤을 때는 거의 다 비슷하게 보이네요.

 

참고로 성능 향상을 위해 주어진 데이터셋으로 InceptionV3를 전이학습하면, 성능 향상을 기대해 볼 수 있으나 오버피팅(overfitting) 가능성이 있습니다.

 

 

아래는 전체 코드입니다.

 

[search.py]

import numpy as np
import tensorflow as tf
import click
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import pickle


def preprocessing(img):
    img = img.resize((224, 224), Image.ANTIALIAS)
    img = tf.keras.preprocessing.image.img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = tf.keras.applications.inception_v3.preprocess_input(img)

    return img


@click.command()
@click.option('--query_image_name', default='data/Abyssinian_1.jpg', help='Query image name')
@click.option('--target_image_dir', default='data', help='Target image path')
def run(query_image_name, target_image_dir):
    rank = {}

    kwargs = {'input_shape': (224, 224, 3),
              'include_top': False,
              'weights': 'imagenet',
              'pooling': 'avg'}
    pretrained_model = tf.keras.applications.InceptionV3(**kwargs)

    query = Image.open(query_image_name)
    query.load()

    query = preprocessing(query)
    if query.shape[3] == 1:
        raise Exception('Expected axis -1 of input shape to have value 3.')

    query_feature = pretrained_model.predict(query)

    target_images = list(Path(target_image_dir).glob(r'**/*.jpg'))
    for target_image in target_images:
        target = Image.open(str(target_image))
        target.load()

        target = preprocessing(target)
        if target.shape[3] != 3:
            continue
        target_feature = pretrained_model.predict(target)

        dist = np.linalg.norm(query_feature - target_feature)
        rank[str(target_image)] = dist

    rank = dict(sorted(rank.items(), key=lambda x: x[1]))
    print(rank)

    with open('rank.pkl', 'wb') as f:
        pickle.dump(rank, f)

    # rank = {}
    # try:
    #     with open('rank.pkl', 'rb') as f:
    #         rank = pickle.load(f)
    # except Exception as e:
    #     pass
    # print(rank)

    fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(15, 7),
                             subplot_kw={'xticks': [], 'yticks': []})

    it = iter(rank)
    for i, ax in enumerate(axes.flat):
        image_name = next(it)
        print(image_name, rank[image_name])
        ax.imshow(plt.imread(image_name))
    plt.tight_layout(pad=0.5)
    plt.show()


if __name__ == '__main__':
    run()