00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034 #ifndef MATRIX_H_
00035 #define MATRIX_H_
00036
00037 #include "matrix_funcs.h"
00038 #ifdef NUMPY_INTERFACE
00039 #include <Python.h>
00040 #include <arrayobject.h>
00041 #endif
00042 #include <limits>
00043 #include <assert.h>
00044 #include <stdio.h>
00045 #include <string.h>
00046
00047 #ifdef USE_MKL
00048 #include <mkl.h>
00049 #include <mkl_cblas.h>
00050 #include <mkl_vsl.h>
00051 #include <mkl_vml.h>
00052
00053 #define IS_MKL true
00054
00055 #ifdef DOUBLE_PRECISION
00056 #define MKL_UNIFORM vdRngUniform
00057 #define MKL_NORMAL vdRngGaussian
00058 #define MKL_UNIFORM_RND_METHOD VSL_METHOD_DUNIFORM_STD_ACCURATE
00059 #define MKL_GAUSSIAN_RND_METHOD VSL_METHOD_DGAUSSIAN_BOXMULLER
00060 #define MKL_EXP vdExp
00061 #define MKL_RECIP vdInv
00062 #define MKL_SQUARE vdSqr
00063 #define MKL_TANH vdTanh
00064 #define MKL_LOG vdLn
00065 #define MKL_VECMUL vdMul
00066 #define MKL_VECDIV vdDiv
00067 #else
00068 #define MKL_UNIFORM vsRngUniform
00069 #define MKL_NORMAL vsRngGaussian
00070 #define MKL_UNIFORM_RND_METHOD VSL_METHOD_SUNIFORM_STD_ACCURATE
00071 #define MKL_GAUSSIAN_RND_METHOD VSL_METHOD_SGAUSSIAN_BOXMULLER
00072 #define MKL_EXP vsExp
00073 #define MKL_RECIP vsInv
00074 #define MKL_SQUARE vsSqr
00075 #define MKL_TANH vsTanh
00076 #define MKL_LOG vsLn
00077 #define MKL_VECMUL vsMul
00078 #define MKL_VECDIV vsDiv
00079 #endif
00080
00081 #else
00082 #include <cblas.h>
00083 #define IS_MKL false
00084 #endif
00085
// Precision-dependent aliases for the CBLAS gemm / scal / axpy routines:
// the `cblas_d*` names when DOUBLE_PRECISION is defined, `cblas_s*` otherwise.
#ifndef DOUBLE_PRECISION
#  define CBLAS_GEMM  cblas_sgemm
#  define CBLAS_SCAL  cblas_sscal
#  define CBLAS_AXPY  cblas_saxpy
#else
#  define CBLAS_GEMM  cblas_dgemm
#  define CBLAS_SCAL  cblas_dscal
#  define CBLAS_AXPY  cblas_daxpy
#endif  // DOUBLE_PRECISION
00095
// Largest finite value representable by MTYPE.
// The name is fully qualified so the macro expands correctly in translation
// units that have no `using namespace std;` in effect at the point of use.
// Requires <limits> and a visible definition of MTYPE where it is expanded.
#define MTYPE_MAX std::numeric_limits<MTYPE>::max()
00097
00098 class Matrix {
00099 private:
00100 MTYPE* _data;
00101 bool _ownsData;
00102 int _numRows, _numCols;
00103 int _numElements;
00104 int _numDataBytes;
00105 CBLAS_TRANSPOSE _trans;
00106
00107 void _init(MTYPE* data, int numRows, int numCols, bool transpose, bool ownsData);
00108 void _tileTo2(Matrix& target) const;
00109 void _copyAllTo(Matrix& target) const;
00110 MTYPE _sum_column(int col) const;
00111 MTYPE _sum_row(int row) const;
00112 MTYPE _aggregate(MTYPE(*agg_func)(MTYPE, MTYPE), MTYPE initialValue) const;
00113 void _aggregate(int axis, Matrix& target, MTYPE(*agg_func)(MTYPE, MTYPE), MTYPE initialValue) const;
00114 MTYPE _aggregateRow(int row, MTYPE(*agg_func)(MTYPE, MTYPE), MTYPE initialValue) const;
00115 MTYPE _aggregateCol(int row, MTYPE(*agg_func)(MTYPE, MTYPE), MTYPE initialValue) const;
00116 void _updateDims(int numRows, int numCols);
00117 void _applyLoop(MTYPE(*func)(MTYPE));
00118 void _applyLoop(MTYPE (*func)(MTYPE), Matrix& target);
00119 void _applyLoop2(const Matrix& a, MTYPE(*func)(MTYPE, MTYPE), Matrix& target) const;
00120 void _applyLoop2(const Matrix& a, MTYPE (*func)(MTYPE,MTYPE, MTYPE), MTYPE scalar, Matrix& target) const;
00121 void _applyLoopScalar(const MTYPE scalar, MTYPE(*func)(MTYPE, MTYPE), Matrix& target) const;
00122 void _checkBounds(int startRow, int endRow, int startCol, int endCol) const;
00123 void _divideByVector(const Matrix& vec, Matrix& target);
00124 inline int _getNumColsBackEnd() const {
00125 return _trans == CblasNoTrans ? _numCols : _numRows;
00126 }
00127 public:
00128 enum FUNCTION {
00129 TANH, RECIPROCAL, SQUARE, ABS, EXP, LOG, ZERO, ONE, LOGISTIC1, LOGISTIC2
00130 };
00131 Matrix();
00132 Matrix(int numRows, int numCols);
00133 #ifdef NUMPY_INTERFACE
00134 Matrix(const PyArrayObject *src);
00135 #endif
00136 Matrix(const Matrix &like);
00137 Matrix(MTYPE* data, int numRows, int numCols);
00138 Matrix(MTYPE* data, int numRows, int numCols, bool transpose);
00139 ~Matrix();
00140
00141 inline MTYPE& getCell(int i, int j) const {
00142
00143
00144 if (_trans == CblasTrans) {
00145 return _data[j * _numRows + i];
00146 }
00147 return _data[i * _numCols + j];
00148 }
00149
00150 MTYPE& operator()(int i, int j) const {
00151 return getCell(i, j);
00152 }
00153
00154 inline MTYPE* getData() const {
00155 return _data;
00156 }
00157
00158 inline bool isView() const {
00159 return !_ownsData;
00160 }
00161
00162 inline int getNumRows() const {
00163 return _numRows;
00164 }
00165
00166 inline int getNumCols() const {
00167 return _numCols;
00168 }
00169
00170 inline int getNumDataBytes() const {
00171 return _numDataBytes;
00172 }
00173
00174 inline int getNumElements() const {
00175 return _numElements;
00176 }
00177
00178 inline CBLAS_TRANSPOSE getBLASTrans() const {
00179 return _trans;
00180 }
00181
00182 inline bool isSameDims(const Matrix& a) const {
00183 return a.getNumRows() == getNumRows() && a.getNumCols() == getNumCols();
00184 }
00185
00186 inline bool isTrans() const {
00187 return _trans == CblasTrans;
00188 }
00189
00190
00191
00192
00193
00194
00195
00196 inline void setTrans(bool trans) {
00197 _trans = trans ? CblasTrans : CblasNoTrans;
00198 }
00199
00200 void apply(FUNCTION f);
00201 void apply(Matrix::FUNCTION f, Matrix& target);
00202 void subtractFromScalar(MTYPE scalar);
00203 void subtractFromScalar(MTYPE scalar, Matrix &target) const;
00204 void biggerThanScalar(MTYPE scalar);
00205 void smallerThanScalar(MTYPE scalar);
00206 void equalsScalar(MTYPE scalar);
00207 void biggerThanScalar(MTYPE scalar, Matrix& target) const;
00208 void smallerThanScalar(MTYPE scalar, Matrix& target) const;
00209 void equalsScalar(MTYPE scalar, Matrix& target) const;
00210 void biggerThan(Matrix& a);
00211 void biggerThan(Matrix& a, Matrix& target) const;
00212 void smallerThan(Matrix& a);
00213 void smallerThan(Matrix& a, Matrix& target) const;
00214 void minWith(Matrix &a);
00215 void minWith(Matrix &a, Matrix &target) const;
00216 void maxWith(Matrix &a);
00217 void maxWith(Matrix &a, Matrix &target) const;
00218 void equals(Matrix& a);
00219 void equals(Matrix& a, Matrix& target) const;
00220 void notEquals(Matrix& a) ;
00221 void notEquals(Matrix& a, Matrix& target) const;
00222 void add(const Matrix &m);
00223 void add(const Matrix &m, MTYPE scale);
00224 void add(const Matrix &m, Matrix& target);
00225 void add(const Matrix &m, MTYPE scale, Matrix& target);
00226 void subtract(const Matrix &m);
00227 void subtract(const Matrix &m, Matrix& target);
00228 void subtract(const Matrix &m, MTYPE scale);
00229 void subtract(const Matrix &m, MTYPE scale, Matrix& target);
00230 void addVector(const Matrix& vec, MTYPE scale);
00231 void addVector(const Matrix& vec, MTYPE scale, Matrix& target);
00232 void addVector(const Matrix& vec);
00233 void addVector(const Matrix& vec, Matrix& target);
00234 void addScalar(MTYPE scalar);
00235 void addScalar(MTYPE scalar, Matrix& target) const;
00236 void maxWithScalar(MTYPE scalar);
00237 void maxWithScalar(MTYPE scalar, Matrix &target) const;
00238 void minWithScalar(MTYPE scalar);
00239 void minWithScalar(MTYPE scalar, Matrix &target) const;
00240 void eltWiseMultByVector(const Matrix& vec);
00241 void eltWiseMultByVector(const Matrix& vec, Matrix& target);
00242 void eltWiseDivideByVector(const Matrix& vec);
00243 void eltWiseDivideByVector(const Matrix& vec, Matrix& target);
00244 void resize(int newNumRows, int newNumCols);
00245 void resize(const Matrix& like);
00246 Matrix& slice(int startRow, int endRow, int startCol, int endCol) const;
00247 void slice(int startRow, int endRow, int startCol, int endCol, Matrix &target) const;
00248 Matrix& sliceRows(int startRow, int endRow) const;
00249 void sliceRows(int startRow, int endRow, Matrix& target) const;
00250 Matrix& sliceCols(int startCol, int endCol) const;
00251 void sliceCols(int startCol, int endCol, Matrix& target) const;
00252 void rightMult(const Matrix &b, MTYPE scale);
00253 void rightMult(const Matrix &b, Matrix &target) const;
00254 void rightMult(const Matrix &b);
00255 void rightMult(const Matrix &b, MTYPE scaleAB, Matrix &target) const;
00256 void addProduct(const Matrix &a, const Matrix &b, MTYPE scaleAB, MTYPE scaleThis);
00257 void addProduct(const Matrix& a, const Matrix& b);
00258 void eltWiseMult(const Matrix& a);
00259 void eltWiseMult(const Matrix& a, Matrix& target) const;
00260 void eltWiseDivide(const Matrix& a);
00261 void eltWiseDivide(const Matrix& a, Matrix &target) const;
00262 Matrix& transpose() const;
00263 Matrix& transpose(bool hard) const;
00264 Matrix& tile(int timesY, int timesX) const;
00265 void tile(int timesY, int timesX, Matrix& target) const;
00266 void copy(Matrix &dest, int srcStartRow, int srcEndRow, int srcStartCol, int srcEndCol, int destStartRow, int destStartCol) const;
00267 Matrix& copy() const;
00268 void copy(Matrix& target) const;
00269 Matrix& sum(int axis) const;
00270 void sum(int axis, Matrix &target) const;
00271 MTYPE sum() const;
00272 MTYPE max() const;
00273 Matrix& max(int axis) const;
00274 void max(int axis, Matrix& target) const;
00275 MTYPE min() const;
00276 Matrix& min(int axis) const;
00277 void min(int axis, Matrix& target) const;
00278 void scale(MTYPE scale);
00279 #ifdef USE_MKL
00280 void randomizeNormal(VSLStreamStatePtr stream, MTYPE mean, MTYPE stdev);
00281 void randomizeUniform(VSLStreamStatePtr stream);
00282 void randomizeNormal(VSLStreamStatePtr stream);
00283 #else
00284 void randomizeNormal(MTYPE mean, MTYPE stdev);
00285 void randomizeUniform();
00286 void randomizeNormal();
00287 #endif
00288 void print() const;
00289 void print(int startRow,int rows, int startCol,int cols) const;
00290 void print(int rows, int cols) const;
00291 };
00292
00293 #endif