torchvision의 fcn_resnet101을 이용하여 시맨틱 분할을 해보겠습니다.
시맨틱 분할에 대한 내용은 이곳을 확인해 주세요.
torchvision은 파이토치에서 제공하는 데이터셋과 모델 패키지입니다.
fcn_resnet101은 말 그대로 ResNet101 기반의 FCN 입니다.
fcn_resnet101은 pre-trained model이기 때문에 따로 학습할 필요가 없습니다.
사용 방법은 간단합니다. 모델 출력의 'out'을 softmax를 이용하여 class를 분류하고 그 값을 픽셀값으로 사용하면 됩니다.
아래는 전체 코드입니다.
[fcn.py]
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
# Load a pre-trained FCN model
model = torchvision.models.segmentation.fcn_resnet101(pretrained=True)
model.eval()
# Preprocess the input image
transform = transforms.Compose([
transforms.ToTensor(),
])
# Load and preprocess the image
image_path = 'image.jpg'
image = Image.open(image_path).convert("RGB")
input_image = transform(image)
input_image = input_image.unsqueeze(0)
# Run the image through the model
with torch.no_grad():
output = model(input_image)['out']
probs = torch.softmax(output, dim=1)
_, predicted_class = torch.max(probs, dim=1)
# Visualize the input image
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title('Input Image')
plt.axis('off')
# Visualize the predicted segmentation mask
semantic_segmentation = predicted_class.squeeze().cpu().numpy()
plt.subplot(1, 2, 2)
plt.imshow(semantic_segmentation, cmap='jet')
plt.title('Semantic Segmentation')
plt.axis('off')
plt.tight_layout()
plt.show()
'AI > TensorFlow & PyTorch' 카테고리의 다른 글
[PyTorch] Mask R-CNN을 이용한 인스턴스 분할 (0) | 2023.07.11 |
---|---|
[PyTorch] Faster R-CNN을 이용한 객체 탐지 (0) | 2023.07.09 |
[TensorFlow] InceptionV3을 이용한 이미지 검색 (0) | 2023.06.19 |
[PyTorch] CNNs을 이용한 이미지 분류 (1) | 2023.06.15 |
[TensorFlow] CNNs을 이용한 이미지 분류 (0) | 2023.06.15 |