(数式を使わない)
Transformer の直感的な説明

  1. RNN の欠点
  2. Transformer はこれをどう解決したか
  3. Transformer の動作原理
  4. 複数の要素間の関係を考慮する (Self-Attention、自己注意)
  5. 要素の順序を考慮する (Positional Encoding、位置エンコーディング)
  6. まとめ

概要: ChatGPT などで使われている Transformer モデルは、 ニューラルネットワークの世界にいくつかの革新的なアイデアをもたらした。 本記事では、プログラマに理解しやすい形でそれらのアイデアを解説する。 実際に使われている数学の詳細には触れない。 (技術的解説については元論文 Attention is All You Need か、 その注釈版である The Annotated Transformer を参照のこと。 日本語では この解説 がわかりやすい。)

必要な前提知識: ニューラルネットワークの基礎、RNN の原理、および Python の基礎。

1. RNN の欠点

これまで、可変長のデータ (音声やテキストなど) に対しては 再帰的ニューラルネットワーク (RNN) や LSTM などが 使われていた。これらのネットワークを使った代表的な seq2seq (sequence to sequence) モデルは以下のように機能する:

"King Midas has donkey ears" [17 29 54] "王様 の 耳 は ロバ の 耳" エンコード デコード 固定長

ここでは、入力列は一度、固定長の中間表現ベクトルに「圧縮 (エンコード)」され、 これを元に出力列が生成 (デコード) される。 自然言語文のようなものは入力のすべての部分を 見なければ適切な出力が求まらないので、一度すべてを読む必要があるためである。 しかし、このモデルには明らかなボトルネックがある。 中間表現は固定長なので、一定の情報量しか保持できないのである。 そのため、このモデルは長い入力列に対してはうまく動かなかった。

2. Transformer はこれをどう解決したか

Transformer モデルでは、入力列と中間表現は同じ長さをもっている。 これにより、上の問題が根本的に解決されている。 Transformer の原理をおおざっぱに図示すると、 以下のようになる:

"King Midas has donkey ears" [ 17 39 27 53 76 ] "王様 の 耳 は ロバ の 耳" エンコード デコード 入力列と同じ長さ

ここで使われている中間表現ベクトルは RNN とは まったく別種のものになっていることに注意してほしい。 RNN では、単語をひとつ処理するごとにベクトル内に情報が蓄積されていくので、 列が長くなればなるほど情報の圧縮が起こっていた。 つまり RNN では、 短い列と長い列の中間表現どうしを直接比較することはできない。 これに対して、Transformer モデルでは中間表現の各要素 (単語) がほぼ等しい量の 情報をもつため、短い列の中間表現は自然に長い列にも拡張できる。 これは情報の一貫性を上げ、学習に有利となる。

さらに RNN では各要素をひとつずつ処理する必要があったが、 Transformer ではこれらを同時に処理できるため、 訓練および推論プロセスを並列化できる。 そのため、(少なくとも理論的には) Transformer は既存の RNN・LSTM によるモデルよりも高い性能が期待できる。

3. Transformer の動作原理

では実際に Transformer の動作原理を見てみよう。 Transformer アルゴリズムを簡単な Python コードで示すと、 以下のようになる:

# 入力列 input に対する出力列 output を求める。
def transformer(input):
    # 入力列を memory に変換する。
    memory = encoder(input)
    # 出力列を初期化する。
    output = [BOS]
    while True:
        # ひとつずつ要素を足していく。
        elem = decoder(memory, output)
        output.append(elem)
        # EOS が来たら終了。
        if elem == EOS: break
    return output

アルゴリズムは非常にストレートである。 まず入力列が "memory" と呼ばれるものに変換され (これが何であるかについては後述する)、出力列がひとつずつ生成されていく。 BOSEOS は それぞれ「列の先頭 (Beginning of Sequence)」 「列の末尾 (End of Sequence)」を示す特別な記号である。

さて、この "memory" とは何だろうか? 直感的にいえば、これは「連想配列」や「ハッシュテーブル」 あるいは Python でいう「辞書」と呼ばれるものに相当する。 encoder() 関数はこの辞書を作成し、 decoder() 関数がその辞書を参照する。

以下はその直感的な Pythonバージョンである:

def encoder(src):
    # 入力列から2つの辞書を作成する。
    h1 = { key1(x): value1(x) for x in src }
    h2 = { key2(x): value2(x) for x in src }
    memory = (h1,h2)
    return memory

def decoder(memory, target):
    # 2つの辞書を使って出力を生成する。
    (h1,h2) = memory
    v1 = [ h1.get(query1(x)) for x in target ]
    v2 = [ h2.get(query2(x)) for x in target ]
    return ff(v1,v2)

ここで、関数 key1(), value1(), key2(), value2(), query1(), query2() および ff() はそれぞれ学習可能な関数であるとする。 実際のニューラルネットワークによる Transformer は 本物の Python辞書を使っていないことに注意 (なぜなら辞書は微分可能ではないからである)。 かわりに、行列の乗算による類似度の計算、およびベクトルの内積による 要素の「選択」を使って同様の処理を実現している。

また、上の encoder/decoder 関数は簡単のため 2つの辞書 (h1, h2) しか使っていないが、論文で提案されている実際の Transformer は 8つの 辞書らしきもの (論文中ではこれらは "head" と呼ばれている) を使っている。

上のフレームワークを使ったごく単純なモデルを作ってみると、 以下のようになる。これはただ入力された列と同じものを出力する (この例は Python だが、実際にはこれらの関数は ニューラルネットワークによって実現されるものとする):

BOS = 0
EOS = 999

# ニューラルネットワークによって学習された関数
def key1(x): return x
def value1(x): return x
def key2(x): return x
def value2(x): return 1
def query1(x): return x
def query2(x): return x+1
def ff(v1,v2):
    x1 = v1[-1]
    x2 = v2[-1]
    if x2 is None:
        return EOS
    else:
        return x1+1

print(transformer([BOS,1,2,3,4,5,EOS])) # [BOS,1,2,3,4,5,EOS]

しかし、このモデルはまだ不十分である。 このアルゴリズムでは、一度にひとつの要素しか考慮できない。 自然言語処理などでは、一般的に以下のような情報を考慮する必要がある:

  1. 複数の要素間の関係。
  2. 各要素の順序。

RNN では「すべての要素をいったん固定長のベクトルにまとめる」 ことでこの問題を解決しようとしていたが、先に見たように Transformer の中間層は可変長である。 以後、これらの問題を Transformer がどのように解決したかを見ていく。

4. 複数の要素間の関係を考慮する (Self-Attention、自己注意)

まず最初の問題から見てみよう。 Transformer は "Self-Attention (自己注意)" という仕組みを使って、 複数の要素からの情報を集約している。 これは Transformer 論文の題名となっている重要なアイデアであり、 近年のニューラルネットワーク研究におけるブレイクスルーのひとつである。

しかしながら、この「自己注意」という呼び名は誤解を招く。 むしろ「内部関係」とでも呼んだほうがわかりやすい。 なぜならこれが意味しているのは、入力列の各要素間の関係を 考慮するという処理に他ならないからである。 「各要素間の関係」とは、たとえば以下の文における 単語間の関係に相当する:

             +---object--+
             |           |
  +-subject--+           |
  |          |           |
  +-adj-+    |     +-adj-+
  |     |    |     |     |
"King Midas has donkey ears"

上の図で、それぞれの「関係」は対象となる 2つの単語および 関係のタイプ ("object" など) からなりたっている。Transformer の優れた点は、 学習によってこれらの関係を自動的に発見できるということである。 ただし、実際に Transformer が発見する「関係」は必ずしも 「きれいな」ものではないということに注意。 Transformer の抽出した関係を人間が見ても、 それが特定の要素間の関係であるということはわかるものの、 いったい「どんな種類の関係なのか」は理解できないことが多い。 これは畳み込みネットワーク (CNN) で、中間層を表示してみても、 それがいったいどんな特徴を表現しているのか、 人間には理解しがたいのに似ている。

この「関係」抽出処理を Python で表してみると、こうなる:

# 入力列 seq 内の self attention を求める。
def self_attn(seq):
    h1 = { sa_key1(x): sa_value1(x) for x in seq }
    h2 = { sa_key2(x): sa_value2(x) for x in seq }
    a1 = [ h1.get(sa_query1(x)) for x in seq ]
    a2 = [ h2.get(sa_query2(x)) for x in seq ]
    return [ aa(y1,y2) for (y1,y2) in zip(a1,a2) ]

ここで seq は入力列であり、関数 sa_key1(), sa_value1(), sa_key2(), sa_value2(), sa_query1(), sa_query2() および aa() は 学習可能な関数である。 まず 2つの Python辞書 h1h2 を作成しており、 これが 2つの「関係」 (論文では "head" と呼ばれている) に相当する。 この例では 2つの関係を最後に組み合わせ、出力列を生成する。

辞書内のキー/バリュー対は、それぞれ sa_keysa_value 関数によって要素ごとに計算される。 その後、関数 sa_query がこれらを参照しながら、 もう一度同じ要素列をスキャンする。 これにより、列内の各要素を任意の他の要素と比較できることになる。 以下の図は要素 "has" が他の要素を参照している 様子を例示したものである:

King Midas has Donkey ears sa_key sa_value sa_query

以下のコードでは、入力列の各要素が他のすべての要素と比較され、 その要素の2倍または 1/2 の要素が列中に含まれていれば 出力は 1 となり、そうでなれば 0 になる。

# ニューラルネットワークによって学習された関数
def sa_key1(x): return x
def sa_value1(x): return x
def sa_key2(x): return x
def sa_value2(x): return x
def sa_query1(x): return x*2
def sa_query2(x): return x/2
def aa(v1,v2):
    if v1 is None and v2 is None: return 0
    return 1

print(self_attn([BOS,1,2,3,4,5,8,EOS])) # [1,1,1,0,1,0,1,0]

実際の Transformer では、(2つではなく) 8つの辞書が出力される ようになっている。したがってこれらは8種類の異なった タイプの関係 (論文中では "Multi-Head Attention" と呼ばれている) を考慮することができる。これも本物の実装では Python の辞書は使っていないが、 行列計算と内積によって類似の処理をおこなっている。 また、関数 self_attn() の入力と出力は 同じ形のテンソルになっていることに注意してほしい。 実際の Transformer ではこの Self-Attention 層を 6つ 積み重ねており、まず要素間の「浅い関係」を抽出したのち、 それらの情報を使ってより複雑な関係を記述し… といったことができるようになっている。 この Self-Attention 機構こそが Transformer の処理能力のキモであるといってよい。

5. 要素の順序を考慮する (Positional Encoding、位置エンコーディング)

さて、複雑な順列を扱うニューラルネットワークを 設計するうえでの 2つ目の問題は、 入力列における各要素の順序を考慮することであった。 これにはいくつか方法がある。もっとも単純なのは、 各要素に順序をあらわす番号を付加することである:

['King', 'Midas', 'has', 'donkey', 'ears']
  ↓
[(0,'King'), (1,'Midas'), (2,'has'), (3,'donkey'), (4,'ears')]

しかしこの方法は追加の領域が必要になり、 ニューラルネットワークへの負荷が増すため Transformer では別の方法を使っている。 それは順序をあわらす番号のようなものを「透かし」として 元のデータに重ね合わせることである。これが論文中で "Positional Encoding (位置エンコーディング)" と呼ばれている手法である。

以下は非常に単純化した例である:

def add_positional(seq):
    return [ i*1000+x for (i, x) in enumerate(seq) ]

print(add_positional([BOS,2,5,7,9,20,EOS])) # [0, 1002, 2005, 3007, 4009, 5020, 6999]

ただし、この実装では 2つの要素を足し合わせると、 他の要素と区別できなくなる場合がある (例: 1002 + 2005 = 3007)。 実際の Positional Encoding はもう少し洗練されており、 このようなことは (ほとんど) 起こらないようになっている。 が、基本的なアイデアは同じである。

このような Positional Encoding を入力列に対して施すと、 「2つ後の要素」「先頭の要素」などの条件を考慮できるため、 より複雑な判断が可能になる。たとえば以下の例は 「同じ要素が2回連続して現れるパターン」も検出するものである:

# Positional Encoding + Self-Attention
def sa_key1(x): return x // 1000
def sa_value1(x): return x % 1000
def sa_key2(x): return x // 1000
def sa_value2(x): return x % 1000
def sa_query1(x): return x // 1000
def sa_query2(x): return (x // 1000)-1
def aa(v1,v2):
    if v1 != v2: return 0
    return 1

print(self_attn(add_positional([BOS,1,1,5,5,2,EOS]))) # [0, 0, 1, 0, 1, 0, 0]

6. まとめ

このように Transformer モデルは Self-Attention 機構と Positional Encoding を組み合わせて 各要素間で 8つの異なる関係を考慮させ、 さらにこの処理を 6回 くり返すことで 自然言語文などにみられる複雑な構造を扱えるようになっている。 以上が数式を使わない Transformer モデルの直感的な説明である。


クリエイティブ・コモンズ・ライセンス
この作品は、クリエイティブ・コモンズ 表示 - 継承 4.0 国際 ライセンスの下に提供されています。
Yusuke Shinyama