001: /* ///////////////////////////// P /// L /// A /// S /// M /// A /////////////////////////////// */
002: /* ///                    PLASMA auxiliary routines (version 2.1.0)                          ///
003:  * ///                    Author: Jakub Kurzak, Hatem Ltaief                                 ///
004:  * ///                    Release Date: November, 15th 2009                                  ///
005:  * ///                    PLASMA is a software package provided by Univ. of Tennessee,       ///
006:  * ///                    Univ. of California Berkeley and Univ. of Colorado Denver          /// */
007: /* ///////////////////////////////////////////////////////////////////////////////////////////// */
008: #include "common.h"
009: 
010: /* ///////////////////////////////////////////////////////////////////////////////////////////// */
011: //  Parallel triangular solve
/*
 * Tile-address helpers.  A(m,n) / B(m,n) expand to the address of the first
 * float of tile (m,n), i.e. element  bsiz*m + bsiz*lmt*n  of the descriptor's
 * `mat` buffer viewed as float.  They refer to whatever objects named `A` and
 * `B` are in scope where the macro is used.  The whole expansion is
 * parenthesized so the result can be safely combined with postfix operators
 * (indexing, etc.) at the call site.
 */
#define A(m,n) (&((float*)A.mat)[A.bsiz*(m)+A.bsiz*A.lmt*(n)])
#define B(m,n) (&((float*)B.mat)[B.bsiz*(m)+B.bsiz*B.lmt*(n)])
014: void plasma_pstrsm(plasma_context_t *plasma)
015: {
016:     PLASMA_enum side;
017:     PLASMA_enum uplo;
018:     PLASMA_enum transA;
019:     PLASMA_enum diag;
020:     float      alpha;
021:     PLASMA_desc A;
022:     PLASMA_desc B;
023: 
024:     int k, m, n;
025:     int next_k;
026:     int next_m;
027:     int next_n;
028: 
029:     plasma_unpack_args_7(side, uplo, transA, diag, alpha, A, B);
030:     ss_init(B.mt, B.nt, -1);
031: 
032:     k = 0;
033:     m = PLASMA_RANK;
034:     while (m >= A.nt) {
035:         k++;
036:         m = m-A.nt+k;
037:     }
038:     n = 0;
039: 
040:     while (k < A.nt && m < A.nt) {
041:         next_n = n;
042:         next_m = m;
043:         next_k = k;
044: 
045:         next_n++;
046:         if (next_n >= B.nt) {
047:             next_m += PLASMA_SIZE;
048:             while (next_m >= A.nt && next_k < A.nt) {
049:                 next_k++;
050:                 next_m = next_m-A.nt+next_k;
051:             }
052:             next_n = 0;
053:         }
054: 
055:         if (m == k)
056:         {
057:             ss_cond_wait(m, n, k-1);
058:             if (uplo == PlasmaLower) {
059:                 if (transA == PlasmaNoTrans)
060:                     CORE_strsm(
061:                         PlasmaLeft, PlasmaLower,
062:                         PlasmaNoTrans, diag,
063:                         k == A.nt-1 ? A.n-k*A.nb : A.nb,
064:                         n == B.nt-1 ? B.n-n*B.nb : B.nb,
065:                         1.0, A(k, k), A.nb,
066:                              B(k, n), B.nb);
067:                 else
068:                     CORE_strsm(
069:                         PlasmaLeft, PlasmaLower,
070:                         PlasmaTrans, diag,
071:                         k == 0 ? A.n-(A.nt-1)*A.nb : A.nb,
072:                         n == B.nt-1 ? B.n-n*B.nb : B.nb,
073:                         1.0, A(A.nt-1-k, A.nt-1-k), A.nb,
074:                              B(A.nt-1-k, n), B.nb);
075:             }
076:             else {
077:                 if (transA == PlasmaNoTrans)
078:                     CORE_strsm(
079:                         PlasmaLeft, PlasmaUpper,
080:                         PlasmaNoTrans, diag,
081:                         k == 0 ? A.n-(A.nt-1)*A.nb : A.nb,
082:                         n == B.nt-1 ? B.n-n*B.nb : B.nb,
083:                         1.0, A(A.nt-1-k, A.nt-1-k), A.nb,
084:                              B(A.nt-1-k, n), A.nb);
085:                 else
086:                     CORE_strsm(
087:                         PlasmaLeft, PlasmaUpper,
088:                         PlasmaTrans, diag,
089:                         k == A.nt-1 ? A.n-k*A.nb : A.nb,
090:                         n == B.nt-1 ? B.n-n*B.nb : B.nb,
091:                         1.0, A(k, k), A.nb,
092:                              B(k, n), B.nb);
093:             }
094:             ss_cond_set(k, n, k);
095:         }
096:         else
097:         {
098:             ss_cond_wait(k, n, k);
099:             ss_cond_wait(m, n, k-1);
100:             if (uplo == PlasmaLower) {
101:                 if (transA == PlasmaNoTrans)
102:                     CORE_sgemm(
103:                         PlasmaNoTrans, PlasmaNoTrans,
104:                         m == A.nt-1 ? A.n-m*A.nb : A.nb,
105:                         n == B.nt-1 ? B.n-n*B.nb : B.nb,
106:                         A.nb,
107:                        -1.0, A(m, k), A.nb,
108:                              B(k, n), B.nb,
109:                         1.0, B(m, n), B.nb);
110:                 else
111:                     CORE_sgemm(
112:                         PlasmaTrans, PlasmaNoTrans,
113:                         A.nb,
114:                         n == B.nt-1 ? B.n-n*B.nb : B.nb,
115:                         k == 0 ? A.n-(A.nt-1)*A.nb : A.nb,
116:                        -1.0, A(A.nt-1-k, A.nt-1-m), A.nb,
117:                              B(A.nt-1-k, n), B.nb,
118:                         1.0, B(A.nt-1-m, n), B.nb);
119:             }
120:             else {
121:                 if (transA == PlasmaNoTrans)
122:                     CORE_sgemm(
123:                         PlasmaNoTrans, PlasmaNoTrans,
124:                         A.nb,
125:                         n == B.nt-1 ? B.n-n*B.nb : B.nb,
126:                         k == 0 ? A.n-(A.nt-1)*A.nb : A.nb,
127:                        -1.0, A(A.nt-1-m, A.nt-1-k), A.nb,
128:                              B(A.nt-1-k, n), B.nb,
129:                         1.0, B(A.nt-1-m, n), B.nb);
130:                 else
131:                     CORE_sgemm(
132:                         PlasmaTrans, PlasmaNoTrans,
133:                         m == A.nt-1 ? A.n-m*A.nb : A.nb,
134:                         n == B.nt-1 ? B.n-n*B.nb : B.nb,
135:                         A.nb,
136:                        -1.0, A(k, m), A.nb,
137:                              B(k, n), B.nb,
138:                         1.0, B(m, n), B.nb);
139:             }
140:             ss_cond_set(m, n, k);
141:         }
142:         n = next_n;
143:         m = next_m;
144:         k = next_k;
145:     }
146:     ss_finalize();
147: }
148: