Actual source code: cd_utils.c
1: #include <../src/ksp/ksp/utils/lmvm/dense/denseqn.h>
2: #include <../src/ksp/ksp/utils/lmvm/diagbrdn/diagbrdn.h>
3: #include <petscblaslapack.h>
4: #include <petscmat.h>
5: #include <petscsys.h>
6: #include <petscsystypes.h>
7: #include <petscis.h>
8: #include <petscoptions.h>
9: #include <petscdevice.h>
10: #include <petsc/private/deviceimpl.h>
/*
  Names for the MatLMVMDenseType values, in order: the two value names, then the
  enum type name, the common prefix of the enum constants, and a NULL terminator.
  (This appears to be the layout PetscOptionsEnum()-style lookups expect.)
*/
const char *const MatLMVMDenseTypes[] = {
  "reorder",
  "inplace",
  "MatLMVMDenseType",
  "MAT_LMVM_DENSE_",
  NULL,
};
14: PETSC_INTERN PetscErrorCode MatMultAddColumnRange(Mat A, Vec xx, Vec zz, Vec yy, PetscInt c_start, PetscInt c_end)
15: {
16: PetscFunctionBegin;
17: PetscCall(PetscLogEventBegin(MAT_MultAdd, (PetscObject)A, NULL, NULL, NULL));
18: PetscUseMethod(A, "MatMultAddColumnRange_C", (Mat, Vec, Vec, Vec, PetscInt, PetscInt), (A, xx, zz, yy, c_start, c_end));
19: PetscCall(PetscLogEventEnd(MAT_MultAdd, (PetscObject)A, NULL, NULL, NULL));
20: PetscFunctionReturn(PETSC_SUCCESS);
21: }
23: PETSC_INTERN PetscErrorCode MatMultHermitianTransposeColumnRange(Mat A, Vec xx, Vec yy, PetscInt c_start, PetscInt c_end)
24: {
25: PetscFunctionBegin;
26: PetscCall(PetscLogEventBegin(MAT_MultTranspose, (PetscObject)A, NULL, NULL, NULL));
27: PetscUseMethod(A, "MatMultHermitianTransposeColumnRange_C", (Mat, Vec, Vec, PetscInt, PetscInt), (A, xx, yy, c_start, c_end));
28: PetscCall(PetscLogEventEnd(MAT_MultTranspose, (PetscObject)A, NULL, NULL, NULL));
29: PetscFunctionReturn(PETSC_SUCCESS);
30: }
32: PETSC_INTERN PetscErrorCode MatMultHermitianTransposeAddColumnRange(Mat A, Vec xx, Vec zz, Vec yy, PetscInt c_start, PetscInt c_end)
33: {
34: PetscFunctionBegin;
35: PetscCall(PetscLogEventBegin(MAT_MultTransposeAdd, (PetscObject)A, NULL, NULL, NULL));
36: PetscUseMethod(A, "MatMultHermitianTransposeAddColumnRange_C", (Mat, Vec, Vec, Vec, PetscInt, PetscInt), (A, xx, zz, yy, c_start, c_end));
37: PetscCall(PetscLogEventEnd(MAT_MultTransposeAdd, (PetscObject)A, NULL, NULL, NULL));
38: PetscFunctionReturn(PETSC_SUCCESS);
39: }
41: PETSC_INTERN PetscErrorCode VecCyclicShift(Mat B, Vec X, PetscInt d, Vec cyclic_work_vec)
42: {
43: Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
44: PetscInt m = lmvm->m;
45: PetscInt n;
46: const PetscScalar *src;
47: PetscScalar *dest;
48: PetscMemType src_memtype;
49: PetscMemType dest_memtype;
51: PetscFunctionBegin;
52: PetscCall(VecGetLocalSize(X, &n));
53: if (!cyclic_work_vec) PetscCall(VecDuplicate(X, &cyclic_work_vec));
54: PetscCall(VecCopy(X, cyclic_work_vec));
55: PetscCall(VecGetArrayReadAndMemType(cyclic_work_vec, &src, &src_memtype));
56: PetscCall(VecGetArrayWriteAndMemType(X, &dest, &dest_memtype));
57: if (n == 0) { /* no work on this process */
58: PetscCall(VecRestoreArrayWriteAndMemType(X, &dest));
59: PetscCall(VecRestoreArrayReadAndMemType(cyclic_work_vec, &src));
60: PetscFunctionReturn(PETSC_SUCCESS);
61: }
62: PetscAssert(src_memtype == dest_memtype, PETSC_COMM_SELF, PETSC_ERR_PLIB, "memtype of duplicate does not match");
63: if (PetscMemTypeHost(src_memtype)) {
64: PetscCall(PetscArraycpy(dest, &src[d], m - d));
65: PetscCall(PetscArraycpy(&dest[m - d], src, d));
66: } else {
67: PetscDeviceContext dctx;
69: PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
70: PetscCall(PetscDeviceRegisterMemory(dest, dest_memtype, m * sizeof(*dest)));
71: PetscCall(PetscDeviceRegisterMemory(src, src_memtype, m * sizeof(*src)));
72: PetscCall(PetscDeviceArrayCopy(dctx, dest, &src[d], m - d));
73: PetscCall(PetscDeviceArrayCopy(dctx, &dest[m - d], src, d));
74: }
75: PetscCall(VecRestoreArrayWriteAndMemType(X, &dest));
76: PetscCall(VecRestoreArrayReadAndMemType(cyclic_work_vec, &src));
77: PetscFunctionReturn(PETSC_SUCCESS);
78: }
80: static inline PetscInt recycle_index(PetscInt m, PetscInt idx)
81: {
82: return idx % m;
83: }
85: static inline PetscInt oldest_update(PetscInt m, PetscInt idx)
86: {
87: return PetscMax(0, idx - m);
88: }
90: PETSC_INTERN PetscErrorCode VecRecycleOrderToHistoryOrder(Mat B, Vec X, PetscInt num_updates, Vec cyclic_work_vec)
91: {
92: Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
93: PetscInt m = lmvm->m;
94: PetscInt oldest_index;
96: PetscFunctionBegin;
97: oldest_index = recycle_index(m, oldest_update(m, num_updates));
98: if (oldest_index == 0) PetscFunctionReturn(PETSC_SUCCESS); /* vector is already in history order */
99: PetscCall(VecCyclicShift(B, X, oldest_index, cyclic_work_vec));
100: PetscFunctionReturn(PETSC_SUCCESS);
101: }
103: PETSC_INTERN PetscErrorCode VecHistoryOrderToRecycleOrder(Mat B, Vec X, PetscInt num_updates, Vec cyclic_work_vec)
104: {
105: Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
106: PetscInt m = lmvm->m;
107: PetscInt oldest_index;
109: PetscFunctionBegin;
110: oldest_index = recycle_index(m, oldest_update(m, num_updates));
111: if (oldest_index == 0) PetscFunctionReturn(PETSC_SUCCESS); /* vector is already in recycle order */
112: PetscCall(VecCyclicShift(B, X, m - oldest_index, cyclic_work_vec));
113: PetscFunctionReturn(PETSC_SUCCESS);
114: }
116: PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace_Internal(MatLMVMDenseType type, PetscMemType memtype, PetscBool hermitian_transpose, PetscInt N, PetscInt oldest_index, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
117: {
118: PetscFunctionBegin;
119: /* if oldest_index == 0, the two strategies are equivalent, redirect to the simpler one */
120: if (oldest_index == 0) type = MAT_LMVM_DENSE_REORDER;
121: switch (type) {
122: case MAT_LMVM_DENSE_REORDER:
123: if (PetscMemTypeHost(memtype)) {
124: PetscBLASInt n, lda_blas, one = 1;
125: PetscCall(PetscBLASIntCast(N, &n));
126: PetscCall(PetscBLASIntCast(lda, &lda_blas));
127: PetscCallBLAS("BLAStrsv", BLAStrsv_("U", hermitian_transpose ? "C" : "N", "NotUnitTriangular", &n, A, &lda_blas, x, &one));
128: PetscCall(PetscLogFlops(1.0 * n * n));
129: #if defined(PETSC_HAVE_CUPM)
130: } else if (PetscMemTypeDevice(memtype)) {
131: PetscCall(MatUpperTriangularSolveInPlace_CUPM(hermitian_transpose, N, A, lda, x, 1));
132: #endif
133: } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported memtype");
134: break;
135: case MAT_LMVM_DENSE_INPLACE:
136: if (PetscMemTypeHost(memtype)) {
137: PetscBLASInt n_old, n_new, lda_blas, one = 1;
138: PetscScalar minus_one = -1.0;
139: PetscScalar sone = 1.0;
140: PetscCall(PetscBLASIntCast(N - oldest_index, &n_old));
141: PetscCall(PetscBLASIntCast(oldest_index, &n_new));
142: PetscCall(PetscBLASIntCast(lda, &lda_blas));
143: if (!hermitian_transpose) {
144: PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "N", "NotUnitTriangular", &n_new, A, &lda_blas, x, &one));
145: PetscCallBLAS("BLASgemv", BLASgemv_("N", &n_old, &n_new, &minus_one, &A[oldest_index], &lda_blas, x, &one, &sone, &x[oldest_index], &one));
146: PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "N", "NotUnitTriangular", &n_old, &A[oldest_index * (lda + 1)], &lda_blas, &x[oldest_index], &one));
147: } else {
148: PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "C", "NotUnitTriangular", &n_old, &A[oldest_index * (lda + 1)], &lda_blas, &x[oldest_index], &one));
149: PetscCallBLAS("BLASgemv", BLASgemv_("C", &n_old, &n_new, &minus_one, &A[oldest_index], &lda_blas, &x[oldest_index], &one, &sone, x, &one));
150: PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "C", "NotUnitTriangular", &n_new, A, &lda_blas, x, &one));
151: }
152: PetscCall(PetscLogFlops(1.0 * N * N));
153: #if defined(PETSC_HAVE_CUPM)
154: } else if (PetscMemTypeDevice(memtype)) {
155: PetscCall(MatUpperTriangularSolveInPlaceCyclic_CUPM(hermitian_transpose, N, oldest_index, A, lda, x, stride));
156: #endif
157: } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported memtype");
158: break;
159: default:
160: PetscUnreachable();
161: }
162: PetscFunctionReturn(PETSC_SUCCESS);
163: }
165: PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace(Mat B, Mat Amat, Vec X, PetscBool hermitian_transpose, PetscInt num_updates, MatLMVMDenseType strategy)
166: {
167: Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
168: PetscInt m = lmvm->m;
169: PetscInt h, local_n;
170: PetscInt oldest_index;
171: PetscInt lda;
172: PetscScalar *x;
173: PetscMemType memtype_r, memtype_x;
174: const PetscScalar *A;
176: PetscFunctionBegin;
177: h = num_updates - oldest_update(m, num_updates);
178: if (!h) PetscFunctionReturn(PETSC_SUCCESS);
179: PetscCall(VecGetLocalSize(X, &local_n));
180: PetscCall(VecGetArrayAndMemType(X, &x, &memtype_x));
181: PetscCall(MatDenseGetArrayReadAndMemType(Amat, &A, &memtype_r));
182: if (!local_n) {
183: PetscCall(MatDenseRestoreArrayReadAndMemType(Amat, &A));
184: PetscCall(VecRestoreArrayAndMemType(X, &x));
185: PetscFunctionReturn(PETSC_SUCCESS);
186: }
187: PetscAssert(memtype_x == memtype_r, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Incompatible device pointers");
188: PetscCall(MatDenseGetLDA(Amat, &lda));
189: oldest_index = recycle_index(m, oldest_update(m, num_updates));
190: PetscCall(MatUpperTriangularSolveInPlace_Internal(strategy, memtype_x, hermitian_transpose, h, oldest_index, A, lda, x, 1));
191: PetscCall(VecRestoreArrayWriteAndMemType(X, &x));
192: PetscCall(MatDenseRestoreArrayReadAndMemType(Amat, &A));
193: PetscFunctionReturn(PETSC_SUCCESS);
194: }
196: /* Shifts R[end-m_keep:end,end-m_keep:end] to R[0:m_keep, 0:m_keep] */
198: PETSC_INTERN PetscErrorCode MatMove_LR3(Mat B, Mat R, PetscInt m_keep)
199: {
200: Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
201: Mat_DQN *lqn = (Mat_DQN *)lmvm->ctx;
202: PetscInt M;
203: Mat mat_local, local_sub, local_temp, temp_sub;
205: PetscFunctionBegin;
206: if (!lqn->temp_mat) PetscCall(MatDuplicate(R, MAT_SHARE_NONZERO_PATTERN, &lqn->temp_mat));
207: PetscCall(MatGetLocalSize(R, &M, NULL));
208: if (M == 0) PetscFunctionReturn(PETSC_SUCCESS);
210: PetscCall(MatDenseGetLocalMatrix(R, &mat_local));
211: PetscCall(MatDenseGetLocalMatrix(lqn->temp_mat, &local_temp));
212: PetscCall(MatDenseGetSubMatrix(mat_local, lmvm->m - m_keep, lmvm->m, lmvm->m - m_keep, lmvm->m, &local_sub));
213: PetscCall(MatDenseGetSubMatrix(local_temp, lmvm->m - m_keep, lmvm->m, lmvm->m - m_keep, lmvm->m, &temp_sub));
214: PetscCall(MatCopy(local_sub, temp_sub, SAME_NONZERO_PATTERN));
215: PetscCall(MatDenseRestoreSubMatrix(mat_local, &local_sub));
216: PetscCall(MatDenseGetSubMatrix(mat_local, 0, m_keep, 0, m_keep, &local_sub));
217: PetscCall(MatCopy(temp_sub, local_sub, SAME_NONZERO_PATTERN));
218: PetscCall(MatDenseRestoreSubMatrix(mat_local, &local_sub));
219: PetscCall(MatDenseRestoreSubMatrix(local_temp, &temp_sub));
220: PetscFunctionReturn(PETSC_SUCCESS);
221: }