MAGMA  1.2.0
Matrix Algebra on GPU and Multicore Architectures
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Groups
codelet_zgemm.c
Go to the documentation of this file.
1 
17 #include "morse_starpu.h"
18 
19 /*
20  * Codelet CPU
21  */
22 static void cl_zgemm_cpu_func(void *descr[], void *cl_arg)
23 {
24  int transA;
25  int transB;
26  int M;
27  int N;
28  int K;
29  PLASMA_Complex64_t alpha;
30  PLASMA_Complex64_t *A;
31  int LDA;
32  PLASMA_Complex64_t *B;
33  int LDB;
34  PLASMA_Complex64_t beta;
35  PLASMA_Complex64_t *C;
36  int LDC;
37 
38  A = (PLASMA_Complex64_t *)STARPU_MATRIX_GET_PTR(descr[0]);
39  B = (PLASMA_Complex64_t *)STARPU_MATRIX_GET_PTR(descr[1]);
40  C = (PLASMA_Complex64_t *)STARPU_MATRIX_GET_PTR(descr[2]);
41 
42  starpu_unpack_cl_args(cl_arg, &transA, &transB, &M, &N, &K, &alpha, &LDA, &LDB, &beta, &LDC);
45  (CBLAS_TRANSPOSE)transA, (CBLAS_TRANSPOSE)transB,
46  M, N, K,
47  CBLAS_SADDR(alpha), A, LDA,
48  B, LDB,
49  CBLAS_SADDR(beta), C, LDC);
50 }
51 
52 /*
53  * Codelet Multi-cores
54  */
55 #ifdef MORSE_USE_MULTICORE
56 static void cl_zgemm_mc_func(void *descr[], void *cl_arg)
57 {
58  int transA;
59  int transB;
60  int M;
61  int N;
62  int K;
63  PLASMA_Complex64_t alpha;
64  PLASMA_Complex64_t *A;
65  int LDA;
66  PLASMA_Complex64_t *B;
67  int LDB;
68  PLASMA_Complex64_t beta;
69  PLASMA_Complex64_t *C;
70  int LDC;
71 
72  A = (PLASMA_Complex64_t *)STARPU_MATRIX_GET_PTR(descr[0]);
73  B = (PLASMA_Complex64_t *)STARPU_MATRIX_GET_PTR(descr[1]);
74  C = (PLASMA_Complex64_t *)STARPU_MATRIX_GET_PTR(descr[2]);
75 
76  starpu_unpack_cl_args(cl_arg, &transA, &transB, &M, &N, &K, &alpha, &LDA, &LDB, &beta, &LDC);
77 
78  PLASMA_zgemm_Lapack(transA, transB,
79  M, N, K,
80  alpha, A, LDA,
81  B, LDB, beta,
82  C, LDC);
83 }
84 #else
85 #define cl_zgemm_mc_func cl_zgemm_cpu_func
86 #endif
87 
88 /*
89  * Codelet GPU
90  */
91 #ifdef MORSE_USE_CUDA
92 static void cl_zgemm_cuda_func(void *descr[], void *cl_arg)
93 {
94  int transA;
95  int transB;
96  int M;
97  int N;
98  int K;
99  cuDoubleComplex alpha;
100  cuDoubleComplex *A;
101  int LDA;
102  cuDoubleComplex *B;
103  int LDB;
104  cuDoubleComplex beta;
105  cuDoubleComplex *C;
106  int LDC;
107 
108  A = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[0]);
109  B = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[1]);
110  C = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[2]);
111 
112  starpu_unpack_cl_args(cl_arg, &transA, &transB, &M, &N, &K, &alpha, &LDA, &LDB, &beta, &LDC);
113 
114  cublasZgemm (
115  plasma_lapack_constants[transA][0],
116  plasma_lapack_constants[transB][0],
117  M, N, K,
118  alpha, A, LDA,
119  B, LDB, beta, C, LDC);
120  cudaThreadSynchronize();
121 }
122 #endif
123 
124 /*
125  * Codelet definition
126  */
127 CODELETS(zgemm, 3, cl_zgemm_cpu_func, cl_zgemm_cuda_func, cl_zgemm_mc_func)
128 
129 
130 /*
131  * Wrapper
132  */
133 void MORSE_zgemm( MorseOption_t *options,
134  int transA, int transB,
135  int m, int n, int k,
136  PLASMA_Complex64_t alpha, magma_desc_t *A, int Am, int An,
137  magma_desc_t *B, int Bm, int Bn,
138  PLASMA_Complex64_t beta, magma_desc_t *C, int Cm, int Cn)
139 {
140  starpu_codelet *zgemm_codelet;
141  void (*callback)(void*) = options->profiling ? cl_zgemm_callback : NULL;
142  int lda = BLKLDD( A, Am );
143  int ldb = BLKLDD( B, Bm );
144  int ldc = BLKLDD( C, Cm );
145 
146 #ifdef MORSE_USE_MULTICORE
147  zgemm_codelet = options->parallel ? &cl_zgemm_mc : &cl_zgemm;
148 #else
149  zgemm_codelet = &cl_zgemm;
150 #endif
151 
153  zgemm_codelet,
154  VALUE, &transA, sizeof(PLASMA_enum),
155  VALUE, &transB, sizeof(PLASMA_enum),
156  VALUE, &m, sizeof(int),
157  VALUE, &n, sizeof(int),
158  VALUE, &k, sizeof(int),
159  VALUE, &alpha, sizeof(PLASMA_Complex64_t),
160  INPUT, BLKADDR( A, PLASMA_Complex64_t, Am, An ),
161  VALUE, &lda, sizeof(int),
162  INPUT, BLKADDR( B, PLASMA_Complex64_t, Bm, Bn ),
163  VALUE, &ldb, sizeof(int),
164  VALUE, &beta, sizeof(PLASMA_Complex64_t),
165  INOUT, BLKADDR( C, PLASMA_Complex64_t, Cm, Cn ),
166  VALUE, &ldc, sizeof(int),
167  PRIORITY, options->priority,
168  CALLBACK, callback, NULL,
169  0);
170 }