001: /* ///////////////////////////// P /// L /// A /// S /// M /// A /////////////////////////////// */
002: /* ///                    PLASMA computational routine (version 2.1.0)                       ///
003:  * ///                    Author: Emmanuel Agullo                                            ///
004:  * ///                    Release Date: November, 15th 2009                                  ///
005:  * ///                    PLASMA is a software package provided by Univ. of Tennessee,       ///
006:  * ///                    Univ. of California Berkeley and Univ. of Colorado Denver          /// */
007: 
008: /* /////////////////////////// P /// U /// R /// P /// O /// S /// E /////////////////////////// */
009: // PLASMA_sgemm - Performs one of the matrix-matrix operations 
010: //
011: //   C = alpha*op( A )*op( B ) + beta*C,
012: //
013: // where op( X ) is one of 
014: //
015: //   op( X ) = X  or op( X ) = X'
016: //
017: // alpha and beta are scalars, and A, B and C  are matrices, with op( A ) 
018: // an m by k matrix, op( B ) a k by n matrix and C an m by n matrix.  
019: 
020: /* ///////////////////// A /// R /// G /// U /// M /// E /// N /// T /// S ///////////////////// */
021: // transA   PLASMA_enum (IN)
//          Specifies whether the matrix A is transposed, not transposed or conjugate transposed:
//          = PlasmaNoTrans:   A is not transposed;
//          = PlasmaTrans:     A is transposed (for real matrices this is also the
//                             conjugate transpose).
026: //          Currently only PlasmaNoTrans is supported
027: //
028: // transB   PLASMA_enum (IN)
//          Specifies whether the matrix B is transposed, not transposed or conjugate transposed:
//          = PlasmaNoTrans:   B is not transposed;
//          = PlasmaTrans:     B is transposed (for real matrices this is also the
//                             conjugate transpose).
033: //          Currently only PlasmaNoTrans is supported
034: //
035: // M        int (IN)
036: //          M specifies the number of rows of the matrix op( A ) and of the matrix C. M >= 0.
037: //
038: // N        int (IN)
039: //          N specifies the number of columns of the matrix op( B ) and of the matrix C. N >= 0.
040: //
041: // K        int (IN)
042: //          K specifies the number of columns of the matrix op( A ) and the number of rows of 
043: //          the matrix op( B ). K >= 0.
044: //
045: // alpha    float (IN)
046: //          alpha specifies the scalar alpha
047: //
048: // A        float* (IN)
049: //          A is a LDA-by-ka matrix, where ka is K when  transA = PlasmaNoTrans,  
050: //          and is  M  otherwise.
051: //
052: // LDA      int (IN)
//          The leading dimension of the array A.
//          LDA >= max(1,M) when transA = PlasmaNoTrans, and LDA >= max(1,K) otherwise.
054: //
055: // B        float* (IN)
056: //          B is a LDB-by-kb matrix, where kb is N when  transB = PlasmaNoTrans,  
057: //          and is  K  otherwise.
058: //
059: // LDB      int (IN)
//          The leading dimension of the array B.
//          LDB >= max(1,K) when transB = PlasmaNoTrans, and LDB >= max(1,N) otherwise.
061: //
062: // beta     float (IN)
063: //          beta specifies the scalar beta
064: //
065: // C        float* (INOUT)
066: //          C is a LDC-by-N matrix.
067: //          On exit, the array is overwritten by the M by N matrix ( alpha*op( A )*op( B ) + beta*C )
068: //
069: // LDC      int (IN)
070: //          The leading dimension of the array C. LDC >= max(1,M).
071: 
072: /* ///////////// R /// E /// T /// U /// R /// N /////// V /// A /// L /// U /// E ///////////// */
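//          = 0: successful exit
//          > 0: the i-th argument had an illegal value (e.g. 5 for K, 13 for LDC)
//          PLASMA_ERR_NOT_INITIALIZED, PLASMA_ERR_OUT_OF_RESOURCES or another PLASMA
//          error code: the PLASMA runtime reported a failure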
074: 
075: /* //////////////////////////////////// C /// O /// D /// E //////////////////////////////////// */
076: #include "common.h"
077: #include "lapack.h"
078: 
079: int PLASMA_sgemm(PLASMA_enum transA, PLASMA_enum transB, int M, int N, int K, 
080:                  float alpha, float *A, int LDA,
081:                  float *B, int LDB, 
082:                  float beta, float *C, int LDC)
083: {
084:     int NB, MT, NT, KT;
085:     int nrowA, nrowB;
086:     int status;
087:     float *Abdl;
088:     float *Bbdl;
089:     float *Cbdl;
090:     plasma_context_t *plasma;
091: 
092:     plasma = plasma_context_self();
093:     if (plasma == NULL) {
094:         plasma_fatal_error("PLASMA_sgemm", "PLASMA not initialized");
095:         return PLASMA_ERR_NOT_INITIALIZED;
096:     }
097: 
098:     /* TODO: to adapt nrowA and nrowB depending on transA et transB, respectively. */
099:     nrowA = M;
100:     nrowB = N;
101: 
102:     /* Check input arguments */
103:     if (transA != PlasmaNoTrans && transA != PlasmaTrans && transA != PlasmaTrans) {
104:         plasma_error("PLASMA_sgemm", "illegal value of transA");
105:         return 1;
106:     }
107:     if (transB != PlasmaNoTrans && transB != PlasmaTrans && transB != PlasmaTrans) {
108:         plasma_error("PLASMA_sgemm", "illegal value of transB");
109:         return 2;
110:     }
111:     if (M < 0) {
112:         plasma_error("PLASMA_sgemm", "illegal value of M");
113:         return 3;
114:     }
115:     if (N < 0) {
116:         plasma_error("PLASMA_sgemm", "illegal value of N");
117:         return 4;
118:     }
119:     if (K < 0) {
120:         plasma_error("PLASMA_sgemm", "illegal value of N");
121:         return 5;
122:     }
123:     if (LDA < max(1, nrowA)) {
124:         plasma_error("PLASMA_sgemm", "illegal value of LDA");
125:         return 8;
126:     }
127:     if (LDB < max(1, nrowB)) {
128:         plasma_error("PLASMA_sgemm", "illegal value of LDB");
129:         return 10;
130:     }
131:     if (LDC < max(1, M)) {
132:         plasma_error("PLASMA_sgemm", "illegal value of LDC");
133:         return 13;
134:     }
135: 
136:     /* Quick return - currently NOT equivalent to LAPACK's
137:      * LAPACK does not have such check for DPOSV */
138: 
139:     if (M == 0 || N == 0 ||
140:         ((alpha == (float)0.0 || K == 0.0) && beta == (float)1.0))
141:         return PLASMA_SUCCESS;
142: 
143:     /* Tune NB depending on M, N & NRHS; Set NBNBSIZE */
144:     status = plasma_tune(PLASMA_FUNC_SGEMM, M, N, 0);
145:     if (status != PLASMA_SUCCESS) {
146:         plasma_error("PLASMA_sgemm", "plasma_tune() failed");
147:         return status;
148:     }
149: 
150:     /* Set MT & NT & KT */
151:     NB = PLASMA_NB;
152:     MT = (M%NB==0) ? (M/NB) : (M/NB+1);
153:     NT = (N%NB==0) ? (N/NB) : (N/NB+1);
154:     KT = (K%NB==0) ? (K/NB) : (K/NB+1);
155: 
156:     /* Allocate memory for matrices in block layout */
157:     Abdl = (float *)plasma_shared_alloc(plasma, MT*KT*PLASMA_NBNBSIZE, PlasmaRealFloat);
158:     Bbdl = (float *)plasma_shared_alloc(plasma, KT*NT*PLASMA_NBNBSIZE, PlasmaRealFloat);
159:     Cbdl = (float *)plasma_shared_alloc(plasma, MT*NT*PLASMA_NBNBSIZE, PlasmaRealFloat);
160:     if (Abdl == NULL || Bbdl == NULL || Cbdl == NULL) {
161:         plasma_error("PLASMA_sgemm", "plasma_shared_alloc() failed");
162:         plasma_shared_free(plasma, Abdl);
163:         plasma_shared_free(plasma, Bbdl);
164:         plasma_shared_free(plasma, Cbdl);
165:         return PLASMA_ERR_OUT_OF_RESOURCES;
166:     }
167: 
168:     PLASMA_desc descA = plasma_desc_init(
169:         Abdl, PlasmaRealFloat,
170:         PLASMA_NB, PLASMA_NB, PLASMA_NBNBSIZE,
171:         M, K, 0, 0, M, K);
172: 
173:     PLASMA_desc descB = plasma_desc_init(
174:         Bbdl, PlasmaRealFloat,
175:         PLASMA_NB, PLASMA_NB, PLASMA_NBNBSIZE,
176:         K, N, 0, 0, K, N);
177: 
178:     PLASMA_desc descC = plasma_desc_init(
179:         Cbdl, PlasmaRealFloat,
180:         PLASMA_NB, PLASMA_NB, PLASMA_NBNBSIZE,
181:         M, N, 0, 0, M, N);
182: 
183:     plasma_parallel_call_3(plasma_lapack_to_tile,
184:         float*, A,
185:         int, LDA,
186:         PLASMA_desc, descA);
187: 
188:     plasma_parallel_call_3(plasma_lapack_to_tile,
189:         float*, B,
190:         int, LDB,
191:         PLASMA_desc, descB);
192: 
193:     plasma_parallel_call_3(plasma_lapack_to_tile,
194:         float*, C,
195:         int, LDC,
196:         PLASMA_desc, descC);
197: 
198:     plasma_parallel_call_7(plasma_psgemm,
199:         PLASMA_enum, transA,
200:         PLASMA_enum, transB,
201:         float, alpha,
202:         PLASMA_desc, descA,
203:         PLASMA_desc, descB,
204:         float, beta,
205:         PLASMA_desc, descC);
206: 
207: 
208: 
209:     if (PLASMA_INFO == PLASMA_SUCCESS)
210:     {
211:         plasma_parallel_call_3(plasma_tile_to_lapack,
212:             PLASMA_desc, descC,
213:             float*, C,
214:             int, LDC);
215:     }
216:     plasma_shared_free(plasma, Abdl);
217:     plasma_shared_free(plasma, Bbdl);
218:     plasma_shared_free(plasma, Cbdl);
219:     return PLASMA_INFO;
220: }
221: