Actual source code: cd_cupm.cxx
1: #include "../denseqn.h"
2: #include <petsc/private/cupminterface.hpp>
3: #include <petsc/private/cupmobject.hpp>
5: namespace Petsc
6: {
8: namespace device
9: {
11: namespace cupm
12: {
14: namespace impl
15: {
17: template <DeviceType T>
18: struct UpperTriangular : CUPMObject<T> {
19: PETSC_CUPMOBJECT_HEADER(T);
21: static PetscErrorCode SolveInPlace(PetscDeviceContext, PetscBool, PetscInt, const PetscScalar[], PetscInt, PetscScalar[], PetscInt) noexcept;
22: static PetscErrorCode SolveInPlaceCyclic(PetscDeviceContext, PetscBool, PetscInt, PetscInt, const PetscScalar[], PetscInt, PetscScalar[], PetscInt) noexcept;
23: };
25: template <DeviceType T>
26: PetscErrorCode UpperTriangular<T>::SolveInPlace(PetscDeviceContext dctx, PetscBool hermitian_transpose, PetscInt N, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride) noexcept
27: {
28: cupmBlasInt_t n;
29: cupmBlasHandle_t handle;
30: auto _A = cupmScalarPtrCast(A);
31: auto _x = cupmScalarPtrCast(x);
33: PetscFunctionBegin;
34: if (!N) PetscFunctionReturn(PETSC_SUCCESS);
35: PetscCall(PetscCUPMBlasIntCast(N, &n));
36: PetscCall(GetHandlesFrom_(dctx, &handle));
37: PetscCall(PetscLogGpuTimeBegin());
38: PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, hermitian_transpose ? CUPMBLAS_OP_C : CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n, _A, lda, _x, stride));
39: PetscCall(PetscLogGpuTimeEnd());
41: PetscCall(PetscLogGpuFlops(1.0 * N * N));
42: PetscFunctionReturn(PETSC_SUCCESS);
43: }
45: template <DeviceType T>
46: PetscErrorCode UpperTriangular<T>::SolveInPlaceCyclic(PetscDeviceContext dctx, PetscBool hermitian_transpose, PetscInt N, PetscInt oldest_index, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride) noexcept
47: {
48: cupmBlasInt_t n_old, n_new;
49: cupmBlasPointerMode_t pointer_mode;
50: cupmBlasHandle_t handle;
51: auto sone = cupmScalarCast(1.0);
52: auto minus_one = cupmScalarCast(-1.0);
53: auto _A = cupmScalarPtrCast(A);
54: auto _x = cupmScalarPtrCast(x);
56: PetscFunctionBegin;
57: if (!N) PetscFunctionReturn(PETSC_SUCCESS);
58: PetscCall(PetscCUPMBlasIntCast(N - oldest_index, &n_old));
59: PetscCall(PetscCUPMBlasIntCast(oldest_index, &n_new));
60: PetscCall(GetHandlesFrom_(dctx, &handle));
61: PetscCall(PetscLogGpuTimeBegin());
62: PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &pointer_mode));
63: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_HOST));
64: if (!hermitian_transpose) {
65: PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n_new, _A, lda, _x, stride));
66: PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_N, n_old, n_new, &minus_one, &_A[oldest_index], lda, _x, stride, &sone, &_x[oldest_index], stride));
67: PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n_old, &_A[oldest_index * (lda + 1)], lda, &_x[oldest_index], stride));
68: } else {
69: PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_C, CUPMBLAS_DIAG_NON_UNIT, n_old, &_A[oldest_index * (lda + 1)], lda, &_x[oldest_index], stride));
70: PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_C, n_old, n_new, &minus_one, &_A[oldest_index], lda, &_x[oldest_index], stride, &sone, _x, stride));
71: PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_C, CUPMBLAS_DIAG_NON_UNIT, n_new, _A, lda, _x, stride));
72: }
73: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, pointer_mode));
74: PetscCall(PetscLogGpuTimeEnd());
76: PetscCall(PetscLogGpuFlops(1.0 * N * N));
77: PetscFunctionReturn(PETSC_SUCCESS);
78: }
80: #if PetscDefined(HAVE_CUDA)
81: template struct UpperTriangular<DeviceType::CUDA>;
82: #endif
84: #if PetscDefined(HAVE_HIP)
85: template struct UpperTriangular<DeviceType::HIP>;
86: #endif
88: } // namespace impl
90: } // namespace cupm
92: } // namespace device
94: } // namespace Petsc
96: PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace_CUPM(PetscBool hermitian_transpose, PetscInt n, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
97: {
98: using ::Petsc::device::cupm::impl::UpperTriangular;
99: using ::Petsc::device::cupm::DeviceType;
100: PetscDeviceContext dctx;
101: PetscDeviceType device_type;
103: PetscFunctionBegin;
104: PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
105: PetscCall(PetscDeviceContextGetDeviceType(dctx, &device_type));
106: switch (device_type) {
107: #if PetscDefined(HAVE_CUDA)
108: case PETSC_DEVICE_CUDA:
109: PetscCall(UpperTriangular<DeviceType::CUDA>::SolveInPlace(dctx, hermitian_transpose, n, A, lda, x, stride));
110: break;
111: #endif
112: #if PetscDefined(HAVE_HIP)
113: case PETSC_DEVICE_HIP:
114: PetscCall(UpperTriangular<DeviceType::HIP>::SolveInPlace(dctx, hermitian_transpose, n, A, lda, x, stride));
115: break;
116: #endif
117: default:
118: SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported device type %s", PetscDeviceTypes[device_type]);
119: }
120: PetscFunctionReturn(PETSC_SUCCESS);
121: }
123: PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlaceCyclic_CUPM(PetscBool hermitian_transpose, PetscInt n, PetscInt oldest_index, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
124: {
125: using ::Petsc::device::cupm::impl::UpperTriangular;
126: using ::Petsc::device::cupm::DeviceType;
127: PetscDeviceContext dctx;
128: PetscDeviceType device_type;
130: PetscFunctionBegin;
131: PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
132: PetscCall(PetscDeviceContextGetDeviceType(dctx, &device_type));
133: switch (device_type) {
134: #if PetscDefined(HAVE_CUDA)
135: case PETSC_DEVICE_CUDA:
136: PetscCall(UpperTriangular<DeviceType::CUDA>::SolveInPlaceCyclic(dctx, hermitian_transpose, n, oldest_index, A, lda, x, stride));
137: break;
138: #endif
139: #if PetscDefined(HAVE_HIP)
140: case PETSC_DEVICE_HIP:
141: PetscCall(UpperTriangular<DeviceType::HIP>::SolveInPlaceCyclic(dctx, hermitian_transpose, n, oldest_index, A, lda, x, stride));
142: break;
143: #endif
144: default:
145: SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported device type %s", PetscDeviceTypes[device_type]);
146: }
147: PetscFunctionReturn(PETSC_SUCCESS);
148: }