統計コンサルの議事メモ

統計や機械学習の話題を中心に、思うがままに

FindBestSplitを書いてみる

背景

前回、前々回の記事でrandomForestを使ってみたのですが、ソースコードを読んでいるとノードの分割においてfindbestsplitというサブルーチンが使われていることに気が付きました。このサブルーチン自体はこちらのL191に定義されているのでそれを読めばわかる(はずな)のですが、もう少しわかりやすい説明はないかなーと探してみたところ、こんな解説記事を見つけました。

http://dni-institute.in/blogs/cart-algorithm-for-decision-tree/

これによると、どうやらfindbestsplit

  1. 閾値(Cut off points)を決める
  2. 閾値におけるGini係数を求める
  3. 現時点のGini係数とのギャップが最大となる閾値を探す

というステップによって最良の閾値を探しているようです。それほど難しくなさそうなので、これをRで書いてみましょう。

実装

findbestsplitを実装するためには、以下のような関数が必要となりそうです。

  1. データ、説明変数を与えると閾値の候補を返す(return_threshold_values)
  2. データ、目的変数、説明変数、閾値を与えるとGini係数を返す(return_gini_index)
  3. 現在のGini係数との差分が最大となる(最良な)閾値を返す(return_best_value

まずは1つ目から書いてみましょう。

1. データ、説明変数を与えると閾値の候補を返す関数

先ほど紹介したページでは閾値の候補を生成するための方法として

One of the common approach is to find splits /cut off point is to take middle values

との説明がありましたので、これに倣います。以下のように書いてみました:

  1. 説明変数列をuniqueする
  2. sortで並び替える
  3. 各要素について1つ前の値との差分(diff)を取る
  4. 差分を2で割る
  5. sort後の列(最初の要素は除く)に加える
return_threshold_values <- function(dat, col) {
   
   uniq_vals <- sort(unique(dat[, col]))
   diffs <- diff(uniq_vals) / 2
   thre_vals <- uniq_vals[-1] - diffs
   
   return(thre_vals)
}

2. データ、目的変数、説明変数、閾値を与えるとGini係数を返す関数

続いてGini係数を求める関数を定義します。ここでは引数として閾値も与え、後ほどapply閾値候補をまとめて並列に処理することを考えました。Gini係数の求め方はここを参考に、以下のように書きました:

  1. 説明変数を所与の閾値で1/0のカテゴリに振り分ける
  2. 全体および各カテゴリのサンプルサイズを求める
  3. 目的変数 × 説明変数による混同行列の各要素の割合(の二乗)を計算する
  4. Gini係数を求める
return_gini_index <- function(dat, target, col, val) {
   
   d <- dat[, c(target, col)]
   d$cat <- ifelse(d[, col] > val, 1, 0)

   n0 <- nrow(d)
   n1 <- sum(d$cat)
   n2 <- n0 - n1

   p11 <- (nrow(d[d$cat == 1 & d$target == 1, ]) / n1)^2
   p12 <- (nrow(d[d$cat == 1 & d$target == 0, ]) / n1)^2
   p21 <- (nrow(d[d$cat == 0 & d$target == 1, ]) / n2)^2
   p22 <- (nrow(d[d$cat == 0 & d$target == 0, ]) / n2)^2

   gini_val <- (n1/n0) * (1 - (p11 + p12)) + (n2/n0) * (1 - (p21 + p22))
   
   return(gini_val)
}

3. 現在のGini係数との差分が最大となる(最良な)閾値を返す関数

上記の処理によってある閾値におけるGini係数を求めることが出来ましたので、これをapplyで並列化します。また実際のところ必要な値はGini係数そのものではなく、現時点におけるGini係数との差分なので、それも計算しましょう。これまでに定義した関数を使って以下のように書きます:

  1. 閾値候補のベクトルをlistにする
  2. 閾値候補リストをreturn_gini_indexsapplyで渡す
  3. 現時点のGini係数を計算する
  4. 差分の最大値、Gini係数の最小値を得る
  5. 差分が最大となる閾値を得る
  6. 変数名と合わせて返す

差分の最大値やGini係数の最小値は別に必要ないのですが、参考のために取っておきます。

return_best_val <- function(dat, target, col) {
   
   thre_vals <- as.list(return_threshold_values(dat, col))
   gini_vals <- sapply(thre_vals, return_gini_index, dat = dat, target = target, col = col)
   
   p1 <- sum(dat$target) / nrow(dat)
   p2 <- 1 - p1
   current_gini <- 1 - ((p1)^2 + (p2)^2)
   
   max_gap <- max(current_gini - gini_vals)
   min_gini <- min(gini_vals)
   max_thre_val <- thre_vals[[which(max_gap == current_gini - gini_vals)]]
   return(c(col, max_thre_val, round(max_gap, 2), round(min_gini, 2)))
   
}

では、このようにして定義した関数を実際に当てはめまめてみます。irisを使って適当に以下のようなデータを用意します。

my_iris <- iris
my_iris$target <- ifelse(my_iris$Species == "setosa", 1, 0)
dat <- my_iris[, -5]

まずはSepal.Lengthの最良な閾値を取得してみます。

> return_best_val(dat, "target", "Sepal.Length")
[1] "Sepal.Length" "5.45"         "0.3"          "0.14"  

これが合っているのかはわかりませんが、続いて全部の変数に同時に当てはめてみましょう。lapplydo.callを使います。

> do.call("rbind", 
+         lapply(as.list(colnames(dat)[1:4]), return_best_val, dat = dat, target = "target"))
     [,1]           [,2]   [,3]   [,4]  
[1,] "Sepal.Length" "5.45" "0.3"  "0.14"
[2,] "Sepal.Width"  "3.35" "0.17" "0.28"
[3,] "Petal.Length" "2.45" "0.44" "0"   
[4,] "Petal.Width"  "0.8"  "0.44" "0"

各変数について最良な閾値を取得することが出来たようです。ところで4列目(Gini係数の最小値)を確認すると、Petal.LengthPetal.Widthで0になっていますが、Gini係数が0ということは完全に分離されていることを意味します。確かめてみましょう。

> plot(dat$Petal.Length, col = dat$target + 1)
> abline(h = 2.45)
> plot(dat$Petal.Width, col = dat$target + 1)
> abline(h = 0.8)

f:id:ushi-goroshi:20190213193440p:plain

f:id:ushi-goroshi:20190213193457p:plain

確かに完全に分離しているようです。

終わりに

今回はrandomForestの中でも使われているfindbestsplitというアルゴリズムをRで書いてみました。実際には決定木やRandom Forestは一連の処理を再帰的に繰り返しているのですが、最も重要なポイントはこちらになるのだと思います。なおPythonDecisionTreeClassifierでも同じようなアルゴリズムとなっているのか確認したかったのですが、Pythonソースコードの表示方法が良くわかりませんでした。

おしまい。