PLASMA  2.4.5
PLASMA - Parallel Linear Algebra for Scalable Multi-core Architectures
 All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Groups
psgemm.c
Go to the documentation of this file.
1 
16 #include "common.h"
17 
18 #define A(m, n) BLKADDR(A, float, m, n)
19 #define B(m, n) BLKADDR(B, float, m, n)
20 #define C(m, n) BLKADDR(C, float, m, n)
21 /***************************************************************************/
/*
 * Tile-by-tile single-precision matrix multiply-accumulate into C.
 *
 * NOTE(review): the signature line that belongs to this body is not present
 * in this listing (the embedded numbering jumps from 21 to 25), so the
 * function name and its parameter list cannot be stated here.
 *
 * What the visible body does:
 *  - unpacks transA, transB, alpha, A, B, beta, C, sequence and request with
 *    plasma_unpack_args_9, then returns at once when sequence->status is not
 *    PLASMA_SUCCESS;
 *  - visits tiles C(m, n) starting at linear tile index PLASMA_RANK and
 *    advancing by PLASMA_SIZE each iteration, with m running fastest
 *    (presumably PLASMA_RANK / PLASMA_SIZE are this worker's rank and the
 *    worker count -- TODO confirm against common.h);
 *  - for every visited tile, calls CORE_sgemm once per k block, choosing the
 *    A and B tile indices according to transA and transB.
 *
 * request is unpacked but not otherwise referenced in the visible body.
 */
25 {
26  PLASMA_enum transA;
27  PLASMA_enum transB;
28  float alpha;
29  PLASMA_desc A;
30  PLASMA_desc B;
31  float beta;
32  PLASMA_desc C;
33  PLASMA_sequence *sequence;
34  PLASMA_request *request;
35 
/* K, X, Y are the three size arguments handed to CORE_sgemm; k, m, n are tile indices. */
36  int K, X, Y;
37  int k, m, n;
38  int next_m;
39  int next_n;
40  int ldam, ldak, ldbn, ldbk, ldcm;
41 
/* zbeta is the scaling of C actually passed for a given k: beta for k == 0, zone afterwards. */
42  float zbeta;
43  float zone = (float)1.0;
44 
45  plasma_unpack_args_9(transA, transB, alpha, A, B, beta, C, sequence, request);
46  if (sequence->status != PLASMA_SUCCESS)
47  return;
48 
/* Convert the starting linear index PLASMA_RANK into (m, n) by peeling off whole columns of C.mt tiles. */
49  n = 0;
50  m = PLASMA_RANK;
51  while (m >= C.mt && n < C.nt) {
52  n++;
53  m = m-C.mt;
54  }
55 
/* Main loop: ends once the column index has moved past the last tile column of C. */
56  while (n < C.nt) {
57  next_m = m;
58  next_n = n;
59 
/* Compute the tile for the next iteration: step PLASMA_SIZE tiles ahead, wrapping m into the following column(s). */
60  next_m += PLASMA_SIZE;
61  while (next_m >= C.mt && next_n < C.nt) {
62  next_n++;
63  next_m = next_m - C.mt;
64  }
65 
/* Extents of tile C(m, n): the last tile row/column takes the remainder, every other tile is C.mb by C.nb. */
66  X = m == C.mt-1 ? C.m - m*C.mb : C.mb;
67  Y = n == C.nt-1 ? C.n - n*C.nb : C.nb;
68 
69  ldcm = BLKLDD(C, m);
70  /*
71  * A: PlasmaNoTrans / B: PlasmaNoTrans
72  */
73  if (transA == PlasmaNoTrans) {
74  ldam = BLKLDD(A, m);
75  if (transB == PlasmaNoTrans) {
/* k runs over the tile columns of A; the last block uses the remainder of A.n. */
76  for (k = 0; k < A.nt; k++) {
77  K = k == A.nt-1 ? A.n-k*A.nb : A.nb;
78  ldbk = BLKLDD(B, k);
/* beta is applied only on the first k so that later calls accumulate into the same C tile with factor 1. */
79  zbeta = k == 0 ? beta : zone;
80  CORE_sgemm(
81  transA, transB,
82  X, Y, K,
83  alpha, A(m, k), ldam,
84  B(k, n), ldbk,
85  zbeta, C(m, n), ldcm);
86  }
87  }
88  /*
89  * A: PlasmaNoTrans / B: Plasma[Conj]Trans
90  */
91  else {
/* B is addressed as B(n, k) in this branch, so its leading dimension depends on n only and is computed once. */
92  ldbn = BLKLDD(B, n);
93  for (k = 0; k < A.nt; k++) {
94  K = k == A.nt-1 ? A.n-k*A.nb : A.nb;
95  zbeta = k == 0 ? beta : zone;
96  CORE_sgemm(
97  transA, transB,
98  X, Y, K,
99  alpha, A(m, k), ldam,
100  B(n, k), ldbn,
101  zbeta, C(m, n), ldcm);
102  }
103  }
104  }
105  /*
106  * A: Plasma[Conj]Trans / B: PlasmaNoTrans
107  */
108  else {
109  if (transB == PlasmaNoTrans) {
/* With A transposed, k runs over the tile rows of A and A is addressed as A(k, m). */
110  for (k = 0; k < A.mt; k++) {
111  K = k == A.mt-1 ? A.m-k*A.mb : A.mb;
112  ldak = BLKLDD(A, k);
113  ldbk = BLKLDD(B, k);
114  zbeta = k == 0 ? beta : zone;
115  CORE_sgemm(
116  transA, transB,
117  X, Y, K,
118  alpha, A(k, m), ldak,
119  B(k, n), ldbk,
120  zbeta, C(m, n), ldcm);
121  }
122  }
123  /*
124  * A: Plasma[Conj]Trans / B: Plasma[Conj]Trans
125  */
126  else {
127  ldbn = BLKLDD(B, n);
128  for (k = 0; k < A.mt; k++) {
129  K = k == A.mt-1 ? A.m-k*A.mb : A.mb;
130  ldak = BLKLDD(A, k);
131  zbeta = k == 0 ? beta : zone;
132  CORE_sgemm(
133  transA, transB,
134  X, Y, K,
135  alpha, A(k, m), ldak,
136  B(n, k), ldbn,
137  zbeta, C(m, n), ldcm);
138  }
139  }
140  }
/* Move on to the tile computed at the top of this iteration. */
141  m = next_m;
142  n = next_n;
143  }
144 }
145 
146 /***************************************************************************/
/*
 * Submits one task per (m, n, k) tile triple for a single-precision
 * matrix multiply-accumulate into C, covering every tile of C.
 *
 * NOTE(review): this listing is missing several lines of the definition
 * (the embedded numbering skips 147-149, 154-155, 184, 201, 221 and 239).
 * The function name, the leading parameters (transA and transB are used
 * below but not declared in the visible text), the declarations of plasma
 * and task_flags, and the name of the routine that receives each argument
 * list are therefore not visible here and are not documented.
 *
 * Visible parameters: alpha, A, B, beta, C, sequence, request.
 * request is not referenced in the visible body.
 *
 * Returns without submitting anything when sequence->status is not
 * PLASMA_SUCCESS.
 */
150  float alpha, PLASMA_desc A, PLASMA_desc B,
151  float beta, PLASMA_desc C,
152  PLASMA_sequence *sequence, PLASMA_request *request)
153 {
156 
157  int m, n, k;
158  int ldam, ldak, ldbn, ldbk, ldcm;
/* Extents of the current tile along m, n and k (k measured on A.n or A.m depending on transA). */
159  int tempmm, tempnn, tempkn, tempkm;
160 
/* zbeta is the scaling of C passed for a given k: beta for k == 0, zone afterwards. */
161  float zbeta;
162  float zone = (float)1.0;
163 
164  plasma = plasma_context_self();
165  if (sequence->status != PLASMA_SUCCESS)
166  return;
/* Record the sequence's quark_sequence in the task flags that are passed with every submission below. */
167  QUARK_Task_Flag_Set(&task_flags, TASK_SEQUENCE, (intptr_t)sequence->quark_sequence);
168 
/* Outer loops enumerate every tile C(m, n); the last tile row/column takes the remainder extent. */
169  for (m = 0; m < C.mt; m++) {
170  tempmm = m == C.mt-1 ? C.m-m*C.mb : C.mb;
171  ldcm = BLKLDD(C, m);
172  for (n = 0; n < C.nt; n++) {
173  tempnn = n == C.nt-1 ? C.n-n*C.nb : C.nb;
174  /*
175  * A: PlasmaNoTrans / B: PlasmaNoTrans
176  */
177  if (transA == PlasmaNoTrans) {
178  ldam = BLKLDD(A, m);
179  if (transB == PlasmaNoTrans) {
180  for (k = 0; k < A.nt; k++) {
181  tempkn = k == A.nt-1 ? A.n-k*A.nb : A.nb;
182  ldbk = BLKLDD(B, k);
/* beta is applied only for k == 0; later k blocks accumulate into the same C tile with factor 1. */
183  zbeta = k == 0 ? beta : zone;
/* NOTE(review): the line naming the callee of the following argument list is missing from this listing (embedded line 184). */
185  plasma->quark, &task_flags,
186  transA, transB,
187  tempmm, tempnn, tempkn, A.mb,
188  alpha, A(m, k), ldam, /* lda * Z */
189  B(k, n), ldbk, /* ldb * Y */
190  zbeta, C(m, n), ldcm); /* ldc * Y */
191  }
192  }
193  /*
194  * A: PlasmaNoTrans / B: Plasma[Conj]Trans
195  */
196  else {
/* B is addressed as B(n, k) in this branch, so its leading dimension depends on n only. */
197  ldbn = BLKLDD(B, n);
198  for (k = 0; k < A.nt; k++) {
199  tempkn = k == A.nt-1 ? A.n-k*A.nb : A.nb;
200  zbeta = k == 0 ? beta : zone;
/* NOTE(review): callee line missing from this listing (embedded line 201). */
202  plasma->quark, &task_flags,
203  transA, transB,
204  tempmm, tempnn, tempkn, A.mb,
205  alpha, A(m, k), ldam, /* lda * Z */
206  B(n, k), ldbn, /* ldb * Z */
207  zbeta, C(m, n), ldcm); /* ldc * Y */
208  }
209  }
210  }
211  /*
212  * A: Plasma[Conj]Trans / B: PlasmaNoTrans
213  */
214  else {
215  if (transB == PlasmaNoTrans) {
/* With A transposed, k runs over the tile rows of A and A is addressed as A(k, m). */
216  for (k = 0; k < A.mt; k++) {
217  tempkm = k == A.mt-1 ? A.m-k*A.mb : A.mb;
218  ldak = BLKLDD(A, k);
219  ldbk = BLKLDD(B, k);
220  zbeta = k == 0 ? beta : zone;
/* NOTE(review): callee line missing from this listing (embedded line 221). */
222  plasma->quark, &task_flags,
223  transA, transB,
224  tempmm, tempnn, tempkm, A.mb,
225  alpha, A(k, m), ldak, /* lda * X */
226  B(k, n), ldbk, /* ldb * Y */
227  zbeta, C(m, n), ldcm); /* ldc * Y */
228  }
229  }
230  /*
231  * A: Plasma[Conj]Trans / B: Plasma[Conj]Trans
232  */
233  else {
234  ldbn = BLKLDD(B, n);
235  for (k = 0; k < A.mt; k++) {
236  tempkm = k == A.mt-1 ? A.m-k*A.mb : A.mb;
237  ldak = BLKLDD(A, k);
238  zbeta = k == 0 ? beta : zone;
/* NOTE(review): callee line missing from this listing (embedded line 239). */
240  plasma->quark, &task_flags,
241  transA, transB,
242  tempmm, tempnn, tempkm, A.mb,
243  alpha, A(k, m), ldak, /* lda * X */
244  B(n, k), ldbn, /* ldb * Z */
245  zbeta, C(m, n), ldcm); /* ldc * Y */
246  }
247  }
248  }
249  }
250  }
251 }