2017/08/01
0:00

TensorFlowで畳み込みニューラルネットワーク(CNN)実装② [ネットワークの記述]

みなさんこんにちは。人工知能ラボの助手です。

早速ですが、前回の続きをやっていきたいと思います。

前回記事:
TensorFlowで畳み込みニューラルネットワーク(CNN)実装①
[重み・バイアス生成と畳み込み・プーリング処理関数の定義]



前回の記事では、CNNで使用する重み・バイアスの生成関数、
そして、畳み込み・プーリングの計算処理を行う関数の定義を
行いました。

今回は、この関数からまずは、重み・バイアスを生成し、
さらにそれらを使用してCNNの計算フローを記述
していきたいと思います。

それでは早速、重みとバイアスの生成を行いましょう。

今回作成するネットワークは、畳み込み・プーリングの処理を
2回繰り返し、その後2層の全結合層に入力、クラス分類の形で
出力というものにしたいと思います。

まずはコードを載せて、その後に説明を書きたいと思います。

コード
# 畳み込み層1の重みとバイアス
w_c1 = set_weight([5, 5, 3, 32])
b_c1 = set_bias([32])

# プーリング処理には、重み・バイアス等は使用しません

# 畳み込み層2の重みとバイアス
w_c2 = set_weight([5, 5, 32, 64])
b_c2 = set_bias([64])

# 全結合層1の重みとバイアス
w_f1 = set_weight([3136, 2000])
b_f1 = set_bias([2000])

# 全結合層2(出力層)の重みとバイアス
w_f2 = set_weight([2000, 5])
b_f2 = set_bias([5])

今回は、入力画像データの大きさを、
28*28*3(カラー画像)とします。

畳み込み層の重みフィルターの各次元は、

フィルターの縦*フィルターの横
*入力のチャネル数(カラー画像なら3)
*出力のチャネル数(フィルターの枚数)

の4次元で構成されます。

バイアスは出力のチャネルごとに1つなので、
出力チャネル数の成分をもつベクトルで作成します。

コードを見ると分かるかと思いますが、

今回はカラー画像に対し、1層目では5*5のフィルター32枚
で畳み込み(ゼロパディングあり)、2*2の範囲で
最大値プーリングします。
(出力は14*14*32となります)

そして、2層目では1層目の出力に対しさらに5*5*32の
フィルター64枚で畳み込み(ゼロパディング)、
1層目と同様に2*2の範囲で最大値プーリングします。
(出力は7*7*64となります)

この出力をベクトル(7*7*64=3,136成分)に展開し、
全結合層に入力します。

全結合層は、通常のニューラルネットワークと
同様の計算を行います。
1層目のユニット数は2000としました。
今回は5クラスの分類を行う想定でいるので、
2層目(出力層)のユニット数は5となります。

上記のコードでは、重みとバイアスを生成しただけ
ですので、あとは計算フローの記述が必要です。

最後にそれを書いて、今回は終了にしましょう。
前回定義した畳み込み・プーリングの関数を使用します。
入力する画像データについては、今回は仮で
x(28*28*3のnumpy配列)としておき、次次回くらい
の記事で、詳しく解説しようかと思います。

コード
# 畳み込み・プーリング層1
h_c1 = tf.nn.relu(convolution(x, w_c1) + b_c1)
h_p1 = pooling(h_c1)

# 畳み込み・プーリング層2
h_c2 = tf.nn.relu(convolution(h_p1, w_c2) + b_c2)
h_p2 = pooling(h_c2)

# プーリング層2の出力を、ベクトルにreshape
h_p2_vec = tf.reshape(h_p2, [-1, 3136])

# 全結合層1
h_f1 = tf.nn.relu(tf.matmul(h_p2_vec, w_f1) + b_f1)

# 全結合層2
h_f2 = tf.nn.softmax(tf.matmul(h_f1, w_f2) + b_f2)

活性化関数については、出力層以外はReLU、
出力層は確率での出力なので、ソフトマックス
を使用しています。

畳み込み・プーリングについては、前回定義した関数を
使用しているので、その内容については前回記事
参照してください。

また、今回記述した処理を冒頭の画像で図にまとめて
ありますので、コードで分かりづらい方はそちらも
ご覧ください。

それでは、今回はここまでにしておきましょう。
次回は、誤差関数とそれを最小化する学習部分の
記述を行おうと思います。

お疲れ様でした!