統計コンサルの議事メモ

統計や機械学習の話題を中心に、思うがままに

GAMをもう少し理解したい

とても久しぶりの更新です。

背景

業務でモデリングを行うとき、私は大抵の場合GLMから始めます。目的変数に合わせて柔軟に分布を選択することが可能で、回帰係数という極めて解釈性の高い結果を得ることができるというのが理由です。

一方でGLMを使っていて不満に感じることの一つが、( \eta の世界で)非線形な効果を表現できないという点です。もちろん2次・3次の項や交互作用項を追加することである程度それらの不満は解消できるのですが、もう少しデータからそれらの特徴を学習したいと思うことがあります。

今回取り上げる一般化加法モデル(Generalized Additive Model, GAM)は、そのような複雑な関連性を表現できるよう説明変数に非線形な変換を行うもので、GLMを拡張したものとなります。ちょっと古いですが、2015年にMicrosoft RearchがKDDに出した論文(PDF)では、GAMを指して「the gold standard for intelligibility when low-dimensional terms are considered」と言っており、解釈性を保ちつつ高い予測精度を得ることができるモデルとしています。なおこの論文ではGAMに2次までの交互作用を追加するGA2Mという手法を提案しています。

このGAMの実装について調べた内容を書き留めておきます。GAMがどういうものかとか、平滑化に関する説明は、他に良いページがありますのでそちらを参照してください。例えば:

GAMの実行結果

始めに、GAMを使うとどのような結果を得ることができるのか確認しましょう。ちなみに検証した環境は以下の通りです。

> sessionInfo()
R version 3.6.0 (2019-04-26)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS Mojave 10.14.6

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib

locale:
[1] ja_JP.UTF-8/ja_JP.UTF-8/ja_JP.UTF-8/C/ja_JP.UTF-8/ja_JP.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

loaded via a namespace (and not attached):
[1] compiler_3.6.0 tools_3.6.0    knitr_1.23     xfun_0.7 

GAMの実装としてRでは {mgcv}が使われることが多いようですが、今回は「Rによる統計的学習入門」を参考に{gam}を使用しました。ちなみに、この本は統計・機械学習の主要な手法を網羅的に押さえつつ章末にRでの実行方法が紹介されており、大変勉強になる良書です。

Rによる 統計的学習入門

Rによる 統計的学習入門

  • 作者: Gareth James,Daniela Witten,Trevor Hastie,Robert Tibshirani,落海浩,首藤信通
  • 出版社/メーカー: 朝倉書店
  • 発売日: 2018/08/03
  • メディア: 単行本(ソフトカバー)
  • この商品を含むブログ (1件) を見る

同書で用いているサンプルデータ Wage を使用するため、{ISLR}を同時に読み込みます。この{ISLR}パッケージは上記の本の原著であるIntroduction of Statistical Learning with Applications in Rから来ているようです。またこのデータは、アメリカの大西洋岸中央部における男性3000人の賃金、および年齢や婚姻状況、人種や学歴などの属性が記録されています。

library(gam)
library(ISLR)

データの中身を見てみましょう。

> head(Wage)
       year age           maritl     race       education             region       jobclass
231655 2006  18 1. Never Married 1. White    1. < HS Grad 2. Middle Atlantic  1. Industrial
86582  2004  24 1. Never Married 1. White 4. College Grad 2. Middle Atlantic 2. Information
161300 2003  45       2. Married 1. White 3. Some College 2. Middle Atlantic  1. Industrial
155159 2003  43       2. Married 3. Asian 4. College Grad 2. Middle Atlantic 2. Information
11443  2005  50      4. Divorced 1. White      2. HS Grad 2. Middle Atlantic 2. Information
376662 2008  54       2. Married 1. White 4. College Grad 2. Middle Atlantic 2. Information
               health health_ins  logwage      wage
231655      1. <=Good      2. No 4.318063  75.04315
86582  2. >=Very Good      2. No 4.255273  70.47602
161300      1. <=Good     1. Yes 4.875061 130.98218
155159 2. >=Very Good     1. Yes 5.041393 154.68529
11443       1. <=Good     1. Yes 4.318063  75.04315
376662 2. >=Very Good     1. Yes 4.845098 127.11574

このデータを用いて早速フィッティングしてみましょう。 glm と同様に関数 gam でモデルを当てはめることができます。

res_gam <- gam(wage ~ s(year, 4) + s(age, 5) + education, data = Wage)

ここで s(age, 5) は、説明変数 year を平滑化した上でモデルに取り込むことを意味し、5は平滑化の自由度です。なお s() の時点ではまだ平滑化は行われておらず、平滑化に必要な情報を属性として付与しているだけのようです。

> head(s(Wage$age, 5))
[1] 18 24 45 43 50 54
attr(,"spar")
[1] 1
attr(,"df")
[1] 5
attr(,"call")
gam.s(data[["s(Wage$age, 5)"]], z, w, spar = 1, df = 5)
attr(,"class")
[1] "smooth"

ではフィッティングした結果を見てみましょう。

> summary(res_gam)

Call: gam(formula = wage ~ s(year, 4) + s(age, 5) + education, data = Wage)
Deviance Residuals:
    Min      1Q  Median      3Q     Max 
-119.43  -19.70   -3.33   14.17  213.48 

(Dispersion Parameter for gaussian family taken to be 1235.69)

    Null Deviance: 5222086 on 2999 degrees of freedom
Residual Deviance: 3689770 on 2986 degrees of freedom
AIC: 29887.75 

Number of Local Scoring Iterations: 2 

Anova for Parametric Effects
             Df  Sum Sq Mean Sq F value    Pr(>F)    
s(year, 4)    1   27162   27162  21.981 2.877e-06 ***
s(age, 5)     1  195338  195338 158.081 < 2.2e-16 ***
education     4 1069726  267432 216.423 < 2.2e-16 ***
Residuals  2986 3689770    1236                      
---
Signif. codes:  0***0.001**0.01*0.05 ‘.’ 0.1 ‘ ’ 1

Anova for Nonparametric Effects
            Npar Df Npar F  Pr(F)    
(Intercept)                          
s(year, 4)        3  1.086 0.3537    
s(age, 5)         4 32.380 <2e-16 ***
education                            
---
Signif. codes:  0***0.001**0.01*0.05 ‘.’ 0.1 ‘ ’ 1

パッと見ると、 glm を当てはめたときと同様の結果が得られるようです。実際、 gam の本体(と思われる) gam.fit の中では stats::lm.wfit が呼ばれます。2019/10/17修正 今回のケースでは lm.wfit は使われないので正しくありませんでした。

しかし、 glm の結果と大きく異なる点として、各説明変数の回帰係数が出ていません。これはどういうことでしょう。以下のように説明変数の効果をプロットしてみます。

par(mfrow = c(1, 3))
plot(res_gam, se = TRUE, col = "blue")

f:id:ushi-goroshi:20191016112558p:plain

真ん中の age が分かりやすいですが、この変数は34まではy軸の値が0未満となっており、 age は( yeareducation を固定した下で)35歳程度まで wage の平均を下回っていることがわかります。その後45歳をピークに減少傾向に入り、65歳を過ぎると再び平均を下回るようになり、以降は急激に減少していきます。このような現象は、年齢とともに役職が上がることで賃金が増加し、退職によって減少することを考えれば非常に納得感のあるものだと思います。

なお上のプロットに必要なデータは以下のように取れるので、説明変数の各点における目的変数に対する影響を数値でも確認できます(必要な関数がエクスポートされていないので gam::: で直接呼び出しています)。

tmp <- gam:::preplot.Gam(res_gam, terms = gam:::labels.Gam(res_gam))
age_sm <- cbind(tmp$`s(age, 5)`$x, tmp$`s(age, 5)`$y) 
age_sm_uniq <- unique(age_sm[order(age_sm[, 1]), ])

プロットしてみましょう。同じ線が描けます。

plot(age_sm_uniq, type = "l", xlab = "age", ylab = "s(age, 5)")

f:id:ushi-goroshi:20191016113801p:plain

さて、上記の結果は例えば age の二次の項を含めることで lm でも再現できるかもしれません。例えば以下のようになります:

### lmでフィッティング
res_lm <- lm(wage ~ year + poly(age, 2) + education, Wage) ## poly(age, 2)で二次の多項式とする

### 予測用のデータ作成。ageだけを変化させ、yearとeducationは固定する。
x_lm <- seq(min(Wage$age), max(Wage$age), length.out = 100)
nd <- data.frame(year = rep(2003, 100),
               age = x_lm,
               education = rep("2. HS Grad"))
prd_lm <- predict(res_lm, nd, se.fit = T)

### 予測値から平均を引いてからプロット
prd_m <- 90
plot(x_lm, prd_lm$fit - prd_m, type = "l", col = "blue", ylim = c(-40, 10))
lines(x_lm, prd_lm$fit - prd_m + 2*prd_lm$se.fit, lty = "dashed")
lines(x_lm, prd_lm$fit - prd_m - 2*prd_lm$se.fit, lty = "dashed")

f:id:ushi-goroshi:20191016114041p:plain

大体同じようなプロットを作成する事ができました。しかしピークを過ぎてからの緩やかな減少は表現できていませんし、一つ一つの変数について何次までの多項式を含めるかを検討していくのは少し手間がかかります。GAMを使えばデータに存在する細やかな変化を自動的に捉えることができます(もちろん代償もあります)。

GAMの実装

gam()

それでは gam がどのようにフィッティングを行っているのかを見ていきましょう。本体は最後の方に出てくる gam.fit なのですが、途中も少し細かく追ってみます。コンソールで gam を実行すると、以下のように関数の中身を見ることができます。

まず以下のブロックでは、 gam にオプションとして指定した内容に沿った処理を実行しています。

function (formula, family = gaussian, data, weights, subset, 
        na.action, start = NULL, etastart, mustart, control = gam.control(...), 
        model = TRUE, method = "glm.fit", x = FALSE, y = TRUE, ...) 
{
### 関数の引数を名前付きで確定。gam(wage ~ s(year, 4) + education, Wage)として与えた場合、
### formula = と data = がそれぞれ保持される。
### match.call returns a call in which all of the specified arguments are specified by their full names.
call <- match.call()

### familyの判定
if (is.character(family)) 
  family <- get(family, mode = "function", envir = parent.frame())
if (is.function(family)) 
  family <- family()
if (is.null(family$family)) {
  print(family)
  stop("`family' not recognized")
}

### データが指定されていない場合
if (missing(data)) 
  data <- environment(formula)

### 指定されている引数の取り出し
mf <- match.call(expand.dots = FALSE)
m <- match(c("formula", "data", "subset", "weights", "etastart", 
             "mustart", "offset"), names(mf), 0L)
mf <- mf[c(1L, m)]

### 指定されていない引数を指定し、stats::model.frame()の形式に仕立てる
mf$na.action = quote(na.pass)
mf$drop.unused.levels <- TRUE
mf[[1L]] <- quote(stats::model.frame)

次に、ここで一つ gam ならではの処理として平滑化に使う関数を取り出しています。

### 平滑化の関数を取ってくる(s, lo, random)
gam.slist <- gam.smoothers()$slist

gam.smoothers として新しい平滑化関数を指定することも出来るようですが、デフォルトは slo および random で、それぞれ 平滑化スプライン局所回帰ランダム効果としての指定を意味しているようです。 3つめの random がわからなかったので調べてみたところ、これはカテゴリ変数に対する指定で、パラメータの推定においていわゆる縮小推定を行なうもののようでした。

https://www.rdocumentation.org/packages/gam/versions/1.16.1/topics/random

これらに接頭語として gam を加えた(e.g. gam.s)関数が平滑化のための関数として実行されます。 gam.s については後述します。

次のブロックですが、 call クラスであった mf を評価することで data.frame に変換しています。

### term を mf$formula に渡す
mt <- if (missing(data)) 
  terms(formula, gam.slist)
else terms(formula, gam.slist, data = data)
mf$formula <- mt

### ここで mf 、つまり model.frame が実行されて data.frame になる
### ただし平滑化は実行されず、平滑化のパラメータは attribute として持っている
mf <- eval(mf, parent.frame())
if (missing(na.action)) {
  naa = getOption("na.action", "na.fail")
  na.action = get(naa)
}
mf = na.action(mf)
mt = attributes(mf)[["terms"]]

ここまでは mfcall クラス、すなわち未評価の関数およびその引数を要素に持つオブジェクトでした。それが eval で評価されたため stats::model.frame が実行され、 formula にしたがい data.frame が生成された、という流れのようです(間違っていたらすみません)。

### method の指定によって処理を分ける。 glm.fit または glm.fit.null 以外の場合はエラー
switch(method, model.frame = return(mf), glm.fit = 1, glm.fit.null = 1, 
       stop("invalid `method': ", method))

ここでは method によって処理を分けていますが、現状は glm.fit または model.frame のみ受け付けているようです。 なおhelpを参照すると、 model.frame を指定した場合フィッティングは行われないようですね。以下のようになります。

> gam(wage ~ s(year, 4) + s(age, 5) + education, data = Wage, method = "model.frame")
wage s(year, 4) s(age, 5)          education
231655  75.04315       2006        18       1. < HS Grad
86582   70.47602       2004        24    4. College Grad
161300 130.98218       2003        45    3. Some College
155159 154.68529       2003        43    4. College Grad
11443   75.04315       2005        50         2. HS Grad

この通りデータが返ってきます。

以下では、YおよびXをそれぞれ抽出し、さらにフィッティングに必要なオプションを指定しています。

### Y を取り出す
Y <- model.response(mf, "any")

### X を matrix で取り出す。
### gam を実行したときのエラーメッセージ( `non-list contrasts argument ignored` )はここで出ている。 contrasts の指定が良くない様子。
X <- if (!is.empty.model(mt)) 
  model.matrix(mt, mf, contrasts) 
else matrix(, NROW(Y), 0)

### その他パラメータ(weights, offset, mustart, etastart)
weights <- model.weights(mf)
offset <- model.offset(mf)
if (!is.null(weights) && any(weights < 0)) 
  stop("Negative wts not allowed")
if (!is.null(offset) && length(offset) != NROW(Y)) 
  stop("Number of offsets is ", length(offset), ", should equal ", 
       NROW(Y), " (number of observations)")
mustart <- model.extract(mf, "mustart")
etastart <- model.extract(mf, "etastart")

以上で準備が完了し、 gam.fit で当てはめを行います。

### ここが本体。 gam.fit を呼び出している
fit <- gam.fit(x = X, y = Y, smooth.frame = mf, weights = weights, 
               start = start, etastart = etastart, mustart = mustart, 
               offset = offset, family = family, control = control)

いったん後続の部分は無視して gam.fit に移りたいと思いますが、長くなったので今回はここまでにします。