GAMをもう少し理解したい
とても久しぶりの更新です。
背景
業務でモデリングを行うとき、私は大抵の場合GLMから始めます。目的変数に合わせて柔軟に分布を選択することが可能で、回帰係数という極めて解釈性の高い結果を得ることができるというのが理由です。
一方でGLMを使っていて不満に感じることの一つが、( の世界で)非線形な効果を表現できないという点です。もちろん2次・3次の項や交互作用項を追加することである程度それらの不満は解消できるのですが、もう少しデータからそれらの特徴を学習したいと思うことがあります。
今回取り上げる一般化加法モデル(Generalized Additive Model, GAM)は、そのような複雑な関連性を表現できるよう説明変数に非線形な変換を行うもので、GLMを拡張したものとなります。ちょっと古いですが、2015年にMicrosoft RearchがKDDに出した論文(PDF)では、GAMを指して「the gold standard for intelligibility when low-dimensional terms are considered」と言っており、解釈性を保ちつつ高い予測精度を得ることができるモデルとしています。なおこの論文ではGAMに2次までの交互作用を追加するGA2Mという手法を提案しています。
このGAMの実装について調べた内容を書き留めておきます。GAMがどういうものかとか、平滑化に関する説明は、他に良いページがありますのでそちらを参照してください。例えば:
- https://logics-of-blue.com/%E5%B9%B3%E6%BB%91%E5%8C%96%E3%82%B9%E3%83%97%E3%83%A9%E3%82%A4%E3%83%B3%E3%81%A8%E5%8A%A0%E6%B3%95%E3%83%A2%E3%83%87%E3%83%AB/
- http://testblog234wfhb.blogspot.com/2014/07/generalized-additive-model.html
- https://blog.datarobot.com/jp/2017/10/24/ga2m-and-rating-table
GAMの実行結果
始めに、GAMを使うとどのような結果を得ることができるのか確認しましょう。ちなみに検証した環境は以下の通りです。
> sessionInfo() R version 3.6.0 (2019-04-26) Platform: x86_64-apple-darwin15.6.0 (64-bit) Running under: macOS Mojave 10.14.6 Matrix products: default BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib locale: [1] ja_JP.UTF-8/ja_JP.UTF-8/ja_JP.UTF-8/C/ja_JP.UTF-8/ja_JP.UTF-8 attached base packages: [1] stats graphics grDevices utils datasets methods base loaded via a namespace (and not attached): [1] compiler_3.6.0 tools_3.6.0 knitr_1.23 xfun_0.7
GAMの実装としてRでは {mgcv}が使われることが多いようですが、今回は「Rによる統計的学習入門」を参考に{gam}を使用しました。ちなみに、この本は統計・機械学習の主要な手法を網羅的に押さえつつ章末にRでの実行方法が紹介されており、大変勉強になる良書です。
- 作者: Gareth James,Daniela Witten,Trevor Hastie,Robert Tibshirani,落海浩,首藤信通
- 出版社/メーカー: 朝倉書店
- 発売日: 2018/08/03
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (1件) を見る
同書で用いているサンプルデータ Wage
を使用するため、{ISLR}を同時に読み込みます。この{ISLR}パッケージは上記の本の原著であるIntroduction of Statistical Learning with Applications in Rから来ているようです。またこのデータは、アメリカの大西洋岸中央部における男性3000人の賃金、および年齢や婚姻状況、人種や学歴などの属性が記録されています。
library(gam) library(ISLR)
データの中身を見てみましょう。
> head(Wage) year age maritl race education region jobclass 231655 2006 18 1. Never Married 1. White 1. < HS Grad 2. Middle Atlantic 1. Industrial 86582 2004 24 1. Never Married 1. White 4. College Grad 2. Middle Atlantic 2. Information 161300 2003 45 2. Married 1. White 3. Some College 2. Middle Atlantic 1. Industrial 155159 2003 43 2. Married 3. Asian 4. College Grad 2. Middle Atlantic 2. Information 11443 2005 50 4. Divorced 1. White 2. HS Grad 2. Middle Atlantic 2. Information 376662 2008 54 2. Married 1. White 4. College Grad 2. Middle Atlantic 2. Information health health_ins logwage wage 231655 1. <=Good 2. No 4.318063 75.04315 86582 2. >=Very Good 2. No 4.255273 70.47602 161300 1. <=Good 1. Yes 4.875061 130.98218 155159 2. >=Very Good 1. Yes 5.041393 154.68529 11443 1. <=Good 1. Yes 4.318063 75.04315 376662 2. >=Very Good 1. Yes 4.845098 127.11574
このデータを用いて早速フィッティングしてみましょう。 glm
と同様に関数 gam
でモデルを当てはめることができます。
res_gam <- gam(wage ~ s(year, 4) + s(age, 5) + education, data = Wage)
ここで s(age, 5)
は、説明変数 year
を平滑化した上でモデルに取り込むことを意味し、5は平滑化の自由度です。なお s()
の時点ではまだ平滑化は行われておらず、平滑化に必要な情報を属性として付与しているだけのようです。
> head(s(Wage$age, 5)) [1] 18 24 45 43 50 54 attr(,"spar") [1] 1 attr(,"df") [1] 5 attr(,"call") gam.s(data[["s(Wage$age, 5)"]], z, w, spar = 1, df = 5) attr(,"class") [1] "smooth"
ではフィッティングした結果を見てみましょう。
> summary(res_gam) Call: gam(formula = wage ~ s(year, 4) + s(age, 5) + education, data = Wage) Deviance Residuals: Min 1Q Median 3Q Max -119.43 -19.70 -3.33 14.17 213.48 (Dispersion Parameter for gaussian family taken to be 1235.69) Null Deviance: 5222086 on 2999 degrees of freedom Residual Deviance: 3689770 on 2986 degrees of freedom AIC: 29887.75 Number of Local Scoring Iterations: 2 Anova for Parametric Effects Df Sum Sq Mean Sq F value Pr(>F) s(year, 4) 1 27162 27162 21.981 2.877e-06 *** s(age, 5) 1 195338 195338 158.081 < 2.2e-16 *** education 4 1069726 267432 216.423 < 2.2e-16 *** Residuals 2986 3689770 1236 --- Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 Anova for Nonparametric Effects Npar Df Npar F Pr(F) (Intercept) s(year, 4) 3 1.086 0.3537 s(age, 5) 4 32.380 <2e-16 *** education --- Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
パッと見ると、 glm
を当てはめたときと同様の結果が得られるようです。実際、 (2019/10/17修正 今回のケースでは lm.wfit は使われないので正しくありませんでした。)gam
の本体(と思われる) gam.fit
の中では stats::lm.wfit
が呼ばれます。
しかし、 glm
の結果と大きく異なる点として、各説明変数の回帰係数が出ていません。これはどういうことでしょう。以下のように説明変数の効果をプロットしてみます。
par(mfrow = c(1, 3)) plot(res_gam, se = TRUE, col = "blue")
真ん中の age
が分かりやすいですが、この変数は34まではy軸の値が0未満となっており、 age
は( year
や education
を固定した下で)35歳程度まで wage
の平均を下回っていることがわかります。その後45歳をピークに減少傾向に入り、65歳を過ぎると再び平均を下回るようになり、以降は急激に減少していきます。このような現象は、年齢とともに役職が上がることで賃金が増加し、退職によって減少することを考えれば非常に納得感のあるものだと思います。
なお上のプロットに必要なデータは以下のように取れるので、説明変数の各点における目的変数に対する影響を数値でも確認できます(必要な関数がエクスポートされていないので gam:::
で直接呼び出しています)。
tmp <- gam:::preplot.Gam(res_gam, terms = gam:::labels.Gam(res_gam)) age_sm <- cbind(tmp$`s(age, 5)`$x, tmp$`s(age, 5)`$y) age_sm_uniq <- unique(age_sm[order(age_sm[, 1]), ])
プロットしてみましょう。同じ線が描けます。
plot(age_sm_uniq, type = "l", xlab = "age", ylab = "s(age, 5)")
さて、上記の結果は例えば age
の二次の項を含めることで lm
でも再現できるかもしれません。例えば以下のようになります:
### lmでフィッティング res_lm <- lm(wage ~ year + poly(age, 2) + education, Wage) ## poly(age, 2)で二次の多項式とする ### 予測用のデータ作成。ageだけを変化させ、yearとeducationは固定する。 x_lm <- seq(min(Wage$age), max(Wage$age), length.out = 100) nd <- data.frame(year = rep(2003, 100), age = x_lm, education = rep("2. HS Grad")) prd_lm <- predict(res_lm, nd, se.fit = T) ### 予測値から平均を引いてからプロット prd_m <- 90 plot(x_lm, prd_lm$fit - prd_m, type = "l", col = "blue", ylim = c(-40, 10)) lines(x_lm, prd_lm$fit - prd_m + 2*prd_lm$se.fit, lty = "dashed") lines(x_lm, prd_lm$fit - prd_m - 2*prd_lm$se.fit, lty = "dashed")
大体同じようなプロットを作成する事ができました。しかしピークを過ぎてからの緩やかな減少は表現できていませんし、一つ一つの変数について何次までの多項式を含めるかを検討していくのは少し手間がかかります。GAMを使えばデータに存在する細やかな変化を自動的に捉えることができます(もちろん代償もあります)。
GAMの実装
gam()
それでは gam
がどのようにフィッティングを行っているのかを見ていきましょう。本体は最後の方に出てくる gam.fit
なのですが、途中も少し細かく追ってみます。コンソールで gam
を実行すると、以下のように関数の中身を見ることができます。
まず以下のブロックでは、 gam
にオプションとして指定した内容に沿った処理を実行しています。
function (formula, family = gaussian, data, weights, subset, na.action, start = NULL, etastart, mustart, control = gam.control(...), model = TRUE, method = "glm.fit", x = FALSE, y = TRUE, ...) { ### 関数の引数を名前付きで確定。gam(wage ~ s(year, 4) + education, Wage)として与えた場合、 ### formula = と data = がそれぞれ保持される。 ### match.call returns a call in which all of the specified arguments are specified by their full names. call <- match.call() ### familyの判定 if (is.character(family)) family <- get(family, mode = "function", envir = parent.frame()) if (is.function(family)) family <- family() if (is.null(family$family)) { print(family) stop("`family' not recognized") } ### データが指定されていない場合 if (missing(data)) data <- environment(formula) ### 指定されている引数の取り出し mf <- match.call(expand.dots = FALSE) m <- match(c("formula", "data", "subset", "weights", "etastart", "mustart", "offset"), names(mf), 0L) mf <- mf[c(1L, m)] ### 指定されていない引数を指定し、stats::model.frame()の形式に仕立てる mf$na.action = quote(na.pass) mf$drop.unused.levels <- TRUE mf[[1L]] <- quote(stats::model.frame)
次に、ここで一つ gam
ならではの処理として平滑化に使う関数を取り出しています。
### 平滑化の関数を取ってくる(s, lo, random) gam.slist <- gam.smoothers()$slist
gam.smoothers
として新しい平滑化関数を指定することも出来るようですが、デフォルトは s
、 lo
および random
で、それぞれ 平滑化スプライン、局所回帰、ランダム効果としての指定を意味しているようです。 3つめの random
がわからなかったので調べてみたところ、これはカテゴリ変数に対する指定で、パラメータの推定においていわゆる縮小推定を行なうもののようでした。
https://www.rdocumentation.org/packages/gam/versions/1.16.1/topics/random
これらに接頭語として gam
を加えた(e.g. gam.s
)関数が平滑化のための関数として実行されます。 gam.s
については後述します。
次のブロックですが、 call
クラスであった mf
を評価することで data.frame
に変換しています。
### term を mf$formula に渡す mt <- if (missing(data)) terms(formula, gam.slist) else terms(formula, gam.slist, data = data) mf$formula <- mt ### ここで mf 、つまり model.frame が実行されて data.frame になる ### ただし平滑化は実行されず、平滑化のパラメータは attribute として持っている mf <- eval(mf, parent.frame()) if (missing(na.action)) { naa = getOption("na.action", "na.fail") na.action = get(naa) } mf = na.action(mf) mt = attributes(mf)[["terms"]]
ここまでは mf
は call
クラス、すなわち未評価の関数およびその引数を要素に持つオブジェクトでした。それが eval
で評価されたため stats::model.frame
が実行され、 formula
にしたがい data.frame
が生成された、という流れのようです(間違っていたらすみません)。
### method の指定によって処理を分ける。 glm.fit または glm.fit.null 以外の場合はエラー switch(method, model.frame = return(mf), glm.fit = 1, glm.fit.null = 1, stop("invalid `method': ", method))
ここでは method
によって処理を分けていますが、現状は glm.fit
または model.frame
のみ受け付けているようです。 なおhelpを参照すると、 model.frame
を指定した場合フィッティングは行われないようですね。以下のようになります。
> gam(wage ~ s(year, 4) + s(age, 5) + education, data = Wage, method = "model.frame") wage s(year, 4) s(age, 5) education 231655 75.04315 2006 18 1. < HS Grad 86582 70.47602 2004 24 4. College Grad 161300 130.98218 2003 45 3. Some College 155159 154.68529 2003 43 4. College Grad 11443 75.04315 2005 50 2. HS Grad
この通りデータが返ってきます。
以下では、YおよびXをそれぞれ抽出し、さらにフィッティングに必要なオプションを指定しています。
### Y を取り出す Y <- model.response(mf, "any") ### X を matrix で取り出す。 ### gam を実行したときのエラーメッセージ( `non-list contrasts argument ignored` )はここで出ている。 contrasts の指定が良くない様子。 X <- if (!is.empty.model(mt)) model.matrix(mt, mf, contrasts) else matrix(, NROW(Y), 0) ### その他パラメータ(weights, offset, mustart, etastart) weights <- model.weights(mf) offset <- model.offset(mf) if (!is.null(weights) && any(weights < 0)) stop("Negative wts not allowed") if (!is.null(offset) && length(offset) != NROW(Y)) stop("Number of offsets is ", length(offset), ", should equal ", NROW(Y), " (number of observations)") mustart <- model.extract(mf, "mustart") etastart <- model.extract(mf, "etastart")
以上で準備が完了し、 gam.fit
で当てはめを行います。
### ここが本体。 gam.fit を呼び出している fit <- gam.fit(x = X, y = Y, smooth.frame = mf, weights = weights, start = start, etastart = etastart, mustart = mustart, offset = offset, family = family, control = control)
いったん後続の部分は無視して gam.fit
に移りたいと思いますが、長くなったので今回はここまでにします。