画像分類は深層学習の中でも最も研究が進んだ分野の一つです。2012年のAlexNetによるブレークスルー以来、CNNが画像認識の王道であり続けた。しかし2020年、Googleが発表したVision Transformer(ViT)が、Transformerアーキテクチャによる画像分類の可能性を示し、パラダイムシフトが起きた。本記事では、PyTorchを使ってCNNとViTの両方を実装し、そのアーキテクチャの違いと特性を実践的に理解する。
画像分類の基本であるCNNから始めよう。CIFAR-10データセットを使ったシンプルなCNNモデルを実装する。
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# データの前処理
transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize(
(0.4914, 0.4822, 0.4465),
(0.2470, 0.2435, 0.2616)
),
])
train_dataset = torchvision.datasets.CIFAR10(
root='./data', train=True,
download=True, transform=transform_train
)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=128,
shuffle=True, num_workers=4
)
単純な畳み込み層の積み重ねではなく、ResNetで導入されたスキップ接続を取り入れた実装を示す。
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels,
stride=1):
super().__init__()
self.conv1 = nn.Conv2d(
in_channels, out_channels, 3,
stride=stride, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels, 3,
stride=1, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1,
stride=stride, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = self.relu(out)
return out
class SimpleCNN(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
self.layer1 = self._make_layer(3, 64, 2, stride=1)
self.layer2 = self._make_layer(64, 128, 2, stride=2)
self.layer3 = self._make_layer(128, 256, 2, stride=2)
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(256, num_classes)
def _make_layer(self, in_ch, out_ch, blocks, stride):
layers = [ResidualBlock(in_ch, out_ch, stride)]
for _ in range(1, blocks):
layers.append(ResidualBlock(out_ch, out_ch))
return nn.Sequential(*layers)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.avg_pool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
CNNの強みは、畳み込み演算による局所的な特徴の効率的な抽出にある。カーネルが画像上をスライドすることで、エッジ、テクスチャ、形状といった階層的な特徴を学習する。この帰納バイアス(局所性と平行移動不変性)のおかげで、比較的少ないデータでも効果的に学習できます。
次に、ViTの実装に移ろう。ViTの核心的なアイデアは、画像をパッチに分割してシーケンスとして扱い、Transformerで処理するというものです。
class PatchEmbedding(nn.Module):
def __init__(self, img_size=32, patch_size=4,
in_channels=3, embed_dim=256):
super().__init__()
self.num_patches = (img_size // patch_size) ** 2
self.proj = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size,
stride=patch_size
)
def forward(self, x):
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
return x
class TransformerBlock(nn.Module):
def __init__(self, embed_dim, num_heads, mlp_ratio=4.0,
drop_rate=0.1):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(
embed_dim, num_heads,
dropout=drop_rate, batch_first=True
)
self.norm2 = nn.LayerNorm(embed_dim)
self.mlp = nn.Sequential(
nn.Linear(embed_dim,
int(embed_dim * mlp_ratio)),
nn.GELU(),
nn.Dropout(drop_rate),
nn.Linear(int(embed_dim * mlp_ratio),
embed_dim),
nn.Dropout(drop_rate),
)
def forward(self, x):
x = x + self.attn(
self.norm1(x), self.norm1(x), self.norm1(x)
)[0]
x = x + self.mlp(self.norm2(x))
return x
class ViT(nn.Module):
def __init__(self, img_size=32, patch_size=4,
num_classes=10, embed_dim=256,
depth=6, num_heads=8):
super().__init__()
self.patch_embed = PatchEmbedding(
img_size, patch_size, 3, embed_dim
)
num_patches = (img_size // patch_size) ** 2
self.cls_token = nn.Parameter(
torch.zeros(1, 1, embed_dim)
)
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim)
)
self.blocks = nn.Sequential(
*[TransformerBlock(embed_dim, num_heads)
for _ in range(depth)]
)
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_classes)
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat([cls_tokens, x], dim=1)
x = x + self.pos_embed
x = self.blocks(x)
x = self.norm(x[:, 0])
x = self.head(x)
return x
実際に両モデルを訓練して比較すると、興味深い違いが見えてくる。
筆者の経験では、データ量が10万枚未満のプロジェクトではCNNベース(EfficientNetやConvNeXtなど)の方が安定した結果を得られます。一方、大規模データセットが利用可能で、かつ事前学習済みモデルをファインチューニングする場合は、ViTベースのモデルが優位に立つ。最近はConvNeXtのようにCNNの設計をTransformerの知見で改良したモデルも登場しており、両者の境界は曖昧になりつつある。
CNNからViTへの進化は、画像認識のパラダイムを大きく変えた。しかし、どちらが絶対的に優れているわけではなく、タスクの特性とデータ量に応じて適切なアーキテクチャを選択することが重要です。PyTorchで両方を実装してみることで、それぞれの仕組みと特性を肌感覚で理解できます。実際のプロジェクトでは、まず事前学習済みモデルからのファインチューニングを試し、必要に応じてアーキテクチャをカスタマイズするアプローチが効率的です。