統計コンサルの議事メモ

統計や機械学習の話題を中心に、思うがままに

glmnetをもう少し理解したい③

前回の記事では R の関数 elnet の中で elnet という Fortran のサブルーチンが呼ばれ(やっぱりややこしいですね)、さらに type.gaussian の値( covariancenaive )によって elnetuelnetn のいずれかが呼ばれるところまで確認しました。 今回は elnetu の中身を見ていきます。 過去の記事はこちらです。

ushi-goroshi.hatenablog.com

ushi-goroshi.hatenablog.com

elnetu の実装

それでは早速 elnetu を見ていきましょう。 elnetuelnet と同様にそれほど大きくないのでいきなり内容の確認に入りますが、処理としては以下の手順になっているようです:

  1. 前処理
  2. 標準化
  3. フィッティング
  4. 後処理

まずは前処理ですが、メモリの割り付けのあとに chkvars というサブルーチンを呼び出しています。

      subroutine elnetu(parm,no,ni,x,y,w,jd,vp,cl,ne,nx,nlam,  flmin,ulam,thr,isd,intr,maxit,  lmu,a0,ca,ia,nin,rsq,alm,nlp,jerr)
      implicit double precision(a-h,o-z)                                
      double precision x(no,ni),y(no),w(no),vp(ni),ulam(nlam),cl(2,ni)  
      double precision ca(nx,nlam),a0(nlam),rsq(nlam),alm(nlam)         
      integer jd(*),ia(nx),nin(nlam)                                    
      double precision, dimension (:), allocatable :: xm,xs,g,xv,vlam   
      integer, dimension (:), allocatable :: ju                         
      allocate(g(1:ni),stat=jerr)              
      if(jerr.ne.0) return                                              
      allocate(xm(1:ni),stat=jerr)                                      
      if(jerr.ne.0) return                                              
      allocate(xs(1:ni),stat=jerr)                                      
      if(jerr.ne.0) return                                              
      allocate(ju(1:ni),stat=jerr)                                     
      if(jerr.ne.0) return                                              
      allocate(xv(1:ni),stat=jerr)                                    
      if(jerr.ne.0) return                                              
      allocate(vlam(1:nlam),stat=jerr)                                  
      if(jerr.ne.0) return

      ! 1. 前処理
      call chkvars(no,ni,x,ju)

      if(jd(1).gt.0) ju(jd(2:(jd(1)+1)))=0
      if(maxval(ju) .gt. 0)goto 10071                                   
      jerr=7777                                                         
      return                                                            
10071 continue

この chkvars では x の各変数について一行目の値と異なる値が二行目以降にあるかを確認し、 ju に格納しています。

      subroutine chkvars(no,ni,x,ju)
      implicit double precision(a-h,o-z)
      double precision x(no,ni)
      integer ju(ni)
      
      ! ここから各変数のチェックを開始
      do 11061 j=1,ni
      ju(j)=0
      t=x(1,j) ! 1行目の値を取得

      ! ここから2行目の値を確認する
      do 11071 i=2,no
      ! t は x(1, j) なので、各変数 j について 1 行目の値と等しいかを確認している
      if(x(i,j).eq.t) goto 11071 ! 等しければ次の行へ
      ju(j)=1 ! 等しくない数値があれば ju を 1 にして次の変数へ
      goto 11072
11071 continue
11072 continue
11061 continue
      continue
      return
      end

異なる値がなければ全ての値は同じということになりますので、例えば回帰係数を推定する意味はありません。 後続の処理ではこの ju を参照してスキップするかを決めている箇所が多々出てきます。

続いて standard というサブルーチンを呼び出して標準化を行います。

      ! 2. 標準化
      call standard(no,ni,x,y,w,isd,intr,ju,g,xm,xs,ym,ys,xv,jerr)

この standard とうサブルーチンは結構大きく見えますが、切片の有無で処理を分けているため重複部分があります。 処理の内容としては:

  1. 重みの変換
  2. y と x の更新
  3. y と x の内積(共分散)を計算

となっています。

まずは重みの変換を確認してみると、重み w を「重みの総和あたりの重み」に変換し、 さらにその平方根をとったものを v として定義しています。 またその次から、先に述べたように切片の有無によって処理を分けています。

      subroutine standard(no,ni,x,y,w,isd,intr,ju,g,xm,xs,ym,ys,xv,jerr)
      implicit double precision(a-h,o-z)                                
      double precision x(no,ni),y(no),w(no),g(ni),xm(ni),xs(ni),xv(ni)  
      integer ju(ni)                                                    
      double precision, dimension (:), allocatable :: v                 
      allocate(v(1:no),stat=jerr)                                       
      if(jerr.ne.0) return
      
      ! 1. 重みの変換
      w=w/sum(w)
      v=sqrt(w) 

      ! intr は intercept なので切片が 0 であるかで判定
      ! 切片が 0 でない場合は 10141 に飛ばされる
      if(intr .ne. 0) goto 10141                                         

以降の処理ではこの vyx に対して掛け合わせるのですが、全ての観測値の重みが等しい単純なパターンを想定すると w には $1/n$ 、v にはその平方根が入ります。 例えば観測値の数が 100 であれば $w = 1/100 = 0.01$ 、$v = sqrt(1/100) = 0.1$ となります。

ではこのような wv を使って何をやっているかというと、 y に対しては:

  1. yv を乗じたものを新たに y とする
  2. その y内積(二乗和)から vy内積の二乗を減じ、平方根をとる(ys
  3. yys で割る

ということをしています。

      ! 2. y と x の更新
      ! 以下のセクションでは y と x それぞれについて観測値の重みを使って色々と調整する
      ! まずは y      
      ym = 0.0
      y  = v*y 
      ys = sqrt(dot_product(y,y)-dot_product(v,y)**2)
      y  = y/ys 

ただこの説明だけでは意味が分からないと思いますので少し式を整理してみましょう。 もとの y および w を $y0$ 、 $w0$ とおくと、


y_{i} = v_{i} * y0_{i} = \sqrt{w_{i}} * y0 = \sqrt{\frac{w0_{i}}{\sum{w0_{i}}}}y0_{i}

となります。 また ys の二乗(平方根を取る前) $ (ys)^{2} $ は:


 (ys)^{2} = \sum{y_{i}^{2}} - (\sum{v_{i}y_{i}})^2 = \sum{\frac{w_{i}}{\sum{w_{i}}}y0_{i}^{2}} - (\sum{\frac{w_{i}}{\sum{w_{i}}}}y0_{i})^{2}

と書けます。 ここで $w$ は観測値に対する重み $w0$ をその総和で除した形(単純なパターンでは $\frac{1}{n}$ )となっていることを思い出すと、これを乗じたものの総和は重み付き平均となります。 そうすると右辺の第一項はもともとの y ($y0$)の二乗の重み付き平均、第二項は重み付き平均の二乗が得られていることがわかります。 二乗の平均から平均の二乗を引いたものと言えば分散ですので、その平方根をとった ys は $y0$ の重み付き標準偏差を得ているようです (ところで $w$、$w0$ は添字 $i$ を付けるべきですが、はてなブログLaTeX がなぜか崩れるので省略しています)。

実際にサンプルデータで計算してみましょう。 まずは以下のような簡単なデータで二乗の平均から平均の二乗を引いたものが分散になることを確認します。

# 適当なデータ
a <- c(5, 5, 6, 7, 9)

# 一般的な分散の計算
mean((a - mean(a))^2)
# 二乗の平均から平均の二乗を引いてみる
mean(a^2) - mean(a)^2
# R の var を使う
var(a) * 4 / 5
[1] 2.24
[1] 2.24
[1] 2.24

上の例ではいずれも同じ値を返していることがわかります。 なお var を使った計算では不偏分散ではなく標本分散に修正しました。

続いて先の計算にしたがった場合に、やはり同じように分散・標準偏差が得られるかを見てみます。

set.seed(123)
n <- 10
y0 <- rnorm(n)
w0 <- rep(1, n)

w <- w0/sum(w0)
v <- sqrt(w)

y <- v*y0
ys <- sqrt(y %*% y - (v %*% y)^2)
y_new <- y/ys[1]
> mean((y0 - mean(y0))^2) # 一般的な分散の計算
[1] 0.8187336
> mean(y0^2) - mean(y0)^2 # 二乗の平均から平均の二乗を引く
[1] 0.8187336
> var(y0) * (n-1) / n # R の var を使って
[1] 0.8187336
> (ys^2)[1] # 計算した値
[1] 0.8187336

$(ys)^{2}$ が $y0$ の分散になっていることが確認できますね。 ということで、先ほどの処理では wv を使ってもともとの y の重み付き標準偏差を計算し、その値で重み付きの y を除していることがわかりました。 このサブルーチンの名前が standard なので当然ですが、標準化をしているようです。

x についても基本的に同様の処理を行っており、v を使って重み付き標準偏差を計算・標準化をしています。 ただし最後に重み付き平均の二乗 / 分散 に 1 を加算したものを xv に格納しており、これを x の分散としているようなのですが、これが良くわかりませんでした。

ちなみに ju は先ほど説明したように各変数に異なる数値・バラつきがあるかを示すもので、バラつきがなければさっさとループを抜けて次の変数に移っていることがわかります。

      ! x
      do 10151 j=1,ni ! ni は nvars
      if(ju(j).eq.0)goto 10151
      xm(j) = 0.0 
      x(:,j) = v*x(:,j) ! x にも重みを乗じる
      xv(j) = dot_product(x(:,j),x(:,j)) ! x の二乗の重み付き平均

      ! isd は標準化するかの指定で、標準化する場合は 1 が入っており 10171 に飛ばされない
      if(isd .eq. 0) goto 10171 
      xbq = dot_product(v, x(:,j))**2 ! x の重み付き平均の二乗
      vc = xv(j)-xbq ! 重み付き分散
      xs(j) = sqrt(vc) ! 重み付き標準偏差。 ys と対応している。
      x(:,j) = x(:,j)/xs(j) ! 標準偏差で割って標準化。 y/ys と対応している。
      
      ! これはよくわからない
      xv(j) = 1.0 + xbq/vc ! 重み付き平均の二乗 / 分散 に 1 を加算
      goto 10181 
10171 continue
      xs(j)=1.0
10181 continue
      continue
10151 continue
      continue
      goto 10191

切片が 0 でない場合はこちらにきます(基本はこっち)が、処理は上記と大体同じです。 yx ともに値を更新する前に重み付き平均を引いているところが違う点ですね。

      ! 切片が 0 でない場合ここに来る
      ! 基本はこっち
10141 continue
      ! x
      do 10201 j=1,ni
      if(ju(j).eq.0)goto 10201 
      xm(j) = dot_product(w,x(:,j)) ! x の重み付き平均
      x(:,j) = v*(x(:,j)-xm(j))  ! 重み付き平均を引いてから重みを乗じる
      xv(j) = dot_product(x(:,j),x(:,j)) ! 二乗の重み付き平均
      if(isd.gt.0) xs(j) = sqrt(xv(j)) ! 重み付き標準偏差
10201 continue
      continue
      if(isd .ne. 0)goto 10221
      xs = 1.0
      goto 10231
10221 continue
      do 10241 j=1,ni
      if(ju(j).eq.0)goto 10241
      x(:,j) = x(:,j)/xs(j) ! 標準化はここで実行
10241 continue
      continue
      xv=1.0
10231 continue
      continue
      ym = dot_product(w,y) ! y の重み付き平均 
      y  = v*(y-ym)          ! y から重み付き平均を引いたものに重みを乗じる
      ys = sqrt(dot_product(y,y)) ! 二乗和(分散)の平方根(SD)
      y  = y/ys ! 標準化

次の処理は共通のもので、y と x の内積を計算し、 g に格納します。 単純に yx内積を計算しているように見えますが、ここでの y


\frac{\sqrt{\frac{w_{i}}{\sum{w_{i}}}}y0_{i}}{SD(y0)}

x


\frac{\sqrt{\frac{w_{i}}{\sum{w_{i}}}}x0_{i}}{SD(x0)}

となっているので、その内積は重み付き共分散をそれぞれの標準偏差の積で除したもの、つまり重み付きの相関係数となっているはずです。

      ! 3. 内積(重み付き相関係数)を格納
10191 continue                                                          
      continue                                                          
      g = 0.0                                                             
      do 10251 j=1,ni 
      ! j 番目の変数にバラツキがあるなら g に y と x の内積(共分散)を格納する
      ! ただしこの時点での y と x はそれぞれ標準偏差で除したものとなっている
      if(ju(j).ne.0) g(j) = dot_product(y, x(:,j))                          
10251 continue 
      continue 
      deallocate(v) 
      return
      end 

先のサンプルデータで確かめてみましょう。 重みが全て等しいという単純なパターンでは、更新された yx内積相関係数になっていることが確認できます。

set.seed(123)
n <- 10
y0 <- rnorm(n)
x0 <- rnorm(n)
w0 <- rep(1, n)

w <- w0/sum(w0)
v <- sqrt(w)

y <- v*(y0 - (w %*% y0)[1])
ys <- sqrt(y %*% y)
y_new <- y/ys[1]

x <- v*(x0 - (w %*% x0)[1])
xs <- sqrt(x %*% x)
x_new <- x/xs[1]
> (y_new %*% x_new)[1] # 内積
[1] 0.5776151
> cor(y_new, x_new) # 更新後の y と xの相関係数
[1] 0.5776151
> cor(y0, x0) # 元の値の相関係数
[1] 0.5776151

一方重みが観測値によって異なる場合はというと、これは近い値になるものの完全に一致はしませんでした(でもこれなんでだろう、一致するような気がするんだけど)。

set.seed(123)
n <- 10
y0 <- rnorm(n)
x0 <- rnorm(n)
w0 <- rep(1, n) - 0.5 * ifelse(runif(n) > 0.8, 1, 0) # 一部のデータに対して重みを小さくしている

w <- w0/sum(w0)
v <- sqrt(w)

y <- v*(y0 - (w %*% y0)[1])
ys <- sqrt(y %*% y)
y_new <- y/ys[1]

x <- v*(x0 - (w %*% x0)[1])
xs <- sqrt(x %*% x)
x_new <- x/xs[1]
> (y_new %*% x_new)[1]
[1] 0.5687947
> cor(y_new, x_new)
[1] 0.5687133
> cor(y0, x0)
[1] 0.5776151

ところで重み調整後の yx内積相関係数と近似(一致?)するなら、個別のデータのペアが相関に対してどのような影響を持っているかを評価できるのではないでしょうか。

内積ではなく各ペアの掛け算語の値を見てみると、6 番目と 8 番目の値が高い値を示していることがわかります。 このデータの重み付き相関係数0.568 ぐらいだったので、この 2 つの観測値の影響が大きそうです。

> cbind(1:n, y_new * x_new)
      [,1]          [,2]
 [1,]    1 -0.0744142551
 [2,]    2 -0.0043928887
 [3,]    3  0.0261049036
 [4,]    4  0.0004833048
 [5,]    5 -0.0025033504
 [6,]    6  0.2852904868
 [7,]    7  0.0104270429
 [8,]    8  0.3466035906
 [9,]    9 -0.0413263433
[10,]   10  0.0225221645

実際にデータを見てみると、 6 番と 8 番のデータは他の観測値と比べて関連性が強そうに見えます。

> cbind(y_new, x_new)
             y_new       x_new
 [1,] -0.233569767  0.31859541
 [2,] -0.117117800  0.03750829
 [3,]  0.513582873  0.05082900
 [4,] -0.011106121 -0.04351698
 [5,]  0.009617489 -0.26029149
 [6,]  0.568708951  0.50164585
 [7,]  0.126538479  0.08240215
 [8,] -0.481982832 -0.71912020
 [9,] -0.278126098  0.14858851
[10,] -0.136535493 -0.16495465

6 と 8 番目のデータを塗り分けてみるとわかりやすいですね。

cols <- c(1, 1, 1, 1, 1, 3, 1, 3, 1, 1) + 1
plot(y ~ x, col = cols, pch = 16)

f:id:ushi-goroshi:20200824113309p:plain

以上で yx について標準化が終わったのでstandard から elnet に帰ってくると今度は回帰係数の上限・下限についても標準化を行います。 また flmin が 1 以上の場合は vlam を更新するのですが、 flminlambda が指定された場合に 1 が入り、そうでなければ $[0, 1)$ の実数が期待されるパラメータでした。 なので lambba が指定された場合(= flmin が 1 のとき)に vlamy の重み付き標準偏差で調整される事になります。 この vlam は後続の処理(フィッティング)では ulam として渡されるものですが、ulamlambda の指定がなければ 1 、指定があればその降順となるものでした。 要するに lambda の大きさについても標準化するよ、という事のようですね。

      ! jerr に 0 でない値が入っていると return
      if(jerr.ne.0) return

      ! cl は glmnet で cl = rbind(lower.limits, upper.limits) と定義される
      ! 回帰係数の上限・加減
      cl=cl/ys

      ! 標準化の指定が 0 であれば以下はスキップ                     
      if(isd .le. 0) goto 10091
      
      ! 説明変数ごとに標準偏差を乗じる
      do 10101 j=1,ni
      cl(:,j)=cl(:,j)*xs(j)
10101 continue                                                          
      continue                                                          
10091 continue                                                          
      
      ! flmin は glmnet のなかで flmin = as.double(lambda.min.ratio) で定義される
      ! ここで lambda.min.ratio = ifelse(nobs < nvars, 0.01, 1e-04)
      if(flmin.ge.1.0) vlam=ulam/ys 

ではフィッティングです。 ここで呼ばれる elnet1 こそが {glmnet} の本体となり、回帰係数の計算はここで行われます。 この中ではもうサブルーチンはほとんど呼ばれず、初期パラメータを取ってくるものとプログレスバーを表示するためのものだけです。 ようやくたどり着きました、今回も長かったですね。

      ! 3. フィッティング
      ! 本体である elnet1 の呼び出し
      call elnet1(parm,ni,ju,vp,cl,g,no,ne,nx,x,nlam,flmin,vlam,thr,maxi,xv,  lmu,ca,ia,nin,rsq,alm,nlp,jerr)

このサブルーチンは量はそこそこ(180行程度)なのですが、ループが込み入っていて紹介が長くなるので今回はここまでです。 また次回。