統計コンサルの議事メモ

統計や機械学習の話題を中心に、思うがままに

glmnetをもう少し理解したい④

それでは前回の記事に続いてelnet1の紹介です。過去の記事はこちらです。

ushi-goroshi.hatenablog.com ushi-goroshi.hatenablog.com ushi-goroshi.hatenablog.com

elnet1の実装

前回の記事で最後に触れた通り、elnet1自体は 180 行程度とそれほど大きくはないサブルーチンなのですが、多数のループが込み入っています。 具体的には以下の通り 9 つのループ処理(fortran なので do 文)がネストした構造となっており、 しかもgotoによって行き来しています(わかりやすいように R で書いてありますが、添字は統一してあります)。

# 1番目
for (m in 1:nlam) {
  # 2番目のループ
  for (j in 1:ni) {
  }
  # 3番目のループ
  for (k in 1:ni) {
    # 4番目のループ
    for (j in 1:ni) {
    }
    # 5番目のループ
    for (j in 1:ni) {
    }
  }
  # 6番目のループ
  for (l in 1:nin) {
    # 7番目のループ
    for (j in 1:nin) {
    }
  }
  # 8番目のループ
  for (j in 1:ni) {
  }
  # 9番目のループ
  for (j in 1:nin) {
  }
}

前処理

まずはいつもの通り変数の定義ですが、それに加えて初期パラメータを取得するという処理が入ります。

subroutine elnet1(beta,ni,ju,vp,cl,g,no,ne,nx,x,nlam,flmin,ulam,th
                  *r,maxit,xv,  lmu,ao,ia,kin,rsqo,almo,nlp,jerr)
implicit double precision(a-h,o-z)
double precision vp(ni),g(ni),x(no,ni),ulam(nlam),ao(nx,nlam)
double precision rsqo(nlam),almo(nlam),xv(ni)
double precision cl(2,ni)
integer ju(ni),ia(nx),kin(nlam)
double precision, dimension (:), allocatable :: a,da
integer, dimension (:), allocatable :: mm
double precision, dimension (:,:), allocatable :: c
allocate(c(1:ni,1:nx),stat=jerr)
if(jerr.ne.0) return;

! 初期パラメータを取得
call get_int_parms(sml,eps,big,mnlam,rsqmax,pmin,exmx,itrace)

! a, mm, da を allocate
allocate(a(1:ni),stat=jerr)  ! a は説明変数の数の次元をもつベクトル

if(jerr.ne.0) return
allocate(mm(1:ni),stat=jerr) ! mm は説明変数の数の次元をもつベクトル

if(jerr.ne.0) return
allocate(da(1:ni),stat=jerr)
if(jerr.ne.0) return

ここでget_int_parmsはそれほど大きくないので全体を見てみましょう。 以下のようなサブルーチンです:

subroutine get_int_parms(sml,eps,big,mnlam,rsqmax,pmin,exmx,itrace)
implicit double precision(a-h,o-z)                                
data sml0,eps0,big0,mnlam0,rsqmax0,pmin0,exmx0,itrace0  /1.0d-5,1.0d-6,9.9d35,5,0.999,1.0d-9,250.0,0/  
  sml=sml0                                                          
eps=eps0                                                          
big=big0                                                          
mnlam=mnlam0                                                      
rsqmax=rsqmax0                                                    
pmin=pmin0                                                        
exmx=exmx0                                                        
itrace=itrace0                                                    
return                                                            
entry chg_fract_dev(arg)                                          
sml0=arg                                                          
return                                                            
entry chg_dev_max(arg)                                            
rsqmax0=arg                                                       
return                                                            
entry chg_min_flmin(arg)                                          
eps0=arg                                                          
return                                                            
entry chg_big(arg)                                                
big0=arg                                                          
return                                                            
entry chg_min_lambdas(irg)                                        
mnlam0=irg                                                        
return                                                            
entry chg_min_null_prob(arg)                                      
pmin0=arg                                                         
return                                                            
entry chg_max_exp(arg)                                            
exmx0=arg                                                         
return                                                            
entry chg_itrace(irg)                                             
itrace0=irg                                                       
return                                                            
end

上から3行目のdata文は変数に初期値を与える fortran の記法のようで、dataに続いて宣言した変数に対して/で挟んだ値を初期値として与えるようです。 そのためsml0には 1.0d-5 が、eps0には 1.0d-6 が入力されます。 ここで d は倍精度の指数表記を表します。13行目のentry以降は各変数について特定の値を指定するためのもののようです(entryの使い方がよくわからない…)。

続けていくつか変数に値を代入します。 まずはbtaですが、代入しているbetaは元々parmとして渡されたもので、これはelnet.rparm = alphaとして渡していたものでした。さらにこのalphaglmnet.rで定義されたもので、L1 と L2 それぞれに対する罰則の配分を決めるパラメータです:

(なぜか Tex が表示されないのでひとまず)

(1 − \alpha)/2 ||\beta||^2_2 + \alpha||\beta||_1

      bta=beta

このbtaを 1 から減じたものをombとしますが、このombはすぐ下で定義されるalmとの乗算でdemを定義する(つまりdem = alm * obm)ためだけに使われています。 さらにalmはループの中で更新されながら最終的にはbtaとの乗算によってabとなり、回帰係数の縮小に使われることになります。 またその次のalfalmの更新に使われますので、これらの変数がループの中で更新されつつ回帰係数の縮小に利用されるということになります(他にもあります)。

      omb=1.0-bta
      alm=0.0
      alf=1.0

以下のブロックではeqsalfを定義しますが、flminが 1.0以上であればスキップされるようです。 このflminというのはglmnet.rにおいて罰則lambdaが指定されていれば 1 が、されていない時にはlambda.min.ratioが入力される変数でした。 lambda.min.ratioはデフォルトではlambda.min.ratio = ifelse(nobs < nvars, 0.01, 1e-04)となっていますので 1 よりは小さい値が入りそうです。 したがって以下のブロックは「lambdaが指定されていないときはalfを定義しよう」という処理になっています(eqsはここしか出てきません)。

その場合、epsflmin(=1)の大きい方を新たにeqsと定義しますが、このepsget_int_parmseps0(1.0d-6 という小さい数)を受け取っていました。 一方lambda.min.ratioは先ほど述べたようにデフォルトではlambda.min.ratio = ifelse(nobs < nvars, 0.01, 1e-04)となっていますので、もう少し大きい値となりそうです。 したがってeqsは 0.01 or 1e-04 、alfはその1/(nlam-1)乗となるようです。

      if(flmin .ge. 1.0)goto 10271
      eqs=max(eps,flmin)
      alf=eqs**(1.0/(nlam-1))  ! alf を eqs の (1/(nlam-1)) で定義する

flminが 1 以上である(lambdaが指定されている)場合は上記をスキップしてこちらにきます。rsqはそのまま残差平方和ですね。

続くaelnet1の中で重要な役割を担っているのでじっくりと見ていきましょう。 実はこのaは(縮小された)回帰係数を格納する変数です。

10271 continue
      ! パラメータの初期化
      rsq=0.0 ! 残差平方和
      a=0.0

このaがどうなるのか、フライングして先の処理を見てみましょう。 elnet1の 70 行目前後に以下の処理があります:

      ak=a(k)                                                           
      u=g(k)+ak*xv(k)                                                   
      v=abs(u)-vp(k)*ab                                                 
      a(k)=0.0                                                          
      if(v.gt.0.0) a(k)=max(cl(1,k),min(cl(2,k),sign(v,u)/(xv(k)+vp(k)*dem)))
      if(a(k).eq.ak)goto 10371                                  

akという変数にaの k 番目の値を渡しておき、uvを定義し、aの k 番目の値を 0 に更新した上で色んな値を参照しながら再度更新しています(このuvは後で確認します)。 最終的にaは以下のようにaoという変数に代入されます(154 行目):

      if(nin.gt.0) ao(1:nin,m)=a(ia(1:nin)) 

このaoですが、elnetuの中でelnet1を呼び出すときにはcaという引数として渡されています。

! elnet1 で受け取る変数
! lmu の次に ao がある
subroutine elnet1(beta,ni,ju,vp,cl,g,no,ne,nx,x,nlam,flmin,ulam,th
     *r,maxit,xv,  lmu,ao,ia,kin,rsqo,almo,nlp,jerr)

! elnetu で elnet1 を call するときの引数
! こちらは lmu の次に ca がある
call elnet1(parm,ni,ju,vp,cl,g,no,ne,nx,x,nlam,flmin,vlam,thr,maxi,xv,  lmu,ca,ia,nin,rsq,alm,nlp,jerr)

このcaelnet.rの中で.Fortran("elnet", ...)と call される際に定義される変数でした:

else .Fortran("elnet", ka, parm = alpha, nobs, nvars, as.double(x), 
              y, weights, jd, vp, cl, ne, nx, nlam, flmin, 
              ulam, thresh, isd, intr, maxit, lmu = integer(1), 
              a0 = double(nlam), 
              # ここで ca が定義されている
              ca = double(nx * nlam), 
              ia = integer(nx), 
              nin = integer(nlam),  rsq = double(nlam), alm = double(nlam), 
              nlp = integer(1), jerr = integer(1), PACKAGE = "glmnet")

ここでnxは説明変数の数、nlamは罰則lambdaの数なので、説明変数の数 × lambda の数のベクトルを定義しています(そしてそれがelnet1の中でaoとして評価・格納される)。

このcaelnet.rの後続の処理において以下の箇所で抽出されます:

outlist = getcoef(fit, nvars, nx, vnames)

ここでglmnet:::getcoefは以下の通りで、fitとして返ってきたオブジェクトのcaそのものをbetaに格納しています(ninmaxが 0 の場合は 0 のベクトルが返る)。

# glmnet:::getcoef
function (fit, nvars, nx, vnames) {
  # ここまで省略
  nin = fit$nin[seq(lmu)]
  ninmax = max(nin)
  # ここまで省略
  if (ninmax > 0) {
    # ここで ca を抽出している 
    ca = matrix(fit$ca[seq(nx * lmu)], nx, lmu)[seq(ninmax), 
                                                , drop = FALSE]
    df = apply(abs(ca) > 0, 2, sum)
    ja = fit$ia[seq(ninmax)]
    oja = order(ja)
    ja = rep(ja[oja], lmu)
    ia = cumsum(c(1, rep(ninmax, lmu)))
    # beta に格納する
    beta = drop0(new("dgCMatrix", Dim = dd, Dimnames = list(vnames, 
                                                            stepnames), x = as.vector(ca[oja, ]), p = as.integer(ia - 
                                                                                                                   1), i = as.integer(ja - 1)))
  }
  else {
    beta = zeromat(nvars, lmu, vnames, stepnames)
    df = rep(0, lmu)
  }
  # ここも省略
  list(a0 = a0, beta = beta, df = df, dim = dd, lambda = lam)
}

これにいくつかの情報を追加したものがglmnetの返り値です。elnet1において評価されたaaoに格納され、elnetcaとして渡され、elnet.rbetaに抽出・格納される流れが伝わりましたでしょうか。

重要な変数を説明したところなので、以下ブロックで初期化している変数の詳細は出てきたときに説明するとして、さっさと次に進んでしまいましょう。

      mm=0
      nlp=0
      nin=nlp
      
      iz=0
      mnl=min(mnlam,nlam)

ループ①(almの更新)

上記までで必要な変数の初期化が完了したので、以下よりループに入ります。 一番外側のループはlambdaの個数(nlam)に対して実行されますが、nlamのデフォルトは 100 となっています(glmnet.r)。

以下ではおおよそalmを更新する処理を行うのですが、lambdaの指定の有無や、ループの回数によってalmに入力する値を変えています。

まずはlambdaの指定の有無で処理を分けます。以下のまとまりはflminが 1.0 より小さい場合にスキップされますが、先ほど述べたように、flminglmnet.rにおいてlambdaの指定がない場合に相当します。 lambdaの指定がある場合にはalm = ulam(m)としてalmを更新した上で、10291 までスキップするのですが、この 10291 は 2 番目のループの中にありますので、少し大きめのスキップとなるようです。 なおulamlambdaが指定されている場合、lambdaの降順になっているため、ループの 1 回目であればlambdaの最大値が入ります。

      do 10281 m=1,nlam ! nlambda なので lambda の個数だけループ

      if(itrace.ne.0) call setpb(m-1)  ! プログレスバー
      if(flmin .lt. 1.0)goto 10301
      alm=ulam(m) ! flmin が 1.0 以上の場合は alm = ulam(m) とする
      goto 10291

lambdaの指定がなければ以下の処理に入るのですが、ここではループの回数によってalmに入力する値を変えています。 具体的には、ループの 1 回目にはbig(9.9d35)という極端に大きな値を入力し、 2 回目には 0.0 を、3 回目以降は 元の値にalfを乗じたものを入力します。

10301 if(m .le. 2)goto 10311 ! ループの1回目と2回目はここをスキップ
      alm=alm*alf ! ループの3回目からは alm を alf を乗じる
      goto 10291
10311 if(m .ne. 1)goto 10321 ! ループの2回目はここをスキップ
      alm=big     ! ループの1回目は alm = big(9.9d35) にする 
      goto 10331
10321 continue
      alm=0.0     ! ループの2回目は alm を いったん 0 にする

このalfは先ほど説明した通りeqs^(1.0/(nlam-1))として定義されますが、eqsが 0.01 or 1e-4 とすると、nlambdaを 10 とした場合には以下のような数値になります:

0.01^(1/(10-1))
# [1] 0.5994843
1e-4^(1.0/(10-1))
# [1] 0.3593814

つまりalmはだんだん絶対値が小さくなるわけですね。

ループ②(罰則の定義)

続いて 2 番目のループに入ります…と言いつつ 2 番目のループは一瞬で終わります。 先ほど更新したalmについて変数ごとの内積と比較し、大きい方を採用します。 したがってここでは各変数に対するループとなります。

まずjuvpですが、juは前回記事で確認した通り、chkvarsによって各変数列の内容が全く同じでないかを確認したものでした。 ある変数列の中身が全く同じであれば 0 であったため、ここで次の変数にスキップされます。 次にvpですが、これは 1 回目の記事で確認した通りglmnet.rにおいて各変数に対する罰則の重み(デフォルトは 1) が入ったベクトルとして定義されたものです(vp = as.double(penalty.factor))。 罰則をかけない場合は 0 となり、スキップされるようです。

変数にバラつきがあり、罰則を検討する場合にはここで再度almを更新します。 ここで出てくるgstandardの中でyx内積(共分散)を格納したものとして定義されたものでした。 それを罰則の大きさで除しているため、penalty.factorを小さくする(分母が小さくなる)と共分散が大きくなり変数として残りやすい、というロジックになっているようですね。

ちなみにループ①の 1 回目のループはalmに 9.9d35 という数値が入るので必ずこの数値が採用されると思います。またループ 2 回目は今度はalmが 0.0 になるため、今度は必ず変数の共分散側の数値がalmになると思われます。

      ! 2番目のループ
      ! alm の更新
      do 10341 j=1,ni  ! ni は変数の数
      if(ju(j).eq.0) goto 10341
      if(vp(j).le.0.0) goto 10341
      alm=max(alm,abs(g(j))/vp(j))
10341 continue  ! 2番目のループここまで

上記の処理で変数を横断してalmを更新したのち、以下でさらにalmを更新します。 ここではbtaalpha; L1 と L2 への重みの配分パラメータ)と 0.001 の max で alm を除し、alfを乗じています。 一応ここで式を確認しておくと以下のようになります:


alm = alf * alm/max(bta, 1.0d-3) = eqs^{(1.0/(nlam-1))} * alm/max((1-alpha), 0.001)

一体これは何をやっているんでしょうか…。

      continue
      alm=alf*alm/max(bta,1.0d-3)  

続いていくつかの変数を更新します。 demalm * ombとして定義されますが、ここでombは (1-bta)でした。 またabalmbtaを乗じたものですので、これらはそれぞれ「lambda×(1-alpha)」および「lambda×alpha」ということになり、demabが実質的な罰則の大きさを表すことになりそうですね。

10331 continue
10291 continue
      dem=alm*omb ! dem = alm * (1-bta)
      ab=alm*bta  ! ab = alm * bta

これらがどのように使われているか少し先を見てみましょう。

! ab
u=g(k)+ak*xv(k)   ! L69(ループ③の中)、L119(ループ⑥の中)
v=abs(u)-vp(k)*ab ! L70、L120(ともに上に同じ)

! dem
a(k)=0.0 ! L71、L121
if(v.gt.0.0) a(k)=max(cl(1,k),min(cl(2,k),sign(v,u)/(xv(k)+vp(k)*dem))) ! L72、L122

両方ともvpに乗じており、ababs(u)からの減算、demxv(k)との加算の後にsign(v,u)と除算し、clとの max/min を取っています。 vpは罰則の重みを定義したものでしたので、alphalambdaで決まる罰則の大きさをそのまま使うか弱くするかを決めています。 demの方は演算の結果をaに格納していますが、前述の通りaは回帰係数を保存する変数でしたので、sign(v,u)/(xv(k)+vp(k)*dem)cl(1,k)よりも大きければa、すなわち回帰係数が更新されるということになりますね。

またこの演算が実行されるかの基準としてvが使われており、このvを計算するためにabが使われている、ということのようです。 じゃあこのuとかvって何なの?という話なのですが、これは次のループの話なので少しお待ちください。

残る変数のうちrsq0は残差平方和ですね。またjzizと組み合わせて使われていますが、この条件分岐がちょっと理解出来なかったのでスキップします。 一応、izはループ①の途中(ループ③が終了した時点)で 1 になるため、iz * jzが 0 になるのはほぼjzが 0 の時に限ると言えそうです。 nlpは iteration のカウンターとして使われており、dlxは回帰係数の更新前後の差分を見ています。 どちらもループを抜けるための基準として使われています。

      rsq0=rsq 
      jz=1
      continue
10351 continue
      if(iz*jz.ne.0) goto 10360   ! iz = 0, jz = 1
      nlp=nlp+1 
      dlx=0.0

ちょっと長くなってしまったので一度きります。 次回はループ③から始めます。