Matrix Exponentiation in R

A short while ago, I needed to do some matrix exponentiation in R (raising a matrix to a power). By default, the exponentiation operator ^ in R, when applied to a matrix, will just raise each element of the matrix to a power, rather than the matrix itself:

> A <- matrix(c(1:4), nrow=2, byrow=TRUE)
> A
     [,1] [,2]
[1,]    1    2
[2,]    3    4
> A ^ 2
     [,1] [,2]
[1,]    1    4
[2,]    9   16

Whereas what we really want in this case is the equivalent of:

> A %*% A
     [,1] [,2]
[1,]    7   10
[2,]   15   22

There are a few different ways of adding matrix exponentiation to R: we could write an R function and bind it to a new operator, similar to the %*% matrix multiplication operator that exists already, or we could write the routine in C and link to it. It seems logical to name the operator %^%, to match the current set of matrix operators. The choice of writing the exponentiation routine in R or C is actually less important than the exponentiation algorithm used. A fast exponentiation algorithm is the “square and multiply” method (also known as exponentiation by squaring), which needs a number of matrix products that grows with the logarithm of the exponent rather than with the exponent itself.
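To make the square and multiply idea concrete, here is a small worked example. The exponent 13 is chosen purely for illustration: write the exponent in binary, square the matrix repeatedly, and multiply together the squares that correspond to the 1 bits.

```latex
13 = 1101_2 = 8 + 4 + 1
\quad\Longrightarrow\quad
A^{13} = A^{8} \cdot A^{4} \cdot A^{1}
```

Computing $A^2$, $A^4$ and $A^8$ takes three squarings, and combining the three factors takes two more products, so five matrix multiplications replace the twelve needed by repeated multiplication.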

I chose to write a basic exponentiation routine in C, exposed through a new operator definition, that follows these basic rules:

  • $A^{-1}$ returns the matrix inverse.
  • $A^0$ returns the identity matrix.
  • $A^1$ returns the original matrix.
  • $A^n$ returns A to the nth power.
  • Complex as well as real-valued matrices should be supported.
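The rules above can also be sketched in plain R, which is handy as a reference when testing a compiled version. This is only a minimal sketch and not the C routine described in this post: the argument checks, and the decision to reject non-integer exponents, are assumptions made for illustration.

```r
# Pure-R sketch of a matrix power operator (illustrative only).
`%^%` <- function(A, n) {
  if (!is.matrix(A) || nrow(A) != ncol(A))
    stop("can only raise square matrix to power")
  if (!is.numeric(n) || length(n) != 1L || is.na(n) || n != round(n))
    stop("exponent must be a scalar integer")
  if (n < -1)
    stop("exponent must be >= -1")

  if (n == -1) return(solve(A))   # matrix inverse
  if (n == 1)  return(A)          # original matrix, unchanged

  # Start from the identity; this is also the answer for n == 0.
  result <- diag(nrow(A))
  if (is.complex(A)) result <- result + 0i

  # Square and multiply: walk the bits of n from least significant up.
  base <- A
  while (n > 0) {
    if (n %% 2 == 1) result <- result %*% base   # low bit set
    n <- n %/% 2
    if (n > 0) base <- base %*% base             # next power of two
  }
  result
}

A <- matrix(c(1:4), nrow = 2, byrow = TRUE)
A %^% 2    # same values as A %*% A
A %^% -1   # same values as solve(A)
```

Because the operator is defined at the R level, it works in an unmodified R session, at the cost of doing the loop bookkeeping in interpreted code.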

The code for this is shown below. Note that to calculate the matrix inverse, I don’t call directly into the underlying BLAS/LAPACK libraries, but instead call back into R’s solve() function (via solve.default). The changes were made to array.c in the R source tree under /src/main.

SEXP do_matexp(SEXP call, SEXP op, SEXP args, SEXP rho)
{
    int nrows, ncols;
    SEXP matrix, tmp, dims, dims2;
    SEXP x, y, x_, x__;
    int i, e, mode;

    /* Work in double precision unless the input is complex. */
    mode = isComplex(CAR(args)) ? CPLXSXP : REALSXP;

    SETCAR(args, coerceVector(CAR(args), mode));
    x = CAR(args);
    y = CADR(args);

    /* Guard against a missing dim attribute before reading it. */
    if (!isMatrix(x))
        error(_("can only raise square matrix to power"));

    dims = getAttrib(x, R_DimSymbol);
    nrows = INTEGER(dims)[0];
    ncols = INTEGER(dims)[1];

    if (nrows != ncols)
        error(_("can only raise square matrix to power"));

    if (!isNumeric(y))
        error(_("exponent must be a scalar integer"));

    e = asInteger(y);

    if (e < -1)
        error(_("exponent must be >= -1"));
    else if (e == 1)
        return x;
    else if (e == -1) {  /* return matrix inverse via solve() */
        SEXP expr, inv;
        /* Build and evaluate the call solve.default(x). */
        PROTECT(expr = lang2(install("solve.default"), x));
        inv = eval(expr, rho);
        UNPROTECT(1);
        return inv;
    }

    PROTECT(matrix = allocVector(mode, nrows * ncols));
    PROTECT(tmp = allocVector(mode, nrows * ncols));
    PROTECT(x_ = allocVector(mode, nrows * ncols));
    PROTECT(x__ = allocVector(mode, nrows * ncols));

    /* x_ holds the running square of x: x, x^2, x^4, ... */
    if (mode == REALSXP)
        Memcpy(REAL(x_), REAL(x), (size_t)nrows*ncols);
    else
        Memcpy(COMPLEX(x_), COMPLEX(x), (size_t)nrows*ncols);

    /* Initialize matrix to the identity matrix. In column-major storage
       the diagonal entries sit at indices that are multiples of ncols+1. */
    if (mode == REALSXP)
        for (i = 0; i < ncols*nrows; i++)
            REAL(matrix)[i] = ((i % (ncols+1) == 0) ? 1 : 0);
    else
        for (i = 0; i < ncols*nrows; i++) {
            COMPLEX(matrix)[i].i = 0.0;
            COMPLEX(matrix)[i].r = ((i % (ncols+1) == 0) ? 1.0 : 0.0);
        }

    /* Square and multiply. For e == 0 the loop never runs, so the
       identity matrix is returned. */
    while (e > 0) {
        if (e & 1) {
            /* Low bit set: tmp = matrix * x_, then copy back. */
            if (mode == REALSXP)
                matprod(REAL(matrix), nrows, ncols,
                        REAL(x_), nrows, ncols, REAL(tmp));
            else
                cmatprod(COMPLEX(matrix), nrows, ncols,
                         COMPLEX(x_), nrows, ncols, COMPLEX(tmp));

            if (mode == REALSXP)
                Memcpy(REAL(matrix), REAL(tmp), (size_t)nrows*ncols);
            else
                Memcpy(COMPLEX(matrix), COMPLEX(tmp), (size_t)nrows*ncols);
            e--;
        }

        /* x__ = x_ * x_, then copy back into x_. */
        if (mode == REALSXP)
            matprod(REAL(x_), nrows, ncols,
                    REAL(x_), nrows, ncols, REAL(x__));
        else
            cmatprod(COMPLEX(x_), nrows, ncols,
                     COMPLEX(x_), nrows, ncols, COMPLEX(x__));

        if (mode == REALSXP)
            Memcpy(REAL(x_), REAL(x__), (size_t)nrows*ncols);
        else
            Memcpy(COMPLEX(x_), COMPLEX(x__), (size_t)nrows*ncols);

        e >>= 1;
    }

    /* Attach the dimensions so the result is a matrix, not a vector. */
    PROTECT(dims2 = allocVector(INTSXP, 2));
    INTEGER(dims2)[0] = nrows;
    INTEGER(dims2)[1] = ncols;
    setAttrib(matrix, R_DimSymbol, dims2);

    UNPROTECT(5);
    return matrix;
}

To actually hook this routine up to the %^% operator, I needed a further piece of glue in /src/main/names.c, an entry in the function table that associates the operator name with do_matexp:

{"%^%", do_matexp, 3, 1, 2, {PP_BINARY, PREC_POWER, 0}}

Then to try it out (after recompiling R of course!):

> A <- matrix(c(1:4), nrow=2, byrow=TRUE)
> A %^% 2
     [,1] [,2]
[1,]    7   10
[2,]   15   22