问题
I am reading Introduction to Algorithms by CLRS. Book shows pseudocode for simple divide and conquer matrix multiplication:
n = A.rows
let c be a new n x n matrix
if n == 1
c11 = a11 * b11
else partition A, B, and C
C11 = SquareMatrixMultiplyRecursive(A11, B11)
+ SquareMatrixMultiplyRecursive(A12, B21)
//...
return C
Where for example, A11 is submatrix of A of size n/2 x n/2. Author also hints that I should use index calculations instead of creating new matrices to represent submatrices, so I did this:
#include <iostream>
#include <vector>
template<class T>
struct Matrix
{
Matrix(size_t r, size_t c)
{
Data.resize(c, std::vector<T>(r, 0));
}
void SetSubMatrix(const int r, const int c, const int n, const Matrix<T>& A, const Matrix<T>& B)
{
for(int _c=c; _c<n; ++_c)
{
for(int _r=r; _r<n; ++_r)
{
Data[_c][_r] = A.Data[_c][_r] + B.Data[_c][_r];
}
}
}
static Matrix<T> SquareMultiplyRecursive(Matrix<T>& A, Matrix<T>& B, int ar, int ac, int br, int bc, int n)
{
Matrix<T> C(n, n);
if(n == 1)
{
C.Data[0][0] = A.Data[ac][ar] * B.Data[bc][br];
}
else
{
C.SetSubMatrix(0, 0, n / 2,
SquareMultiplyRecursive(A, B, ar, ac, br, bc, n / 2),
SquareMultiplyRecursive(A, B, ar, ac + (n / 2), br + (n / 2), bc, n / 2));
C.SetSubMatrix(0, n / 2, n / 2,
SquareMultiplyRecursive(A, B, ar, ac, br, bc + (n / 2), n / 2),
SquareMultiplyRecursive(A, B, ar, ac + (n / 2), br + (n / 2), bc + (n / 2), n / 2));
C.SetSubMatrix(n / 2, 0, n / 2,
SquareMultiplyRecursive(A, B, ar + (n / 2), ac, br, bc, n / 2),
SquareMultiplyRecursive(A, B, ar + (n / 2), ac + (n / 2), br + (n / 2), bc, n / 2));
C.SetSubMatrix(n / 2, n / 2, n / 2,
SquareMultiplyRecursive(A, B, ar + (n / 2), ac, br, bc + (n / 2), n / 2),
SquareMultiplyRecursive(A, B, ar + (n / 2), ac + (n / 2), br + (n / 2), bc + (n / 2), n / 2));
}
return C;
}
void Print()
{
for(int c=0; c<Data.size(); ++c)
{
for(int r=0; r<Data[0].size(); ++r)
{
std::cout << Data[c][r] << " ";
}
std::cout << "\n";
}
std::cout << "\n";
}
std::vector<std::vector<T> > Data;
};
int main()
{
Matrix<int> A(2, 2);
Matrix<int> B(2, 2);
A.Data[0][0] = 2;
A.Data[0][1] = 1;
A.Data[1][0] = 1;
A.Data[1][1] = 2;
B.Data[0][0] = 2;
B.Data[0][1] = 1;
B.Data[1][0] = 1;
B.Data[1][1] = 2;
A.Print();
B.Print();
Matrix<int> C(Matrix<int>::SquareMultiplyRecursive(A, B, 0, 0, 0, 0, 2));
C.Print();
}
It gives me incorrect results, tho I am not sure what I'm doing wrong...
回答1:
// Recursive naive matrix multiplication in C, not strassen.
// 2013-Feb-15 Fri 12:28 moshahmed/at/gmail
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#define M 2
#define N (1<<M)
typedef int mat[N][N]; // mat[2**M,2**M] for divide and conquer mult.
typedef struct { int ra, rb, ca, cb; } corners; // for tracking rows and columns.
// set A[a] = k
void set(mat A, corners a, int k){
int i,j;
for(i=a.ra;i<a.rb;i++)
for(j=a.ca;j<a.cb;j++)
A[i][j] = k;
}
// set A[a] = [random(l..h)].
void randk(mat A, corners a, int l, int h){
int i,j;
for(i=a.ra;i<a.rb;i++)
for(j=a.ca;j<a.cb;j++)
A[i][j] = l + rand()% (h-l);
}
// Print A[a]
void print(mat A, corners a, char *name) {
int i,j;
printf("%s = {\n",name);
for(i=a.ra;i<a.rb;i++){
for(j=a.ca;j<a.cb;j++)
printf("%4d, ", A[i][j]);
printf("\n");
}
printf("}\n");
}
// Return 1/4 of the matrix: top/bottom , left/right.
void find_corners(corners a, int i, int j, corners *b) {
int rm = a.ra + (a.rb - a.ra)/2 ;
int cm = a.ca + (a.cb - a.ca)/2 ;
*b = a;
if (i==0) b->rb = rm; // top rows
else b->ra = rm; // bot rows
if (j==0) b->cb = cm; // left cols
else b->ca = cm; // right cols
}
// Naive Multiply: A[a] * B[b] => C[c], recursively.
void mul(mat A, mat B, mat C, corners a, corners b, corners c) {
corners aii[2][2], bii[2][2], cii[2][2];
int i, j, m, n, p;
// Check: A[m n] * B[n p] = C[m p]
m = a.rb - a.ra; assert(m==(c.rb-c.ra));
n = a.cb - a.ca; assert(n==(b.rb-b.ra));
p = b.cb - b.ca; assert(p==(c.cb-c.ca));
assert(m>0);
if (n==1) {
C[c.ra][c.ca] += A[a.ra][a.ca] * B[b.ra][b.ca];
return;
}
// Create the smaller matrices:
for(i=0;i<2;i++) {
for(j=0;j<2;j++) {
find_corners(a, i, j, &aii[i][j]);
find_corners(b, i, j, &bii[i][j]);
find_corners(c, i, j, &cii[i][j]);
}
}
// Now do the 8 sub matrix multiplications.
// C00 = A00*B00 + A01*B10
// C01 = A00*B01 + A01*B11
// C10 = A10*B00 + A11*B10
// C11 = A10*B01 + A11*B11
mul( A, B, C, aii[0][0], bii[0][0], cii[0][0] );
mul( A, B, C, aii[0][1], bii[1][0], cii[0][0] );
mul( A, B, C, aii[0][0], bii[0][1], cii[0][1] );
mul( A, B, C, aii[0][1], bii[1][1], cii[0][1] );
mul( A, B, C, aii[1][0], bii[0][0], cii[1][0] );
mul( A, B, C, aii[1][1], bii[1][0], cii[1][0] );
mul( A, B, C, aii[1][0], bii[0][1], cii[1][1] );
mul( A, B, C, aii[1][1], bii[1][1], cii[1][1] );
}
int main() {
mat A, B, C;
corners ai = {0,N,0,N};
corners bi = {0,N,0,N};
corners ci = {0,N,0,N};
//set(A,ai,2);
//set(B,bi,2);
srand(time(0));
randk(A,ai, 0, 2);
randk(B,bi, 0, 2);
set(C,ci,0); // set to zero before mult.
print(A, ai, "A");
print(B, bi, "B");
mul(A,B,C, ai, bi, ci);
print(C, ci, "C");
return 0;
}
回答2:
I found the solution... SetSubMatrix was completly incorrect:
void SetSubMatrix(const int r, const int c, const int rn, const int cn, const Matrix<T>& A, const Matrix<T>& B)
{
for(int _c=c; _c<cn; ++_c)
{
for(int _r=r; _r<rn; ++_r)
{
Data[_c][_r] = A.Data[_c-c][_r-r] + B.Data[_c-c][_r-r];
}
}
}
static Matrix<T> SquareMultiplyRecursive(Matrix<T>& A, Matrix<T>& B, int ar, int ac, int br, int bc, int n)
{
Matrix<T> C(n, n);
if(n == 1)
{
C.Data[0][0] = A.Data[ac][ar] * B.Data[bc][br];
}
else
{
C.SetSubMatrix(0, 0, n / 2, n / 2,
SquareMultiplyRecursive(A, B, ar, ac, br, bc, n / 2),
SquareMultiplyRecursive(A, B, ar, ac + (n / 2), br + (n / 2), bc, n / 2));
C.SetSubMatrix(0, n / 2, n / 2, n,
SquareMultiplyRecursive(A, B, ar, ac, br, bc + (n / 2), n / 2),
SquareMultiplyRecursive(A, B, ar, ac + (n / 2), br + (n / 2), bc + (n / 2), n / 2));
C.SetSubMatrix(n / 2, 0, n, n / 2,
SquareMultiplyRecursive(A, B, ar + (n / 2), ac, br, bc, n / 2),
SquareMultiplyRecursive(A, B, ar + (n / 2), ac + (n / 2), br + (n / 2), bc, n / 2));
C.SetSubMatrix(n / 2, n / 2, n, n,
SquareMultiplyRecursive(A, B, ar + (n / 2), ac, br, bc + (n / 2), n / 2),
SquareMultiplyRecursive(A, B, ar + (n / 2), ac + (n / 2), br + (n / 2), bc + (n / 2), n / 2));
}
return C;
}
回答3:
Here is my answer, depending on the recursive matrix multiplication. Only for N = 2 ^ M, Where M >= 2
template <std::size_t size>
int matrix_mul_recursive(int N, int i, int j, const int (&A)[size][size], const int (&B)[size][size], int (&C)[size][size]) {
if (N == 1) {
return *(const_cast<int*>(&(A[0][0])) + i) * (*(const_cast<int*>(&(B[0][0])) + j));
}
else {
const int H = N / 2;
const int T = (size * H);
int r = i / size;
int c = 0;
if (j < size) {
c = j;
}
else {
c = j % size;
}
C[r][c] += matrix_mul_recursive<size>(H, i, j, A, B, C) +
matrix_mul_recursive<size>(H, i + H, T + j, A, B, C);
C[r][c + H] += matrix_mul_recursive<size>(H, i, j + H, A, B, C) +
matrix_mul_recursive<size>(H, i + H, T + j + H, A, B, C);
C[r + H][c] += matrix_mul_recursive<size>(H, T + i, j, A, B, C) +
matrix_mul_recursive<size>(H, T + i + H, T + j, A, B, C);
C[r + H][c + H] += matrix_mul_recursive<size>(H, T + i, j + H, A, B, C) +
matrix_mul_recursive<size>(H, T + i + H, T + j + H, A, B, C);
}
return 0;
}
来源:https://stackoverflow.com/questions/12922031/recursive-matrix-multiplication