randomForestで有効な交互作用を発見したい
背景
GLMは発想がわかりやすく解釈性も高くて良いアルゴリズム1なのですが、線形の仮定があるため変数間の交互作用を見るのが苦手です。実際のプロジェクトでGLMを使った結果を見せ、
- 変数の組み合わせ効果みたいなものは見れないの?
- この変数は条件によって効き方が違うんだよね〜
みたいな指摘を受けて困った経験があったりしないでしょうか。そんな時に使えるテクニックを同僚から教えてもらったので、備忘がてらメモしておきます。勝手に公開して怒られる可能性もありますが。。。
概要
手法の概要ですが、話としてはシンプルで「もしも有効な変数の組み合わせ(交互作用)が存在しているのであれば、Random Forestの各決定木において、ノードの分岐に使われる変数の順番として出現しやすいペアがあるのではないか」ということです。例えば変数X1とX2の間に交互作用があれば、決定木においてX1が選択された場合、続く分岐ではX2が選択されやすくなるのではないでしょうか。
実装
上記のアイディアを実現するために、以下のように実装してみます:
- Random Forestでモデルを作る
- 各決定木から分岐に用いられた変数ペアを得る
- 出現回数のカウントを取る
- 交互作用効果を確かめてみる
1. Random Forestでモデルを作る
まずはRandom Forestでモデルを作ります。randomForest
パッケージを使ってサクッと作りましょう。
### libraryの読み込み library(randomForest) library(tidyverse)
データには前回記事と同じTelco Customer Churn
を使いますが、前回の反省を踏まえてread.csv
を使います。
前回の記事はこちら。
d <- read.csv("./Data/WA_Fn-UseC_-Telco-Customer-Churn.csv") %>% as_data_frame()
本来ならここから一つ一つの変数を観察するところですが、今回はそれが目的ではないので欠損だけ埋めておきます。
> colSums(apply(d, c(1, 2), is.na)) customerID gender SeniorCitizen Partner Dependents tenure 0 0 0 0 0 0 PhoneService MultipleLines InternetService OnlineSecurity OnlineBackup DeviceProtection 0 0 0 0 0 0 TechSupport StreamingTV StreamingMovies Contract PaperlessBilling PaymentMethod 0 0 0 0 0 0 MonthlyCharges TotalCharges Churn 0 11 0
TotalCharges
に欠損があるようですね。
> summary(d$TotalCharges) Min. 1st Qu. Median Mean 3rd Qu. Max. NA's 18.8 401.4 1397.5 2283.3 3794.7 8684.8 11
MedianとMeanに差があるので分布が偏っていそうです。ひとまずNA
はMedianで埋めておきましょう。
d2 <- d %>% mutate(TotalCharges = if_else(is.na(.$TotalCharges), median(.$TotalCharges, na.rm = T), .$TotalCharges))
customerID
は変数として使えないので除外しましょう。またrandomForest
はカテゴリ数が53より多い変数を扱えないので、カテゴリ数をチェックしておきます。
cat_vars <- sapply(d2, is.factor)
> apply(d2[, cat_vars], 2, function(x) length(unique(x))) customerID gender Partner Dependents PhoneService MultipleLines 7043 2 2 2 2 3 InternetService OnlineSecurity OnlineBackup DeviceProtection TechSupport StreamingTV 3 3 3 3 3 3 StreamingMovies Contract PaperlessBilling PaymentMethod Churn 3 3 2 4 2
大丈夫そうですね。customerID
だけ落としておきます。
d3 <- d2 %>% select(-customerID)
randomForest
を当てはめます。目的変数はChurn
です。
set.seed(123) result <- randomForest(Churn ~ ., d3, ntree = 500)
> result Call: randomForest(formula = Churn ~ ., data = d3, ntree = 500) Type of random forest: classification Number of trees: 500 No. of variables tried at each split: 4 OOB estimate of error rate: 20.3% Confusion matrix: No Yes class.error No 4643 531 0.1026285 Yes 899 970 0.4810059
精度とかは気にしません。
2. 各決定木から分岐に用いられた変数ペアを得る
次にRandom Forestから変数ペアを取得します。そのためにはRandom Forestの各決定木について、どの変数がどの順番で分岐に用いられたかを知る必要があります。
まずはRandom Forestから各決定木の結果を取ってきましょう。 getTree
関数を使います。過去記事も参考にしてください。
> getTree(result, 1, labelVar = TRUE) left daughter right daughter split var split point status prediction 1 2 3 Contract 1.000 1 <NA> 2 4 5 OnlineBackup 1.000 1 <NA> 3 6 7 TechSupport 2.000 1 <NA> 4 8 9 tenure 2.500 1 <NA> 5 10 11 InternetService 2.000 1 <NA> 6 12 13 Dependents 1.000 1 <NA> # 中略
getTree
は分岐に用いられた変数とその分岐先ノードなどを返します。ここで必要なのは1~3列目なのですが、それぞれ「左の子ノード」「右の子ノード」「分岐に用いられた変数」を意味しています。例えば1行目を見るとノードの分岐に Contract
が用いられたことがわかります。
このテーブルに行インデックスを列として追加しましょう。なおこれ以降のコードはこちらの記事を参考にさせて頂きました。
tree_tbl <- getTree(result, 1, labelVar = TRUE) %>% # labelVar = Fだとエラー rownames_to_column() %>% mutate(rowname = as.integer(rowname))
作成した tree_tbl
では分岐に用いた変数( split var
)はわかりますが、変数のペアはわかりません。例えば1行目の Contract
で分岐された子ノード(2と3)は、次にそれぞれが OnlineBackup
と TechSupport
で分岐されています(2行目、3行目)ので、 Contract - OnlineBackup
と Contract - TechSupport
という変数ペアが出現したことがわかるような形に整形したいですね。
各行には「分岐に用いた変数」と「分岐先の子ノードの番号」がありますので、「分岐先の子ノード(左右両方)」に「分岐元の変数」を追加すれば欲しいものが得られそうです。
まずはノードと変数のマスタを用意しましょう。
var_name <- tree_tbl %>% select(rowname, "split var") %>% rename(split_var =`split var`) %>% # スペースを`_`に修正 unique() %>% filter(!is.na(.$split_var))
続けて left daughter
と right daughter
それぞれに rowname
でJOINします。
> tree_tbl %>% + left_join(var_name, by = c("left daughter" = "rowname")) %>% + left_join(var_name, by = c("right daughter" = "rowname")) %>% + select(rowname, `split var`, `split_var.x`, `split_var.y`) %>% + na.omit() rowname split var split_var.x split_var.y 1 1 Contract OnlineBackup TechSupport 2 2 OnlineBackup tenure InternetService 3 3 TechSupport Dependents StreamingMovies 4 4 tenure InternetService InternetService 5 5 InternetService Dependents OnlineSecurity 6 6 Dependents TotalCharges tenure # 中略
良さそうですね。ただしこのままでは後の工程で使いにくいのでもう少し加工します。
> tree_tbl %>% + left_join(var_name, by = c("left daughter" = "rowname")) %>% + left_join(var_name, by = c("right daughter" = "rowname")) %>% + select(`split var`, `split_var.x`, `split_var.y`) %>% + na.omit() %>% + rename(from_var = `split var`, + left = `split_var.x`, + right = `split_var.y`) %>% + gather(key = node, value = to_var, -from_var) %>% + select(-node) from_var to_var 1 Contract OnlineBackup 2 OnlineBackup tenure 3 TechSupport Dependents 4 tenure InternetService 5 InternetService Dependents 6 Dependents TotalCharges # 中略
これで必要なアウトプットが得られました。あとは上記の加工を関数化しておき、各決定木に当てはめれば良さそうです。
get_var_pairs <- function(tree_num, rf = result) { # 決定木の結果を得る tree_tbl <- getTree(rf, tree_num, labelVar = TRUE) %>% rownames_to_column() %>% mutate(rowname = as.integer(rowname)) var_name <- tree_tbl %>% select(rowname, "split var") %>% rename(split_var =`split var`) %>% # スペースを`_`に修正 unique() %>% filter(!is.na(.$split_var)) out <- tree_tbl %>% left_join(var_name, by = c("left daughter" = "rowname")) %>% left_join(var_name, by = c("right daughter" = "rowname")) %>% select(`split var`, `split_var.x`, `split_var.y`) %>% na.omit() %>% rename(from_var = `split var`, left = `split_var.x`, right = `split_var.y`) %>% gather(key = node, value = to_var, -from_var) %>% select(-node) return(out) }
試してみましょう。
> get_var_pairs(5, result) from_var to_var 1 PaymentMethod Contract 2 Contract TechSupport 3 OnlineSecurity tenure 4 TechSupport PaperlessBilling 5 tenure TechSupport 6 TotalCharges StreamingTV # 中略
これを全ての決定木に当てはめます。 purrr
の map_dfr
を使ってみましょう。
var_pairs <- map_dfr(as.matrix(1:5), get_var_pairs, result) # 少しだけ実行
> dim(var_pairs) [1] 2950 2 > head(var_pairs) from_var to_var 1 Contract OnlineBackup 2 OnlineBackup tenure 3 TechSupport Dependents 4 tenure InternetService 5 InternetService Dependents 6 Dependents TotalCharges
素直に for
で書いた時と同じ結果になっていますでしょうか?
tmp <- c() for (i in 1:5) { tmp <- bind_rows(tmp, get_var_pairs(i, result)) }
> dim(tmp) [1] 2950 2 > head(tmp) from_var to_var 1 Contract OnlineBackup 2 OnlineBackup tenure 3 TechSupport Dependents 4 tenure InternetService 5 InternetService Dependents 6 Dependents TotalCharges
合っているようなので全ての結果を取得します。 randomForest
では作成する決定木の数を ntree
で指定するので 1:ntree
で全ての決定木に適用できるのですが、直接 ntree
を取ってくることは出来ないようなので length(treesize())
を使います。エラーは気にしないことにします。
var_pairs <- map_dfr(as.matrix(1:length(treesize(result))), get_var_pairs, result)
30万弱の変数ペアが得られました。
3. 出現回数のカウントを取る
さっそく変数ペアのカウントを取ってみましょう。
> var_pairs %>% + group_by(from_var) %>% + summarise(cnt = n()) %>% + arrange(desc(cnt)) # A tibble: 19 x 2 from_var cnt <chr> <int> 1 TotalCharges 36908 2 MonthlyCharges 36592 3 tenure 34132 4 PaymentMethod 21244 5 gender 17626 6 MultipleLines 14496 # 中略
分岐の始点となった変数(分岐元)の数は19ですが、分析対象のデータセットには目的変数を含んで20列だったので、分岐元とならない変数はなかったようです。一方で分岐元としての出現頻度には大きなばらつきがあり、 TotalCharges
、 MonthlyCharges
、 tenure
が選ばれやすいようですね。
ちなみに varImpPlot
で変数重要度を見てみると、これらはいずれも上位に付けており、4位以下と大きな隔たりがあるようです。
varImpPlot(result)
続いて分岐の終点となった変数(分岐先)についても見てみましょう。
> var_pairs %>% + group_by(to_var) %>% + # group_by(from_var, to_var) %>% + summarise(cnt = n()) %>% + arrange(desc(cnt)) # A tibble: 19 x 2 to_var cnt <chr> <int> 1 TotalCharges 47553 2 MonthlyCharges 47137 3 tenure 37964 4 PaymentMethod 20138 5 gender 14242 6 Partner 12567 # 中略
同じく19変数ありますので、全ての変数は分岐元・分岐先ともに出現しています。分岐元と同じく出現頻度はばらつきがあり、出現しやすい変数としては、 TotalCharges
、 MonthlyCharges
、 tenure
となっています。これは少し意外ですね。てっきり分岐元に選ばれやすい変数と分岐先に選ばれやすい変数は違うものになると思っていましたが。
せっかくなので分岐元と分岐先で選ばれやすさが異なるか、可視化してみましょう。
plt <- var_pairs %>% group_by(from_var) %>% summarise(cnt_f = n()) %>% left_join(var_pairs %>% group_by(to_var) %>% summarise(cnt_t = n()), by = c("from_var" = "to_var")) %>% gather(var_type, cnt, -from_var) %>% rename(var = from_var) ggplot(plt, aes(x = reorder(var, -cnt), y = cnt, fill = var_type)) + geom_bar(stat = "identity", position = "dodge") + # scale_color_brewer(palette = "Set2") + # facet_wrap(~var_type, nrow = 2) + theme_classic() + theme(axis.text.x = element_text(angle = 90, hjust = 1)) + NULL
上位3変数は分岐元として選ばれやすいですが、分岐先としては更に頻度が多くなっています。また変数間の順位にはほとんど変動はないようですね。
組み合わせでも見てみましょう。
> var_pairs %>% + group_by(from_var, to_var) %>% + summarise(cnt = n()) %>% + arrange(desc(cnt)) # A tibble: 355 x 3 # Groups: from_var [19] from_var to_var cnt <chr> <chr> <int> 1 MonthlyCharges TotalCharges 6032 2 TotalCharges MonthlyCharges 6018 3 TotalCharges TotalCharges 5858 4 MonthlyCharges MonthlyCharges 5623 5 tenure MonthlyCharges 5602 6 tenure TotalCharges 5407 # 中略
上位10件の組み合わせを見ると、9行目の PaymentMethod
を除いていずれも上位3変数の組み合わせになっています。同じ変数のペアも出てきているので、レンジを絞る形で分岐条件として選ばれているようですね、なるほど。なお19 * 19 = 361なので発生していない組み合わせがあるようですが、一部ですね。
4. 交互作用効果を確かめてみる
ひとまず目的としていた分析は以上となります。今回のデータセットおよび分析条件を用いた場合、 TotalCharges
、 MonthlyCharges
、 tenure
の3変数が(組み合わせの意味でも、重要度の意味でも)影響の大きい変数であるようです。したがって冒頭のようなクライアントからの指摘があった場合には、特にこの3変数を中心に他の変数との交互作用を確認していくと良いのではないでしょうか。
と、ここまで書いたところで一つ疑問が浮かんできました。このような組み合わせ効果は、GLMでも発見できないでしょうか?例えば二次の交互作用項を準備しておき、Lassoで変数選択させるとこれらの組み合わせが残らないでしょうか?
やってみましょう。yとxを用意します。
library(glmnet) # glmnetを使う y <- as.matrix(ifelse(d3$Churn == "Yes", 1, 0)) tmp <- scale(model.matrix(Churn ~ .^2 , d3)) x_vars <- which(colSums(apply(tmp, c(1,2), is.nan)) == 0) x <- cbind(1, tmp[, x_vars])
> dim(x) [1] 7043 403
Lassoは回帰係数の絶対値に対して罰則が与えられるため、説明変数のスケールを揃えておく必要があります。そのため model.matrix
でダミー化したあと scale
で正規化しています。またその際に分散が0であるために NaN
となってしまう変数は除外しています。
このデータでLassoにかけてみましょう。まずは適切な lambda
を得るために cv.glmnet
を使いますが、計算時間が少しかかるため nfolds
は5にしておきましょう。
res_lasso_cv <- cv.glmnet(x = x, y = y, family = "binomial", alpha = 1, nfolds = 5)
> res_lasso_cv$lambda.min [1] 0.004130735
このときにDevianceが最小となる lambda
は0.004130735のようです。プロットも見ておきましょう。
plot(res_lasso_cv)
ここで縦に引かれた二本の破線はそれぞれ lambda.min
および lambda.1se
を表しています。 lambda.1se
は lambda.min
の1SD以内で最も罰則を与えたときの lambda
を示すようですね。詳しくは以下を参照してください。
https://web.stanford.edu/~hastie/glmnet/glmnet_alpha.html
ひとまず lambda.min
を与えたときの結果を確認しましょう。
res_lasso <- glmnet(x = x, y = y, family = "binomial", alpha = 1, lambda = res_lasso_cv$lambda.min)
# このときの回帰係数の絶対値 > as.data.frame(as.matrix(res_lasso$beta)) %>% + rownames_to_column() %>% + filter(s0 != 0) %>% + mutate(abs_beta = abs(s0)) %>% + arrange(desc(abs_beta)) %>% + select(rowname, abs_beta, s0) rowname abs_beta s0 1 tenure 7.708817e-01 -7.708817e-01 2 ContractTwo year 4.658562e-01 -4.658562e-01 3 ContractOne year 2.723315e-01 -2.723315e-01 4 InternetServiceFiber optic 2.640571e-01 2.640571e-01 5 InternetServiceNo 2.202043e-01 -2.202043e-01 6 tenure:PaymentMethodMailed check 1.714972e-01 -1.714972e-01 # 中略
むむ。。Lassoにおいても tenure
は影響の大きい(回帰係数の絶対値が大きい)変数として選ばれましたが、 TotalCharges
と MonthlyCharges
はいませんね。
っていうか、
> as.data.frame(as.matrix(res_lasso$beta)) %>% + rownames_to_column() %>% + filter(rowname %in% c("TotalCharges", "MonthlyCharges")) rowname s0 1 MonthlyCharges 0 2 TotalCharges 0
Lassoで落とされてる。。。
TotalCharges
がどのような影響を示すのか partialPlot
で見てみましょう。
partialPlot(result, as.data.frame(d3), "TotalCharges") # tibbleのままだとエラーになる
グニャグニャですね。特定のレンジでは影響が大きいものの、他ではそうでもないということなんでしょうか。だからRandom Forestのような非線形のアルゴリズムだと効果が認められる一方、Lassoのような線形のアルゴリズムでは拾いきれないのかもしれません2。これは素直に、Random Forestの結果から効果のありそうな組み合わせ変数を見つけ、分布を見ながら組み込んだ方が良さそうです。
しかしy軸が1を超えるのはなぜなんでしょうか。。。
終わりに
今回の分析はRandom Forestの結果から交互作用の良い候補を見つけようという趣旨でした。また同様の結果がLassoからも得られるかを検証しましたが、両者の結果は異なるものとなりました。Random Forestは非線形な効果を捉えることができるアルゴリズムなのでこちらの結果から有効な変数ペアを絞り込み、一つずつ検証していくスタイルが良さそうです。