第8回 さらにディープな世界へ: 勾配消失問題と残差ネットワーク (ResNet)

  1. 勾配消失問題とは
  2. 残差ネットワークとは
  3. PyTorchをさらに活用するために
  4. まとめ

1. 勾配消失問題とは

本講座では、これまで 「ニューラルネットワークのレイヤーを増やすほど学習能力は上がる」と 説明してきた。 しかし当然レイヤーを増やすことによるコストもある。 とくに顕著なのは、 「レイヤーを増やせば増やすほど訓練に時間がかかる」ということである。 これには 2つの要因があって:

  1. レイヤーを増やすとノードの数が増えるため、ネットワーク全体の計算に時間がかかるようになる。
  2. レイヤーを増やすと、最初のほうの (入力に近い) レイヤーの重み・バイアスの変化が指数的に遅くなる。
ためである。 特に b. の現象を「勾配消失問題 (vanishing gradient problem)」という。 勾配消失問題が起こる原因は、ニューラルネットワークの訓練方法のためである。 第3回の誤差逆伝播法を思い出そう。 ここでは、勾配を計算するため backward() メソッドが 最終レイヤーから順に呼ばれていた:
x y y y y0 delta delta delta forward forward forward backward backward backward mse_loss レイヤー1 レイヤー2 レイヤー3

そして実際の backward() メソッドは次のようになっている (読みやすさのため、NumPy で書かれたバージョン を使っている):

    def backward(self, delta):
        # self.y が計算されたときのシグモイド関数の微分を求める。
        ds = d_sigmoid(self.y)
        # 各偏微分を計算する。
        self.dw += (delta * ds).reshape(self.nout, 1) * self.x
        self.db += delta * ds
        # 各入力値の微分を求める。
        dx = np.dot(delta * ds, self.w)
        return dx

各ノードにおける重み・バイアスの変化 (self.dwself.db) は、どちらも引数 delta に依存している。 さらに delta の値は、最終レイヤーから最初のレイヤーへと 渡されていくことを思い出してほしい。 活性化関数にシグモイド関数を使っている場合、ここでの 各ノードへの入力 x および微分 ds はどちらも 1 以下なので、delta の値はレイヤーが戻るたびに減衰していく。 これが最初のレイヤーに到達するころには、 delta は 最終レイヤーの delta よりもずっと小さくなってしまう (消失してしまう)。

演習8-1. 勾配消失問題を実際に観察する

第4回の演習4-7.のコードを変更した以下のプログラムを実行せよ。 このコードはN個の中間レイヤー (+ Softmax レイヤー) を使って、 MNIST の学習をおこなうものである。 ミニバッチごとに、各レイヤーの重みの変化 (dw) の大きさを表示する。 中間レイヤーの数 N を増やし、N が増えると各レイヤーの変化率が どのように異なっていくか観察せよ。

# 訓練データの画像・ラベルを読み込む (パス名は適宜変更)。
train_images = load_mnist('train-images-idx3-ubyte.gz')
train_labels = load_mnist('train-labels-idx1-ubyte.gz')
# レイヤーをリストで管理する。
layers = [Layer(784, 100)]  # 最初のレイヤー
# N個の中間レイヤーを作成。
N = 2
for _ in range(N):
    layers.append(Layer(100, 100))
softmax = SoftmaxLayer(100, 10)
n = 0
for i in range(1):
    for (image,label) in zip(train_images, train_labels):
        x = (image/255).reshape(784)
        ya = np.zeros(10)
        ya[label] = 1
        # 各レイヤーに入力を与える。
        for layer in layers:
            x = layer.forward(x)
        y = softmax.forward(x)
        delta = softmax.cross_entropy_loss_backward(ya)
        # 逆順で勾配を与える。
        for layer in reversed(layers):
            delta = layer.backward(delta)
        n += 1
        if (n % 50 == 0):
            # 各レイヤーの重み変化 (dw) の平均的な大きさを計算する。
            grads = [ np.sqrt((layer.dw**2).mean()) for layer in layers ]
            print(n, softmax.loss, grads)
            for layer in layers:
                layer.update(0.01)
            softmax.update(0.01)

上の演習を実行すると、レイヤーが増えるに従って 最初のほうのレイヤーの dw が顕著に小さくなっているのがわかる。 これらはほぼ指数的に小さくなっている。つまり、 レイヤーが増えるに従って、訓練にかかる時間は指数的に増大するのである。 実際、N=2 のときは損失 (softmax.loss) がゆるやかに減少していくが、 N=10 のときはある時点でほとんど変化しなくなってしまう。

勾配消失問題を緩和する対策のひとつが 「なるべく勾配が各レイヤーで減衰しないようにする」ことである。 第5回で紹介した 活性化関数 ReLU はまさにこのような目的で導入された。 しかしこれは問題を本質的に解決しているわけではない。 レイヤーが多くなれば (つまりネットワークが「ディープ」になればなるほど) いずれ勾配は消失してしまう。 勾配消失問題は、よりディープなニューラルネットワークを 設計しようとするときに避けて通れない問題なのである。

2. 残差ネットワークとは

勾配消失問題を解決するために考え出されたのが 残差ネットワーク (residual network) である。 これは以下のような「残差ブロック」を重ねて作られた ニューラルネットワークで、 図中の + はベクトルの加算をあらわす。 中央のレイヤーを飛び越して入力と出力を直接つなぐ 「スキップ接続 (skip connection)」に注目してほしい。 これが勾配消失問題を解決するためのメカニズムである。

Layer +
残差ブロックとスキップ接続

残差ブロックの機能を数式で表すと、次のようになる。 簡単のため、レイヤーの入力 x と 出力 y はそれぞれ同数のチャンネルをもつと仮定する。 レイヤーが 2つの入力 x1, x2 および 2つの出力 y1, y2 をもつとき、 残差ブロックの出力は以下のように表せる (活性化関数にはシグモイド関数 σ を使っている):

y1 = σ(w11·x1 + w12·x2 + b1) + x1
y2 = σ(w21·x1 + w22·x2 + b2) + x2

学習時、重み・バイアスの勾配は以前と同じである:

y1/w11 = σ′(w11·x1 + w12·x2 + b1) · x1
y1/w12 = σ′(w11·x1 + w12·x2 + b1) · x2
y2/w21 = σ′(w21·x1 + w22·x2 + b2) · x1
y2/w22 = σ′(w21·x1 + w22·x2 + b2) · x2
y1/b1 = σ′(w11·x1 + w12·x2 + b1) · 1
y2/b2 = σ′(w21·x1 + w22·x2 + b2) · 1

しかし誤差逆伝播法でひとつ前のレイヤーに戻る x1, x2 に対する勾配は違っている:

y1/x1 = σ′(w11·x1 + w12·x2 + b1) · w11 + 1
y2/x1 = σ′(w21·x1 + w22·x2 + b2) · w21 + 0

y1/x2 = σ′(w11·x1 + w12·x2 + b1) · w12 + 0
y2/x2 = σ′(w21·x1 + w22·x2 + b2) · w22 + 1

y1 と y2 の微分は足し合わされるので、 結局のところ x1, x2 に対する勾配は もともとの勾配の各成分に 1 を足したものになっている。 したがって、backward() メソッドにおける返り値は 従来の dx から dx + delta になり、delta の値は減衰せずにそのまま 前のレイヤーに渡されることになる。

演習8-2. 残差ブロックを実装する

Layer クラスに残差ブロックの機能を追加した 以下の ResidualLayer クラスを定義し、 演習 8-1. の例で N個の Layerのかわりにこれを利用する。 N=10 のときに勾配が消失しないことを観察せよ。

class ResidualLayer(Layer):

    def forward(self, x):
        # xは nin個の要素をもつ入力値のリスト。
        # 与えられた入力に対する各ノードの出力を計算する。
        self.x = x
        self.y = sigmoid(np.dot(self.w, x) + self.b)
        # yは nout個の要素をもつ出力値のリスト
        return self.y + x

    def backward(self, delta):
        # self.y が計算されたときのシグモイド関数の微分を求める。
        ds = d_sigmoid(self.y)
        # 各偏微分を計算する。
        self.dw += (delta * ds).reshape(self.nout, 1) * self.x
        self.db += delta * ds
        # 各入力値の微分を求める。
        dx = np.dot(delta * ds, self.w)
        return dx + delta

2.1. 画像認識の最先端 ResNet

では残差ネットワークを応用したニューラルネットワーク ResNet を見てみよう。 ResNet は VGG の翌年に発表されたネットワークで、 ImageNet の画像を 94% 程度の精度で認識できる。 ResNet にはレイヤーの数が 34個から 152個までのいくつかの実装があるが、 そのひとつである ResNet-50 は計50個のレイヤーをもっており、 これは現在、ディープラーニングを使った画像認識では デファクト・スタンダードとなっている (YOLO の後期バージョンである YOLOv3 も ResNet の構造を参考にしている)。

入力画像 (3×224×224) Conv-7, /2 (64) Max Pooling (64×56×56) Conv-1, /2 (64) Conv-3 (64) Conv-1 (256) ... ×3 (256×28×28) Conv-1, /2 (128) Conv-3 (128) Conv-1 (512) ... ×4 (512×14×14) Conv-1, /2 (256) Conv-3 (256) Conv-1 (1024) ... ×6 (1024×7×7) Conv-1, /2 (512) Conv-3 (512) Conv-1 (2048) ... ×3 (2048×7×7) Linear (1000) Softmax (1000)
ResNet-50 のレイヤー (/2 は最初のブロックのみ)

ResNet-50 では、ひとつの残差ブロックは 3つの畳み込みレイヤーから 構成されている。これを複数回くり返し、さらに「画像を縮小しつつ チャンネルを増やす」という従来の手法を使うことによって 高精度な画像認識を達成できている。

PyTorch を使って ResNet を実装する場合は、前述の backward() は必要ないため、 forward() メソッド中で単に入力の x を 出力時に足せばよい。PyTorch を使った残差ブロックの定義は以下のようになる:

class ResBlock(nn.Module):

    def __init__(self, cin, cmid, cout, stride=1):
        nn.Module.__init__(self)
        # 1つの残差ブロックは 3つの畳み込みレイヤー (+バッチ正規化) からなる。
        self.conv1 = nn.Conv2d(cin, cmid, 1, stride=stride)
        self.norm1 = nn.BatchNorm2d(cmid)
        self.conv2 = nn.Conv2d(cmid, cmid, 3, padding=1)
        self.norm2 = nn.BatchNorm2d(cmid)
        self.conv3 = nn.Conv2d(cmid, cout, 1)
        self.norm3 = nn.BatchNorm2d(cout)
        return

    def forward(self, x):
        skip = x  # スキップ接続
        x = self.conv1(x)
        x = self.norm1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = F.relu(x)
        # スキップ接続の値を x に足す。
        return x + expand(skip, x.shape)

ここで使っている関数 expand() は、 残差ブロックの入力・出力チャンネル数が違う場合に、 出力に合わせてテンソルを拡大する (不足部分を 0 で埋める) ものである:

# expand: テンソル t をサイズ target に拡張する。
def expand(t, target):
    padding = sum(( (0,a-b) for (a,b) in zip(target, t.shape) ), ())
    return F.pad(ta, padding)

PyTorch では、nn.Module クラスから派生したクラスは 1つのレイヤーのように扱える。そのため、ResNet50 クラスでは 先に定義した ResBlock クラスをあたかもひとつのレイヤーのように 使うことができる:

class ResNet50(nn.Module):

    def __init__(self):
        nn.Module.__init__(self)
        self.conv1 = nn.Conv2d(3, 64, 7, padding=3, stride=2)
        self.pool1 = nn.MaxPool2d(2)
        # 最初の3ブロック。
        self.conv2_1 = ResBlock(64, 64, 256, stride=2)
        self.conv2_2 = ResBlock(256, 64, 256)
        self.conv2_3 = ResBlock(256, 64, 256)
        # 次の4ブロック。
        self.conv3_1 = ResBlock(128, 512, stride=2)
        ...

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool1(x)
        # 各ブロックをひとつのレイヤーのように使う。
        x = self.conv2_1(x)
        x = F.relu(x)
        x = self.conv2_2(x)
        x = F.relu(x)
        x = self.conv2_3(x)
        x = F.relu(x)
        ...

ResNet の訓練には非常に時間がかかるため、ここでは実際に実行はしないが、 完成した実際の ResNet50 モデル (重み・バイアス) は torchvision モジュール (後述) として利用可能である。

なお、残差ネットワークの利点は単に勾配消失問題がない (レイヤーが多い) 以外にも あるのではないか、という指摘がされている。 まず、本来ニューラルネットワークの出力には入力の要素がそのまま含まれていることが多く、 f(x) という関数 f を直接推定するよりも、 f'(x)+x における 関数 f' を推定するほうが容易だという説がある。さらに、残差ブロックに入力された情報は 通常のレイヤーとスキップ接続の 2通りの経路をたどると考えられるため、 たとえば以下のような 3層の残差ネットワークがある場合、 情報の経路は 23 = 8 通り存在する。 これは事実上、8種類のニューラルネットワークを並列に組み合わせている、 と考えることができるというのである。

=
残差ブロックが並列なネットワークを構成する

3. PyTorchをさらに活用するために

ここでは、PyTorch をさらに開発するための いくつかのトピックについて紹介する。

3.1. torchvision を使う

PyTorch には、いくつもの派生プロジェクトが存在する。 ここではよく知られる torchvision モジュールについて紹介する。 これは PyTorch で画像処理をするときの便利な機能をまとめたものである。 また VGG や ResNet など、よく知られている画像認識モデルの PyTorch による完全な実装が含まれており、 これらのモデルを (重み・バイアスつきで) ダウンロードして すぐに使えるようになっている。 (これ以外にも PyTorch Hub という枠組みがあり、 公開済みのモデルを自動的にダウンロードして利用できるのだが、 現時点ではまだ動作が不安定なことが多いため、ここでは解説しない。)

まず torchvision モジュールをインストールしよう (Google Colab を使っている場合はインストール不要):

Torchvision で提供されているモデルの利用は、非常に簡単である。 ただ単に torchvision.models.vgg16torchvision.models.resnet50 といった関数を呼べばよい:

どちらも関数も pretrainedTrue の場合、 定義されたモデルを返すだけでなく、その重み・バイアスも torch.load() 関数を使って自動的に外部サイトからダウンロードする (ダウンロードした重み・バイアスは、 ホームディレクトリ中の ~/.cache/torch/ 以下のフォルダに格納される):

>>> import torchvision
>>> model = torchvision.models.vgg16(True)     # VGG-16 を利用
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /home/euske/.cache\torch\hub\checkpoints\vgg16-397923af.pth
100.0%
>>> model = torchvision.models.resnet50(True)  # ResNet-50 を利用
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/euske/.cache\torch\hub\checkpoints\resnet50-0676ba61.pth
100.0%

ここで返される modelnn.Module 型の インスタンスであり、すぐに画像のテンソルを渡して推論させることができる:

import torch
import numpy as np
from PIL import Image

# RGB画像を読み込み、224×224 に縮小する
image = Image.open('input.jpg')
image.thumbnail((224, 224))
# 画像をndarray型に変換し、正規化する。
a = convert_image(image)
# さらにTensor型に変換。
t = torch.tensor(a)
# 次元をひとつ追加し (N×C×H×W) の形式にする。
x = t.reshape(1,3,224,224)
# 推論を実行。
y = model(x)
# もっとも確率の高いラベルを取得する。
i = torch.argmax(y)

なお、ここでは画像を ndarray型に変換し、色を正規化する関数 convert_image() を使っている:

def convert_image(image):
    # ndarray型に変換。
    a = np.array(image)
    # (H×W×C) → (C×H×W) の形式に並び換えをおこなう。
    a = a.transpose(2,0,1)
    # RGBの色を正規化する。
    a = a/255
    a[0] = (a[0]-0.485)/0.229
    a[1] = (a[1]-0.456)/0.224
    a[2] = (a[2]-0.406)/0.225
    return a
演習8-3. Torchvision を使って画像認識を実行する

適当な画像を用意し、上の例にならって torchvision モジュールの ResNet-50 を使って認識を実行せよ。

3.2. 転移学習を実装する

転移学習 (transfer learning) とは、 「すでに訓練されたニューラルネットワーク (の一部) を別の目的に転用する」 ことである。ニューラルネットワークに限らず、機械学習では一般に 精度を上げるためには大量の訓練データを必要とするが、 転移学習はすでに学習されたニューラルネットワークを 微調整するだけなので、以下のメリットがあると考えられている:

第7回で紹介した事前学習も、 あらかじめ学習したネットワークを別のネットワークの一部に使うことから、 転移学習の一種といえる。 実際、現在の実用的な画像認識システムのほとんどは あらかじめ ImageNet などを使って訓練した VGG や ResNet をもとに 転移学習を使って作られている。自然言語処理の分野では BERT のように 転移学習に使われることを前提として作られたモデルも存在する。 今後、ニューラルネットワークの利用の拡大にともなって 転移学習はますます利用されるようになっていくと思われる。

一般的な転移学習では、 あらかじめ訓練されたニューラルネットワーク末尾の 数レイヤーを削除し、そこに新しいレイヤーを追加する:

末尾の 数レイヤーを 削除 新規に レイヤーを 追加
一般的な転移学習の方法

さらに、転移学習には2種類の訓練方法が存在する:

  1. 訓練されたモデルを「特徴量抽出器 (feature extractor)」として利用し その出力を使って新たに学習する方法。もとのモデルは独立した 「モジュール」として扱い、そこに新たなレイヤーを付加する。 元のモデルの重み・バイアスは変更しない。
  2. 訓練されたモデルを、別タスク用に「再調整 (fine tuning)」する方法。 元のモデルと付加したレイヤーをまとめて「ひとつのモデル」として扱い、 重み・バイアスをまるごと学習する。 (以前に説明した YOLO の事前学習はこの方法を使っている)

PyTorch で転移学習を実現するには

では実際に PyTorch を使って転移学習を実装する。 ここでは ResNet の簡易版である ResNet-18 (torchvision モジュールに含まれる) を、 YOLO の訓練に用いた VOC データセットの画像認識に適用してみよう。 VOC データセットは画像の各オブジェクトに 20種類のラベルがついているが、 ここでは各画像のうち、もっとも大きな矩形をもつ物体のラベルを 「その画像のラベル」と定義することにする。 また、ResNet-18 は特徴量抽出器として利用し、 あらかじめ学習された重み・バイアスは変えないものとする。

最初に、torchvision モジュールで定義されている ResNet-18 から最後のレイヤーを除いたものを取得する。 じつは PyTorch は転移学習を簡単に実現するための枠組みが 用意されているわけではなく、いったん定義された nn.Moduleクラスから特定のレイヤーを 除去する決まった方法はない。したがって、 ここでは場当たり的なやり方を使うことにする。 まず torchvision.models.resnet18() を使って定義した モデルを見ると、以下のようなものが表示される:

>>> import torchvision
>>> pretrained = torchvision.models.resnet18(True)
>>> print(prerained)
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  ...
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

ここから ResNet クラスの全接続レイヤー (fc) が 512要素の特徴量を受けとり、1000要素の出力を返す Linear レイヤーであることが読みとれる。 (PyTorch の慣例により、fc の後に活性化関数は適用されていない。) ResNetクラスはすでに 定義されてしまっているので、このレイヤーを直接削除することはできないが、 同様の効果を得る方法はいくつかある:

ここでは後者の方法を使ってみよう。作成した ResNet インスタンスの fc を何もしないレイヤーに置き換える. この nn.Identity は、 f(x) = x という関数を実現するだけのレイヤーで、 入力の値はそのまま出力され、勾配も変化しない:

pretrained.fc = nn.Identity()

さて、訓練済みレイヤーができたとして、 次に追加で訓練すべきレイヤーを定義する:

class AdapterNet(nn.Module):

    def __init__(self):
        super().__init__()
        # 512個の特徴量から 20種類のラベルを推論する。
        self.fc = nn.Linear(512, 20)

    def forward(self, x):
        x = self.fc(x)
        return x

adapter = AdapterNet()

(AdapterNet クラスは 1つのレイヤーしかないので、 クラスを定義せずに adapter = nn.Linear(512, 20) としてもかまわない。)

PyTorch では、2つのニューラルネットワークを接続するのは 非常に簡単である。ただ1つの関数呼び出しの結果を別の関数に渡せばよい。 PyTorch では勾配は各レイヤーではなく、 渡される値 (テンソル) の中に格納されるので、 このようにしても正しく勾配降下法が実行できる:

y = adapter(pretrained(x))
pretrained adapter
2つのニューラルネットワークを接続する

今回は pretrained の重み・バイアスは固定するので、 訓練するのは adapter だけである。 PyTorch で訓練時に特定のレイヤーの重み・バイアスを固定するには 3つの方法がある:

  1. 固定するレイヤーの重み・バイアスが格納されている各テンソルを requires_grad = False に設定する。
  2. 固定するレイヤーを with torch.no_grad(): ブロックの中で計算する。
  3. 最適化器 (optimizer) に、 固定するレイヤーの重み・バイアスを渡さないようにする。

ここでは上の a. および c. を両方実装してみよう。 訓練部分のコードは以下のようになる (強調部分はそれぞれレイヤーの重み・バイアスを固定する部分である) :

# 訓練ずみのモデルを取得する。
pretrained = torchvision.models.resnet18(True)
pretrained.fc = nn.Identity()
pretrained.eval()
# 訓練ずみレイヤーの重み・バイアスを固定する (勾配の計算を禁止する)。
for p in pretrained.parameters():
     p.requires_grad = False
# 追加レイヤーを作成する。
adapter = AdapterNet()
adapter.train()
# 最適化器と学習率を定義する (追加レイヤーのみ更新する)。
optimizer = optim.Adam(adapter.parameters(), lr=0.001)
# 10エポック分の訓練をおこなう。
for epoch in range(10):
    # 各ミニバッチを処理する。
    for (idx, samples) in enumerate(loader):
        (inputs, labels) = make_batch(samples)
        inputs = torch.tensor(inputs).to(device)
        labels = torch.tensor(labels).to(device)
        # すべての勾配(.grad)をクリアしておく。
        optimizer.zero_grad()
        # 与えられたミニバッチをニューラルネットワークに処理させる。
        outputs = adapter(pretrained(inputs))
        # 損失を計算する。
        loss = F.cross_entropy(outputs, labels)
        # 勾配を計算する。
        loss.backward()
        # 重み・バイアスを更新する。
        optimizer.step()
演習8-4. ResNet-18 を使った転移学習を実行する

上で説明した ResNet-18 を使った転移学習をおこなうコード transfer.py をダウンロードし、実行せよ。 (実行には 演習7-3. で使った PASCAL VOC データセットが必要。)

$ python transfer.py ./PASCALVOC2007.zip
次に、転移学習を使ず、ランダムな重み・バイアスから 初めて全体を訓練するモードを実行せよ:
$ python transfer.py --start-random ./PASCALVOC2007.zip

上の演習のプログラムは VOC データセットを使って、各エポックの訓練後に 検証データを使って精度を測定している。これを実行してみると、転移学習をおこなった場合は 3エポックほど終了した時点で、すでに 75% 程度の精度が出ていることがわかる。 最初から (ResNetの重みも含めて) すべて学習した場合は 30% 程度しか出ていないことからも、 少ないデータ・短い学習時間で効果が出せるという転移学習のメリットがうかがえる。

3.3. ONNX形式を使ってモデルを出力する

さて、本講座では PyTorch を使ってニューラルネットワークを開発する方法を 説明してきたが、機械学習フレームワークと呼ばれるものには PyTorch 以外にも多く存在する。代表的な例をあげると:

といったものである。これらのフレームワークはどれも基本的な機能は似ているが、 細かい仕様は異なっており、当然ながらこれらを使って開発された プログラム間には互換性がない。 しかし、これらで訓練したニューラルネットワークを、 言語やフレームワークに依存しない共通の形式で (重み・バイアスも含めて) エクポートする方法がある。 それが Open Neural Network Exchange (ONNX)形式 である。 たとえば PyTorch で作成したモデル (nn.Module クラス) を ONNX形式でエクスポートすると、PyTorch が動かないプラットフォーム (iOS や RaspberryPi など) でもモデルを使って推論することができる。 ONNX Model Zoo のページでは、 ONNX形式に変換されたさまざまなモデルがダウンロード可能である:
PyTorch TensorFlow CNTK ONNX 形式 ONNX Runtime TensorRT ... ...
異なるフレームワークで開発されたモデルの統一形式としての ONNX

ONNX形式はいくつかの 演算子が定義された 小規模なプログラミング言語のようなものである。 したがって、PyTorch のモデルを ONNX形式に変換する処理は、 Python プログラムを他のプログラミング言語に変換する処理とみなせる。 PyTorch では、このために 2つの方法を用意している:

  1. Tracing - Pythonプログラムをダミーの入力を使って実際に実行し、 各処理を動的に ONNX の演算子に変換する。 手軽に変換できるが、単調なプログラムしか変換できない。
  2. Scripting - Pythonプログラムを事前に解析し、ONNX にコンパイルする。 場合によっては Pythonプログラムの一部変更が必要。

今回はプログラムを変更する必要のない tracing方式を使って モデルを ONNX形式に変換してみよう。そのためには、モデルを定義して 実際に実行させる必要がある。 一般的に、あらゆる PyTorch モデルが ONNX形式に変換できるわけではない。 Tracing方式で ONNX形式に変換できるのは、forward() メソッド中に 条件分岐やループを含まない「単調な」処理だけである。 また、ONNXの仕様にはいくつかの「バージョン」が存在し、 出力した ONNX 形式がすべてのプラットフォームで使える保証はないので注意。

PyTorch のモデルを ONNX にエクスポートする手順は以下のとおり:

  1. モデルを定義し、インスタンスを作成する:
    class Net(nn.Module):
        ...
    
    model = Net()
    
  2. 各レイヤーの重み・バイアスを設定する (あるいは、読み込む)。 また、モデルを評価モードに設定する:
    params = torch.load(path)
    model.load_state_dict(params)
    model.eval()
    
  3. そのモデルを実行させるためのダミー入力を作成する。 たとえば、このニューラルネットワークが 3×224×224 のテンソルを入力する場合は、 ダミーのバッチサイズ 1 を加えて以下のような Tensor を作成する。
    dummy = torch.rand((1, 3, 224, 224))
    
  4. モデルのインスタンスを torch.onnx.export() 関数に渡す。 このときダミー入力と、出力する ONNX ファイル名も指定する。 また、モデルの入力・出力に特定の名前をつける。 (入力・出力名がリストになっているのは、複数の入力・出力をもつ ニューラルネットワークを想定しているためである。)
    torch.onnx.export(
        model,                    # モデル
        dummy,                    # ダミー入力
        "model.onnx",             # ONNXファイル名
        input_names=["image"],    # モデルの入力名
        output_names=["classes"]  # モデルの出力名
    )
    

以上で PyTorch のモデルを ONNX形式で出力できた。 これは 1×3×224×224 の "image" テンソルを入力し、 (ネットワークが決めた) "classes" テンソルを返すような ニューラルネットワークである。

image model.onnx classes

エクスポートした ONNX形式を使って推論する

ここでは代表的な ONNX Runtime を使って、 生成された ONNXファイルを使って推論してみる。 ONNX Runtime は PyTorch とは完全に独立したプロジェクトで、 GPU (CUDA) を使って推論するコードも含まれている。まず GPU を使うか否かに応じて、onnxruntime-gpu あるいは onnxruntime のどちらかのパッケージをインストールする:

C:\> pip install onnxruntime-gpu  # GPUを使って推論。
あるいは
C:\> pip install onnxruntime      # CPUを使って推論。

ONNX Runtime を使った推論は以下のようにおこなう。 まず、ONNX形式のファイルを指定し、InferenceSession インスタンスを作成する。このとき引数 providers で どのバックエンド (CPU または GPU) を使うかを指定する:

import onnxruntime as ort
# CPUを使う場合。
ort_sess = ort.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])
# GPU (CUDA) を使う場合。
ort_sess = ort.InferenceSession('model.onnx', providers=['CUDAExecutionProvider'])

推論を実行するには、作成したインスタンスの run() メソッドを呼ぶ。ONNX Runtime では、 入出力のデータ型として PyTorch のテンソルではなく NumPy の ndarray型を利用する。 また ONNX形式は入力・出力ともに複数ある場合を想定しているため、 入力は Python の辞書として与え、出力はリストとして受け取る。

# 入力値 (ndarray)
a = np.array(...)
outputs = ort_sess.run(None, {'image': a})
# 出力値
print(outputs[0])
演習8-5. ResNet-18 を ONNX形式に変換し、ONNX Runtime を使って推論する
  1. torchvisionモジュールで提供されている訓練済みの ResNet-18 モデルを resnet18.onnx というファイルで出力せよ。
  2. 以下のコードを用いて onnxruntime を使って推論し、同一の画像の 推論結果が PyTorch で実行したときと等しくなうことを確認せよ。
    # 推論セッションを作成する。
    ort_sess = ort.InferenceSession('resnet18.onnx', providers=['CPUExecutionProvider'])
    # RGB画像を読み込み、224×224 に縮小する
    img = Image.open('input.jpg')
    img.thumbnail((224, 224))
    # 画像をndarray型に変換し、正規化する。
    a = convert_image(image)
    # 次元をひとつ追加し (N×C×H×W) の形式にする。
    a = a.reshape(1,3,224,224)
    # 推論を実行。
    inputs = {'image': a}
    outputs = ort_sess.run(None, {'image':a})
    # もっとも確率の高いラベルを取得する。
    i = np.argmax(outputs[0])
    

ONNXモデルを利用するさいの注意として、 ONNX モデルは単なる「テンソルを入力し、テンソルを出力する」だけの関数であり、 各入力・出力の意味や使い方についてモデルは何も知らない ということがある。たとえば公開されている YOLO の ONNX モデルは 「画像をどのように前処理すべきか」 「テンソルの次元は (N×C×H×W) か (N×H×W×C) にすべきか」 「出力をどう解釈すべきか」などについては、すべて利用する側で 正しく実装せねばならず、さもないとまったく意味不明な 結果が返されることになる。

4. まとめ


クリエイティブ・コモンズ・ライセンス
この作品は、クリエイティブ・コモンズ 表示 - 継承 4.0 国際 ライセンスの下に提供されています。
Yusuke Shinyama