統計コンサルの議事メモ

統計や機械学習の話題を中心に、思うがままに

Random Forestの結果をSQLに落としたい

Random Forestは数ある機械学習アルゴリズムの中でも、高い精度を得やすく分散処理が容易であることから頻繁に用いられるものの一つであると思う。
Rでは{randomForest}パッケージが有名だが、最近では{ranger}や{Rborist}といった新しいパッケージが出ていたり、{edarf}のようなRandom Forestによるデータの解析や解釈を主眼にしたものが出ていたりして動きとして面白い。

そんなRandom Forestであるが、せっかくRで良いモデルを作ることができたとしても実際には使えないことがある。これは、例えばSQLに落とすことが容易ではないためシステムに組み込めないからだ。

ビジネスにおいてデータをRで分析する場面は数多くあるが、Rの結果を(文字通り)そのままシステムに埋め込むことはほとんどないと言って良いと思う。Rを使って得られたモデルは、SQLなどの他の環境に合わせて再構築した上で利用することが多いだろう。そこで今回はRの{randomForest}パッケージで作成したモデルをSQLに落とすための方法を紹介したい。

なおこの記事は基本的に以下の記事を参考に作成したが、ひとまずRの{randomForest}で作成したモデルからSQLを得るための流れを示すことが目的であったため大幅に簡略化してある。またSQLのインデントも私の好みに合わせたので、元記事も参照することをお薦めする。

https://gist.github.com/shanebutler/96f0e78a02c84cdcf558

処理の流れ

Random Forestというアルゴリズムは、大雑把に言えば簡易な決定木を大量に作成し、その結果の多数決によって最終的な判断を行うというものである。よってSQLに落とす場合も手順としては以下のようになる:

  1. それぞれの決定木の結果をSQLとして書き下す
  2. 1つのテーブルに各SQLの結果をINSERTする
  3. 結果を集約する

ただし実際には1と2は同時に行われるだろう。

1. 決定木の結果をSQLとして書き下す

それでは実装に移ろう。まずはSQLに書き下すところから。サンプルとしてirisを用い、簡単のためrandomForestの木の数は2としたモデルを作成する。

library(randomForest)
set.seed(123)
res <- randomForest(Species ~ ., data = iris, ntree = 2)

この処理で得られた結果はresオブジェクトに保存されるが、特にSQLを書くために必要となる決定木の分岐はgetTreeで取り出すことができる。取り出したい木が何番目であるかを指定し、ラベルの表示をTRUEにすることで:

> getTree(res, k = 1, labelVar = TRUE)
   left daughter right daughter    split var split point status prediction
1              2              3  Petal.Width        0.80      1       <NA>
2              0              0         <NA>        0.00     -1     setosa
3              4              5  Petal.Width        1.65      1       <NA>
4              6              7  Petal.Width        1.35      1       <NA>
5              8              9 Sepal.Length        6.05      1       <NA>
6              0              0         <NA>        0.00     -1 versicolor
7             10             11 Petal.Length        4.95      1       <NA>
8             12             13 Petal.Length        4.85      1       <NA>
9              0              0         <NA>        0.00     -1  virginica
10             0              0         <NA>        0.00     -1 versicolor
11             0              0         <NA>        0.00     -1  virginica
12            14             15 Sepal.Length        5.40      1       <NA>
13             0              0         <NA>        0.00     -1  virginica
14             0              0         <NA>        0.00     -1  virginica
15             0              0         <NA>        0.00     -1 versicolor

このようなテーブルが得られる。決定木は親となるノードから次々に分岐していく形で木が形成されるが、このテーブルはそのためのルールを示したものとなっており、「left daughter」が該当する場合に次に参照する行を、「right daughter」が該当しなかった場合に参照する行を示す。
例えばこの場合、1行目を読むことで「Petal.Widthが0.8未満であるか」を基準に分岐が発生することがわかり、また該当する場合(left daughter)には2行目を参照することでsetosaになることを示している。

このようなテーブルを各決定木について得ることができるため、SQLに落とすための流れとしては:

  1. テーブルをSQLにするための処理を書く
  2. 決定木の数だけループを回す

ことが考えられるだろう。そこでこのテーブルをSQLに落とすための関数を以下のように定義する:

Make_SQL <- function(Rule_Table, cnt, ind = 1) {
   Rule   <- Rule_Table[cnt, ]
   Indent <- paste(rep("   ", ind), collapse = "")
   var    <- as.character(Rule[, "split var"])
   val    <- Rule[, "split point"]
   
   if(Rule[, "status"] != -1) {
      cat(paste0("\n", Indent, "CASE\n", Indent, "   WHEN ", var, " <= ", val, " THEN"))
      Make_SQL(Rule_Table, Rule[, "left daughter"], ind = (ind + 2))
      cat(paste0("\n", Indent, "   ELSE"))
      Make_SQL(Rule_Table, Rule[, "right daughter"], ind = (ind + 2))
      cat(paste0("\n", Indent, "END"))
   } else { 
      cat(paste0(" '", Rule[, "prediction"], "'"))
   }
}

ここでRule_Tableは先ほどのテーブルを示している。また関数の序盤で定義しているIndentSQLを書くときのインデントを私の好みに合わせるために修正してある。

全体的に、SQLはCASE文で形成されることがわかる。これは決定木がIf THEN Elseで表現可能なルールによって作成されることからも自明だろう。なおRule_Tablestatusには、当該ノードが末端(ターミナルノード)であるかが示されているため、ここを参照することで予測値が決まるか(それ以上の分岐がないか)がわかる。もしターミナルノードでないなら分岐した先で同様の処理を行うため、再帰的な呼び出しが必要であることがわかる。分岐した場合、該当するものについてはleft daughterの行を参照し、そうでないものについてはright daughterを参照するようになっている。

これを実行すると、例えば以下のようなSQL文が得られる:

> Make_SQL(Rule_Table, 1, 1)

   CASE
      WHEN Petal.Width <= 0.75 THEN 'setosa'
      ELSE
         CASE
            WHEN Petal.Width <= 1.75 THEN
               CASE
                  WHEN Sepal.Width <= 2.25 THEN
                     CASE
                        WHEN Petal.Length <= 4.75 THEN 'versicolor'
                        ELSE 'virginica'
                     END
                  ELSE
                     CASE
                        WHEN Petal.Width <= 1.35 THEN 'versicolor'
                        ELSE
                           CASE
                              WHEN Sepal.Width <= 2.65 THEN
                                 CASE
                                    WHEN Petal.Width <= 1.45 THEN 'virginica'
                                    ELSE 'versicolor'
                                 END
                              ELSE 'versicolor'
                           END
                     END
               END
            ELSE
               CASE
                  WHEN Sepal.Length <= 5.95 THEN
                     CASE
                        WHEN Sepal.Width <= 3 THEN 'virginica'
                        ELSE 'versicolor'
                     END
                  ELSE 'virginica'
               END
         END
   END

これをRandom Forestのモデル作成に用いた決定木の数だけ作成すれば良いということになる。なお決定木の数を大きなものとした場合にSQL文が大きくなってしまうため、実行の際には注意が必要である。

このようにして得られたSQLはこのままだと表示して終わりであるため、何らかのテーブルに格納しておく必要がある。そこでINSERT文を前に置いておくことにする。すなわち:

cat(paste0("INSERT INTO tbl_rf \nSELECT\n   id, "))
Make_SQL(Rule_Table, 1, 1)

このような書き方になる。ここでidは予測する場合の粒度を示すことになる。irisの場合は1行が1データ(個体)になるため行ごとに連番を振れば良いが、ここでは省略する。

これで一つの決定木についてSQLによる結果をテーブルに格納することができるようになったため、上記のスクリプトをループで回すと以下のようになるだろう:

for (i in 1:(res$ntree)) {
   Rule_Table <- getTree(res, k = i, labelVar = TRUE)
   cat(paste0("INSERT INTO tbl_rf \nSELECT\n   id, "))
   Make_SQL(Rule_Table, 1, 1)
   cat(paste0(" as tree", i, "\nFROM\n", "   input_data", ";\n\n"))
}

res$ntreeにはモデルの作成に用いた決定木の数が格納されているため、その分だけループを回すことになる。またFROM句では決定木を当てはめるためのデータ(この例ではiris)を指定することになる。

1つのテーブルに各SQLの結果を集約する

上記の処理によりtbl_rfには各決定木の結果が格納されているだろう。これを用いて最終的な判断を行うには、各決定木の結果を集約する必要がある。以下のようになるだろう:

> cat(paste0("INSERT INTO rf_predictions\n",
+            "SELECT\n   a.id,\n   a.pred\n",
+            "FROM\n   (\n      SELECT id as id, pred, COUNT(*) as cnt \n",
+            "      FROM tbl_rf\n      GROUP BY id, pred\n   ) a\n",
+            "   INNER JOIN\n   (\n      ",
+            "SELECT id, MAX(cnt) as cnt\n",
+            "      FROM\n",
+            "         (\n", 
+            "            SELECT id as id, pred, COUNT(*) as cnt\n", 
+            "            FROM tbl_rf\n",
+            "            GROUP BY id, pred\n", 
+            "         )\n",
+            "      GROUP BY id\n   ) b\n",
+            "      ON a.id = b.id AND a.cnt = b.cnt;\n\n"))
INSERT INTO rf_predictions
SELECT
   a.id,
   a.pred
FROM
   (
      SELECT id as id, pred, COUNT(*) as cnt 
      FROM tbl_rf
      GROUP BY id, pred
   ) a
   INNER JOIN
   (
      SELECT id, MAX(cnt) as cnt
      FROM
         (
            SELECT id as id, pred, COUNT(*) as cnt
            FROM tbl_rf
            GROUP BY id, pred
         )
      GROUP BY id
   ) b
      ON a.id = b.id AND a.cnt = b.cnt;

irisSpeciesを用いる場合はclassificationとなるため、各決定木の判断を集約し、最も多かったクラスを最終的な予測とする。上記はそのためのクエリとなっていることがわかる。

最後に、これまでのスクリプトを一つにまとめたものを以下に示す:

library(randomForest)
set.seed(123)
res <- randomForest(Species ~ ., data = iris, ntree = 2)

Make_SQL <- function(Rule_Table, cnt, ind = 1) {
   Rule   <- Rule_Table[cnt, ]
   Indent <- paste(rep("   ", ind), collapse = "")
   var    <- as.character(Rule[, "split var"])
   val    <- Rule[, "split point"]
   
   if(Rule[, "status"] != -1) {
      cat(paste0("\n", Indent, "CASE\n", Indent, "   WHEN ", var, " <= ", val, " THEN"))
      Make_SQL(Rule_Table, Rule[, "left daughter"], ind = (ind + 2))
      cat(paste0("\n", Indent, "   ELSE"))
      Make_SQL(Rule_Table, Rule[, "right daughter"], ind = (ind + 2))
      cat(paste0("\n", Indent, "END"))
   } else { 
      cat(paste0(" '", Rule[, "prediction"], "'"))
   }
}

for (i in 1:(res$ntree)) {
   Rule_Table <- getTree(res, k = i, labelVar = TRUE)
   cat(paste0("INSERT INTO tbl_rf \nSELECT\n   id, "))
   Make_SQL(Rule_Table, 1, 1)
   cat(paste0(" as tree", i, "\nFROM\n", "   input_data", ";\n\n"))
}

cat(paste0("INSERT INTO rf_predictions\n",
           "SELECT\n   a.id,\n   a.pred\n",
           "FROM\n   (\n      SELECT id as id, pred, COUNT(*) as cnt \n",
           "      FROM tbl_rf\n      GROUP BY id, pred\n   ) a\n",
           "   INNER JOIN\n   (\n      ",
           "SELECT id, MAX(cnt) as cnt\n",
           "      FROM\n",
           "         (\n", 
           "            SELECT id as id, pred, COUNT(*) as cnt\n", 
           "            FROM tbl_rf\n",
           "            GROUP BY id, pred\n", 
           "         )\n",
           "      GROUP BY id\n   ) b\n",
           "      ON a.id = b.id AND a.cnt = b.cnt;\n\n"))

このスクリプトにおいて、for以降の部分をsinkで書き出せば必要なSQLを得ることができるだろう。

最後に

Random ForestはDeep Learningが注目されるようになった現在においても強力なアルゴリズムの一つである。またDeep Forestのように多段構造とすることでDeep Learningを上回る性能を得る場合もある。分散処理に強く、人による解釈も容易であるためビジネスにおける実用性は相当に高いアルゴリズムであるため、実装への障害によって敬遠されるのは非常に勿体無いと思う。この記事がそれらの障害を乗り越えるための一助となれば幸いである。

現状ではDeep LearningよりもRandom Forestを使いこなせる方が実用においては有益だろう。曲線的な分類に弱いなどの欠点もあるが、まだまだ発展のポテンシャルが高いアルゴリズムだと思うので今後も注目していきたい。