統計コンサルの議事メモ

統計や機械学習の話題を中心に、思うがままに

glmnetをもう少し理解したい⑤

それでは前回の記事に続いてelnet1の紹介です。前回の記事はこちらです。

ushi-goroshi.hatenablog.com

ushi-goroshi.hatenablog.com

ushi-goroshi.hatenablog.com

ushi-goroshi.hatenablog.com

ループ③(回帰係数の推定)

以上までで見てきた通り、ループ①・②では almすなわちlambdaを更新しつつ、alphaalf)やpenalty.factorvp)との乗算によって罰則を計算していました。 ループ③ではその罰則を用いて回帰係数を更新します。 なのでこのループがglmnetにおいてメインとなる処理と言って良いと思います。

ループ③はniに対するループです。ここでniは説明変数の数ですね。k をインデックスとして各説明変数をさらっていきます。

まずjuですが、これは各説明変数列における数値のバラつきの有無を示す 1/0 のベクトルでした。バラつきがない、すなわち全ての数値が同じであれば(ju(k) == 0 )ループ③をスキップします(gotoの向かう先が10371で、ループの範囲も同じく10371となっています)。

do 10371 k=1,ni
if(ju(k).eq.0) goto 10371

次にaから k 番目の変数の値をakに格納します。前回記事で追いかけた通り、このa(またはao)が最終的には回帰係数として返ります。

前処理においてa = 0.0で初期化されているのでループの 1 周目時点ではakも 0 ですが、ループ①の 2 周目以降は縮小された回帰係数が入っています。

ak=a(k) ! k 番目の変数の a の値を ak に代入。

続いてuvを計算します。これらは前回の記事で少し紹介した通り、次のブロックで回帰係数aを更新するためのものです。

ug(k)ak*xv(k)を加算して計算します。ここでg(k)standerdにおいてg(j)=dot_product(y,x(:,j))、つまりyx内積として定義されたものでした(yxはそれぞれ標準化されています)。もしも罰則が付いていなければこの共分散が OLS による回帰係数となるはずです(標準化されているのでx標準偏差は 1)。

このgxvで重みをつけたakを加算します。ここでxvは weight を乗じたxの二乗和です。しかしループの 1 周目ではak=0であるためgがそのまま利用されることになります。

このようにして定義されたuの絶対値から罰則を減じたものがvとなります。

u=g(k)+ak*xv(k)
v=abs(u)-vp(k)*ab

そしてさらにvが 0 よりも大きい場合(OLS による回帰係数が罰則よりも大きい場合)、

  • cl(2,k)」と「sign(v,u)/(xv(k)+vp(k)*dem)」を比較して小さい方を選ぶ
  • それを「cl(1,k)」と比較して大きい方を選ぶ

という処理を行い、新たにaとして格納します。 ここでclglmnet.rcl = rbind(lower.limits, upper.limits) として定義されたものなので、推定された値を上限と下限の間に抑えようとしていることがわかります。またvが 0 以下の場合は 0 となります。

! a(k) を更新
a(k)=0.0
if(v.gt.0.0) a(k)=max(cl(1,k),min(cl(2,k),sign(v,u)/(xv(k)+vp(k)*dem)))

以上が回帰係数の更新を行う処理になります。 ややアッサリしていますが、ここの処理は glmnet を理解する上で極めて重要なのでもう少し説明します。

まず前提として、(Elastic Net ではなく)Lasso では軟閾値作用素と呼ばれる写像を用いて解を推定しています。 ここで軟閾値作用素とは、定数  a および  \lambda (> 0) において  a の絶対値が  \lambda よりも大きければ  a-\lambda を、そうでなければ 0 を返す作用素です:


S(a, \lambda) = \begin{cases} a - \lambda & (a > \lambda) \\ 0 & (|a| \le \lambda) \\ a + \lambda & (a < -\lambda) \end{cases}

すなわち、推定された回帰係数(の絶対値)が罰則よりも小さければ 0 に丸めてしまい、大きくても罰則の分だけ係数を縮小してしまう、ということです。 一般に Lasso は効果の小さな変数の回帰係数を 0 に縮小する方法として知られていますが、実装としてはこのような軟閾値作用素が用いられており、これを見ると「Lasso はスパースな解を推定できる」という言葉の意味がわかるのではないでしょうか。推定したら 0 になるわけではなく、明示的に 0 にしているのだと。

ここで少し余談なのですが、Lasso や Ridge に関する参考書などを読んでいると「幾何学的な説明」として以下のようなグラフが描かれることがよくあると思います:

f:id:ushi-goroshi:20210421132013j:plain

このグラフを見るたびに私は納得いかない気分になっていました。と言うのも、Lasso の方(グラフ左側)に着目すると、OLS による推定値の座標(グラフ中の×印の位置)や楕円の広がり方によっては菱形の頂点ではなく辺に接することが普通にあり得そうだからです。 少なくともこのグラフをもって「Lasso は菱形の頂点に接しやすい(ゆえに解が 0 と推定されやすい)」というのは全く自明ではないし直感的でもないな、と思っていました。

そんな時に「機械学習の数理100問シリーズ」の「スパース推定100問 with R」を読んでいると、またも上記のようなグラフが出てきたので悶々としたのですが、次のページには以下のようなグラフがありました:

f:id:ushi-goroshi:20210421132410j:plain

まさにこれです。このグラフにおいて白色の部分に OLS の推定値がある場合、頂点ではなく辺に接することになります。そこから少しずれて緑色の部分に OLS の推定値が存在する場合には菱形の頂点に接することとなる、つまりいずれか重要でない方の解が 0 として推定されるようになります。

上のグラフのような「幾何学的な説明」は本当に多くの本・記事で見かけるのですが、下のグラフも合わせて説明することでより理解が深まるのでは、と思いました。 余談おわり。

さて、上記のブロックでは、回帰係数が罰則よりも大きく、かつ上限・下限の範囲内であればsign(v,u)/(xv(k)+vp(k)*dem)を新たなaとするのでした。 さきほどの軟閾値作用素の説明においては「罰則を減じた回帰係数」(つまりv)をLasso推定値としていましたが、ここではそれをxv(k)+vp(k)*demで除しています。 これは、ここで得ようとしている推定値というのが Lasso ではなく Elastic Net であるためであり、(第一回で紹介した)教科書(P36)では Elastic Net の推定量


\hat{\beta}^{EN}_{j} = \begin{cases} (\hat{\beta}^{OLS}_{j} - \lambda_{1})/(1+\lambda_{2}) & (\hat{\beta}^{OLS}_{j} > \lambda_{1}) \\ 0 & (|\hat{\beta}^{OLS}_{j}| \le \lambda_{1}) \\ (\hat{\beta}^{OLS}_{j} + \lambda_{1})/(1+\lambda_{2}) & (\hat{\beta}^{OLS}_{j} < -\lambda_{1}) \end{cases}

としています。demalm*(1-bta)で定義されていたことを思い出すと、これは Ridge (L2)に対する罰則であり、上記の式では $\lambda_{2}$ に該当します。 またxvは X の二乗和を分散で除して 1 を加算したもので、これが何を意味しているのかは以前紹介したときもわからなかったのですが、サンプルデータを使って計算してみるとおおよそ 1 になりそうなのできっとそういう数値なんだろうと思います(適当)。

残る処理ですが、上記によってa(k)が更新されなければループを抜けて次の変数に移ります(gotoの移動先10371はループ③の終点でした)。 またmmが 0 でなければ10391(ループ④の先)に移動するため、以降の処理から次に紹介するループ④までをスキップするようです。 なおこのmmはループ①の1回目では 0 なので1回目は確実に処理が行われるようですね。 またnxは非ゼロとする変数の数の上限なので、推定したパラメータ数がそれを越えると3番目のループを抜けるようです。

if(a(k).eq.ak) goto 10371
if(mm(k) .ne. 0) goto 10391 
nin=nin+1                                                    
if(nin.gt.nx)goto 10372 
ループ④(分散共分散行列の計算)

続いてループ④です。 ここでもループの対象は説明変数(ni)ですが、今度はインデックスとしてjを用い、分散共分散行列(のようなもの)を計算してcに格納するようです。 ここでcni*nxのサイズの行列です。 このループは短いのでまとめて見てしまいましょう。

まずはjuで変数にバラツキがあるかを確認し、なければ次の変数にスキップします。 続いてmmをチェックし、mmが 0 でなければcmmを代入して次の変数にスキップします(なおこのmmには後続の処理でninが代入されるのですが、そのninmmを基準に数値が加算されるような変数となっており互いに入り組んでいて何をやっているのかよくわかりませんでした)。 続いてjkを比較して同一(同じ変数)だったらcxvを、同一でなければxjk内積cに代入します。xvは先ほど出てきたxの二乗和ですので、このcは分散共分散行列のようなものを計算しているようです(正方行列ではないので分散共分散行列とは言わないでしょうけども)。

do 10401 j=1,ni
! バラツキがなければ以降の処理をスキップ
if(ju(j).eq.0)goto 10401

! mm が 0(パラメータが 0 でない)でなければ次のブロックを実行して次の変数へスキップ
if(mm(j) .eq. 0)goto 10421
c(j,nin)=c(k,mm(j))
goto 10401

10421 continue
if(j .ne. k)goto 10441  ! 変数が同一でなければ 10441 に飛ぶ
c(j,nin)=xv(j) ! 同一だったらここ
goto 10401
10441 continue
c(j,nin)=dot_product(x(:,j),x(:,k)) ! 同一でなかったら j と k の内積をとる
10401 continue ! 4番目のループはここまで

ループ④が終わった後は少しだけ処理が入ります。 mmにはninが代入されます。またiaにはkが入りますが、このkはループ③のインデックスで、ループ③は更新がなければループ④をスキップしてしまうため、パラメータに更新があった変数のインデックスを表すことになります。 その上で、推定された回帰係数の差分を評価し、残差平方和を更新します。 このときg(k)は縮小前の回帰係数(yx(k)内積)で、そこから weight調整済みの x の二乗和 を減じたものを残差平方和から減じて計算しています。

continue
! mm に nin を入れる
mm(k)=nin

! ia に k を格納
ia(nin)=k

10391 continue   
! a(k) の差分をとる。 a(k)、 ak は推定された回帰係数。
del=a(k)-ak

! 残差平方和を更新する
rsq=rsq+del*(2.0*g(k)-del*xv(k))
dlx=max(xv(k)*del**2,dlx)
ループ⑤(回帰係数の更新)

さらに続けてループ⑤です。ここは一瞬で終わり、いま計算されたdelを用いてg(j)つまり縮小前の回帰係数を更新します。ところでkはループ③のインデックスで、このループの中では固定されていますので、各変数の回帰係数の縮小に別の変数との共分散を利用しているわけですね。 共分散が大きいということは互いの変数間に相関があるということであり、相関が正なら回帰係数が小さくなるように働くようです。

! 探索範囲は三度説明変数
do 10451 j=1,ni ! インデックスは再度 j を使う
if(ju(j).ne.0) g(j)=g(j)-c(j,mm(k))*del                           
10451 continue ! 5番目のループはここまで
continue

ループ⑤を抜けるとすぐにループ③も終了です。

10371 continue ! 3番目のループはここまで

続いて以下のブロックで終了処理の判定を行います。10352まで飛ぶと、いくつか処理はあるもののそのままreturnとなります。つまりdlxthrよりも小さい、またはninnxよりも大きい場合にはelnet1を抜けます。 そうではない場合、もう少し処理が続きます。

10372 continue
if(dlx.lt.thr)goto 10352
if(nin.gt.nx)goto 10352
if(nlp .le. maxit)goto 10471
jerr=-m
return
10471 continue
10360 continue
iz=1
da(1:nin)=a(ia(1:nin))
continue
10481 continue
nlp=nlp+1
dlx=0.0

ループ⑥(回帰係数の推定・再)

さらに続いてループ⑥です。実はこのループ、以下の通りループ③と処理がほとんど同じです。

! 3番目のループ(一部省略)
do 10371 k=1,ni
if(ju(k).eq.0)goto 10371
ak=a(k)
u=g(k)+ak*xv(k)
v=abs(u)-vp(k)*ab
a(k)=0.0
if(v.gt.0.0) a(k)=max(cl(1,k),min(cl(2,k),sign(v,u)/(xv(k)+vp(k)*dem)))
if(a(k).eq.ak)goto 10371
if(mm(k) .ne. 0)goto 10391
nin=nin+1
if(nin.gt.nx)goto 10372
continue
mm(k)=nin
ia(nin)=k
10391 continue
del=a(k)-ak
rsq=rsq+del*(2.0*g(k)-del*xv(k))
dlx=max(xv(k)*del**2,dlx)
do 10451 j=1,ni
if(ju(j).ne.0) g(j)=g(j)-c(j,mm(k))*del

! 6番目のループ
do 10491 l=1,nin
k=ia(l)
ak=a(k)
u=g(k)+ak*xv(k)
v=abs(u)-vp(k)*ab
a(k)=0.0
if(v.gt.0.0) a(k)=max(cl(1,k),min(cl(2,k),sign(v,u)/(xv(k)+vp(k)*dem)))
if(a(k).eq.ak)goto 10491
del=a(k)-ak
rsq=rsq+del*(2.0*g(k)-del*xv(k))
dlx=max(xv(k)*del**2,dlx)
do 10501 j=1,nin
g(ia(j))=g(ia(j))-c(ia(j),mm(k))*del

ループの対象がniではなくninになっている点が異なりますが、処理としては大体同じなので説明は省略します。

do 10491 l=1,nin
k=ia(l) ! k を取り出す( ia には 0 ではないパラメータが推定された変数の列が格納されてる)
ak=a(k) ! a を取り出す
u=g(k)+ak*xv(k)
v=abs(u)-vp(k)*ab
a(k)=0.0
if(v.gt.0.0) a(k)=max(cl(1,k),min(cl(2,k),sign(v,u)/(xv(k)+vp(k)*dem)))
if(a(k).eq.ak)goto 10491
del=a(k)-ak
rsq=rsq+del*(2.0*g(k)-del*xv(k))
dlx=max(xv(k)*del**2,dlx)
ループ⑦(回帰係数の更新・再)

ループ⑦も同様にループ⑤と同じ処理をninに対して行っています。

do 10501 j=1,nin
g(ia(j))=g(ia(j))-c(ia(j),mm(k))*del
10501 continue ! 7番目のループはここまで

そしてループ⑥が終了。

continue
10491 continue ! 6番目のループはここまで

ここで終了判定が行われます。 nlpはループのカウンターとなっているようで、一定回数を過ぎていなければ10481まで戻されます。 この10481はループ⑥の手前ですので、dlxが十分に小さくなければ再度ループ⑥を実行するような流れになっているようですね。

continue
if(dlx.lt.thr)goto 10482
if(nlp .le. maxit)goto 10521
jerr=-m
return
10521 continue
goto 10481  ! ループ⑥の手前まで戻す
10482 continue
da(1:nin)=a(ia(1:nin))-da(1:nin)

ループ⑧(回帰係数の更新・再々)

ループ⑧です。改めて、ninではなくniに対して回帰係数の更新が行われます。 ここでdaにはすぐ上のブロックでaの値からdaの値を減じて更新しているのですが、もう少し上の方でdaにはaを渡しています。 つまり順番としては、da <- a とした上でaを更新し、更新後のada(つまり更新前のa)の差分を改めてdaとする、という流れです。 この更新後のdaと分散共分散行列の内積を回帰係数から減じるわけですので、やっていることはループ⑤における回帰係数の更新と同じですね。

do 10531 j=1,ni
if(mm(j).ne.0)goto 10531
if(ju(j).ne.0) g(j)=g(j)-dot_product(da(1:nin),c(j,1:nin))
10531 continue ! 8番目のループはここまで

ループ⑧を抜けると後は終了まで一直線です…と言いたいところですが、ここでなんと衝撃的なことに、10351、つまりループ③の開始まで戻されてしまいます。なんてこった。

実はループ③の開始直後にはiz*jzで処理を変える判定があり、ともに 1 であればループ③の終了時点まで移動するのですが、ここでjzを 0 にしてしまっているので愚直にループ③を再度実行することになります。 しかもjzが 1 に更新される機会があるのはループ③よりも前の段階なので、一度この処理に入った場合には必ずループ③の処理から再開しないといけない、ということですね。

continue
jz=0
goto 10351  ! えっ!! 

上のgotoを無事に回避できた場合、最後の処理に入ります。 以下では必要な変数を格納しています。

10352 continue
if(nin .le. nx)goto 10551  ! nin が nx を超えた場合はここにくる
jerr=-10000-m
goto 10282 ! jerr を 更新して elnet1 を抜ける
10551 continue
if(nin.gt.0) ao(1:nin,m)=a(ia(1:nin))
kin(m)=nin   ! m 回目のループの nin を kin[m] に格納する
rsqo(m)=rsq  ! m 回目のループの rsq を rsqo[m] に格納する
almo(m)=alm  ! m 回目のループの alm を almo[m] に格納する
lmu=m
if(m.lt.mnl)goto 10281
if(flmin.ge.1.0)goto 10281
me=0

ループ⑨(回帰係数が推定された変数のカウント)

以下ではelnet1のここまでのループによって推定された回帰係数を確認し、0.0 ではない変数の数をカウントしています。改めて、jは変数、mlambdaのインデックスです。

! 9番目のループ
do 10561 j=1,nin
if(ao(j,m).ne.0.0) me=me+1
10561 continue ! 9番目のループここまで 

最後にmersqrsq0の確認をし、問題なければ次のlambdaに移ります。

continue
if(me.gt.ne)goto 10282
if(rsq-rsq0.lt.sml*rsq)goto 10282
if(rsq.gt.rsqmax)goto 10282
10281 continue ! 1番目のループはここまで

10282 continue
deallocate(a,mm,c,da)
return
end

終わりに

以上でelnet1は終了です。 ここまで随分とかかりましたが、なんとか{glmnet}のメインの処理を最後まで追いかけることが出来ました(途中でわからない部分を飛ばしたりしましたが)。

今回の調査での一番のポイントはやはり、「Lassoでは推定された回帰係数が罰則よりも小さければ 0 に丸めてしまう」ということを確認できたことだと思います。 「Lassoは不要な変数を0として推定することで変数選択できる」というのは間違ってはいないのですが、0として推定できるというよりも明示的に0にしてしまっているという表現の方が正しいと思います。 なので「変数選択できる」という言葉も本来であれば「効果の小さな変数を無視することで変数選択している」という言い方になるのかなと思いました。

こういったモデルにおける重要なポイントを、ソースコードを追いかけながら理解するというのは本当に大事なことだと改めて思います。

それでは。