統計コンサルの議事メモ

統計や機械学習の話題を中心に、思うがままに

randomForestで有効な交互作用を発見したい

背景

GLMは発想がわかりやすく解釈性も高くて良いアルゴリズム1なのですが、線形の仮定があるため変数間の交互作用を見るのが苦手です。実際のプロジェクトでGLMを使った結果を見せ、

  • 変数の組み合わせ効果みたいなものは見れないの?
  • この変数は条件によって効き方が違うんだよね〜

みたいな指摘を受けて困った経験があったりしないでしょうか。そんな時に使えるテクニックを同僚から教えてもらったので、備忘がてらメモしておきます。勝手に公開して怒られる可能性もありますが。。。

概要

手法の概要ですが、話としてはシンプルで「もしも有効な変数の組み合わせ(交互作用)が存在しているのであれば、Random Forestの各決定木において、ノードの分岐に使われる変数の順番として出現しやすいペアがあるのではないか」ということです。例えば変数X1とX2の間に交互作用があれば、決定木においてX1が選択された場合、続く分岐ではX2が選択されやすくなるのではないでしょうか。

実装

上記のアイディアを実現するために、以下のように実装してみます:

  1. Random Forestでモデルを作る
  2. 各決定木から分岐に用いられた変数ペアを得る
  3. 出現回数のカウントを取る
  4. 交互作用効果を確かめてみる

1. Random Forestでモデルを作る

まずはRandom Forestでモデルを作ります。randomForestパッケージを使ってサクッと作りましょう。

### libraryの読み込み
library(randomForest)
library(tidyverse)

データには前回記事と同じTelco Customer Churnを使いますが、前回の反省を踏まえてread.csvを使います。

前回の記事はこちら。

ushi-goroshi.hatenablog.com

d <- read.csv("./Data/WA_Fn-UseC_-Telco-Customer-Churn.csv") %>% as_data_frame()

本来ならここから一つ一つの変数を観察するところですが、今回はそれが目的ではないので欠損だけ埋めておきます。

> colSums(apply(d, c(1, 2), is.na))
      customerID           gender    SeniorCitizen          Partner       Dependents           tenure 
               0                0                0                0                0                0 
    PhoneService    MultipleLines  InternetService   OnlineSecurity     OnlineBackup DeviceProtection 
               0                0                0                0                0                0 
     TechSupport      StreamingTV  StreamingMovies         Contract PaperlessBilling    PaymentMethod 
               0                0                0                0                0                0 
  MonthlyCharges     TotalCharges            Churn 
               0               11                0 

TotalChargesに欠損があるようですね。

> summary(d$TotalCharges)
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max.    NA's 
   18.8   401.4  1397.5  2283.3  3794.7  8684.8      11 

MedianとMeanに差があるので分布が偏っていそうです。ひとまずNAはMedianで埋めておきましょう。

d2 <- 
   d %>% 
   mutate(TotalCharges = if_else(is.na(.$TotalCharges),
                                 median(.$TotalCharges, na.rm = T),
                                 .$TotalCharges))

customerIDは変数として使えないので除外しましょう。またrandomForestはカテゴリ数が53より多い変数を扱えないので、カテゴリ数をチェックしておきます。

cat_vars <- sapply(d2, is.factor)
> apply(d2[, cat_vars], 2, function(x) length(unique(x)))
      customerID           gender          Partner       Dependents     PhoneService    MultipleLines 
            7043                2                2                2                2                3 
 InternetService   OnlineSecurity     OnlineBackup DeviceProtection      TechSupport      StreamingTV 
               3                3                3                3                3                3 
 StreamingMovies         Contract PaperlessBilling    PaymentMethod            Churn 
               3                3                2                4                2 

大丈夫そうですね。customerIDだけ落としておきます。

d3 <- 
   d2 %>% 
   select(-customerID)

randomForestを当てはめます。目的変数はChurnです。

set.seed(123)
result <- randomForest(Churn ~ ., d3, ntree = 500)
> result

Call:
 randomForest(formula = Churn ~ ., data = d3, ntree = 500) 
               Type of random forest: classification
                     Number of trees: 500
No. of variables tried at each split: 4

        OOB estimate of  error rate: 20.3%
Confusion matrix:
      No Yes class.error
No  4643 531   0.1026285
Yes  899 970   0.4810059

精度とかは気にしません。

2. 各決定木から分岐に用いられた変数ペアを得る

次にRandom Forestから変数ペアを取得します。そのためにはRandom Forestの各決定木について、どの変数がどの順番で分岐に用いられたかを知る必要があります。

まずはRandom Forestから各決定木の結果を取ってきましょう。 getTree 関数を使います。過去記事も参考にしてください。

ushi-goroshi.hatenablog.com

> getTree(result, 1, labelVar = TRUE)
     left daughter right daughter        split var split point status prediction
1                2              3         Contract       1.000      1       <NA>
2                4              5     OnlineBackup       1.000      1       <NA>
3                6              7      TechSupport       2.000      1       <NA>
4                8              9           tenure       2.500      1       <NA>
5               10             11  InternetService       2.000      1       <NA>
6               12             13       Dependents       1.000      1       <NA>
# 中略

getTree は分岐に用いられた変数とその分岐先ノードなどを返します。ここで必要なのは1~3列目なのですが、それぞれ「左の子ノード」「右の子ノード」「分岐に用いられた変数」を意味しています。例えば1行目を見るとノードの分岐に Contract が用いられたことがわかります。

このテーブルに行インデックスを列として追加しましょう。なおこれ以降のコードはこちらの記事を参考にさせて頂きました。

tree_tbl <- getTree(result, 1, labelVar = TRUE) %>% # labelVar = Fだとエラー
   rownames_to_column() %>%
   mutate(rowname = as.integer(rowname))

作成した tree_tbl では分岐に用いた変数( split var )はわかりますが、変数のペアはわかりません。例えば1行目の Contract で分岐された子ノード(2と3)は、次にそれぞれが OnlineBackupTechSupport で分岐されています(2行目、3行目)ので、 Contract - OnlineBackupContract - TechSupport という変数ペアが出現したことがわかるような形に整形したいですね。

各行には「分岐に用いた変数」と「分岐先の子ノードの番号」がありますので、「分岐先の子ノード(左右両方)」に「分岐の変数」を追加すれば欲しいものが得られそうです。

まずはノードと変数のマスタを用意しましょう。

var_name <- 
   tree_tbl %>% 
   select(rowname, "split var") %>% 
   rename(split_var =`split var`) %>% # スペースを`_`に修正
   unique() %>% 
   filter(!is.na(.$split_var))

続けて left daughterright daughter それぞれに rowname でJOINします。

> tree_tbl %>% 
+     left_join(var_name, by = c("left daughter" = "rowname")) %>% 
+     left_join(var_name, by = c("right daughter" = "rowname")) %>% 
+     select(rowname, `split var`, `split_var.x`, `split_var.y`) %>% 
+     na.omit()
     rowname        split var      split_var.x      split_var.y
1          1         Contract     OnlineBackup      TechSupport
2          2     OnlineBackup           tenure  InternetService
3          3      TechSupport       Dependents  StreamingMovies
4          4           tenure  InternetService  InternetService
5          5  InternetService       Dependents   OnlineSecurity
6          6       Dependents     TotalCharges           tenure
# 中略

良さそうですね。ただしこのままでは後の工程で使いにくいのでもう少し加工します。

> tree_tbl %>% 
+     left_join(var_name, by = c("left daughter" = "rowname")) %>% 
+     left_join(var_name, by = c("right daughter" = "rowname")) %>% 
+     select(`split var`, `split_var.x`, `split_var.y`) %>% 
+     na.omit() %>% 
+     rename(from_var = `split var`, 
+            left = `split_var.x`, 
+            right = `split_var.y`) %>% 
+     gather(key = node, value = to_var, -from_var) %>% 
+     select(-node)
            from_var           to_var
1           Contract     OnlineBackup
2       OnlineBackup           tenure
3        TechSupport       Dependents
4             tenure  InternetService
5    InternetService       Dependents
6         Dependents     TotalCharges
# 中略

これで必要なアウトプットが得られました。あとは上記の加工を関数化しておき、各決定木に当てはめれば良さそうです。

get_var_pairs <- function(tree_num, rf = result) {
   
   # 決定木の結果を得る
   tree_tbl <- getTree(rf, tree_num, labelVar = TRUE) %>%
      rownames_to_column() %>%
      mutate(rowname = as.integer(rowname))
   
   var_name <- 
      tree_tbl %>% 
      select(rowname, "split var") %>% 
      rename(split_var =`split var`) %>% # スペースを`_`に修正
      unique() %>% 
      filter(!is.na(.$split_var))
   
   out <- 
      tree_tbl %>% 
      left_join(var_name, by = c("left daughter" = "rowname")) %>% 
      left_join(var_name, by = c("right daughter" = "rowname")) %>% 
      select(`split var`, `split_var.x`, `split_var.y`) %>% 
      na.omit() %>% 
      rename(from_var = `split var`, 
             left = `split_var.x`, 
             right = `split_var.y`) %>% 
      gather(key = node, value = to_var, -from_var) %>% 
      select(-node)
   
   return(out)
}

試してみましょう。

> get_var_pairs(5, result)
            from_var           to_var
1      PaymentMethod         Contract
2           Contract      TechSupport
3     OnlineSecurity           tenure
4        TechSupport PaperlessBilling
5             tenure      TechSupport
6       TotalCharges      StreamingTV
# 中略

これを全ての決定木に当てはめます。 purrrmap_dfr を使ってみましょう。

var_pairs <- map_dfr(as.matrix(1:5), get_var_pairs, result) # 少しだけ実行
> dim(var_pairs)
[1] 2950    2

> head(var_pairs)
         from_var          to_var
1        Contract    OnlineBackup
2    OnlineBackup          tenure
3     TechSupport      Dependents
4          tenure InternetService
5 InternetService      Dependents
6      Dependents    TotalCharges

素直に for で書いた時と同じ結果になっていますでしょうか?

tmp <- c()
for (i in 1:5) {
   tmp <- bind_rows(tmp, get_var_pairs(i, result))
}
> dim(tmp)
[1] 2950    2

> head(tmp)
         from_var          to_var
1        Contract    OnlineBackup
2    OnlineBackup          tenure
3     TechSupport      Dependents
4          tenure InternetService
5 InternetService      Dependents
6      Dependents    TotalCharges

合っているようなので全ての結果を取得します。 randomForest では作成する決定木の数を ntree で指定するので 1:ntree で全ての決定木に適用できるのですが、直接 ntree を取ってくることは出来ないようなので length(treesize()) を使います。エラーは気にしないことにします。

var_pairs <- map_dfr(as.matrix(1:length(treesize(result))), get_var_pairs, result)

30万弱の変数ペアが得られました。

3. 出現回数のカウントを取る

さっそく変数ペアのカウントを取ってみましょう。

> var_pairs %>% 
+     group_by(from_var) %>%
+     summarise(cnt = n()) %>% 
+     arrange(desc(cnt))
# A tibble: 19 x 2
   from_var           cnt
   <chr>            <int>
 1 TotalCharges     36908
 2 MonthlyCharges   36592
 3 tenure           34132
 4 PaymentMethod    21244
 5 gender           17626
 6 MultipleLines    14496
# 中略

分岐の始点となった変数(分岐元)の数は19ですが、分析対象のデータセットには目的変数を含んで20列だったので、分岐元とならない変数はなかったようです。一方で分岐元としての出現頻度には大きなばらつきがあり、 TotalChargesMonthlyChargestenure が選ばれやすいようですね。

ちなみに varImpPlot で変数重要度を見てみると、これらはいずれも上位に付けており、4位以下と大きな隔たりがあるようです。

varImpPlot(result)

f:id:ushi-goroshi:20190206152054p:plain

続いて分岐の終点となった変数(分岐先)についても見てみましょう。

> var_pairs %>% 
+     group_by(to_var) %>%
+     # group_by(from_var, to_var) %>% 
+     summarise(cnt = n()) %>% 
+     arrange(desc(cnt))
# A tibble: 19 x 2
   to_var             cnt
   <chr>            <int>
 1 TotalCharges     47553
 2 MonthlyCharges   47137
 3 tenure           37964
 4 PaymentMethod    20138
 5 gender           14242
 6 Partner          12567
# 中略

同じく19変数ありますので、全ての変数は分岐元・分岐先ともに出現しています。分岐元と同じく出現頻度はばらつきがあり、出現しやすい変数としては、 TotalChargesMonthlyChargestenure となっています。これは少し意外ですね。てっきり分岐元に選ばれやすい変数と分岐先に選ばれやすい変数は違うものになると思っていましたが。

せっかくなので分岐元と分岐先で選ばれやすさが異なるか、可視化してみましょう。

plt <- 
   var_pairs %>% 
   group_by(from_var) %>%
   summarise(cnt_f = n()) %>% 
   left_join(var_pairs %>% group_by(to_var) %>% summarise(cnt_t = n()),
             by = c("from_var" = "to_var")) %>%
   gather(var_type, cnt, -from_var) %>% 
   rename(var = from_var)

ggplot(plt, aes(x = reorder(var, -cnt), y = cnt, fill = var_type)) +
   geom_bar(stat = "identity", position = "dodge") +
   # scale_color_brewer(palette = "Set2") +
   # facet_wrap(~var_type, nrow = 2) +
   theme_classic() +
   theme(axis.text.x = element_text(angle = 90, hjust = 1)) +
   NULL

f:id:ushi-goroshi:20190206152252p:plain

上位3変数は分岐元として選ばれやすいですが、分岐先としては更に頻度が多くなっています。また変数間の順位にはほとんど変動はないようですね。

組み合わせでも見てみましょう。

> var_pairs %>% 
+     group_by(from_var, to_var) %>%
+     summarise(cnt = n()) %>% 
+     arrange(desc(cnt))
# A tibble: 355 x 3
# Groups:   from_var [19]
   from_var       to_var           cnt
   <chr>          <chr>          <int>
 1 MonthlyCharges TotalCharges    6032
 2 TotalCharges   MonthlyCharges  6018
 3 TotalCharges   TotalCharges    5858
 4 MonthlyCharges MonthlyCharges  5623
 5 tenure         MonthlyCharges  5602
 6 tenure         TotalCharges    5407
# 中略

上位10件の組み合わせを見ると、9行目の PaymentMethod を除いていずれも上位3変数の組み合わせになっています。同じ変数のペアも出てきているので、レンジを絞る形で分岐条件として選ばれているようですね、なるほど。なお19 * 19 = 361なので発生していない組み合わせがあるようですが、一部ですね。

4. 交互作用効果を確かめてみる

ひとまず目的としていた分析は以上となります。今回のデータセットおよび分析条件を用いた場合、 TotalChargesMonthlyChargestenure の3変数が(組み合わせの意味でも、重要度の意味でも)影響の大きい変数であるようです。したがって冒頭のようなクライアントからの指摘があった場合には、特にこの3変数を中心に他の変数との交互作用を確認していくと良いのではないでしょうか。

と、ここまで書いたところで一つ疑問が浮かんできました。このような組み合わせ効果は、GLMでも発見できないでしょうか?例えば二次の交互作用項を準備しておき、Lassoで変数選択させるとこれらの組み合わせが残らないでしょうか?

やってみましょう。yとxを用意します。

library(glmnet) # glmnetを使う
y <- as.matrix(ifelse(d3$Churn == "Yes", 1, 0))
tmp <- scale(model.matrix(Churn ~ .^2 , d3))
x_vars <- which(colSums(apply(tmp, c(1,2), is.nan)) == 0)
x <- cbind(1, tmp[, x_vars])
> dim(x)
[1] 7043  403

Lassoは回帰係数の絶対値に対して罰則が与えられるため、説明変数のスケールを揃えておく必要があります。そのため model.matrix でダミー化したあと scale で正規化しています。またその際に分散が0であるために NaN となってしまう変数は除外しています。

このデータでLassoにかけてみましょう。まずは適切な lambda を得るために cv.glmnet を使いますが、計算時間が少しかかるため nfolds は5にしておきましょう。

res_lasso_cv <- cv.glmnet(x = x, y = y, family = "binomial", alpha = 1, nfolds = 5)
> res_lasso_cv$lambda.min
[1] 0.004130735

このときにDevianceが最小となる lambda は0.004130735のようです。プロットも見ておきましょう。

plot(res_lasso_cv)

f:id:ushi-goroshi:20190206152638p:plain

ここで縦に引かれた二本の破線はそれぞれ lambda.min および lambda.1se を表しています。 lambda.1selambda.min の1SD以内で最も罰則を与えたときの lambda を示すようですね。詳しくは以下を参照してください。

https://web.stanford.edu/~hastie/glmnet/glmnet_alpha.html

ひとまず lambda.min を与えたときの結果を確認しましょう。

res_lasso <- glmnet(x = x, y = y, family = "binomial", alpha = 1,
                    lambda = res_lasso_cv$lambda.min)
# このときの回帰係数の絶対値
> as.data.frame(as.matrix(res_lasso$beta)) %>% 
+     rownames_to_column() %>% 
+     filter(s0 != 0) %>% 
+     mutate(abs_beta = abs(s0)) %>% 
+     arrange(desc(abs_beta)) %>% 
+     select(rowname, abs_beta, s0)
                                                                  rowname     abs_beta            s0
1                                                                  tenure 7.708817e-01 -7.708817e-01
2                                                        ContractTwo year 4.658562e-01 -4.658562e-01
3                                                        ContractOne year 2.723315e-01 -2.723315e-01
4                                              InternetServiceFiber optic 2.640571e-01  2.640571e-01
5                                                       InternetServiceNo 2.202043e-01 -2.202043e-01
6                                        tenure:PaymentMethodMailed check 1.714972e-01 -1.714972e-01
# 中略

むむ。。Lassoにおいても tenure は影響の大きい(回帰係数の絶対値が大きい)変数として選ばれましたが、 TotalChargesMonthlyCharges はいませんね。

っていうか、

> as.data.frame(as.matrix(res_lasso$beta)) %>% 
+     rownames_to_column() %>% 
+     filter(rowname %in% c("TotalCharges", "MonthlyCharges"))
         rowname s0
1 MonthlyCharges  0
2   TotalCharges  0

Lassoで落とされてる。。。

TotalCharges がどのような影響を示すのか partialPlot で見てみましょう。

partialPlot(result, as.data.frame(d3), "TotalCharges") # tibbleのままだとエラーになる

f:id:ushi-goroshi:20190206153106p:plain

グニャグニャですね。特定のレンジでは影響が大きいものの、他ではそうでもないということなんでしょうか。だからRandom Forestのような非線形アルゴリズムだと効果が認められる一方、Lassoのような線形のアルゴリズムでは拾いきれないのかもしれません2。これは素直に、Random Forestの結果から効果のありそうな組み合わせ変数を見つけ、分布を見ながら組み込んだ方が良さそうです。

しかしy軸が1を超えるのはなぜなんでしょうか。。。

終わりに

今回の分析はRandom Forestの結果から交互作用の良い候補を見つけようという趣旨でした。また同様の結果がLassoからも得られるかを検証しましたが、両者の結果は異なるものとなりました。Random Forestは非線形な効果を捉えることができるアルゴリズムなのでこちらの結果から有効な変数ペアを絞り込み、一つずつ検証していくスタイルが良さそうです。


  1. 余談ですがGLMをアルゴリズムと呼ぶのは少し抵抗があります

  2. もちろん加法モデルのようにxに非線形な変換を施すことで捉えにいく方法もあるでしょうけども