MAGMA  1.2.0
Matrix Algebra on GPU and Multicore Architectures
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Groups
codelet_zgeqrt.c
Go to the documentation of this file.
1 
17 #include "morse_starpu.h"
18 
19 /*
20  * Codelet CPU
21  */
22 static void cl_zgeqrt_cpu_func(void *descr[], void *cl_arg)
23 {
24  int M;
25  int N;
26  int IB;
27  PLASMA_Complex64_t *A;
28  int LDA;
29  PLASMA_Complex64_t *T;
30  int LDT;
31  PLASMA_Complex64_t *TAU;
32  PLASMA_Complex64_t *WORK;
33 
34  morse_starpu_ws_t *h_work;
35 
36  starpu_unpack_cl_args(cl_arg, &M, &N, &IB, &LDA, &LDT,
37  &h_work, NULL);
38 
39  /* descr[0] : tile from A, descr[1] : tile from T */
40  A = (PLASMA_Complex64_t *)STARPU_MATRIX_GET_PTR(descr[0]);
41  T = (PLASMA_Complex64_t *)STARPU_MATRIX_GET_PTR(descr[1]);
42 
43  TAU = (PLASMA_Complex64_t*)morse_starpu_ws_getlocal(h_work);
44  WORK = TAU + min( M, N );
45 
46  CORE_zgeqrt(M, N, IB, A, LDA, T, LDT, TAU, WORK);
47 }
48 
/*
 * Codelet Multi-cores
 *
 * Placeholder only: the "&& 0" keeps it out of every build. Its body is
 * empty, so enabling it as-is would perform no factorization at all.
 */
#if defined(MORSE_USE_MULTICORE) && 0
static void cl_zgeqrt_mc_func(void *descr[], void *cl_arg)
{
    /* Not implemented yet; parameters are intentionally unused. */
    (void)descr;
    (void)cl_arg;
}
#endif
57 
/*
 * Codelet GPU
 */
#if defined(MORSE_USE_CUDA) && 0
/*
 * CUDA codelet for zgeqrt. Compiled out by the "&& 0" above.
 *
 * Flow:
 *   1. Zero the host scratch buffers.
 *   2. Download the leading M*IB entries of the device A tile.
 *   3. Call magma_zgeqrt_gpu on the host/device copies.
 *   4. Upload h_D into d_D and merge it into the device tile with
 *      splagma_zload_d_into_tile.
 *   5. Wait for the device to finish.
 *
 *   descr[0] : tile from A, descr[1] : tile from T (device pointers)
 *
 * NOTE(review): this codelet unpacks five ints followed by seven workspace
 * handles, but MORSE_zgeqrt in this file packs five ints followed by only
 * h_work and d_work. Most scratch handles below would therefore be garbage.
 * Reconcile the two before removing the "&& 0".
 */
static void cl_zgeqrt_cuda_func(void *descr[], void *cl_arg)
{
    int M;
    int N;
    int IB;
    int LDA;
    int LDT;
    int MxMx2;
    int INFO;
    cudaError_t status;
    cuDoubleComplex *h_A, *d_A;
    cuDoubleComplex *h_T, *d_T;
    cuDoubleComplex *h_D, *d_D;
    cuDoubleComplex *h_TAU;
    cuDoubleComplex *h_WORK, *d_WORK;

    morse_starpu_ws_t *scratch_work;
    morse_starpu_ws_t *scratch_h_work;
    morse_starpu_ws_t *scratch_h_a;
    morse_starpu_ws_t *scratch_h_T;
    morse_starpu_ws_t *scratch_h_D;
    morse_starpu_ws_t *scratch_d_D;
    morse_starpu_ws_t *scratch_tau;

    /* NULL-terminated, like the CPU codelet's unpack call. */
    starpu_unpack_cl_args(cl_arg, &M, &N, &IB, &LDA, &LDT,
                          &scratch_tau, &scratch_work,
                          &scratch_h_work, &scratch_h_a,
                          &scratch_h_T, &scratch_h_D, &scratch_d_D,
                          NULL);

    /* descr[0] : tile from A, descr[1] : tile from T */
    d_A    = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[0]);
    d_T    = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[1]);
    d_WORK = morse_starpu_ws_getlocal(scratch_work);
    h_A    = morse_starpu_ws_getlocal(scratch_h_a);
    h_T    = morse_starpu_ws_getlocal(scratch_h_T);
    h_D    = morse_starpu_ws_getlocal(scratch_h_D);
    h_WORK = morse_starpu_ws_getlocal(scratch_h_work);
    h_TAU  = morse_starpu_ws_getlocal(scratch_tau);
    d_D    = morse_starpu_ws_getlocal(scratch_d_D);

    /* TODO are the memset really needed ? */
    memset(h_A,    0, M*N  *sizeof(cuDoubleComplex));
    memset(h_T,    0, IB*N *sizeof(cuDoubleComplex));
    memset(h_D,    0, IB*M *sizeof(cuDoubleComplex));
    memset(h_TAU,  0, M    *sizeof(cuDoubleComplex));
    memset(h_WORK, 0, 2*M*M*sizeof(cuDoubleComplex));

    /* Copy A panel (leading M*IB entries of the device tile) */
    status = cudaMemcpy(h_A, d_A, M*IB*sizeof(cuDoubleComplex),
                        cudaMemcpyDeviceToHost);
    if (status != cudaSuccess)
        STARPU_CUDA_REPORT_ERROR(status);

    MxMx2 = M*M*2;
    magma_zgeqrt_gpu(
        &M, &N, IB,
        d_A, &LDA,
        h_A, &LDA,
        d_T, &LDT,
        h_T, &LDT,
        h_D, &IB,
        h_TAU,
        h_WORK, &MxMx2,
        d_WORK, &INFO);

    status = cudaMemcpy(d_D, h_D, IB*M*sizeof(cuDoubleComplex),
                        cudaMemcpyHostToDevice);
    if (status != cudaSuccess)
        STARPU_CUDA_REPORT_ERROR(status);

    splagma_zload_d_into_tile(M, IB, d_A, d_D);

    /* cudaThreadSynchronize() is deprecated; cudaDeviceSynchronize() is its
     * replacement and also surfaces errors from the asynchronous work above. */
    status = cudaDeviceSynchronize();
    if (status != cudaSuccess)
        STARPU_CUDA_REPORT_ERROR(status);
}
#endif
129 
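/*
 * Codelet definition
 *
 * The combined CPU+CUDA declaration below is commented out. Note that
 * cl_zgeqrt_cuda_func is itself compiled out ("&& 0").
 * NOTE(review): the wrapper below uses cl_zgeqrt, cl_zgeqrt_mc and
 * cl_zgeqrt_callback, none of which is declared in this copy of the file.
 * A CPU-only codelet declaration was presumably lost here; confirm.
 */
//CODELETS(zgeqrt, 2, cl_zgeqrt_cpu_func, cl_zgeqrt_cuda_func, cl_zgeqrt_cpu_func)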
135 
136 /*
137  * Wrapper
138  */
139 void MORSE_zgeqrt( MorseOption_t *options,
140  int m, int n, int ib,
141  magma_desc_t *A, int Am, int An,
142  magma_desc_t *T, int Tm, int Tn)
143 {
144  starpu_codelet *zgeqrt_codelet;
145  void (*callback)(void*) = options->profiling ? cl_zgeqrt_callback : NULL;
146  int lda = BLKLDD( A, Am );
147  int ldt = BLKLDD( T, Tm );
148  morse_starpu_ws_t *h_work = (morse_starpu_ws_t*)(options->ws_host);
149  morse_starpu_ws_t *d_work = (morse_starpu_ws_t*)(options->ws_device);
150 
151 #ifdef MORSE_USE_MULTICORE
152  zgeqrt_codelet = options->parallel ? &cl_zgeqrt_mc : &cl_zgeqrt;
153 #else
154  zgeqrt_codelet = &cl_zgeqrt;
155 #endif
156 
158  &cl_zgeqrt,
159  VALUE, &m, sizeof(int),
160  VALUE, &n, sizeof(int),
161  VALUE, &ib, sizeof(int),
162  INOUT, BLKADDR( A, PLASMA_Complex64_t, Am, An ),
163  VALUE, &lda, sizeof(int),
164  OUTPUT, BLKADDR( T, PLASMA_Complex64_t, Tm, Tn ),
165  VALUE, &ldt, sizeof(int),
166  VALUE, &h_work, sizeof(morse_starpu_ws_t *),
167  VALUE, &d_work, sizeof(morse_starpu_ws_t *),
168  PRIORITY, options->priority,
169  CALLBACK, callback, NULL,
170  0);
171 }