glmnetをもう少し理解したい⑤
それでは前回の記事に続いてelnet1
の紹介です。前回の記事はこちらです。
ループ③(回帰係数の推定)
以上までで見てきた通り、ループ①・②では alm
すなわちlambda
を更新しつつ、alpha
(alf
)やpenalty.factor
(vp
)との乗算によって罰則を計算していました。
ループ③ではその罰則を用いて回帰係数を更新します。
なのでこのループがglmnet
においてメインとなる処理と言って良いと思います。
ループ③はni
に対するループです。ここでni
は説明変数の数ですね。k をインデックスとして各説明変数をさらっていきます。
まずju
ですが、これは各説明変数列における数値のバラつきの有無を示す 1/0 のベクトルでした。バラつきがない、すなわち全ての数値が同じであれば(ju(k) == 0
)ループ③をスキップします(goto
の向かう先が10371
で、ループの範囲も同じく10371
となっています)。
do 10371 k=1,ni if(ju(k).eq.0) goto 10371
次にa
から k 番目の変数の値をak
に格納します。前回記事で追いかけた通り、このa
(またはao
)が最終的には回帰係数として返ります。
前処理においてa = 0.0
で初期化されているのでループの 1 周目時点ではak
も 0 ですが、ループ①の 2 周目以降は縮小された回帰係数が入っています。
ak=a(k) ! k 番目の変数の a の値を ak に代入。
続いてu
とv
を計算します。これらは前回の記事で少し紹介した通り、次のブロックで回帰係数a
を更新するためのものです。
u
はg(k)
にak*xv(k)
を加算して計算します。ここでg(k)
はstanderd
においてg(j)=dot_product(y,x(:,j))
、つまりy
とx
の内積として定義されたものでした(y
とx
はそれぞれ標準化されています)。もしも罰則が付いていなければこの共分散が OLS による回帰係数となるはずです(標準化されているのでx
の標準偏差は 1)。
このg
にxv
で重みをつけたak
を加算します。ここでxv
は weight を乗じたx
の二乗和です。しかしループの 1 周目ではak=0
であるためg
がそのまま利用されることになります。
このようにして定義されたu
の絶対値から罰則を減じたものがv
となります。
u=g(k)+ak*xv(k) v=abs(u)-vp(k)*ab
そしてさらにv
が 0 よりも大きい場合(OLS による回帰係数が罰則よりも大きい場合)、
- 「
cl(2,k)
」と「sign(v,u)/(xv(k)+vp(k)*dem)
」を比較して小さい方を選ぶ - それを「
cl(1,k)
」と比較して大きい方を選ぶ
という処理を行い、新たにa
として格納します。
ここでcl
はglmnet.r
でcl = rbind(lower.limits, upper.limits)
として定義されたものなので、推定された値を上限と下限の間に抑えようとしていることがわかります。またv
が 0 以下の場合は 0 となります。
! a(k) を更新 a(k)=0.0 if(v.gt.0.0) a(k)=max(cl(1,k),min(cl(2,k),sign(v,u)/(xv(k)+vp(k)*dem)))
以上が回帰係数の更新を行う処理になります。 ややアッサリしていますが、ここの処理は glmnet を理解する上で極めて重要なのでもう少し説明します。
まず前提として、(Elastic Net ではなく)Lasso では軟閾値作用素と呼ばれる写像を用いて解を推定しています。 ここで軟閾値作用素とは、定数 および において の絶対値が よりも大きければ を、そうでなければ 0 を返す作用素です:
すなわち、推定された回帰係数(の絶対値)が罰則よりも小さければ 0 に丸めてしまい、大きくても罰則の分だけ係数を縮小してしまう、ということです。 一般に Lasso は効果の小さな変数の回帰係数を 0 に縮小する方法として知られていますが、実装としてはこのような軟閾値作用素が用いられており、これを見ると「Lasso はスパースな解を推定できる」という言葉の意味がわかるのではないでしょうか。推定したら 0 になるわけではなく、明示的に 0 にしているのだと。
ここで少し余談なのですが、Lasso や Ridge に関する参考書などを読んでいると「幾何学的な説明」として以下のようなグラフが描かれることがよくあると思います:
このグラフを見るたびに私は納得いかない気分になっていました。と言うのも、Lasso の方(グラフ左側)に着目すると、OLS による推定値の座標(グラフ中の×印の位置)や楕円の広がり方によっては菱形の頂点ではなく辺に接することが普通にあり得そうだからです。 少なくともこのグラフをもって「Lasso は菱形の頂点に接しやすい(ゆえに解が 0 と推定されやすい)」というのは全く自明ではないし直感的でもないな、と思っていました。
そんな時に「機械学習の数理100問シリーズ」の「スパース推定100問 with R」を読んでいると、またも上記のようなグラフが出てきたので悶々としたのですが、次のページには以下のようなグラフがありました:
まさにこれです。このグラフにおいて白色の部分に OLS の推定値がある場合、頂点ではなく辺に接することになります。そこから少しずれて緑色の部分に OLS の推定値が存在する場合には菱形の頂点に接することとなる、つまりいずれか重要でない方の解が 0 として推定されるようになります。
上のグラフのような「幾何学的な説明」は本当に多くの本・記事で見かけるのですが、下のグラフも合わせて説明することでより理解が深まるのでは、と思いました。 余談おわり。
さて、上記のブロックでは、回帰係数が罰則よりも大きく、かつ上限・下限の範囲内であればsign(v,u)/(xv(k)+vp(k)*dem)
を新たなa
とするのでした。
さきほどの軟閾値作用素の説明においては「罰則を減じた回帰係数」(つまりv
)をLasso推定値としていましたが、ここではそれをxv(k)+vp(k)*dem
で除しています。
これは、ここで得ようとしている推定値というのが Lasso ではなく Elastic Net であるためであり、(第一回で紹介した)教科書(P36)では Elastic Net の推定量を
としています。dem
はalm*(1-bta)
で定義されていたことを思い出すと、これは Ridge (L2)に対する罰則であり、上記の式では $\lambda_{2}$ に該当します。
またxv
は X の二乗和を分散で除して 1 を加算したもので、これが何を意味しているのかは以前紹介したときもわからなかったのですが、サンプルデータを使って計算してみるとおおよそ 1 になりそうなのできっとそういう数値なんだろうと思います(適当)。
残る処理ですが、上記によってa(k)
が更新されなければループを抜けて次の変数に移ります(goto
の移動先10371
はループ③の終点でした)。
またmm
が 0 でなければ10391(ループ④の先)に移動するため、以降の処理から次に紹介するループ④までをスキップするようです。
なおこのmm
はループ①の1回目では 0 なので1回目は確実に処理が行われるようですね。
またnx
は非ゼロとする変数の数の上限なので、推定したパラメータ数がそれを越えると3番目のループを抜けるようです。
if(a(k).eq.ak) goto 10371 if(mm(k) .ne. 0) goto 10391 nin=nin+1 if(nin.gt.nx)goto 10372
ループ④(分散共分散行列の計算)
続いてループ④です。
ここでもループの対象は説明変数(ni
)ですが、今度はインデックスとしてj
を用い、分散共分散行列(のようなもの)を計算してc
に格納するようです。
ここでc
はni*nx
のサイズの行列です。
このループは短いのでまとめて見てしまいましょう。
まずはju
で変数にバラツキがあるかを確認し、なければ次の変数にスキップします。
続いてmm
をチェックし、mm
が 0 でなければc
にmm
を代入して次の変数にスキップします(なおこのmm
には後続の処理でnin
が代入されるのですが、そのnin
はmm
を基準に数値が加算されるような変数となっており互いに入り組んでいて何をやっているのかよくわかりませんでした)。
続いてj
とk
を比較して同一(同じ変数)だったらc
にxv
を、同一でなければx
のj
とk
の内積をc
に代入します。xv
は先ほど出てきたx
の二乗和ですので、このc
は分散共分散行列のようなものを計算しているようです(正方行列ではないので分散共分散行列とは言わないでしょうけども)。
do 10401 j=1,ni ! バラツキがなければ以降の処理をスキップ if(ju(j).eq.0)goto 10401 ! mm が 0(パラメータが 0 でない)でなければ次のブロックを実行して次の変数へスキップ if(mm(j) .eq. 0)goto 10421 c(j,nin)=c(k,mm(j)) goto 10401 10421 continue if(j .ne. k)goto 10441 ! 変数が同一でなければ 10441 に飛ぶ c(j,nin)=xv(j) ! 同一だったらここ goto 10401 10441 continue c(j,nin)=dot_product(x(:,j),x(:,k)) ! 同一でなかったら j と k の内積をとる 10401 continue ! 4番目のループはここまで
ループ④が終わった後は少しだけ処理が入ります。
mm
にはnin
が代入されます。またia
にはk
が入りますが、このk
はループ③のインデックスで、ループ③は更新がなければループ④をスキップしてしまうため、パラメータに更新があった変数のインデックスを表すことになります。
その上で、推定された回帰係数の差分を評価し、残差平方和を更新します。
このときg(k)
は縮小前の回帰係数(y
とx(k)
の内積)で、そこから weight調整済みの x の二乗和 を減じたものを残差平方和から減じて計算しています。
continue ! mm に nin を入れる mm(k)=nin ! ia に k を格納 ia(nin)=k 10391 continue ! a(k) の差分をとる。 a(k)、 ak は推定された回帰係数。 del=a(k)-ak ! 残差平方和を更新する rsq=rsq+del*(2.0*g(k)-del*xv(k)) dlx=max(xv(k)*del**2,dlx)
ループ⑤(回帰係数の更新)
さらに続けてループ⑤です。ここは一瞬で終わり、いま計算されたdel
を用いてg(j)
つまり縮小前の回帰係数を更新します。ところでk
はループ③のインデックスで、このループの中では固定されていますので、各変数の回帰係数の縮小に別の変数との共分散を利用しているわけですね。
共分散が大きいということは互いの変数間に相関があるということであり、相関が正なら回帰係数が小さくなるように働くようです。
! 探索範囲は三度説明変数 do 10451 j=1,ni ! インデックスは再度 j を使う if(ju(j).ne.0) g(j)=g(j)-c(j,mm(k))*del 10451 continue ! 5番目のループはここまで continue
ループ⑤を抜けるとすぐにループ③も終了です。
10371 continue ! 3番目のループはここまで
続いて以下のブロックで終了処理の判定を行います。10352
まで飛ぶと、いくつか処理はあるもののそのままreturnとなります。つまりdlx
がthr
よりも小さい、またはnin
がnx
よりも大きい場合にはelnet1
を抜けます。
そうではない場合、もう少し処理が続きます。
10372 continue if(dlx.lt.thr)goto 10352 if(nin.gt.nx)goto 10352 if(nlp .le. maxit)goto 10471 jerr=-m return 10471 continue 10360 continue iz=1 da(1:nin)=a(ia(1:nin)) continue 10481 continue nlp=nlp+1 dlx=0.0
ループ⑥(回帰係数の推定・再)
さらに続いてループ⑥です。実はこのループ、以下の通りループ③と処理がほとんど同じです。
! 3番目のループ(一部省略) do 10371 k=1,ni if(ju(k).eq.0)goto 10371 ak=a(k) u=g(k)+ak*xv(k) v=abs(u)-vp(k)*ab a(k)=0.0 if(v.gt.0.0) a(k)=max(cl(1,k),min(cl(2,k),sign(v,u)/(xv(k)+vp(k)*dem))) if(a(k).eq.ak)goto 10371 if(mm(k) .ne. 0)goto 10391 nin=nin+1 if(nin.gt.nx)goto 10372 continue mm(k)=nin ia(nin)=k 10391 continue del=a(k)-ak rsq=rsq+del*(2.0*g(k)-del*xv(k)) dlx=max(xv(k)*del**2,dlx) do 10451 j=1,ni if(ju(j).ne.0) g(j)=g(j)-c(j,mm(k))*del ! 6番目のループ do 10491 l=1,nin k=ia(l) ak=a(k) u=g(k)+ak*xv(k) v=abs(u)-vp(k)*ab a(k)=0.0 if(v.gt.0.0) a(k)=max(cl(1,k),min(cl(2,k),sign(v,u)/(xv(k)+vp(k)*dem))) if(a(k).eq.ak)goto 10491 del=a(k)-ak rsq=rsq+del*(2.0*g(k)-del*xv(k)) dlx=max(xv(k)*del**2,dlx) do 10501 j=1,nin g(ia(j))=g(ia(j))-c(ia(j),mm(k))*del
ループの対象がni
ではなくnin
になっている点が異なりますが、処理としては大体同じなので説明は省略します。
do 10491 l=1,nin k=ia(l) ! k を取り出す( ia には 0 ではないパラメータが推定された変数の列が格納されてる) ak=a(k) ! a を取り出す u=g(k)+ak*xv(k) v=abs(u)-vp(k)*ab a(k)=0.0 if(v.gt.0.0) a(k)=max(cl(1,k),min(cl(2,k),sign(v,u)/(xv(k)+vp(k)*dem))) if(a(k).eq.ak)goto 10491 del=a(k)-ak rsq=rsq+del*(2.0*g(k)-del*xv(k)) dlx=max(xv(k)*del**2,dlx)
ループ⑦(回帰係数の更新・再)
ループ⑦も同様にループ⑤と同じ処理をnin
に対して行っています。
do 10501 j=1,nin g(ia(j))=g(ia(j))-c(ia(j),mm(k))*del 10501 continue ! 7番目のループはここまで
そしてループ⑥が終了。
continue 10491 continue ! 6番目のループはここまで
ここで終了判定が行われます。
nlp
はループのカウンターとなっているようで、一定回数を過ぎていなければ10481
まで戻されます。
この10481
はループ⑥の手前ですので、dlx
が十分に小さくなければ再度ループ⑥を実行するような流れになっているようですね。
continue if(dlx.lt.thr)goto 10482 if(nlp .le. maxit)goto 10521 jerr=-m return 10521 continue goto 10481 ! ループ⑥の手前まで戻す 10482 continue da(1:nin)=a(ia(1:nin))-da(1:nin)
ループ⑧(回帰係数の更新・再々)
ループ⑧です。改めて、nin
ではなくni
に対して回帰係数の更新が行われます。
ここでda
にはすぐ上のブロックでa
の値からda
の値を減じて更新しているのですが、もう少し上の方でda
にはa
を渡しています。
つまり順番としては、da <- a
とした上でa
を更新し、更新後のa
とda
(つまり更新前のa
)の差分を改めてda
とする、という流れです。
この更新後のda
と分散共分散行列の内積を回帰係数から減じるわけですので、やっていることはループ⑤における回帰係数の更新と同じですね。
do 10531 j=1,ni if(mm(j).ne.0)goto 10531 if(ju(j).ne.0) g(j)=g(j)-dot_product(da(1:nin),c(j,1:nin)) 10531 continue ! 8番目のループはここまで
ループ⑧を抜けると後は終了まで一直線です…と言いたいところですが、ここでなんと衝撃的なことに、10351
、つまりループ③の開始まで戻されてしまいます。なんてこった。
実はループ③の開始直後にはiz*jz
で処理を変える判定があり、ともに 1 であればループ③の終了時点まで移動するのですが、ここでjz
を 0 にしてしまっているので愚直にループ③を再度実行することになります。
しかもjz
が 1 に更新される機会があるのはループ③よりも前の段階なので、一度この処理に入った場合には必ずループ③の処理から再開しないといけない、ということですね。
continue jz=0 goto 10351 ! えっ!!
上のgoto
を無事に回避できた場合、最後の処理に入ります。
以下では必要な変数を格納しています。
10352 continue if(nin .le. nx)goto 10551 ! nin が nx を超えた場合はここにくる jerr=-10000-m goto 10282 ! jerr を 更新して elnet1 を抜ける 10551 continue if(nin.gt.0) ao(1:nin,m)=a(ia(1:nin)) kin(m)=nin ! m 回目のループの nin を kin[m] に格納する rsqo(m)=rsq ! m 回目のループの rsq を rsqo[m] に格納する almo(m)=alm ! m 回目のループの alm を almo[m] に格納する lmu=m if(m.lt.mnl)goto 10281 if(flmin.ge.1.0)goto 10281 me=0
ループ⑨(回帰係数が推定された変数のカウント)
以下ではelnet1
のここまでのループによって推定された回帰係数を確認し、0.0 ではない変数の数をカウントしています。改めて、j
は変数、m
はlambda
のインデックスです。
! 9番目のループ do 10561 j=1,nin if(ao(j,m).ne.0.0) me=me+1 10561 continue ! 9番目のループここまで
最後にme
、rsq
、rsq0
の確認をし、問題なければ次のlambda
に移ります。
continue if(me.gt.ne)goto 10282 if(rsq-rsq0.lt.sml*rsq)goto 10282 if(rsq.gt.rsqmax)goto 10282 10281 continue ! 1番目のループはここまで 10282 continue deallocate(a,mm,c,da) return end
終わりに
以上でelnet1
は終了です。
ここまで随分とかかりましたが、なんとか{glmnet}のメインの処理を最後まで追いかけることが出来ました(途中でわからない部分を飛ばしたりしましたが)。
今回の調査での一番のポイントはやはり、「Lassoでは推定された回帰係数が罰則よりも小さければ 0 に丸めてしまう」ということを確認できたことだと思います。 「Lassoは不要な変数を0として推定することで変数選択できる」というのは間違ってはいないのですが、0として推定できるというよりも明示的に0にしてしまっているという表現の方が正しいと思います。 なので「変数選択できる」という言葉も本来であれば「効果の小さな変数を無視することで変数選択している」という言い方になるのかなと思いました。
こういったモデルにおける重要なポイントを、ソースコードを追いかけながら理解するというのは本当に大事なことだと改めて思います。
それでは。