11 #include "common_magma.h"
/* Pointer to element (m,n) of the column-major matrix `a`, whose leading
 * dimension is read through the pointer `lda`. */
#define A(m,n) ((a) + (n)*(*lda) + (m))
/* Pointer into the `work` buffer at offset m*nb; the driver passes it as
 * the T argument of the GEQRT/LARFB tasks. */
#define T(m) ((work) + (m)*(nb))
/* Address of the `local_work` slot for block row k and block column n.
 * NOTE(review): local_work is allocated with (nt-1)*mt entries, so n is
 * presumably in 1..nt-1 -- confirm against callers.
 * Every argument and the whole expansion are parenthesized so that
 * expression arguments (e.g. a ternary or bitwise op) expand correctly. */
#define W(k,n) (&(local_work[(mt)*((n)-1)+(k)]))
21 const cuDoubleComplex *Atmp;
29 if (trans[0] ==
'C') {
33 ap = (
double *)&(*Atmp);
71 quark_unpack_args_13(quark, M, N, MM, NN, IB, V, LDV, C, LDC, T, LDT, W, LDW);
74 printf(
"SCHED_zlarfb: illegal value of M\n");
77 printf(
"SCHED_zlarfb: illegal value of N\n");
80 printf(
"SCHED_zlarfb: illegal value of IB\n");
83 *W = (cuDoubleComplex*) malloc(LDW*MM*
sizeof(cuDoubleComplex));
90 &NN, &MM, &c_one, V, &LDV, *W, &LDW);
95 &c_one, &C[MM], &LDC, &V[MM], &LDV, &c_one, *W, &LDW);
98 &NN, &MM, &c_one, T, &LDT, *W, &LDW);
112 cuDoubleComplex *
TAU;
113 cuDoubleComplex *WORK;
121 printf(
"SCHED_zgeqrt: illegal value of M\n");
125 printf(
"SCHED_zgeqrt: illegal value of N\n");
128 if ((IB < 0) || ( (IB == 0) && ((M > 0) && (N > 0)) )) {
129 printf(
"SCHED_zgeqrt: illegal value of IB\n");
132 if ((LDA <
max(1,M)) && (M > 0)) {
133 printf(
"SCHED_zgeqrt: illegal value of LDA\n");
136 if ((LDT <
max(1,IB)) && (IB > 0)) {
137 printf(
"SCHED_zgeqrt: illegal value of LDT\n");
153 cuDoubleComplex alpha;
158 cuDoubleComplex beta;
161 cuDoubleComplex *
work;
167 quark_unpack_args_11(quark, m, n, alpha, a, lda, b, ldb, beta, c, ldc, work);
170 printf(
"SCHED_ztrmm: illegal value of m\n");
174 printf(
"SCHED_ztrmm: illegal value of n\n");
180 &m, &n, &alpha, a, &lda, work, &m);
182 for (j = 0; j < n; j++)
184 blasf77_zaxpy(&m, &beta, &(work[j*m]), &one, &(c[j*ldc]), &one);
195 cuDoubleComplex alpha;
200 cuDoubleComplex beta;
204 cuDoubleComplex *fake;
208 quark_unpack_args_13(quark, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, fake, dkdk);
211 &m, &n, &k, &alpha, a, &lda, *b, &ldb, &beta, c, &ldc);
220 cuDoubleComplex alpha,
225 cuDoubleComplex beta,
228 cuDoubleComplex *fake,
238 sizeof(cuDoubleComplex), &alpha,
VALUE,
239 sizeof(cuDoubleComplex)*ldb*ldb, a,
INPUT,
241 sizeof(cuDoubleComplex*), b,
INPUT,
243 sizeof(cuDoubleComplex), &beta,
VALUE,
258 cuDoubleComplex alpha,
263 cuDoubleComplex beta,
273 sizeof(cuDoubleComplex), &alpha,
VALUE,
274 sizeof(cuDoubleComplex)*ldb*ldb, a,
INPUT,
276 sizeof(cuDoubleComplex*), b,
INPUT,
278 sizeof(cuDoubleComplex), &beta,
VALUE,
281 sizeof(cuDoubleComplex)*ldb*ldb, NULL,
SCRATCH,
296 cuDoubleComplex *tau,
308 sizeof(cuDoubleComplex)*ldt*ldt, t,
OUTPUT,
310 sizeof(cuDoubleComplex)*ldt, tau,
OUTPUT,
311 sizeof(cuDoubleComplex)*ldt*ldt, NULL,
SCRATCH,
345 sizeof(cuDoubleComplex)*m*n, v,
INPUT,
347 sizeof(cuDoubleComplex)*m*n, c,
INPUT,
349 sizeof(cuDoubleComplex)*ib*ib, t,
INPUT,
362 cuDoubleComplex *a,
magma_int_t *lda, cuDoubleComplex *tau,
448 char sgeqrt_dag_label[1000];
449 char slarfb_dag_label[1000];
450 char strmm_dag_label[1000];
451 char sgemm_dag_label[1000];
463 long int lquery = *lwork == -1;
470 }
else if (*lda <
max(1,*m)) {
472 }
else if (*lwork <
max(1,*n) && ! lquery) {
488 magma_int_t nt = (((*n)%nb) == 0) ? (*n)/nb : (*n)/nb + 1;
489 magma_int_t mt = (((*m)%nb) == 0) ? (*m)/nb : (*m)/nb + 1;
491 cuDoubleComplex **local_work = (cuDoubleComplex**) malloc(
sizeof(cuDoubleComplex*)*(nt-1)*mt);
492 memset(local_work, 0,
sizeof(cuDoubleComplex*)*(nt-1)*mt);
497 for (i = 0; i < k; i += nb) {
503 sprintf(sgeqrt_dag_label,
"GEQRT %d",ii);
507 0, (*m)-i,
min(nb,(*n)-i),
A(i,i), *lda,
T(i), nb, &tau[i], sgeqrt_dag_label);
514 for (j = (i-nb) + (2*nb); j < *n; j += nb) {
520 sprintf(slarfb_dag_label,
"LARFB %d %d",ii-1, jj);
524 (*m)-(i-nb),
min(nb,(*n)-(i-nb)),
min(nb,(*m)-(i-nb)),
min(nb,(*n)-j), nb,
525 A(i-nb,i-nb), *lda,
A(i-nb,j), *lda,
T(i-nb), nb,
W(ii-1,jj), nb, slarfb_dag_label, priority);
527 sprintf(strmm_dag_label,
"TRMM %d %d",ii-1, jj);
531 A(i-nb,i-nb), *lda,
W(ii-1,jj), nb, c_one,
A(i-nb,j), *lda, strmm_dag_label, priority);
533 sprintf(sgemm_dag_label,
"GEMM %d %d %d",ii-1, jj, ll);
537 A(i,i-nb), *lda,
W(ii-1,jj), nb, c_one,
A(i,j), *lda,
A(i,j), sgemm_dag_label, priority, jj);
556 sprintf(slarfb_dag_label,
"LARFB %d %d",ii, jj);
560 (*m)-i,
min(nb,(*n)-i),
min(nb,(*m)-i),
min(nb,(*n)-j), nb,
561 A(i,i), *lda,
A(i,j), *lda,
T(i), nb,
W(ii,jj), nb, slarfb_dag_label, priority);
563 sprintf(strmm_dag_label,
"TRMM %d %d",ii, jj);
567 A(i,i), *lda,
W(ii,jj), nb, c_one,
A(i,j), *lda, strmm_dag_label, priority);
569 sprintf(sgemm_dag_label,
"GEMM %d %d %d",ii, jj, ll);
573 A(i+nb,i), *lda,
W(ii,jj), nb, c_one,
A(i+nb,j), *lda,
A(i+nb,j), sgemm_dag_label, priority, jj);
583 for(k = 0 ; k < (nt-1)*mt; k++) {
584 if (local_work[k] != NULL) {