ピークがズレる!深層学習RNN/LSTMで陥りがちな罠について説明してみる

ピークがズレる!深層学習RNN/LSTMで陥りがちな罠について説明してみる

ディープラーニング – 深層学習がブームになって久しいですが、ネット上でもKerasやTensorFlowなどを用いたディープラーニングの例を多く見るようになりました。

中でも画像認識系と並んで人気なのが再帰型ニューラルネットワーク(Recurrent Neural Network)で、その中のLSTMと呼ばれる手法になります。

 

ネット上の株価予測などの例を見ていると、結構な頻度でLeakageしてるな…と思われる結果のグラフを見かけます。Leakageとは本来検証用データの中にしか存在しないはずのデータが学習データなどに紛れ込むことです。

一番多いのは、学習データへのフィット具合が完璧で、予測結果もほぼ完ぺき…のはずなんだけどなんか1ステップずれているように見える、これでは株価予測に使えないなというもの。

時系列データにおいてやっかいなのは、一見Leakageっぽくないのによく考えてみるとLeakしているという例があるためです。

LSTMでLeakageを起こしている具体的な例

どんな機械学習の手法でもそうですが、学習データに含まれている情報以上のものを学習したり、予測することはできません。

例えば、以下の2つの画像を見てみましょう。

学習データ:f:id:rk12liv:20180311150239j:plainテストデータ:

f:id:rk12liv:20180311150358j:plain

(LSTM for time series prediction · Issue #2856 · keras-team/keras · GitHubより抜粋)

緑色が学習データ / 検証データで、青色がLSTMでの予測値です。見ての通り、まだ見ていないはずのテストデータにおいてLSTMは超高精度な予測を行っています。これ、すごすぎる。

機械学習のGolden Rule:良すぎる結果が出たら疑う

機械学習においては、あまりにも良すぎる精度が出た場合はまず疑ってみるのが基本です。

なぜなら前述の通り、統計モデリングと違って機械学習では簡単にLeakageを起こしてしまうから。それは単にコーディングのミスだったり、データの分け方に問題があったり原因は様々です。

 

先ほどの例では、学習データは周期的でスパイクのようなものは存在していません。にもかかわらず、テストデータで初めて出現したスパイクを高精度で予測しているように見えます。

そんなことはありえません。何が起きているのでしょうか。

そのLSTM、実はただの後出しジャンケンだった

これは学習のされ方と、テストデータの作り方によるちょっと複雑なLeakageが理由です。

まず学習用データを作成する際、時系列データであればnからn-aまでのデータを学習用データとして、n+1を予測するというようにすることが多いです。

そうすると、n+2の学習用データにはn+1が含まれることになります。同じように、n+xの学習データにはn+x-1が含まれることになります。

つまり学習データにおいてはあるnを予測する際、その1つ前の数値が必ず含まれることになります。

データによっては、上記条件の際に学習器は”一つ前の入力データを返す”ことを学習してしまうことがあります。特に周期性が少なく、ランダムに近い動き方をするデータは危険です。

誤差を最小二乗誤差などにすると、特別なパターンが学習できない際は一つ前の数値を常に返した方が誤差を最小化できるというケースがあるからです。

 

そして”1つ前の入力データを返すこと”を学習したLSTMに今度は切り分けておいたテストデータを与えることにします。

この時点では、LSTMはテストデータを一切見たことがないはずです。

 

LSTMを使ってnからn-aまでのデータを元にn+1を予測するとします。テストデータでのnからn-aというデータは、本来ならLSTMが自分で予測しなければならないもののはずです。

なぜなら時系列データの学習の都合上、nからn-aが与えられたときn+1がどうなるか・・・という考え方をするため。ですが、例えばscikit learnならtrain test splitで何時刻目までを学習用、その時刻以降をテスト用とすることが多かったと思います。

 

時系列データでそれを行ってしまうと、ある時刻以降のデータ(予測したいデータ)がLSTM予測時の入力データとして紛れ込んでしまいます。

そして、一つ前の入力データを返すことを学習したLSTMは、無事に紛れ込んだデータを見てから次の予測でその数字を返すようになります。

 

こうして、非常に高精度に予測ができているように見えるが、1もしくは数ステップずれているように見える予測結果の出来上がりというわけです。

時系列データは、再帰的に予測する

再帰型ニューラルネットワークということで、予測の際も再帰的な手続きを踏んで予測することが必要です。

具体的には上記のような時系列データでは、テストの際の入力データは、学習用の入力データの一番最後の時刻のものを使うことになります。

 

要するに予測の際の入力データは、LSTMで予測させたその時刻より1ステップ前のものにしなければならないということです。

そうすれば、例えば1つ前の入力データを返すという学習をしてしまった場合にはずっと同じ数値を返し続け、学習に失敗していることがわかるはずです。

 

良い結果が出てしまうと信じたくなりますし、なんか合いすぎな気もするけど”天下のディープラーニングだから高精度で当たり前”という雰囲気に流されることもしばしば。

ディープラーニングといっても機械学習であり、機械学習である以上あまりにも良すぎる結果が出たら、まずは疑ってかかることにしましょう。

スポンサーリンク

機械学習カテゴリの最新記事