一般社団法人 全国個人事業主支援協会

COLUMN コラム

  • PyTorchで実装する画像分類モデル:CNNからViTへの進化

画像分類の歴史的転換点

画像分類は深層学習の中でも最も研究が進んだ分野の一つです。2012年のAlexNetによるブレークスルー以来、CNNが画像認識の王道であり続けた。しかし2020年、Googleが発表したVision Transformer(ViT)が、Transformerアーキテクチャによる画像分類の可能性を示し、パラダイムシフトが起きた。本記事では、PyTorchを使ってCNNとViTの両方を実装し、そのアーキテクチャの違いと特性を実践的に理解する。

まずはCNNの基本実装

画像分類の基本であるCNNから始めよう。CIFAR-10データセットを使ったシンプルなCNNモデルを実装する。

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# データの前処理
transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize(
(0.4914, 0.4822, 0.4465),
(0.2470, 0.2435, 0.2616)
),
])

train_dataset = torchvision.datasets.CIFAR10(
root='./data', train=True,
download=True, transform=transform_train
)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=128,
shuffle=True, num_workers=4
)

モダンなCNNアーキテクチャ

単純な畳み込み層の積み重ねではなく、ResNetで導入されたスキップ接続を取り入れた実装を示す。

class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels,
stride=1):
super().__init__()
self.conv1 = nn.Conv2d(
in_channels, out_channels, 3,
stride=stride, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels, 3,
stride=1, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)

self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1,
stride=stride, bias=False),
nn.BatchNorm2d(out_channels)
)

def forward(self, x):
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = self.relu(out)
return out


class SimpleCNN(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
self.layer1 = self._make_layer(3, 64, 2, stride=1)
self.layer2 = self._make_layer(64, 128, 2, stride=2)
self.layer3 = self._make_layer(128, 256, 2, stride=2)
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(256, num_classes)

def _make_layer(self, in_ch, out_ch, blocks, stride):
layers = [ResidualBlock(in_ch, out_ch, stride)]
for _ in range(1, blocks):
layers.append(ResidualBlock(out_ch, out_ch))
return nn.Sequential(*layers)

def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.avg_pool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x

CNNの強みは、畳み込み演算による局所的な特徴の効率的な抽出にある。カーネルが画像上をスライドすることで、エッジ、テクスチャ、形状といった階層的な特徴を学習する。この帰納バイアス(局所性と平行移動不変性)のおかげで、比較的少ないデータでも効果的に学習できます。

Vision Transformerの実装

次に、ViTの実装に移ろう。ViTの核心的なアイデアは、画像をパッチに分割してシーケンスとして扱い、Transformerで処理するというものです。

class PatchEmbedding(nn.Module):
def __init__(self, img_size=32, patch_size=4,
in_channels=3, embed_dim=256):
super().__init__()
self.num_patches = (img_size // patch_size) ** 2
self.proj = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size,
stride=patch_size
)

def forward(self, x):
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
return x


class TransformerBlock(nn.Module):
def __init__(self, embed_dim, num_heads, mlp_ratio=4.0,
drop_rate=0.1):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(
embed_dim, num_heads,
dropout=drop_rate, batch_first=True
)
self.norm2 = nn.LayerNorm(embed_dim)
self.mlp = nn.Sequential(
nn.Linear(embed_dim,
int(embed_dim * mlp_ratio)),
nn.GELU(),
nn.Dropout(drop_rate),
nn.Linear(int(embed_dim * mlp_ratio),
embed_dim),
nn.Dropout(drop_rate),
)

def forward(self, x):
x = x + self.attn(
self.norm1(x), self.norm1(x), self.norm1(x)
)[0]
x = x + self.mlp(self.norm2(x))
return x


class ViT(nn.Module):
def __init__(self, img_size=32, patch_size=4,
num_classes=10, embed_dim=256,
depth=6, num_heads=8):
super().__init__()
self.patch_embed = PatchEmbedding(
img_size, patch_size, 3, embed_dim
)
num_patches = (img_size // patch_size) ** 2

self.cls_token = nn.Parameter(
torch.zeros(1, 1, embed_dim)
)
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim)
)
self.blocks = nn.Sequential(
*[TransformerBlock(embed_dim, num_heads)
for _ in range(depth)]
)
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_classes)

nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)

def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat([cls_tokens, x], dim=1)
x = x + self.pos_embed
x = self.blocks(x)
x = self.norm(x[:, 0])
x = self.head(x)
return x

CNNとViTの特性比較

実際に両モデルを訓練して比較すると、興味深い違いが見えてくる。

  • データ効率:CNNは少量のデータでも合理的な精度を達成する。ViTは大規模データセットで真価を発揮するが、小規模データでは過学習しやすい
  • 計算コスト:ViTのSelf-Attentionはパッチ数の二乗に比例する計算量を持つ。高解像度画像では特に注意が必要
  • 特徴の捉え方:CNNが局所的な特徴から大域的な特徴を積み上げるのに対し、ViTは最初から大域的な関係性を捉えられます
  • 解釈可能性:ViTのAttention Mapを可視化することで、モデルが画像のどの部分に注目しているかを直感的に理解できます

実務での選択指針

筆者の経験では、データ量が10万枚未満のプロジェクトではCNNベース(EfficientNetやConvNeXtなど)の方が安定した結果を得られます。一方、大規模データセットが利用可能で、かつ事前学習済みモデルをファインチューニングする場合は、ViTベースのモデルが優位に立つ。最近はConvNeXtのようにCNNの設計をTransformerの知見で改良したモデルも登場しており、両者の境界は曖昧になりつつある。

まとめ

CNNからViTへの進化は、画像認識のパラダイムを大きく変えた。しかし、どちらが絶対的に優れているわけではなく、タスクの特性とデータ量に応じて適切なアーキテクチャを選択することが重要です。PyTorchで両方を実装してみることで、それぞれの仕組みと特性を肌感覚で理解できます。実際のプロジェクトでは、まず事前学習済みモデルからのファインチューニングを試し、必要に応じてアーキテクチャをカスタマイズするアプローチが効率的です。

この記事をシェアする

  • Twitterでシェア
  • Facebookでシェア
  • LINEでシェア