AI/TensorFlow & PyTorch

[PyTorch] FCN을 이용한 시맨틱 분할

byunghyun23 2023. 7. 11. 21:44

torchvision의 fcn_resnet101을 이용하여 시맨틱 분할을 해보겠습니다.
시맨틱 분할에 대한 내용은 이곳을 확인해 주세요.

 

torchvision은 파이토치에서 제공하는 데이터셋과 모델 패키지입니다.

fcn_resnet101은 말 그대로 ResNet101 기반의 FCN 입니다.

fcn_resnet101은 pre-trained model이기 때문에 따로 학습할 필요가 없습니다.

 

사용 방법은 간단합니다. 모델 출력의 'out'을 softmax를 이용하여 class를 분류하고 그 값을 픽셀값으로 사용하면 됩니다.

 

아래는 전체 코드입니다.

 

[fcn.py]

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt


# Load a pre-trained FCN model
model = torchvision.models.segmentation.fcn_resnet101(pretrained=True)
model.eval()

# Preprocess the input image
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load and preprocess the image
image_path = 'image.jpg'
image = Image.open(image_path).convert("RGB")
input_image = transform(image)
input_image = input_image.unsqueeze(0)

# Run the image through the model
with torch.no_grad():
    output = model(input_image)['out']

probs = torch.softmax(output, dim=1)
_, predicted_class = torch.max(probs, dim=1)

# Visualize the input image
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title('Input Image')
plt.axis('off')

# Visualize the predicted segmentation mask
semantic_segmentation = predicted_class.squeeze().cpu().numpy()
plt.subplot(1, 2, 2)
plt.imshow(semantic_segmentation, cmap='jet')
plt.title('Semantic Segmentation')
plt.axis('off')

plt.tight_layout()
plt.show()