統計コンサルの議事メモ

統計や機械学習の話題を中心に、思うがままに

randomForestで有効な交互作用を発見したい

背景

GLMは発想がわかりやすく解釈性も高くて良いアルゴリズム1なのですが、線形の仮定があるため変数間の交互作用を見るのが苦手です。実際のプロジェクトでGLMを使った結果を見せ、

  • 変数の組み合わせ効果みたいなものは見れないの?
  • この変数は条件によって効き方が違うんだよね〜

みたいな指摘を受けて困った経験があったりしないでしょうか。そんな時に使えるテクニックを同僚から教えてもらったので、備忘がてらメモしておきます。勝手に公開して怒られる可能性もありますが。。。

概要

手法の概要ですが、話としてはシンプルで「もしも有効な変数の組み合わせ(交互作用)が存在しているのであれば、Random Forestの各決定木において、ノードの分岐に使われる変数の順番として出現しやすいペアがあるのではないか」ということです。例えば変数X1とX2の間に交互作用があれば、決定木においてX1が選択された場合、続く分岐ではX2が選択されやすくなるのではないでしょうか。

実装

上記のアイディアを実現するために、以下のように実装してみます:

  1. Random Forestでモデルを作る
  2. 各決定木から分岐に用いられた変数ペアを得る
  3. 出現回数のカウントを取る
  4. 交互作用効果を確かめてみる

1. Random Forestでモデルを作る

まずはRandom Forestでモデルを作ります。randomForestパッケージを使ってサクッと作りましょう。

### libraryの読み込み
library(randomForest)
library(tidyverse)

データには前回記事と同じTelco Customer Churnを使いますが、前回の反省を踏まえてread.csvを使います。

前回の記事はこちら。

ushi-goroshi.hatenablog.com

d <- read.csv("./Data/WA_Fn-UseC_-Telco-Customer-Churn.csv") %>% as_data_frame()

本来ならここから一つ一つの変数を観察するところですが、今回はそれが目的ではないので欠損だけ埋めておきます。

> colSums(apply(d, c(1, 2), is.na))
      customerID           gender    SeniorCitizen          Partner       Dependents           tenure 
               0                0                0                0                0                0 
    PhoneService    MultipleLines  InternetService   OnlineSecurity     OnlineBackup DeviceProtection 
               0                0                0                0                0                0 
     TechSupport      StreamingTV  StreamingMovies         Contract PaperlessBilling    PaymentMethod 
               0                0                0                0                0                0 
  MonthlyCharges     TotalCharges            Churn 
               0               11                0 

TotalChargesに欠損があるようですね。

> summary(d$TotalCharges)
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max.    NA's 
   18.8   401.4  1397.5  2283.3  3794.7  8684.8      11 

MedianとMeanに差があるので分布が偏っていそうです。ひとまずNAはMedianで埋めておきましょう。

d2 <- 
   d %>% 
   mutate(TotalCharges = if_else(is.na(.$TotalCharges),
                                 median(.$TotalCharges, na.rm = T),
                                 .$TotalCharges))

customerIDは変数として使えないので除外しましょう。またrandomForestはカテゴリ数が53より多い変数を扱えないので、カテゴリ数をチェックしておきます。

cat_vars <- sapply(d2, is.factor)
> apply(d2[, cat_vars], 2, function(x) length(unique(x)))
      customerID           gender          Partner       Dependents     PhoneService    MultipleLines 
            7043                2                2                2                2                3 
 InternetService   OnlineSecurity     OnlineBackup DeviceProtection      TechSupport      StreamingTV 
               3                3                3                3                3                3 
 StreamingMovies         Contract PaperlessBilling    PaymentMethod            Churn 
               3                3                2                4                2 

大丈夫そうですね。customerIDだけ落としておきます。

d3 <- 
   d2 %>% 
   select(-customerID)

randomForestを当てはめます。目的変数はChurnです。

set.seed(123)
result <- randomForest(Churn ~ ., d3, ntree = 500)
> result

Call:
 randomForest(formula = Churn ~ ., data = d3, ntree = 500) 
               Type of random forest: classification
                     Number of trees: 500
No. of variables tried at each split: 4

        OOB estimate of  error rate: 20.3%
Confusion matrix:
      No Yes class.error
No  4643 531   0.1026285
Yes  899 970   0.4810059

精度とかは気にしません。

2. 各決定木から分岐に用いられた変数ペアを得る

次にRandom Forestから変数ペアを取得します。そのためにはRandom Forestの各決定木について、どの変数がどの順番で分岐に用いられたかを知る必要があります。

まずはRandom Forestから各決定木の結果を取ってきましょう。 getTree 関数を使います。過去記事も参考にしてください。

ushi-goroshi.hatenablog.com

> getTree(result, 1, labelVar = TRUE)
     left daughter right daughter        split var split point status prediction
1                2              3         Contract       1.000      1       <NA>
2                4              5     OnlineBackup       1.000      1       <NA>
3                6              7      TechSupport       2.000      1       <NA>
4                8              9           tenure       2.500      1       <NA>
5               10             11  InternetService       2.000      1       <NA>
6               12             13       Dependents       1.000      1       <NA>
# 中略

getTree は分岐に用いられた変数とその分岐先ノードなどを返します。ここで必要なのは1~3列目なのですが、それぞれ「左の子ノード」「右の子ノード」「分岐に用いられた変数」を意味しています。例えば1行目を見るとノードの分岐に Contract が用いられたことがわかります。

このテーブルに行インデックスを列として追加しましょう。なおこれ以降のコードはこちらの記事を参考にさせて頂きました。

tree_tbl <- getTree(result, 1, labelVar = TRUE) %>% # labelVar = Fだとエラー
   rownames_to_column() %>%
   mutate(rowname = as.integer(rowname))

作成した tree_tbl では分岐に用いた変数( split var )はわかりますが、変数のペアはわかりません。例えば1行目の Contract で分岐された子ノード(2と3)は、次にそれぞれが OnlineBackupTechSupport で分岐されています(2行目、3行目)ので、 Contract - OnlineBackupContract - TechSupport という変数ペアが出現したことがわかるような形に整形したいですね。

各行には「分岐に用いた変数」と「分岐先の子ノードの番号」がありますので、「分岐先の子ノード(左右両方)」に「分岐の変数」を追加すれば欲しいものが得られそうです。

まずはノードと変数のマスタを用意しましょう。

var_name <- 
   tree_tbl %>% 
   select(rowname, "split var") %>% 
   rename(split_var =`split var`) %>% # スペースを`_`に修正
   unique() %>% 
   filter(!is.na(.$split_var))

続けて left daughterright daughter それぞれに rowname でJOINします。

> tree_tbl %>% 
+     left_join(var_name, by = c("left daughter" = "rowname")) %>% 
+     left_join(var_name, by = c("right daughter" = "rowname")) %>% 
+     select(rowname, `split var`, `split_var.x`, `split_var.y`) %>% 
+     na.omit()
     rowname        split var      split_var.x      split_var.y
1          1         Contract     OnlineBackup      TechSupport
2          2     OnlineBackup           tenure  InternetService
3          3      TechSupport       Dependents  StreamingMovies
4          4           tenure  InternetService  InternetService
5          5  InternetService       Dependents   OnlineSecurity
6          6       Dependents     TotalCharges           tenure
# 中略

良さそうですね。ただしこのままでは後の工程で使いにくいのでもう少し加工します。

> tree_tbl %>% 
+     left_join(var_name, by = c("left daughter" = "rowname")) %>% 
+     left_join(var_name, by = c("right daughter" = "rowname")) %>% 
+     select(`split var`, `split_var.x`, `split_var.y`) %>% 
+     na.omit() %>% 
+     rename(from_var = `split var`, 
+            left = `split_var.x`, 
+            right = `split_var.y`) %>% 
+     gather(key = node, value = to_var, -from_var) %>% 
+     select(-node)
            from_var           to_var
1           Contract     OnlineBackup
2       OnlineBackup           tenure
3        TechSupport       Dependents
4             tenure  InternetService
5    InternetService       Dependents
6         Dependents     TotalCharges
# 中略

これで必要なアウトプットが得られました。あとは上記の加工を関数化しておき、各決定木に当てはめれば良さそうです。

get_var_pairs <- function(tree_num, rf = result) {
   
   # 決定木の結果を得る
   tree_tbl <- getTree(rf, tree_num, labelVar = TRUE) %>%
      rownames_to_column() %>%
      mutate(rowname = as.integer(rowname))
   
   var_name <- 
      tree_tbl %>% 
      select(rowname, "split var") %>% 
      rename(split_var =`split var`) %>% # スペースを`_`に修正
      unique() %>% 
      filter(!is.na(.$split_var))
   
   out <- 
      tree_tbl %>% 
      left_join(var_name, by = c("left daughter" = "rowname")) %>% 
      left_join(var_name, by = c("right daughter" = "rowname")) %>% 
      select(`split var`, `split_var.x`, `split_var.y`) %>% 
      na.omit() %>% 
      rename(from_var = `split var`, 
             left = `split_var.x`, 
             right = `split_var.y`) %>% 
      gather(key = node, value = to_var, -from_var) %>% 
      select(-node)
   
   return(out)
}

試してみましょう。

> get_var_pairs(5, result)
            from_var           to_var
1      PaymentMethod         Contract
2           Contract      TechSupport
3     OnlineSecurity           tenure
4        TechSupport PaperlessBilling
5             tenure      TechSupport
6       TotalCharges      StreamingTV
# 中略

これを全ての決定木に当てはめます。 purrrmap_dfr を使ってみましょう。

var_pairs <- map_dfr(as.matrix(1:5), get_var_pairs, result) # 少しだけ実行
> dim(var_pairs)
[1] 2950    2

> head(var_pairs)
         from_var          to_var
1        Contract    OnlineBackup
2    OnlineBackup          tenure
3     TechSupport      Dependents
4          tenure InternetService
5 InternetService      Dependents
6      Dependents    TotalCharges

素直に for で書いた時と同じ結果になっていますでしょうか?

tmp <- c()
for (i in 1:5) {
   tmp <- bind_rows(tmp, get_var_pairs(i, result))
}
> dim(tmp)
[1] 2950    2

> head(tmp)
         from_var          to_var
1        Contract    OnlineBackup
2    OnlineBackup          tenure
3     TechSupport      Dependents
4          tenure InternetService
5 InternetService      Dependents
6      Dependents    TotalCharges

合っているようなので全ての結果を取得します。 randomForest では作成する決定木の数を ntree で指定するので 1:ntree で全ての決定木に適用できるのですが、直接 ntree を取ってくることは出来ないようなので length(treesize()) を使います。エラーは気にしないことにします。

var_pairs <- map_dfr(as.matrix(1:length(treesize(result))), get_var_pairs, result)

30万弱の変数ペアが得られました。

3. 出現回数のカウントを取る

さっそく変数ペアのカウントを取ってみましょう。

> var_pairs %>% 
+     group_by(from_var) %>%
+     summarise(cnt = n()) %>% 
+     arrange(desc(cnt))
# A tibble: 19 x 2
   from_var           cnt
   <chr>            <int>
 1 TotalCharges     36908
 2 MonthlyCharges   36592
 3 tenure           34132
 4 PaymentMethod    21244
 5 gender           17626
 6 MultipleLines    14496
# 中略

分岐の始点となった変数(分岐元)の数は19ですが、分析対象のデータセットには目的変数を含んで20列だったので、分岐元とならない変数はなかったようです。一方で分岐元としての出現頻度には大きなばらつきがあり、 TotalChargesMonthlyChargestenure が選ばれやすいようですね。

ちなみに varImpPlot で変数重要度を見てみると、これらはいずれも上位に付けており、4位以下と大きな隔たりがあるようです。

varImpPlot(result)

f:id:ushi-goroshi:20190206152054p:plain

続いて分岐の終点となった変数(分岐先)についても見てみましょう。

> var_pairs %>% 
+     group_by(to_var) %>%
+     # group_by(from_var, to_var) %>% 
+     summarise(cnt = n()) %>% 
+     arrange(desc(cnt))
# A tibble: 19 x 2
   to_var             cnt
   <chr>            <int>
 1 TotalCharges     47553
 2 MonthlyCharges   47137
 3 tenure           37964
 4 PaymentMethod    20138
 5 gender           14242
 6 Partner          12567
# 中略

同じく19変数ありますので、全ての変数は分岐元・分岐先ともに出現しています。分岐元と同じく出現頻度はばらつきがあり、出現しやすい変数としては、 TotalChargesMonthlyChargestenure となっています。これは少し意外ですね。てっきり分岐元に選ばれやすい変数と分岐先に選ばれやすい変数は違うものになると思っていましたが。

せっかくなので分岐元と分岐先で選ばれやすさが異なるか、可視化してみましょう。

plt <- 
   var_pairs %>% 
   group_by(from_var) %>%
   summarise(cnt_f = n()) %>% 
   left_join(var_pairs %>% group_by(to_var) %>% summarise(cnt_t = n()),
             by = c("from_var" = "to_var")) %>%
   gather(var_type, cnt, -from_var) %>% 
   rename(var = from_var)

ggplot(plt, aes(x = reorder(var, -cnt), y = cnt, fill = var_type)) +
   geom_bar(stat = "identity", position = "dodge") +
   # scale_color_brewer(palette = "Set2") +
   # facet_wrap(~var_type, nrow = 2) +
   theme_classic() +
   theme(axis.text.x = element_text(angle = 90, hjust = 1)) +
   NULL

f:id:ushi-goroshi:20190206152252p:plain

上位3変数は分岐元として選ばれやすいですが、分岐先としては更に頻度が多くなっています。また変数間の順位にはほとんど変動はないようですね。

組み合わせでも見てみましょう。

> var_pairs %>% 
+     group_by(from_var, to_var) %>%
+     summarise(cnt = n()) %>% 
+     arrange(desc(cnt))
# A tibble: 355 x 3
# Groups:   from_var [19]
   from_var       to_var           cnt
   <chr>          <chr>          <int>
 1 MonthlyCharges TotalCharges    6032
 2 TotalCharges   MonthlyCharges  6018
 3 TotalCharges   TotalCharges    5858
 4 MonthlyCharges MonthlyCharges  5623
 5 tenure         MonthlyCharges  5602
 6 tenure         TotalCharges    5407
# 中略

上位10件の組み合わせを見ると、9行目の PaymentMethod を除いていずれも上位3変数の組み合わせになっています。同じ変数のペアも出てきているので、レンジを絞る形で分岐条件として選ばれているようですね、なるほど。なお19 * 19 = 361なので発生していない組み合わせがあるようですが、一部ですね。

4. 交互作用効果を確かめてみる

ひとまず目的としていた分析は以上となります。今回のデータセットおよび分析条件を用いた場合、 TotalChargesMonthlyChargestenure の3変数が(組み合わせの意味でも、重要度の意味でも)影響の大きい変数であるようです。したがって冒頭のようなクライアントからの指摘があった場合には、特にこの3変数を中心に他の変数との交互作用を確認していくと良いのではないでしょうか。

と、ここまで書いたところで一つ疑問が浮かんできました。このような組み合わせ効果は、GLMでも発見できないでしょうか?例えば二次の交互作用項を準備しておき、Lassoで変数選択させるとこれらの組み合わせが残らないでしょうか?

やってみましょう。yとxを用意します。

library(glmnet) # glmnetを使う
y <- as.matrix(ifelse(d3$Churn == "Yes", 1, 0))
tmp <- scale(model.matrix(Churn ~ .^2 , d3))
x_vars <- which(colSums(apply(tmp, c(1,2), is.nan)) == 0)
x <- cbind(1, tmp[, x_vars])
> dim(x)
[1] 7043  403

Lassoは回帰係数の絶対値に対して罰則が与えられるため、説明変数のスケールを揃えておく必要があります。そのため model.matrix でダミー化したあと scale で正規化しています。またその際に分散が0であるために NaN となってしまう変数は除外しています。

このデータでLassoにかけてみましょう。まずは適切な lambda を得るために cv.glmnet を使いますが、計算時間が少しかかるため nfolds は5にしておきましょう。

res_lasso_cv <- cv.glmnet(x = x, y = y, family = "binomial", alpha = 1, nfolds = 5)
> res_lasso_cv$lambda.min
[1] 0.004130735

このときにDevianceが最小となる lambda は0.004130735のようです。プロットも見ておきましょう。

plot(res_lasso_cv)

f:id:ushi-goroshi:20190206152638p:plain

ここで縦に引かれた二本の破線はそれぞれ lambda.min および lambda.1se を表しています。 lambda.1selambda.min の1SD以内で最も罰則を与えたときの lambda を示すようですね。詳しくは以下を参照してください。

https://web.stanford.edu/~hastie/glmnet/glmnet_alpha.html

ひとまず lambda.min を与えたときの結果を確認しましょう。

res_lasso <- glmnet(x = x, y = y, family = "binomial", alpha = 1,
                    lambda = res_lasso_cv$lambda.min)
# このときの回帰係数の絶対値
> as.data.frame(as.matrix(res_lasso$beta)) %>% 
+     rownames_to_column() %>% 
+     filter(s0 != 0) %>% 
+     mutate(abs_beta = abs(s0)) %>% 
+     arrange(desc(abs_beta)) %>% 
+     select(rowname, abs_beta, s0)
                                                                  rowname     abs_beta            s0
1                                                                  tenure 7.708817e-01 -7.708817e-01
2                                                        ContractTwo year 4.658562e-01 -4.658562e-01
3                                                        ContractOne year 2.723315e-01 -2.723315e-01
4                                              InternetServiceFiber optic 2.640571e-01  2.640571e-01
5                                                       InternetServiceNo 2.202043e-01 -2.202043e-01
6                                        tenure:PaymentMethodMailed check 1.714972e-01 -1.714972e-01
# 中略

むむ。。Lassoにおいても tenure は影響の大きい(回帰係数の絶対値が大きい)変数として選ばれましたが、 TotalChargesMonthlyCharges はいませんね。

っていうか、

> as.data.frame(as.matrix(res_lasso$beta)) %>% 
+     rownames_to_column() %>% 
+     filter(rowname %in% c("TotalCharges", "MonthlyCharges"))
         rowname s0
1 MonthlyCharges  0
2   TotalCharges  0

Lassoで落とされてる。。。

TotalCharges がどのような影響を示すのか partialPlot で見てみましょう。

partialPlot(result, as.data.frame(d3), "TotalCharges") # tibbleのままだとエラーになる

f:id:ushi-goroshi:20190206153106p:plain

グニャグニャですね。特定のレンジでは影響が大きいものの、他ではそうでもないということなんでしょうか。だからRandom Forestのような非線形アルゴリズムだと効果が認められる一方、Lassoのような線形のアルゴリズムでは拾いきれないのかもしれません2。これは素直に、Random Forestの結果から効果のありそうな組み合わせ変数を見つけ、分布を見ながら組み込んだ方が良さそうです。

しかしy軸が1を超えるのはなぜなんでしょうか。。。

終わりに

今回の分析はRandom Forestの結果から交互作用の良い候補を見つけようという趣旨でした。また同様の結果がLassoからも得られるかを検証しましたが、両者の結果は異なるものとなりました。Random Forestは非線形な効果を捉えることができるアルゴリズムなのでこちらの結果から有効な変数ペアを絞り込み、一つずつ検証していくスタイルが良さそうです。


  1. 余談ですがGLMをアルゴリズムと呼ぶのは少し抵抗があります

  2. もちろん加法モデルのようにxに非線形な変換を施すことで捉えにいく方法もあるでしょうけども

randomForestではCharacterは使わないようにしよう

RのrandomForestを使っていてはまったのでメモしておきます。

①目的変数がcharacterだと分類として扱ってくれない

最初にはまったのがこちらでした。目的変数がcharacterだとカテゴリ変数として扱ってもらえないため、分類ではなく回帰としてプログラムが進んでしまい、エラーが返ります。

まずは以下のようにデータを読み込みます。ちなみにこのデータはTelco Customer Churnで、 KaggleからDLしてきました。

library(tidyverse)
library(randomForest)
d <- read_csv("./Data/WA_Fn-UseC_-Telco-Customer-Churn.csv")

このデータを使って下記のようにrandomForestを実行します:

> randomForest(Churn ~ gender, d)
 y - ymean でエラー:  二項演算子の引数が数値ではありません 

「何このエラー??」と思いながらrandomForestの中身を見てみると、以下のような記述が見つかります。

addclass <- is.null(y)
classRF <- addclass || is.factor(y)

classRFというのはカテゴリ変数であるかを判定するbooleanなのですが、is.factorで判断しているんですね。ここでFALSEが返ると、後の工程で分類のためのプログラム(Cで書かれたもの、多分コレのL38 ~ L540)ではなく、回帰用のプログラム(多分コレのL22 ~ L340)が呼ばれてしまうようです。

ちなみにrandomForestはformulaでもmatrixでも対応できるような総称関数になっているので、 コンソールでrandomForestと叩いても中身を見ることはできません。そのような場合、まずはmethodsでどのような関数が含まれているかを確認しましょう。

> methods(randomForest)
[1] randomForest.default* randomForest.formula*
see '?methods' for accessing help and source code

randomForestという関数はrandomForest.defaultrandomForest.formulaという関数を総称しているようです。前者がメインのようなので以下のコマンドを実行します。

getS3method("randomForest", "default")

そうすると先程の関数定義を確認することができます。さらにちなみに、lookupというパッケージを用いることでそういった総称関数やCで書かれた関数などをRStudio上でシンタックスハイライトさせながら表示することができます。大変便利ですので、是非こちらの記事を参考にして使ってみてください。

②説明変数がcharacterだとダミー化してくれない

次にはまったのがこちらでした。さきほどエラーが返ってきたrandomForestを、目的変数をfactorに直して実行してみましょう:

d2 <- 
   d %>% 
   mutate(Churn = as.factor(Churn))
> randomForest(Churn ~ gender, d2)
 強制変換により NA が生成されました  randomForest.default(m, y, ...) でエラー: 
   外部関数の呼び出し (引数 1) 中に NA/NaN/Inf があります

NAがありますというエラーなのですが、このデータではNATotalCharges列にしかありませんので、どうやら途中で生成されているようです。ちなみにRStudioでRMarkdownを使っているとき、チャンク内で実行すると上記のエラーが表示されるのですが、コンソールで実行すると以下のようにもう少し情報が追加されます:

> randomForest(Churn ~ gender, d2)
 randomForest.default(m, y, ...) でエラー: 
   外部関数の呼び出し (引数 1) 中に NA/NaN/Inf があります 
 追加情報:  警告メッセージ: 
 data.matrix(x):   強制変換により NA が生成されました 

data.matrixが悪さしているようですね。私は普段チャンク内で実行させることが多いため、 この表示に気付かなくて時間を無駄にしました。

randomForest.defaultを見ると以下の記述があります。

if (is.data.frame(x)) {
   xlevels <- lapply(x, mylevels)
   ncat <- sapply(xlevels, length)
   ncat <- ifelse(sapply(x, is.ordered), 1, ncat)
   x <- data.matrix(x)

ここでxをdata.matrix(x)でダミー化しようとしたものの、characterであったために失敗しているようですね。

> d2 %>% select(Churn, gender) %>% str()
Classes ‘tbl_df’, ‘tbl’ and 'data.frame':  7043 obs. of  2 variables:
 $ Churn : Factor w/ 2 levels "No","Yes": 1 1 2 1 2 2 1 1 2 1 ...
 $ gender: chr  "Female" "Male" "Male" "Male" ...
> d2 %>% select(Churn, gender) %>% data.matrix(.) %>% head()
 強制変換により NA が生成されました      Churn gender
[1,]     1     NA
[2,]     1     NA
[3,]     2     NA
[4,]     1     NA
[5,]     2     NA
[6,]     2     NA

それではfactorに直して実行してみましょう。

d3 <- d2 %>% select(Churn, gender) %>% mutate(gender = as.factor(.$gender))
> str(d3)
Classes ‘tbl_df’, ‘tbl’ and 'data.frame':  7043 obs. of  2 variables:
 $ Churn : Factor w/ 2 levels "No","Yes": 1 1 2 1 2 2 1 1 2 1 ...
 $ gender: Factor w/ 2 levels "Female","Male": 1 2 2 2 1 1 2 1 1 2 ...
> randomForest(Churn ~ gender, d3)

Call:
 randomForest(formula = Churn ~ gender, data = d3) 
               Type of random forest: classification
                     Number of trees: 500
No. of variables tried at each split: 1

        OOB estimate of  error rate: 26.54%
Confusion matrix:
      No Yes class.error
No  5174   0           0
Yes 1869   0           1

ようやく動くようになりました。

このようなエラーにはまらないためには、例えばread_csvの代わりに(StringsAsFactorsFALSEにしないで)read.csvを使うか、randomForestの代わりにrangerを使うという手があります。

> ranger::ranger(Churn ~ gender, d)
Ranger result

Call:
 ranger::ranger(Churn ~ gender, d) 

Type:                             Classification 
Number of trees:                  500 
Sample size:                      7043 
Number of independent variables:  1 
Mtry:                             1 
Target node size:                 1 
Variable importance mode:         none 
Splitrule:                        gini 
OOB prediction error:             26.54 % 

もとのデータでも動いていますね!

終わりに

今回のエラーはモデリングに入る前にダミー化を行っていれば防げたものでしたが、Rだとカテゴリ変数をそのまま渡しても動くものが多いのでついサボってしまいました。普段からダミー化をする習慣が身についている人や、多分Pythonでははまらないんでしょうね。

ただ今回はエラーを追いかけながら総称関数のソースの便利な確認方法を知ることができたので良かったです。

GLMをもう少し理解したい④

前回の記事では、結局GLMというのは以下の方程式:


\mathbf{X^ {T}WXb}^ {(m)} = \mathbf{X^ {T}Wz}

を用いて、 \mathbf{b}を反復的に求めることであると説明しました(IRLS)。

ushi-goroshi.hatenablog.com

そのために必要なパーツとしては \mathbf{W} \mathbf{z}であり、これらは( \mathbf{Y}を除けば) \mu \etaとそれらのそれぞれに対する微分です。 ではそれらの値をどのように求めるのか実際に試してみましょう。

ポアソン回帰

関数の定義

まずはポアソン回帰で確かめてみましょう。以下のように関数を定義します:

get_eta <- function(b) {
   eta <- X %*% b
   return(eta)
}

get_mu <- function(b) {
   mu <- exp(X %*% b)
   return(mu)
}

get_var <- get_mu

GLMにおいて \etaは常に線形予測子によって説明されるものなので、シンプルに \mathbf{X} \mathbf{b}の積によって得られます。一方 \mu \etaを逆リンク関数で \mathbf{Y}の世界に戻したものなので、ポアソン回帰で一般的な逆リンク関数である exp を使います。 3つ目の関数は \mathbf{Y}の分散を得るためのものですが、ポアソン回帰では Var(\mathbf{Y}) = E(\mathbf{Y})を仮定しますので get_mu を再利用しています。

続いて \mathbf{W} \mathbf{z}を得るための関数を定義します:

get_z <- function(b) {
   z <- get_eta(b) + (d$y - get_mu(b)) / get_mu(b)
   return(z)
}

get_W <- function(b) {
   w <- get_mu(b)^2 / get_var(b)
   return(w)
}

get_z の二項目には分母として get_mu を用いました。これは前回の記事で説明した通り \mathbf{z}


z_{i} = \Sigma_{k=1}^ {p} x_{ik}b_{k}^ {(m-1)} + (y_{i} - \mu_{i})( \frac{\partial{\eta_{i}}}{\partial{\mu_{i}}} )

で得られますが1


\frac{\partial{\eta_{i}}}{\partial{\mu_{i}}} = \frac{\partial{log(\mu_{i})}}{\partial{\mu_{i}}} = \frac{1}{\mu_{i}}

となるためです。また get_W の分子には get_mu の二乗を使いましたが、これは


\frac{\partial{\mu_{i}}}{\partial{\eta_{i}}} = \frac{\partial{exp(\eta_{i})}}{\partial{\eta_{i}}} = exp(\eta_{i}) = \mu_{i}

となるためです。

サンプルデータ作成

続いてサンプルデータを作成します。

set.seed(123)
n <- 100
x <- cbind(rep(1, n), runif(n, -1, 1))
b <- c(1, 0.5)
lam <- exp(x %*% b)

d <- data.frame(
   y = rpois(n, lam),
   x = x[, -1])

ここでは(推定されるべき)真の回帰係数として b1 = 1b2 = 0.5 としました。 特に難しいところはないと思いますので、このデータを使って解を求めてみます。

解の推定

以下のように反復の条件を設定します:

iteration <- 0 # イテレータ
b_old <- c(2, 0.3) # 解の初期値
threshold <- 1e-06 # 反復終了の閾値
diff <- 1 # 解の更新前後の差
X <- cbind(rep(1, n), d[, -1]) # 説明変数X
n <- nrow(d) # サンプルサイズ
W <- matrix(0, n, n) # W

最後の W ですが、IRLSにおいて \mathbf{W}は対角行列であり、対角要素のみが更新されるため予め0行列を用意しておきました。

では、反復を開始します:

while (diff > threshold) {

   z <- get_z(b_old) # zを計算
   w <- get_W(b_old) # Wを計算
   diag(W) <- w # Wの対角要素を更新
   
   xwx <- t(X) %*% W %*% X # 左辺を計算
   xwz <- t(X) %*% W %*% z # 右辺を計算
   
   b_new <- solve(xwx) %*% xwz # solveで逆行列を求める
   diff  <- sum(abs(b_old - b_new)) # 解の更新前後の差分
   b_old <- b_new # 解の更新
   
   iteration <- iteration + 1
   cat(sprintf("Iterations: %i, b_New_1: %1.8f, b_New_2: %1.8f \n", 
               iteration, b_new[1], b_new[2]))
   if (iteration > 100) break

}

上を実行すると、下記のような結果が得られます。

Iterations: 1, b_New_1: 1.37767882, b_New_2: 0.37311981 
Iterations: 2, b_New_1: 1.07883183, b_New_2: 0.45722167 
Iterations: 3, b_New_1: 1.02225611, b_New_2: 0.49049195 
Iterations: 4, b_New_1: 1.02040076, b_New_2: 0.49240696 
Iterations: 5, b_New_1: 1.02039846, b_New_2: 0.49241027 
Iterations: 6, b_New_1: 1.02039846, b_New_2: 0.49241027

最終的に b1 = 1.02039846b2 = 0.49241027 で停止したようですが、 glm の解と一致するでしょうか?

> coef(glm(y ~ x, d, family = poisson("log")))
(Intercept)           x 
  1.0203985   0.4924103 

合っていますね!

ロジスティック回帰

続いてロジスティック回帰を試してみます。先に言っておくとこの記事の執筆時点でロジスティック回帰の方は上手く行っていませんので予めご承知おきください。

関数の定義

やるべきことはポアソン回帰と変わりません。まずは関数を以下のように定義します:

get_eta <- function(b) {
   eta <- X %*% b
   return(eta)
}

get_mu <- function(b) {
   mu <- exp(X %*% b) / (1 + exp(X %*% b))
   return(mu)
}

get_var <- function(b) {
   var <- get_mu(b) * (1 - get_mu(b))
   return(var)
}

get_z <- function(b) {
   p <- get_mu(b)
   z <- get_eta(b) + (d$y - p) / (1/p + 1/(1-p))
   return(z)
}

get_W <- function(b) {
   t <- get_mu(b) * (1 - get_mu(b))
   w <- t^2 / get_var(b)
   return(w)
}

eta についてはポアソン回帰と同じですが、逆リンク関数が異なるため mu は定義が違います。ロジスティック回帰における一般的なリンク関数はロジットなので、逆リンク関数にはロジスティック関数を用います。また get_var にはベルヌーイ分布の分散である pq を用いました。

 \mathbf{z} \mathbf{W}については、


logit(\mu) = log(\mu) - log(1-\mu) \\
\frac{\partial{logit(\mu_{i})}}{\partial{\mu_{i}}} = \frac{1}{\mu} + \frac{1}{(1-\mu)}

および


logistic(\eta) = \frac{exp(\eta)}{1 + exp(\eta)}  = \frac{exp(\eta) + exp(\eta)^ {2} - exp(\eta)^ {2}}{(1 + exp(\eta))^ {2}} \\
= \frac{exp(\eta)}{(1 + exp(\eta))^ {2}} = \frac{exp(\eta)}{1 + exp(\eta)} \frac{1}{1 + exp(\eta)} \\
= logistic(\eta) (1 - logistic(\eta))

を使っています。

サンプルデータ作成

先程と同様にサンプルデータを作成します。

set.seed(789)
n <- 100
x <- cbind(rep(1, n), runif(n, -1, 1))
b <- c(1, 0.5)
eta <- x %*% b
p <- exp(eta)/(1 + exp(eta))

d <- data.frame(
   y = rbinom(n, 1, p),
   x = x[, -1])

回帰係数などは特に変更していません。

解の推定

では反復を開始します。ポアソン回帰の時と違い、随分と収束に時間がかかるため反復回数の上限を500回とし、100回ごとに表示しています。また収束の基準も厳し目にしましたが、それ以外は while の中身に変更はありません。

iteration <- 0
b_old <- c(2, 0.3)
threshold <- 1e-08
diff <- 1
X <- cbind(rep(1, n), d[, -1])
n <- nrow(d)
W <- matrix(0, n, n)

while (diff > threshold) {

   z <- get_z(b_old)
   w <- get_W(b_old)
   diag(W) <- w
   
   xwx <- t(X) %*% W %*% X
   xwz <- t(X) %*% W %*% z
   
   b_new <- solve(xwx) %*% xwz
   diff  <- sum(abs(b_old - b_new))
   b_old <- b_new
   
   iteration <- iteration + 1
   if (iteration %% 100 == 0) {
      cat(sprintf("Iterations: %i, b_New_1: %1.8f, b_New_2: %1.8f \n", 
                  iteration, b_new[1], b_new[2]))
   }
   if (iteration > 500) break
}

上記を実行すると、以下のような結果が得られます:

Iterations: 100, b_New_1: 0.91099754, b_New_2: 0.35764440 
Iterations: 200, b_New_1: 0.86993338, b_New_2: 0.33293052 
Iterations: 300, b_New_1: 0.86937860, b_New_2: 0.33250983

一応、500回までは行かずに収束したと判断されたようなので、 b_new を確認してみましょう。

> b_new
          [,1]
[1,] 0.8693712
[2,] 0.3325039

> coef(glm(y ~ x, d, family = binomial("logit")))
(Intercept)           x 
  0.8825275   0.3975239

…うーん、合っていませんね。収束までにも随分と反復していますし、何かおかしいようです。 残念ながらこの理由についてはまだ良くわかっていません。途中の微分の計算が間違っているのか、分散の定義がダメなのか。。。

終わりに

というわけで、GLMではどのように計算が行われているのかを4回に渡って追いかけてきました。 最後はパッとしない結果になってしまいましたが、IRLSがどのように計算を行っているのかを理解できた気がします。これで今までよりももう少し自信を持って glm を使うことができそうですね。


  1. 一部修正あり

GLMをもう少し理解したい③

前回の記事において、GLMでは以下の方程式を用いてパラメータベクトルを推定するという話をしました:


\mathbf{b}^ {(m)} = \mathbf{b}^ {(m-1)}  + [\mathfrak{J}^ {(m-1)}]^ {-1} \mathbf{U}^ {(m-1)}

ushi-goroshi.hatenablog.com


今回はその続きです。

※ 1/25 記事を修正しました

最尤推定

上の式には情報行列 \mathfrak{J}逆行列が入っているので、前から情報行列を乗じます:


\mathfrak{J}^ {(m-1)}\mathbf{b}^ {(m)} = \mathfrak{J}^ {(m-1)}\mathbf{b}^ {(m-1)}  +  \mathbf{U}^ {(m-1)}
\tag{1}

ところで、(ここもまた天下りですが)スコア Uは以下の式で表されますが*1


\mathbf{U_{j}} = \Sigma_{i=1}^ {N} [ \frac{(y_{i} - \mu_{i})}{var(Y_{i})} x_{ij} ( \frac{\partial{\mu_{i}}}{\partial{\eta_{i}}} ) ]
\tag{2}

 \mathfrak{J}はスコア Uの分散共分散行列であり、上記の式から

 \begin{align}
\mathfrak{J}_{jk} &= E\left\{\Sigma_{i=1}^ {N}[ \frac{(Y_{i} - \mu_{i})}{var(Y_{i})} x_{ij} (\frac{\partial{\mu_{i}}}{\partial{\eta_{i}}}) ] \Sigma_{l=1}^ {N}[ \frac{(Y_{l} - \mu_{l})}{var(Y_{l})} x_{lk} (\frac{\partial{\mu_{l}}}{\partial{\eta_{l}}}) ]\right\} \\

 &= \Sigma_{i=1}^ {N} \frac{E[ (Y_{i} - \mu_{i})^ {2} ] x_{ij}x_{ik}}
 {[var(Y_{i})]^ 2} (\frac{\partial{\mu_{i}}}{\partial{\eta_{i}}})^ {2}
\end{align}

となり、さらに E[ (Y_{i} - \mu_{i})^ {2}] = var(Y_{i})から以下のようになります:


\mathfrak{J}_{jk} = \Sigma_{i=1}^ {N} \frac{x_{ij}x_{ik}}
 {var(Y_{i})} (\frac{\partial{\mu_{i}}}{\partial{\eta_{i}}})^ {2}

これを行列で表記すると:


\mathfrak{J} = \mathbf{X^ {T} WX}

となります。ただし \mathbf{W}


w_{ii} = \frac{1}{var(Y_{i})} (\frac{\partial{\mu_{i}}}{\partial{\eta_{i}}})^ {2}

となる対角行列です。よって冒頭の式(1)の左辺は \mathbf{X^ {T} WXb}^ {(m)}と書けます。

次に右辺ですが、式(2)から以下のように書けます:


\Sigma_{k=1}^ {p} \Sigma_{i=1}^ {N} \frac{x_{ij}x_{ik}}{var(Y_{i})} 
(\frac{\partial{\mu_{i}}}{\partial{\eta_{i}}} )^ {2}b_{k}^ {(m-1)} + \Sigma_{i=1}^ {N} \frac{(y_{i}-\mu_{i})x_{ij}}{var(Y_{i})} (\frac{\partial{\mu_{i}}}{\partial{\eta_{i}}} )

このとき1項目と2項目を \mathbf{X^ {T}W}でくくり、残りを \mathbf{z}とすると


\mathbf{X^ {T}WXb}^ {(m)} = \mathbf{X^ {T}Wz}

が得られます。これは線形モデルに対する重み付き最小二乗法を適用して得られる正規方程式と同じ形となりますが、 \mathbf{z} \mathbf{W} \mathbf{b}に依存するため、反復的に解く必要があります。
なお \mathbf{X^ {T}W}でくくる際、 (\frac{\partial{\mu_{i}}}{\partial{\eta_{i}}})^ {2}を使っているため二項目にはこの逆数である (\frac{\partial{\eta_{i}}}{\partial{\mu_{i}}})が残り、 \mathbf{z}にはそれが渡ります。すなわち:


\mathbf{z_{i}} = \Sigma_{k=1}^ {p} x_{ik}b_{k}^ {(m-1)} + (y_{i} - \mu_{i})( \frac{\partial{\eta_{i}}}{\partial{\mu_{i}}} )

です。

※ 以下のセクションで \eta \muの関係を取り違えて説明していたので修正しました

さて、この \mathbf{z}に含まれる (\frac{\partial{\eta_{i}}}{\partial{\mu_{i}}})はどのような形になるでしょうか?GLMにおいて \eta線形予測子 \mathbf{X^ {T}b}をリンク関数で変換したもの線形予測子で説明されるもので、 \muはそれを逆リンク関数で変換したものでした。つまり、


\eta = g(\mu) \\
\mu = g^ {-1}(\eta)

という関係であることを思い出すと、 (\frac{\partial{\eta_{i}}}{\partial{\mu_{i}}})は分析者が事前に設定したリンク関数に依存することになります。例えばリンク関数としてlogを指定していれば、


(\frac{\partial{\eta_{i}}}{\partial{\mu_{i}}}) = (\frac{\partial{log(\mu_{i})}}{\partial{\mu_{i}}}) = \frac{1}{\mu_{i}}

となります。identityなら1になるでしょう。

同様に w_{ii}に含まれる (\frac{\partial{\mu_{i}}}{\partial{\eta_{i}}})は、 \mu g^ {-1}(\eta)であることから逆リンク関数に依存します。例えば逆リンク関数がexpなら


(\frac{\partial{\mu_{i}}}{\partial{\eta_{i}}}) = (\frac{\partial{exp(\eta_{i})}}{\partial{\eta_{i}}}) = exp(\eta_{i}) = \mu_{i}

となるでしょう。

以上から、 \mathbf{X} \mathbf{W} \mathbf{z}が分かれば \mathbf{b}を更新することができ、またそれぞれについても具体的に計算できそうです。GLMではこの反復重み付き最小二乗法(IRLS, Iteratively Reweighted Least Squares)によって最尤推定量を求めます。

それでは次回は、具体的な数値を使って計算してみましょう。

*1:Dobsonの式(4.18)より。この式の導出がちょっと煩雑だったので割愛。

GLMをもう少し理解したい②

前回の記事からだいぶ間が空いてしまいましたが続きを書いてみます。なおこの記事は主にDobsonの「一般化線形モデル入門」の第3・4章を参考にしていますので、そちらも合わせてご確認ください。良書です。

ushi-goroshi.hatenablog.com

一般化線形モデル入門 原著第2版

一般化線形モデル入門 原著第2版

さて、前回の記事では、以下のように結びました:

つまりglmは最尤法によって解を推定していると思われている(そしてそれは正しい)のだけれど、実際には最小二乗解をHouseholder法によって得ているのだということがわかりました。

これは果たしてどういうことなのでしょうか?これについて答えるために以下の流れで見ていきます。

※1/22 記事を少し修正しました

ニュートン・ラプソン法

初めにニュートン・ラプソン法についてです。やや天下り的で申し訳ないのですが、一般にGLMでは解を推定する際に数値的な解法を必要とします(理由については後述します追記:考えていた説明が誤っていたので削除します)。その際に用いられる手法としてニュートン・ラプソン法というものがあり、これは非線形な方程式の解を反復によって求める手法です。

あまり詳しい説明はできませんが、原理的には以下の式を反復的に更新することで求めたい解を得ることができます:


x^ {(m)} = x^ {(m-1)} - \frac{f(x^ {(m-1)})}{f'(x^ {(m-1)})}

ここで x^ {(m-1)}および x^ {(m)}はそれぞれ更新前後の解となります。  f(x)は対象となる関数、 f'(x)はその導関数を表します。以下の例で簡単に見てみましょう:

x <- seq(-3, 3, 0.1)
y <- sin(x)
plot(x, y, type = "l")

f:id:ushi-goroshi:20190116233606p:plain

上記のようなsin関数において sin(x) = 0となる点を見つけます。先ほど示したニュートン・ラプソンのアルゴリズムに従い、

  1. 初期値を入力
  2. 関数およびその導関数による返り値を計算する
  3. 解を更新する
  4. 更新後の関数値が十分0に近くなったら反復を停止する

ことで最終的な解を得ます。Rで書くと以下のようになるでしょう:

iteration <- 0 # イテレータ
threshold <- 1e-02 # 反復の停止を判断する閾値
x_old <- -1 # 更新前の解
updates <- 1 # 更新後の解による関数値(の絶対値)

while (updates > threshold) {
   x_new <- x_old - sin(x_old)/cos(x_old) # 更新後の解を得る
   updates <- abs(sin(x_new)) # 更新後の解による関数値を得る
   x_old <- x_new # 解を更新する
   cat(sprintf("Iterations: %i, X_New: %f \n", iteration, x_new)) # 表示
   iteration <- iteration + 1 # イテレータを増やす
}

上記のスクリプトを実行すると下のような結果が得られました。

Iterations: 1, X_New: 0.557408 
Iterations: 2, X_New: -0.065936 
Iterations: 3, X_New: 0.000096 

3回の反復でほぼ f(x) = 0に収束している様子がわかります。この例の場合は x = 0において f(x)つまり sin(x)が0となることが示されました1

スコア法

上記では一般的なニュートン・ラプソン法の原理を紹介しましたが、本記事のテーマであるGLMに置き換えると、 xおよび f(x)はそれぞれ \thetaおよび l'(\theta)が対象となります。ここで \thetaはパラメータ、即ち推定したい対象であり、 l'は対数尤度関数導関数です。

なぜ対数尤度関数そのものではなく対数尤度関数の導関数なのでしょうか?これについては以下のような図を考えるとわかりやすいかもしれません。

x <- seq(-2, 3, 0.1)
quad <- function(x) -(x-1)^2 + 3
derive <- function(x) -2 * x + 2
plot(x, quad(x), type = "l", ylim = c(-10, 10), ylab = "")
par(new = T)
plot(x, derive(x), type = "l", col = "red", ylim = c(-10, 10), ylab = "f(x) or f'(x)")
lines(x, rep(0, length(x)), lty = 2, col = "blue")

f:id:ushi-goroshi:20190117003859p:plain

ここで黒い線は対象となる関数(二次関数)、赤い線はその導関数、青い点線はy = 0を示しています。

先ほどニュートン・ラプソン法の節において

 sin(x) = 0となる点を見つけます

と書きましたが、このアルゴリズムでは返り値が0となる点を見つける一方、我々が得たいのは対数尤度関数を最大にする点です。このグラフで言えば黒い線を最大にするような xです。

もしこの黒い線に対してニュートン・ラプソン法を当てはめるとどうなるでしょうか?おそらく黒い線と青い線が交わるところに収束するでしょう。先ほどのスクリプトを少し修正してみます:

iteration <- 0
threshold <- 1e-02
x_old <- 0 # 初期値を変更
updates <- 1

while (updates > threshold) {
   x_new <- x_old - quad(x_old)/derive(x_old) # 関数を変更
   updates <- abs(quad(x_new))
   x_old <- x_new
   iteration <- iteration + 1
   cat(sprintf("Iterations: %i, X_New: %f \n", iteration, x_new))
}
Iterations: 1, X_New: -1.000000 
Iterations: 2, X_New: -0.750000 
Iterations: 3, X_New: -0.732143

-0.732143という点で収束しましたが、quad(-0.732143)-0.0003193724となるため、確かに f(x)が0となる点を見つけているようです。しかしこれは最大となる値ではありません。

それでは赤い線で試してみるとどうでしょうか?答えはグラフからも明らかですが、元々の目的である黒い線の最大値で収束しそうです。

上記のような理由により、GLMにニュートン・ラプソンを当てはめる場合には、対数尤度関数そのものではなく対数尤度関数の導関数を用いると言えそうです。

それでは改めて、GLMにおけるニュートン・ラプソン法の推定方程式を見てみましょう:


\theta^ {(m)} = \theta^ {(m-1)} - \frac{U^ {(m-1)}}{U'^ {(m-1)}}

ここで Uは対数尤度関数を \theta微分したもので、スコア関数スコア統計量、または単にスコアなどと呼ばれます。

 U'はスコアを更に微分したものとなりますが、最尤推定においては U'そのものではなく E(U')により近似したものを用いることがあります。このとき、


E(U') = -Var(U) = -\mathfrak{J}

という関係があるため2、先程の方程式は以下のように書き直されます:


\theta^ {(m)} = \theta^ {(m-1)} + \frac{U^ {(m-1)}}{\mathfrak{J}^ {(m-1)}}

さらに推定対象をパラメータのベクトル \mathbf{\beta}に置き換えると:


\mathbf{b}^ {(m)} = \mathbf{b}^ {(m-1)}  + [\mathfrak{J}^ {(m-1)}]^ {-1} \mathbf{U}^ {(m-1)}

と一般化され、この式を用いて解を得る方法をスコア法と言います。

最尤推定

それでは上の式がそれぞれどのように得られるのかを確認していきますが、ここで一度切りたいと思います。


  1. 仮に初期値を2としていた場合、グラフの右端に向かって値が更新されることになるためx = 3.1415…で収束します。

  2. Dobsonの式(3.16)

小話

以下は全て憶測に依るもので、全く根拠のない話です。

ハイパフォーマーを発見するための分析を考えたとき、過去の個人ごとの業績から予測モデルを構築すると、「性別」の効果が有意に効くのではと思います1。そしてそれはきっと、男性の効果がプラス(または女性だとマイナス)になるでしょう。

もちろんこれは、業績に対して男女に差があることを示すものというよりも、他の条件が同一であったとしても女性の業績が低く評価されやすい現象(いわゆるガラスの天井)を、「性別」という効果が吸収してしまっているだけなんだと思います。しかし、そのようにして構築したモデルをテストデータ(つまりは新入社員)に対して当てはめると、潜在的な能力に係わらず、女性は不当に低く評価されてしまいます。

これを避けるにはどうしたら良いかと言うと、答えは簡単で、予測の際には「性別」の効果を予測モデルから外せば良いだけです。ここでのポイントは、モデルを学習する際にはガラスの天井効果を調整するために「性別」の効果を入れつつも、予測の際にはそれを用いないという点にあります。

このような操作が可能なところが線形モデル(GLMやGAMを含む)の良いところだと思うのですが、しかし同時に、その解釈可能性の高さ故に誤解を招くこともあるんじゃないかとも思います。つまり、学習した結果を見た人が短絡的に「女性の方が能力が低い」と理解してしまうことを助長してしまうのではないでしょうか。

その可能性を考えると、むしろ解釈可能性が低いアルゴリズムの方が良いのかもしれません。そのようなアルゴリズムでは、線形モデルのように単純に性別の効果を外して予測を行うことができるわけではないでしょうけども、例えば全ての対象者について真の性別に係わらず一律に男性 or 女性としておけばガラスの天井効果が紛れることを避けられます。

そう考えると、今後の機械学習界隈において「解釈可能性の高いモデル」は本当に求められるべきものなのか悩みます。


  1. 性別は単に例として挙げただけで、ここは学歴でも出身地方でも人種でも何でも構いません。

GLMをもう少し理解したい①

背景

一般化線形モデル(GLM)は、一般に線形回帰モデルを正規分布を含む指数分布族に拡張したものだと捉えられています。アイディアとしてはシンプルである割に非常に有用で、GLMによって

  • 整数値(ポアソン回帰)
  • 二値(ロジスティック回帰)
  • 0〜1の実数(ベータ回帰) ※2020/4/8追記 ベータ回帰は一般的でなかったので消しておきます
  • 0以上の実数(ガンマ回帰) ※2020/4/8追記 代わりにガンマ回帰を追加

などを扱うことができ、しかも回帰係数という非常に解釈性の高い結果を得ることができます1

そんなGLMですが、よく使う割には内容を今ひとつ理解できていないなと思うことがあったので、もう少しだけGLMを理解したいと思いRのglmの中身を見てみました。その内容をメモしておきます。

ちなみにこの検証を行っている環境は以下の通りです:

> sessionInfo()
R version 3.3.3 (2017-03-06)
Platform: x86_64-apple-darwin13.4.0 (64-bit)
Running under: macOS  10.13.3

locale:
   [1] ja_JP.UTF-8/ja_JP.UTF-8/ja_JP.UTF-8/C/ja_JP.UTF-8/ja_JP.UTF-8

attached base packages:
   [1] stats     graphics  grDevices utils     datasets  methods   base     

loaded via a namespace (and not attached):
   [1] tools_3.3.3  yaml_2.1.13  knitr_1.15.1

glm

まずはRのglmがどのように定義されているかを見てみましょう。コンソールでglmと入力することで、以下のようにglmという関数の定義を見ることができます。

> glm
function (formula, family = gaussian, data, weights, subset, 
          na.action, start = NULL, etastart, mustart, offset, control = list(...), 
          model = TRUE, method = "glm.fit", x = FALSE, y = TRUE, contrasts = NULL, 
          ...) 

まずはここでglmに渡す引数を定義しています。これらの引数でよく使われるのはformulafamilydataでしょうか。それぞれglmに渡す線形予測子(数式)、Yの従う分布、モデリングに用いるデータを指定しています。その他、データポイントの一つ一つの重みを変えたい場合にはweights、データの一部を使用する場合にはsubsetを指定したりします。

関数の定義は以下より始まりますが、細かい話は飛ばしてglmの本体に向かいましょう。

{
   
   call <- match.call()
   if (is.character(family)) 
      family <- get(family, mode = "function", envir = parent.frame())   
   
   ##
   ## 中略
   ##
   
   ## 本体はココのようです
   fit <- eval(call(if(is.function(method)) "method" else method,
                    x = X, y = Y, weights = weights, start = start, etastart = etastart, 
                    mustart = mustart, offset = offset, family = family, 
                    control = control, intercept = attr(mt, "intercept") > 0L))

glmの本体と呼べそうな部分はどうやらこのfitを定義している部分です。最初のevalは与えられた文字列をスクリプトとして解釈するための関数なので、call以降を実行するようです。またcallここによると「与えられた名前の関数の、与えられた引数への適応からなる未評価の表現式である」とのことなので、callに続くmethodおよび残りの引数がmethod(...)の形でevalに与えられ、関数として評価されます。

つまり

eval(call(if(is.function(method)) "method" else method, ...

method(...)

と同じとなるはずで、以下の例では同じように動いていることが確認できました。

# 関数を定義
return_cube <- function(x) x^3

# 普通に呼び出す
> return_cube(3)
[1] 27

# eval(call(...))で呼び出す
> eval(call("return_cube", x = 3))
[1] 27

# match.call
> eval(match.call(return_cube, call("return_cube", x = 3)))
[1] 27

> eval(match.call(return_cube, call("return_cube", 3)))
[1] 27

さて、callで指定しているmethodglmの引数で指定されているものでしたが、デフォルトではglm.fitが入力されています。したがってmethod(...)glm.fit(...)となるはずです。そこで今度はglm.fitの定義を確認してみましょう。

glm.fit

glmと同じく、glm.fitについてもコンソールに直接打ち込むことで関数の定義を表示することができます。まずは引数から見てみましょう。

> glm.fit
function (x, y, weights = rep(1, nobs), start = NULL, etastart = NULL, 
    mustart = NULL, offset = rep(0, nobs), family = gaussian(), 
    control = list(), intercept = TRUE) 

xyはそれぞれ説明変数と目的変数を指定し、その他の引数はglmから引き継がれるようですね。

またメインとなるのは以下のループ部分のようです。

   control <- do.call("glm.control", control)
   x <- as.matrix(x)
   xnames <- dimnames(x)[[2L]]

   ##
   ## 中略
   ##

      for (iter in 1L:control$maxit) {

         ##
         ## 中略
         ##

         z <- (eta - offset)[good] + (y - mu)[good]/mu.eta.val[good]
         w <- sqrt((weights[good] * mu.eta.val[good]^2)/variance(mu)[good])
         fit <- .Call(C_Cdqrls, x[good, , drop = FALSE] * 
             w, z * w, min(1e-07, control$epsilon/1000), check = FALSE)
          
         ##
         ## 中略
         ##
      }

   ##
   ## 中略
   ##
      

do.call(...)に渡すcontrollist()なので、do.call("glm.control", list())を実行すると以下が返ります;

> do.call("glm.control", list())
$epsilon
[1] 1e-08

$maxit
[1] 25

$trace
[1] FALSE

maxitが25なので、このループは最大で25回実行されます。ではこのループ内で何が行われているかというと、zwを新たに定義し、それをxおよび互いに乗じた形でC_Cdqrlsに渡しています。また.CallはCで書かれたルーチンを呼び出すための関数なので、ここではC_Cdqrlsという関数にx * wz * wといった引数を渡しているようです。

ではこのC_Cdqrlsはどこにあるのでしょうか?今度はC_Cdqrlsを探してみましょう。

C_Cdqrls

実はこのC_Cdqrlsstatsパッケージの関数として定義されています。しかしエクスポートされていないため、そのままコンソールに打ち込んでも表示されません。そのような場合には:::を使います。

> stats:::C_Cdqrls
$name
[1] "Cdqrls"

$address
<pointer: 0x101a2cdd0>
attr(,"class")
[1] "RegisteredNativeSymbol"

$dll
DLL name: stats
Filename:
         /Library/Frameworks/R.framework/Versions/3.3/Resources/library/stats/libs/stats.so
Dynamic lookup: FALSE

$numParameters
[1] 4

attr(,"class")
[1] "CallRoutine"      "NativeSymbolInfo"

しかしstats:::C_Cdqrlsと打ち込んでも、これまでと異なり関数の定義が表示されません。これは先ほど書いた通り、C_CdqrlsがCで書かれた関数であり、.callで呼び出されるためです。

ではどこから呼び出されるのかと言うと、私の環境では上記のFilenameで指定されている場所のようなのですが、これ自体は実行ファイル(stats.so)となっていてソースが見当たりません。それじゃどこにあるのかということで色々とググってみたところ、どうやらここで見れそうです。ファイル名を見てわかる通り、これはlmを定義しているコードです。glmは深く潜っていくとlmにたどり着くようです

C_Cdqrlsを定義している部分を見てみると:

SEXP Cdqrls(SEXP x, SEXP y, SEXP tol, SEXP chk)
{
   SEXP ans;

   ###
   ### 中略
   ###

   work = (double *) R_alloc(2 * p, sizeof(double));
   F77_CALL(dqrls)(REAL(qr), &n, &p, REAL(y), &ny, &rtol,
           REAL(coefficients), REAL(residuals), REAL(effects),
           &rank, INTEGER(pivot), REAL(qraux), work);
   SET_VECTOR_ELT(ans, 4, ScalarInteger(rank));
   for(int i = 0; i < p; i++)
    if(ip[i] != i+1) { pivoted = 1; break; }
   SET_VECTOR_ELT(ans, 8, ScalarLogical(pivoted));
   UNPROTECT(nprotect);

   return ans;
}

ここまで来ると何がなんだか私にはわかりませんが、returnansを返しているのでansを定義している箇所に着目すると、どうもF77_CALLが怪しい感じです。F77_CALLこのページによるとCからFortranを呼び出すための関数のようです。

F77_CALL

ではFortranで書かれたdqrlsソースコードはどこで見れるのかと言うと、ここのようです。重要そうなところを抜き出すと:

subroutine dqrls(x,n,p,y,ny,tol,b,rsd,qty,k,jpvt,qraux,work)
      integer n,p,ny,k,jpvt(p)
      double precision x(n,p),y(n,ny),tol,b(p,ny),rsd(n,ny),
     .                 qty(n,ny),qraux(p),work(p)
      integer info,j,jj,kk
      
      ### Householder transformation
      call dqrdc2(x,n,n,p,tol,k,qraux,jpvt,work)
      
      ### 
      if(k .gt. 0) then
         do 20 jj=1,ny
            call dqrsl(x,n,n,k,qraux,y(1,jj),rsd(1,jj),qty(1,jj),
     1           b(1,jj),rsd(1,jj),rsd(1,jj),1110,info)
   20       continue
      else
         do 35 i=1,n
            do 30 jj=1,ny
                rsd(i,jj) = y(i,jj)
   30       continue
   35   continue
      endif

となっており、dqrdc2dqrsl(紛らわしいけどdqrlsではない)を呼んでいます。これらはそれぞれ、

  • Householder変換を行う関数
  • そのアウトプットに対して加工および最小二乗解を与える関数

となっています。

随分かかりましたが、ここに来てようやく解を得ることができました。ここまでを振り返ると、glmという関数のコアの部分の役割はそれぞれ:

  1. Householder法による最小二乗解の推定(C_Cdqrls
  2. 上記の反復による収束判定(glm.fit
  3. もろもろの条件設定など(glm

となっているようでした。つまりglmは最尤法によって解を推定していると思われている(そしてそれは正しい)のだけれど、実際には最小二乗解をHouseholder法によって得ているのだということがわかりました。

長くなってしまったので、一旦切ります。次回はこの意味についてもう少し追いかけてみたいと思います。


  1. そのため個人的にはGLMをモデリングのベースラインとすることが多く、ここで十分な精度が得られるかでその後の対応を決めたりしています