Actual source code: cd_utils.c

  1: #include <../src/ksp/ksp/utils/lmvm/dense/denseqn.h>
  2: #include <../src/ksp/ksp/utils/lmvm/diagbrdn/diagbrdn.h>
  3: #include <petscblaslapack.h>
  4: #include <petscmat.h>
  5: #include <petscsys.h>
  6: #include <petscsystypes.h>
  7: #include <petscis.h>
  8: #include <petscoptions.h>
  9: #include <petscdevice.h>
 10: #include <petsc/private/deviceimpl.h>

 12: const char *const MatLMVMDenseTypes[] = {"reorder", "inplace", "MatLMVMDenseType", "MAT_LMVM_DENSE_", NULL};

 14: PETSC_INTERN PetscErrorCode MatMultAddColumnRange(Mat A, Vec xx, Vec zz, Vec yy, PetscInt c_start, PetscInt c_end)
 15: {
 16:   PetscFunctionBegin;
 17:   PetscCall(PetscLogEventBegin(MAT_MultAdd, (PetscObject)A, NULL, NULL, NULL));
 18:   PetscUseMethod(A, "MatMultAddColumnRange_C", (Mat, Vec, Vec, Vec, PetscInt, PetscInt), (A, xx, zz, yy, c_start, c_end));
 19:   PetscCall(PetscLogEventEnd(MAT_MultAdd, (PetscObject)A, NULL, NULL, NULL));
 20:   PetscFunctionReturn(PETSC_SUCCESS);
 21: }

 23: PETSC_INTERN PetscErrorCode MatMultHermitianTransposeColumnRange(Mat A, Vec xx, Vec yy, PetscInt c_start, PetscInt c_end)
 24: {
 25:   PetscFunctionBegin;
 26:   PetscCall(PetscLogEventBegin(MAT_MultTranspose, (PetscObject)A, NULL, NULL, NULL));
 27:   PetscUseMethod(A, "MatMultHermitianTransposeColumnRange_C", (Mat, Vec, Vec, PetscInt, PetscInt), (A, xx, yy, c_start, c_end));
 28:   PetscCall(PetscLogEventEnd(MAT_MultTranspose, (PetscObject)A, NULL, NULL, NULL));
 29:   PetscFunctionReturn(PETSC_SUCCESS);
 30: }

 32: PETSC_INTERN PetscErrorCode MatMultHermitianTransposeAddColumnRange(Mat A, Vec xx, Vec zz, Vec yy, PetscInt c_start, PetscInt c_end)
 33: {
 34:   PetscFunctionBegin;
 35:   PetscCall(PetscLogEventBegin(MAT_MultTransposeAdd, (PetscObject)A, NULL, NULL, NULL));
 36:   PetscUseMethod(A, "MatMultHermitianTransposeAddColumnRange_C", (Mat, Vec, Vec, Vec, PetscInt, PetscInt), (A, xx, zz, yy, c_start, c_end));
 37:   PetscCall(PetscLogEventEnd(MAT_MultTransposeAdd, (PetscObject)A, NULL, NULL, NULL));
 38:   PetscFunctionReturn(PETSC_SUCCESS);
 39: }

 41: PETSC_INTERN PetscErrorCode VecCyclicShift(Mat B, Vec X, PetscInt d, Vec cyclic_work_vec)
 42: {
 43:   Mat_LMVM          *lmvm = (Mat_LMVM *)B->data;
 44:   PetscInt           m    = lmvm->m;
 45:   PetscInt           n;
 46:   const PetscScalar *src;
 47:   PetscScalar       *dest;
 48:   PetscMemType       src_memtype;
 49:   PetscMemType       dest_memtype;

 51:   PetscFunctionBegin;
 52:   PetscCall(VecGetLocalSize(X, &n));
 53:   if (!cyclic_work_vec) PetscCall(VecDuplicate(X, &cyclic_work_vec));
 54:   PetscCall(VecCopy(X, cyclic_work_vec));
 55:   PetscCall(VecGetArrayReadAndMemType(cyclic_work_vec, &src, &src_memtype));
 56:   PetscCall(VecGetArrayWriteAndMemType(X, &dest, &dest_memtype));
 57:   if (n == 0) { /* no work on this process */
 58:     PetscCall(VecRestoreArrayWriteAndMemType(X, &dest));
 59:     PetscCall(VecRestoreArrayReadAndMemType(cyclic_work_vec, &src));
 60:     PetscFunctionReturn(PETSC_SUCCESS);
 61:   }
 62:   PetscAssert(src_memtype == dest_memtype, PETSC_COMM_SELF, PETSC_ERR_PLIB, "memtype of duplicate does not match");
 63:   if (PetscMemTypeHost(src_memtype)) {
 64:     PetscCall(PetscArraycpy(dest, &src[d], m - d));
 65:     PetscCall(PetscArraycpy(&dest[m - d], src, d));
 66:   } else {
 67:     PetscDeviceContext dctx;

 69:     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
 70:     PetscCall(PetscDeviceRegisterMemory(dest, dest_memtype, m * sizeof(*dest)));
 71:     PetscCall(PetscDeviceRegisterMemory(src, src_memtype, m * sizeof(*src)));
 72:     PetscCall(PetscDeviceArrayCopy(dctx, dest, &src[d], m - d));
 73:     PetscCall(PetscDeviceArrayCopy(dctx, &dest[m - d], src, d));
 74:   }
 75:   PetscCall(VecRestoreArrayWriteAndMemType(X, &dest));
 76:   PetscCall(VecRestoreArrayReadAndMemType(cyclic_work_vec, &src));
 77:   PetscFunctionReturn(PETSC_SUCCESS);
 78: }

 80: static inline PetscInt recycle_index(PetscInt m, PetscInt idx)
 81: {
 82:   return idx % m;
 83: }

 85: static inline PetscInt oldest_update(PetscInt m, PetscInt idx)
 86: {
 87:   return PetscMax(0, idx - m);
 88: }

 90: PETSC_INTERN PetscErrorCode VecRecycleOrderToHistoryOrder(Mat B, Vec X, PetscInt num_updates, Vec cyclic_work_vec)
 91: {
 92:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
 93:   PetscInt  m    = lmvm->m;
 94:   PetscInt  oldest_index;

 96:   PetscFunctionBegin;
 97:   oldest_index = recycle_index(m, oldest_update(m, num_updates));
 98:   if (oldest_index == 0) PetscFunctionReturn(PETSC_SUCCESS); /* vector is already in history order */
 99:   PetscCall(VecCyclicShift(B, X, oldest_index, cyclic_work_vec));
100:   PetscFunctionReturn(PETSC_SUCCESS);
101: }

103: PETSC_INTERN PetscErrorCode VecHistoryOrderToRecycleOrder(Mat B, Vec X, PetscInt num_updates, Vec cyclic_work_vec)
104: {
105:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
106:   PetscInt  m    = lmvm->m;
107:   PetscInt  oldest_index;

109:   PetscFunctionBegin;
110:   oldest_index = recycle_index(m, oldest_update(m, num_updates));
111:   if (oldest_index == 0) PetscFunctionReturn(PETSC_SUCCESS); /* vector is already in recycle order */
112:   PetscCall(VecCyclicShift(B, X, m - oldest_index, cyclic_work_vec));
113:   PetscFunctionReturn(PETSC_SUCCESS);
114: }

116: PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace_Internal(MatLMVMDenseType type, PetscMemType memtype, PetscBool hermitian_transpose, PetscInt N, PetscInt oldest_index, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
117: {
118:   PetscFunctionBegin;
119:   /* if oldest_index == 0, the two strategies are equivalent, redirect to the simpler one */
120:   if (oldest_index == 0) type = MAT_LMVM_DENSE_REORDER;
121:   switch (type) {
122:   case MAT_LMVM_DENSE_REORDER:
123:     if (PetscMemTypeHost(memtype)) {
124:       PetscBLASInt n, lda_blas, one = 1;
125:       PetscCall(PetscBLASIntCast(N, &n));
126:       PetscCall(PetscBLASIntCast(lda, &lda_blas));
127:       PetscCallBLAS("BLAStrsv", BLAStrsv_("U", hermitian_transpose ? "C" : "N", "NotUnitTriangular", &n, A, &lda_blas, x, &one));
128:       PetscCall(PetscLogFlops(1.0 * n * n));
129: #if defined(PETSC_HAVE_CUPM)
130:     } else if (PetscMemTypeDevice(memtype)) {
131:       PetscCall(MatUpperTriangularSolveInPlace_CUPM(hermitian_transpose, N, A, lda, x, 1));
132: #endif
133:     } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported memtype");
134:     break;
135:   case MAT_LMVM_DENSE_INPLACE:
136:     if (PetscMemTypeHost(memtype)) {
137:       PetscBLASInt n_old, n_new, lda_blas, one = 1;
138:       PetscScalar  minus_one = -1.0;
139:       PetscScalar  sone      = 1.0;
140:       PetscCall(PetscBLASIntCast(N - oldest_index, &n_old));
141:       PetscCall(PetscBLASIntCast(oldest_index, &n_new));
142:       PetscCall(PetscBLASIntCast(lda, &lda_blas));
143:       if (!hermitian_transpose) {
144:         PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "N", "NotUnitTriangular", &n_new, A, &lda_blas, x, &one));
145:         PetscCallBLAS("BLASgemv", BLASgemv_("N", &n_old, &n_new, &minus_one, &A[oldest_index], &lda_blas, x, &one, &sone, &x[oldest_index], &one));
146:         PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "N", "NotUnitTriangular", &n_old, &A[oldest_index * (lda + 1)], &lda_blas, &x[oldest_index], &one));
147:       } else {
148:         PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "C", "NotUnitTriangular", &n_old, &A[oldest_index * (lda + 1)], &lda_blas, &x[oldest_index], &one));
149:         PetscCallBLAS("BLASgemv", BLASgemv_("C", &n_old, &n_new, &minus_one, &A[oldest_index], &lda_blas, &x[oldest_index], &one, &sone, x, &one));
150:         PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "C", "NotUnitTriangular", &n_new, A, &lda_blas, x, &one));
151:       }
152:       PetscCall(PetscLogFlops(1.0 * N * N));
153: #if defined(PETSC_HAVE_CUPM)
154:     } else if (PetscMemTypeDevice(memtype)) {
155:       PetscCall(MatUpperTriangularSolveInPlaceCyclic_CUPM(hermitian_transpose, N, oldest_index, A, lda, x, stride));
156: #endif
157:     } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported memtype");
158:     break;
159:   default:
160:     PetscUnreachable();
161:   }
162:   PetscFunctionReturn(PETSC_SUCCESS);
163: }

165: PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace(Mat B, Mat Amat, Vec X, PetscBool hermitian_transpose, PetscInt num_updates, MatLMVMDenseType strategy)
166: {
167:   Mat_LMVM          *lmvm = (Mat_LMVM *)B->data;
168:   PetscInt           m    = lmvm->m;
169:   PetscInt           h, local_n;
170:   PetscInt           oldest_index;
171:   PetscInt           lda;
172:   PetscScalar       *x;
173:   PetscMemType       memtype_r, memtype_x;
174:   const PetscScalar *A;

176:   PetscFunctionBegin;
177:   h = num_updates - oldest_update(m, num_updates);
178:   if (!h) PetscFunctionReturn(PETSC_SUCCESS);
179:   PetscCall(VecGetLocalSize(X, &local_n));
180:   PetscCall(VecGetArrayAndMemType(X, &x, &memtype_x));
181:   PetscCall(MatDenseGetArrayReadAndMemType(Amat, &A, &memtype_r));
182:   if (!local_n) {
183:     PetscCall(MatDenseRestoreArrayReadAndMemType(Amat, &A));
184:     PetscCall(VecRestoreArrayAndMemType(X, &x));
185:     PetscFunctionReturn(PETSC_SUCCESS);
186:   }
187:   PetscAssert(memtype_x == memtype_r, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Incompatible device pointers");
188:   PetscCall(MatDenseGetLDA(Amat, &lda));
189:   oldest_index = recycle_index(m, oldest_update(m, num_updates));
190:   PetscCall(MatUpperTriangularSolveInPlace_Internal(strategy, memtype_x, hermitian_transpose, h, oldest_index, A, lda, x, 1));
191:   PetscCall(VecRestoreArrayWriteAndMemType(X, &x));
192:   PetscCall(MatDenseRestoreArrayReadAndMemType(Amat, &A));
193:   PetscFunctionReturn(PETSC_SUCCESS);
194: }

196: /* Shifts R[end-m_keep:end,end-m_keep:end] to R[0:m_keep, 0:m_keep] */

198: PETSC_INTERN PetscErrorCode MatMove_LR3(Mat B, Mat R, PetscInt m_keep)
199: {
200:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
201:   Mat_DQN  *lqn  = (Mat_DQN *)lmvm->ctx;
202:   PetscInt  M;
203:   Mat       mat_local, local_sub, local_temp, temp_sub;

205:   PetscFunctionBegin;
206:   if (!lqn->temp_mat) PetscCall(MatDuplicate(R, MAT_SHARE_NONZERO_PATTERN, &lqn->temp_mat));
207:   PetscCall(MatGetLocalSize(R, &M, NULL));
208:   if (M == 0) PetscFunctionReturn(PETSC_SUCCESS);

210:   PetscCall(MatDenseGetLocalMatrix(R, &mat_local));
211:   PetscCall(MatDenseGetLocalMatrix(lqn->temp_mat, &local_temp));
212:   PetscCall(MatDenseGetSubMatrix(mat_local, lmvm->m - m_keep, lmvm->m, lmvm->m - m_keep, lmvm->m, &local_sub));
213:   PetscCall(MatDenseGetSubMatrix(local_temp, lmvm->m - m_keep, lmvm->m, lmvm->m - m_keep, lmvm->m, &temp_sub));
214:   PetscCall(MatCopy(local_sub, temp_sub, SAME_NONZERO_PATTERN));
215:   PetscCall(MatDenseRestoreSubMatrix(mat_local, &local_sub));
216:   PetscCall(MatDenseGetSubMatrix(mat_local, 0, m_keep, 0, m_keep, &local_sub));
217:   PetscCall(MatCopy(temp_sub, local_sub, SAME_NONZERO_PATTERN));
218:   PetscCall(MatDenseRestoreSubMatrix(mat_local, &local_sub));
219:   PetscCall(MatDenseRestoreSubMatrix(local_temp, &temp_sub));
220:   PetscFunctionReturn(PETSC_SUCCESS);
221: }