矩阵乘法的Strassen算法

  最简单的矩阵乘法可以通过三重循环来实现,其时间复杂度为\(\Theta(n^{3})\),Strassen算法通过巧妙的增加加法来减少乘法实现了\(O(n^{2.81})\)的时间复杂度

Strassen算法的四个步骤:

  1. 将输入矩阵A、B与输出矩阵C分解为\(n/2\times n/2\)的子矩阵,采用下标计算方法,此步骤花费\(\Theta\)(1)时间。
  2. 创建10个\(n/2\times n/2\)的矩阵,每个矩阵保存步骤1中创建的两个子矩阵的和或差,花费\(\Theta(n^2)\)
  3. 用步骤1中创建的子矩阵和步骤2中创建的10个矩阵,递归的计算7个\(P_i\)矩阵积。
  4. 通过\(P_i\)矩阵的不同组合进行加减运算,计算出C的子矩阵,花费时间\(\Theta(n^2)\)

  为了方便计算矩阵积C=A\(\cdot\)B,假定三个矩阵均为\(n\times n\)矩阵,其中n为2的幂。做出这个假设是因为在每个分解步骤中,\(n\times n\)矩阵都被划分为4个\(n/2\times n/2\)的子矩阵,如果假定\(n\)是2的幂,则只要\(n\geq 2\)即可保证子矩阵规模\(n/2\)为整数。

假定将A、B和C均分解为4个\(n/2\times n/2\)的子矩阵:

\[A= \left[ \begin{matrix} A_{11} & A_{12} \\ A_{21} & A_{22} \\ \end{matrix} \right] ,B= \left[ \begin{matrix} B_{11} & B_{12} \\ B_{21} & B_{22} \\ \end{matrix} \right] ,C= \left[ \begin{matrix} C_{11} & C_{12} \\ C_{21} & C_{22} \\ \end{matrix} \right] \] 根据矩阵乘法的定义,可以得到如下4个公式: \[ \begin{equation}\label{1} \begin{aligned} C_{11}&=A_{11}\cdot B_{11}+A_{12}\cdot B_{21} \\ C_{12}&=A_{11}\cdot B_{12}+A_{12}\cdot B_{22} \\ C_{21}&=A_{21}\cdot B_{11}+A_{22}\cdot B_{21} \\ C_{22}&=A_{21}\cdot B_{12}+A_{22}\cdot B_{22} \\ \end{aligned} \end{equation} \]

步骤2中,创建如下10个矩阵: \[ \begin{equation}\label{2} \begin{aligned} S_1&=B_{12}-B_{22}\\ S_2&=A_{11}-A_{12}\\ S_3&=A_{21}+A_{22}\\ S_4&=B_{21}-B_{21}\\ S_5&=A_{11}+A_{22}\\ S_6&=B_{11}+B_{22}\\ S_7&=A_{12}-A_{22}\\ S_8&=B_{21}+B_{22}\\ S_9&=A_{11}-A_{21}\\ S_{10}&=B_{11}+B_{12}\\ \end{aligned} \end{equation} \] 由于必须进行10次\(n/2\times n/2\)矩阵的加减法,因此,该步骤花费\(\Theta(n^2)\)时间。 步骤3中,递归的计算7次\(n/2\times n/2\)矩阵的乘法,如下所示: \[ \begin{align*} P_1&=A_{11}\cdot S_1\\ P_2&=S_2\cdot B_{22}\\ P_3&=S_3\cdot B_{11}\\ P_4&=A_{22}\cdot S_4\\ P_5&=S_5\cdot S_6\\ P_6&=S_7\cdot S_8\\ P_7&=S_9\cdot S_{10}\\ \end{align*} \] 步骤4中, \[ \begin{align*} C_{11}&=P_5+P_4-P_2+P_6\\ C_{12}&=P_1+P_2\\ C_{21}&=P_3+P_4\\ C_{22}&=P_5+P_1-P_3-P_7\\ \end{align*} \] 共进行了8次\(n/2\times n/2\)矩阵的加减法,因此花费\(\Theta(n^2)\)时间。 代值计算后可以发现(2)式结果与(1)式是相同的。

描述Strassen算法运行时间T(n)的递归式: \[T(n)=\begin{cases} \Theta(1)&n=1\\ 7T(n/2)+\Theta(n^2)&n>1\\ \end{cases}\]

  用主方法来求解这个递归式,可知解为\(T\left(n\right)=\Theta(n^{lg7})\),由于\(lg7\)介于2.80和2.81之间,所以时间复杂度为\(O(n^{2.81})\)

天知道Strassen是怎么想到这个方法的QAQ