FindBestSplitを書いてみる
背景
前回、前々回の記事でrandomForest
を使ってみたのですが、ソースコードを読んでいるとノードの分割においてfindbestsplit
というサブルーチンが使われていることに気が付きました。このサブルーチン自体はこちらのL191に定義されているのでそれを読めばわかる(はずな)のですが、もう少しわかりやすい説明はないかなーと探してみたところ、こんな解説記事を見つけました。
http://dni-institute.in/blogs/cart-algorithm-for-decision-tree/
これによると、どうやらfindbestsplit
は
というステップによって最良の閾値を探しているようです。それほど難しくなさそうなので、これをRで書いてみましょう。
実装
findbestsplit
を実装するためには、以下のような関数が必要となりそうです。
- データ、説明変数を与えると閾値の候補を返す(return_threshold_values)
- データ、目的変数、説明変数、閾値を与えるとGini係数を返す(return_gini_index)
- 現在のGini係数との差分が最大となる(最良な)閾値を返す(return_best_value)
まずは1つ目から書いてみましょう。
1. データ、説明変数を与えると閾値の候補を返す関数
先ほど紹介したページでは閾値の候補を生成するための方法として
One of the common approach is to find splits /cut off point is to take middle values
との説明がありましたので、これに倣います。以下のように書いてみました:
- 説明変数列を
unique
する sort
で並び替える- 各要素について1つ前の値との差分(
diff
)を取る - 差分を2で割る
sort
後の列(最初の要素は除く)に加える
return_threshold_values <- function(dat, col) { uniq_vals <- sort(unique(dat[, col])) diffs <- diff(uniq_vals) / 2 thre_vals <- uniq_vals[-1] - diffs return(thre_vals) }
2. データ、目的変数、説明変数、閾値を与えるとGini係数を返す関数
続いてGini係数を求める関数を定義します。ここでは引数として閾値も与え、後ほどapply
で閾値候補をまとめて並列に処理することを考えました。Gini係数の求め方はここを参考に、以下のように書きました:
- 説明変数を所与の閾値で1/0のカテゴリに振り分ける
- 全体および各カテゴリのサンプルサイズを求める
- 目的変数 × 説明変数による混同行列の各要素の割合(の二乗)を計算する
- Gini係数を求める
return_gini_index <- function(dat, target, col, val) { d <- dat[, c(target, col)] d$cat <- ifelse(d[, col] > val, 1, 0) n0 <- nrow(d) n1 <- sum(d$cat) n2 <- n0 - n1 p11 <- (nrow(d[d$cat == 1 & d$target == 1, ]) / n1)^2 p12 <- (nrow(d[d$cat == 1 & d$target == 0, ]) / n1)^2 p21 <- (nrow(d[d$cat == 0 & d$target == 1, ]) / n2)^2 p22 <- (nrow(d[d$cat == 0 & d$target == 0, ]) / n2)^2 gini_val <- (n1/n0) * (1 - (p11 + p12)) + (n2/n0) * (1 - (p21 + p22)) return(gini_val) }
3. 現在のGini係数との差分が最大となる(最良な)閾値を返す関数
上記の処理によってある閾値におけるGini係数を求めることが出来ましたので、これをapply
で並列化します。また実際のところ必要な値はGini係数そのものではなく、現時点におけるGini係数との差分なので、それも計算しましょう。これまでに定義した関数を使って以下のように書きます:
- 閾値候補のベクトルを
list
にする - 閾値候補リストを
return_gini_index
にsapply
で渡す - 現時点のGini係数を計算する
- 差分の最大値、Gini係数の最小値を得る
- 差分が最大となる閾値を得る
- 変数名と合わせて返す
差分の最大値やGini係数の最小値は別に必要ないのですが、参考のために取っておきます。
return_best_val <- function(dat, target, col) { thre_vals <- as.list(return_threshold_values(dat, col)) gini_vals <- sapply(thre_vals, return_gini_index, dat = dat, target = target, col = col) p1 <- sum(dat$target) / nrow(dat) p2 <- 1 - p1 current_gini <- 1 - ((p1)^2 + (p2)^2) max_gap <- max(current_gini - gini_vals) min_gini <- min(gini_vals) max_thre_val <- thre_vals[[which(max_gap == current_gini - gini_vals)]] return(c(col, max_thre_val, round(max_gap, 2), round(min_gini, 2))) }
では、このようにして定義した関数を実際に当てはめまめてみます。iris
を使って適当に以下のようなデータを用意します。
my_iris <- iris my_iris$target <- ifelse(my_iris$Species == "setosa", 1, 0) dat <- my_iris[, -5]
まずはSepal.Length
の最良な閾値を取得してみます。
> return_best_val(dat, "target", "Sepal.Length") [1] "Sepal.Length" "5.45" "0.3" "0.14"
これが合っているのかはわかりませんが、続いて全部の変数に同時に当てはめてみましょう。lapply
とdo.call
を使います。
> do.call("rbind", + lapply(as.list(colnames(dat)[1:4]), return_best_val, dat = dat, target = "target")) [,1] [,2] [,3] [,4] [1,] "Sepal.Length" "5.45" "0.3" "0.14" [2,] "Sepal.Width" "3.35" "0.17" "0.28" [3,] "Petal.Length" "2.45" "0.44" "0" [4,] "Petal.Width" "0.8" "0.44" "0"
各変数について最良な閾値を取得することが出来たようです。ところで4列目(Gini係数の最小値)を確認すると、Petal.Length
とPetal.Width
で0になっていますが、Gini係数が0ということは完全に分離されていることを意味します。確かめてみましょう。
> plot(dat$Petal.Length, col = dat$target + 1) > abline(h = 2.45) > plot(dat$Petal.Width, col = dat$target + 1) > abline(h = 0.8)
確かに完全に分離しているようです。
終わりに
今回はrandomForest
の中でも使われているfindbestsplit
というアルゴリズムをRで書いてみました。実際には決定木やRandom Forestは一連の処理を再帰的に繰り返しているのですが、最も重要なポイントはこちらになるのだと思います。なおPythonのDecisionTreeClassifier
でも同じようなアルゴリズムとなっているのか確認したかったのですが、Pythonのソースコードの表示方法が良くわかりませんでした。
おしまい。