統計コンサルの議事メモ

統計や機械学習の話題を中心に、思うがままに

データを小集団に分割しながら線形回帰の解を推定する

背景

突然ですが、一般に線形回帰と言えば以下の正規方程式:

 X'Xb = (X'X)^{-1}X'y

をbについて解くことで得られると教わり、そのまま理解していることが多いのではないでしょうか。

これ自体は決して間違っていないのですが、実装においては計算精度の問題から、逆行列ではなくQR分解を活用して解いている場合があります。例えばRでも、以前の記事においてlmソースコードをたどっていった結果、ハウスホルダー変換によってQR分解が行われていることを確認しました。

過去の記事はこちら。 ushi-goroshi.hatenablog.com

ここでlm逆行列およびQR分解による解の推定値をちょっと見てみましょう。適当にデータを作成します。

set.seed(123)
n <- 100
b <- c(1, 1.5) # 切片と回帰係数
x <- cbind(1, rnorm(n))
y <- x %*% b + rnorm(n)

また、それぞれによる解の推定方法を以下のように定義します。

## lm(.fit)を使う
my_lm <- function() { coef(lm.fit(x, y)) }

## 逆行列で解く
my_solve <- function() { solve(crossprod(x, x)) %*% crossprod(x, y) }

## QR分解で解く
my_qr <- function() { solve(qr.R(qr(x))) %*% t(qr.Q(qr(x))) %*% y }

上で定義した関数は、いずれも同じ解を返します:

> cbind(my_solve(), my_qr(), my_lm())
        [,1]      [,2]      [,3]
x1 0.8971969 0.8971969 0.8971969
x2 1.4475284 1.4475284 1.4475284

一緒の値になっていますね。少し脱線しますが、ついでに計算時間も見てみましょう:

time_1000 <- data.frame(microbenchmark::microbenchmark(my_solve(), my_qr(), my_lm(), times = 1000))
> library(ggplot2)
> ggplot(time_1000, aes(x = expr, y = log(time), group = expr)) +
     geom_violin() + 
     coord_flip() +
     labs(x = "functions") +
     NULL

f:id:ushi-goroshi:20190225182213p:plain

逆行列を用いた場合が一番早く、QR分解を用いたものが最も遅いようでした。なおこのグラフは横軸が対数となっていることに注意してください。

さて、このようにして線形回帰の解はQR分解を使って求めることができますが、実は計算を工夫することで、Xを小集団に分割した上でそれぞれのデータからX全体の解を得ることができます。これが何を意味するかというと、メモリに全部載せきれないような大きいデータであっても解を推定したり、あるいは線形回帰であっても並列に計算を回すことができる、ということです1

もともと今回の記事を書こうと思ったのは、以前に「線形回帰はデータを分割して並列計算できる」という話を知人から聞いたことをふと思い出したのがきっかけです。当時は何を言っているのか今いち理解できなかったのですが、大変わかりやすい下記の記事を見つけたため、写経した内容をメモしておきます。

freakonometrics.hypotheses.org

手順

実装に取り掛かる前に手順について簡単に理解しておきましょう。まずXをQR分解すると、冒頭に示した正規方程式から得られる \hat{\beta}は以下のようになります:


X = QR \
\hat{\beta} = (X'X)^{-1}X'y = (R'Q'QR)^{-1}R'Q'y = (R'R)^{-1}R'Q'y

QR分解によって得られる行列Qは直交行列であるため、 (Q'Q) = Iとなります。またここで積の逆行列 (AB)^{-1} = B^{-1}A^{-1}という性質があることから、


(R'R)^{-1}R'Q'y = R^{-1}R'^{-1}R'Q'y = R^{-1}Q'y

となります。すなわちQR分解によって得られた行列Rの逆行列と、行列Qの転置があれば良いことになります。先ほどmy_qrを定義したときは説明なく示しましたが、これは下のように書けます:

## my_qrの定義(再掲)
solve(qr.R(qr(x))) %*% t(qr.Q(qr(x))) %*% y

問題は、この R^{-1}および Q'をどのようにして小集団から再構成するか、ということになりますが、これは以下の手順で計算できるようです:

  1. 共通処理
    1. X、yをそれぞれ小集団に分割する
    2. 各小集団のXをQR分解する
  2.  R^{-1}を計算する
    1. 各小集団からのRを統合する
    2. 再度QR分解してRを得る
    3. Rの逆行列 R^{-1}を求める
  3.  Q'を計算する
    1. 1-2で得られたQを2-2で得たQに乗じる( Q'
  4. 2と3の結果およびyにより解を得る
    1. 3-1で得たQ'にyを乗じる
    2. 両者を乗じる

なおこの手順で確かに R^{-1} Q'が再構成できることは確認できたのですが、これがなぜ上手くいくのかについては残念ながら調べても分からなかりませんでした。もしご存知でしたら誰か教えてください。

実装

それでは実装に入りますが、先にデータをすべて使った時の回帰係数を確認しておきましょう。サンプルデータにはcarsを使い、目的変数をdist、説明変数をspeedとした単回帰を回してみます。

> lm(dist ~ speed, data = cars)$coefficients
(Intercept)       speed 
 -17.579095    3.932409 

切片とspeedの回帰係数がそれぞれ-17.5793.932と推定されました。冒頭でも確認した通り、lmの結果は下記の方法と一致します。

y <- cars$dist
x <- cbind(1, cars$speed)
> cbind(
+    solve(crossprod(x, x)) %*% crossprod(x, y),
+    solve(qr.R(qr(x))) %*% t(qr.Q(qr(x))) %*% y
+ )
           [,1]       [,2]
[1,] -17.579095 -17.579095
[2,]   3.932409   3.932409

この数値を、分割した小集団に対する計算結果から再び得ることが目標となります。

1. 共通処理

1. X、yをそれぞれ小集団に分割する

それではxを小集団に分割した上で解を推定していきます。今回はデータを5つに分割しましょう。xは50行のデータなので各データセットには10行ずつ割り当てられます。各データをlist形式で保存しておきます。

# 分割するデータの数
m <- 5 
n_per_d <- nrow(x) / m

# 割り切れなかった場合用
if (nrow(x) %% m != 0) m <- m + 1 

xlist <- list() # 各xの保存用リスト
ylist <- list() # 各yの保存用リスト
for (i in 1:m) {
   if(i == m) {
      xlist[[i]] = x[((i-1) * n_per_d + 1):nrow(x), ]
      ylist[[i]] = y[((i-1) * n_per_d + 1):nrow(x)]
   }
   xlist[[i]] = x[(i-1) * n_per_d + 1:n_per_d, ]
   ylist[[i]] = y[(i-1) * n_per_d + 1:n_per_d]
}

このような形でデータが保存されます:

> head(xlist[[1]])
     [,1] [,2]
[1,]    1    4
[2,]    1    4
[3,]    1    7
[4,]    1    7
[5,]    1    8
[6,]    1    9
2. 各小集団のXをQR分解する

次に各小集団をQR分解し、その結果として得られる行列QおよびRをそれぞれ保存しておきましょう。リストの各要素は、更にそれぞれQとRを要素に持つリストとなります。

QR1 <- list() # 各データセットに対するQR分解の結果を保存するリスト
for (i in 1:m) {
   QR1[[i]] <- list(Q = qr.Q(qr(xlist[[i]])),
                    R = qr.R(qr(xlist[[i]])))
}

この時点でQR1は、10行2列の行列Qと2行2列の上三角行列Rを要素に持つリストになっています。

> str(QR1)
List of 5
 $ :List of 2
  ..$ Q: num [1:10, 1:2] -0.316 -0.316 -0.316 -0.316 -0.316 ...
  ..$ R: num [1:2, 1:2] -3.16 0 -25.3 7.48
 $ :List of 2
  ..$ Q: num [1:10, 1:2] -0.316 -0.316 -0.316 -0.316 -0.316 ...
  ..$ R: num [1:2, 1:2] -3.16 0 -39.53 2.55
 $ :List of 2
  ..$ Q: num [1:10, 1:2] -0.316 -0.316 -0.316 -0.316 -0.316 ...
  ..$ R: num [1:2, 1:2] -3.16 0 -48.38 3.48
 $ :List of 2
  ..$ Q: num [1:10, 1:2] -0.316 -0.316 -0.316 -0.316 -0.316 ...
  ..$ R: num [1:2, 1:2] -3.16 0 -58.82 2.9
 $ :List of 2
  ..$ Q: num [1:10, 1:2] -0.316 -0.316 -0.316 -0.316 -0.316 ...
  ..$ R: num [1:2, 1:2] -3.16 0 -71.47 5.87

2. R^{-1}を計算する

1. 各小集団からのRを統合する

続いてQR1に格納された行列Rを、rbindで一つにまとめます。

R1 <- c()
for(i in 1:m) {
   R1 <- rbind(R1, QR1[[i]]$R)
}
2. 再度QR分解してRを得る

このR1を再度QR分解し、 その行列Rを得ます。

R2 <- qr.R(qr(R1))
3. Rの逆行列を求める(R^{-1})

この逆行列が、当初求めようとしていたものの1つ R^{-1}になります。

R_inv <- solve(R2)

では、このR_invがデータ全体を使って求めた R^{-1}を同じ値になっているかを確認してみましょう。

> R_inv
          [,1]        [,2]
[1,] 0.1414214  0.41606428
[2,] 0.0000000 -0.02701716

> solve(qr.R(qr(x)))
           [,1]        [,2]
[1,] -0.1414214 -0.41606428
[2,]  0.0000000  0.02701716

あれ?符号が反転していますね。。

3. Q'を計算する

ひとまず置いておいて、先に進みましょう。

1. 1-2で得られたQを2-2で得たQに乗じる(Q')

先ほどR2を計算したときと同じQR分解で、今度は行列Qを得ます。

Q1 <- qr.Q(qr(R1))

さらに説明変数の数(今回は2)ごとにデータを分割します。

## 説明変数の数
p <- ncol(x)

Q2list <- list()
for(i in 1:m) {
   Q2list[[i]] <- Q1[(i-1) * p + 1:p, ]
}

このQ2listに、最初にQR分解した結果の行列Q(QR1$Q)を掛け合わせます。

Q3list <- list()
for(i in 1:m) {
   Q3list[[i]] <- QR1[[i]]$Q %*% Q2list[[i]]
}

ここで得られたQ3listはデータ全体を使ってQR分解したときの Q'になっているはずです。確認してみましょう:

> head(cbind(
+    do.call("rbind", Q3list),
+    qr.Q(qr(x))))
          [,1]      [,2]       [,3]       [,4]
[1,] 0.1414214 0.3079956 -0.1414214 -0.3079956
[2,] 0.1414214 0.3079956 -0.1414214 -0.3079956
[3,] 0.1414214 0.2269442 -0.1414214 -0.2269442
[4,] 0.1414214 0.2269442 -0.1414214 -0.2269442
[5,] 0.1414214 0.1999270 -0.1414214 -0.1999270
[6,] 0.1414214 0.1729098 -0.1414214 -0.1729098

また符号が反転してますね。。。

原因はわかりませんが、 R^{-1}も符号が反転していたので、結果的には元に戻るはずです。気にしないで進めましょう。

4. 2と3の結果およびyにより解を得る

1. 3-1で得たQ'にyを乗じる

上で計算された行列をylistと乗じ、結果を要素ごとに足し合わせます。

Vlist <- list()
for(i in 1:m) {
   Vlist[[i]] <- t(Q3list[[i]]) %*% ylist[[i]]
}

sumV <- Vlist[[1]]
for(i in 2:m) {
   sumV <- sumV + Vlist[[i]]
}
2. 両者を乗じる

最後に、2-3で得た R^{-1}sumVを掛け合わせれば解が得られるはずです。どうでしょうか?

> cbind(
+    R_inv %*% sumV,
+    solve(crossprod(x, x)) %*% crossprod(x, y),
+    solve(qr.R(qr(x))) %*% t(qr.Q(qr(x))) %*% y
+ )
           [,1]       [,2]       [,3]
[1,] -17.579095 -17.579095 -17.579095
[2,]   3.932409   3.932409   3.932409

同じですね!

終わりに

今回はデータを小集団に分割しながら線形回帰の解を推定する、ということを紹介しました。今の時代にどうしても必要な知識かと言えばそんなこともありませんが、知っておくとと良いこともあるよ、ということで。

なおこの記事の冒頭で紹介したこちらのページでは、さらに「データソースが複数に分かれている」条件でも線形回帰の解が推定できることを示しています(例えばデータを格納しているサーバーが複数に分かれており、しかもデータのコピーが難しい状況を想定しているようです)。こちらはなかなか実用的なのではないでしょうか?

freakonometrics.hypotheses.org

おまけ

上記の工程をtidyに実行しようとすると、以下のようになるようです(こちらから)

library(tidyverse)
X <- data_frame(intercept = 1, speed = cars$speed) %>% as.matrix()
y <- cars$dist
mats <- X %>%
   as_data_frame() %>%
   mutate(
      id = rep(1:5, each = 10) ,
      y = y
   ) %>% 
   ## this is where partitioning happens
   nest(-id) %>% 
   mutate(
      X = map(data, ~ .x %>% select(-y) %>% as.matrix()),
      y = map(data, ~ .x %>% pull(y))
   ) %>% 
   ## We calculate QR decomposition for each partition independently
   mutate(
      Q2 = map(X, ~ .x %>% qr() %>% qr.Q()),
      R1 = map(X, ~ .x %>% qr() %>% qr.R())
   )


df_collect <- mats$R1 %>% do.call(what = 'rbind', args = .)
data.frame(dimension = c('rows', 'columns'), cbind(X %>% dim(), df_collect %>% dim()))


## Number of groups for nesting can be automatically inferred
m2 <-  dim(mats$R1[[1]])[2]

## The map-stage QR-decomposition
Q1 = df_collect %>% qr %>% qr.Q
R2 = df_collect %>% qr %>% qr.R

## For some reason this did not work with a `mutate` command...
mats$Q1 = 
   Q1 %>% 
   as_data_frame() %>% 
   mutate(id = ceiling(row_number() / m2)) %>% 
   nest(-id) %>% 
   mutate(data = map(data, ~ as.matrix(.x))) %>% 
   pull(data)

v_sum = 
   mats %>% 
   mutate(Q3_t = map2(.x = Q2, .y = Q1, .f = ~ t(.x %*% .y))) %>%
   mutate(V = map2(.x = Q3_t, .y = y, .f = ~ .x %*% .y)) %>% 
   pull(V) %>% 
   reduce(`+`)

t(solve(R2) %*% v_sum)

  1. 果たして今の時代にどれほどのニーズがあるのかわかりませんが。。。