Actual source code: cd_cupm.cxx

  1: #include "../denseqn.h"
  2: #include <petsc/private/cupminterface.hpp>
  3: #include <petsc/private/cupmobject.hpp>

  5: namespace Petsc
  6: {

  8: namespace device
  9: {

 11: namespace cupm
 12: {

 14: namespace impl
 15: {

 17: template <DeviceType T>
 18: struct UpperTriangular : CUPMObject<T> {
 19:   PETSC_CUPMOBJECT_HEADER(T);

 21:   static PetscErrorCode SolveInPlace(PetscDeviceContext, PetscBool, PetscInt, const PetscScalar[], PetscInt, PetscScalar[], PetscInt) noexcept;
 22:   static PetscErrorCode SolveInPlaceCyclic(PetscDeviceContext, PetscBool, PetscInt, PetscInt, const PetscScalar[], PetscInt, PetscScalar[], PetscInt) noexcept;
 23: };

 25: template <DeviceType T>
 26: PetscErrorCode UpperTriangular<T>::SolveInPlace(PetscDeviceContext dctx, PetscBool hermitian_transpose, PetscInt N, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride) noexcept
 27: {
 28:   cupmBlasInt_t    n;
 29:   cupmBlasHandle_t handle;
 30:   auto             _A = cupmScalarPtrCast(A);
 31:   auto             _x = cupmScalarPtrCast(x);

 33:   PetscFunctionBegin;
 34:   if (!N) PetscFunctionReturn(PETSC_SUCCESS);
 35:   PetscCall(PetscCUPMBlasIntCast(N, &n));
 36:   PetscCall(GetHandlesFrom_(dctx, &handle));
 37:   PetscCall(PetscLogGpuTimeBegin());
 38:   PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, hermitian_transpose ? CUPMBLAS_OP_C : CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n, _A, lda, _x, stride));
 39:   PetscCall(PetscLogGpuTimeEnd());

 41:   PetscCall(PetscLogGpuFlops(1.0 * N * N));
 42:   PetscFunctionReturn(PETSC_SUCCESS);
 43: }

 45: template <DeviceType T>
 46: PetscErrorCode UpperTriangular<T>::SolveInPlaceCyclic(PetscDeviceContext dctx, PetscBool hermitian_transpose, PetscInt N, PetscInt oldest_index, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride) noexcept
 47: {
 48:   cupmBlasInt_t         n_old, n_new;
 49:   cupmBlasPointerMode_t pointer_mode;
 50:   cupmBlasHandle_t      handle;
 51:   auto                  sone      = cupmScalarCast(1.0);
 52:   auto                  minus_one = cupmScalarCast(-1.0);
 53:   auto                  _A        = cupmScalarPtrCast(A);
 54:   auto                  _x        = cupmScalarPtrCast(x);

 56:   PetscFunctionBegin;
 57:   if (!N) PetscFunctionReturn(PETSC_SUCCESS);
 58:   PetscCall(PetscCUPMBlasIntCast(N - oldest_index, &n_old));
 59:   PetscCall(PetscCUPMBlasIntCast(oldest_index, &n_new));
 60:   PetscCall(GetHandlesFrom_(dctx, &handle));
 61:   PetscCall(PetscLogGpuTimeBegin());
 62:   PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &pointer_mode));
 63:   PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_HOST));
 64:   if (!hermitian_transpose) {
 65:     PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n_new, _A, lda, _x, stride));
 66:     PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_N, n_old, n_new, &minus_one, &_A[oldest_index], lda, _x, stride, &sone, &_x[oldest_index], stride));
 67:     PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n_old, &_A[oldest_index * (lda + 1)], lda, &_x[oldest_index], stride));
 68:   } else {
 69:     PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_C, CUPMBLAS_DIAG_NON_UNIT, n_old, &_A[oldest_index * (lda + 1)], lda, &_x[oldest_index], stride));
 70:     PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_C, n_old, n_new, &minus_one, &_A[oldest_index], lda, &_x[oldest_index], stride, &sone, _x, stride));
 71:     PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_C, CUPMBLAS_DIAG_NON_UNIT, n_new, _A, lda, _x, stride));
 72:   }
 73:   PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, pointer_mode));
 74:   PetscCall(PetscLogGpuTimeEnd());

 76:   PetscCall(PetscLogGpuFlops(1.0 * N * N));
 77:   PetscFunctionReturn(PETSC_SUCCESS);
 78: }

 80: #if PetscDefined(HAVE_CUDA)
 81: template struct UpperTriangular<DeviceType::CUDA>;
 82: #endif

 84: #if PetscDefined(HAVE_HIP)
 85: template struct UpperTriangular<DeviceType::HIP>;
 86: #endif

 88: } // namespace impl

 90: } // namespace cupm

 92: } // namespace device

 94: } // namespace Petsc

 96: PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace_CUPM(PetscBool hermitian_transpose, PetscInt n, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
 97: {
 98:   using ::Petsc::device::cupm::impl::UpperTriangular;
 99:   using ::Petsc::device::cupm::DeviceType;
100:   PetscDeviceContext dctx;
101:   PetscDeviceType    device_type;

103:   PetscFunctionBegin;
104:   PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
105:   PetscCall(PetscDeviceContextGetDeviceType(dctx, &device_type));
106:   switch (device_type) {
107: #if PetscDefined(HAVE_CUDA)
108:   case PETSC_DEVICE_CUDA:
109:     PetscCall(UpperTriangular<DeviceType::CUDA>::SolveInPlace(dctx, hermitian_transpose, n, A, lda, x, stride));
110:     break;
111: #endif
112: #if PetscDefined(HAVE_HIP)
113:   case PETSC_DEVICE_HIP:
114:     PetscCall(UpperTriangular<DeviceType::HIP>::SolveInPlace(dctx, hermitian_transpose, n, A, lda, x, stride));
115:     break;
116: #endif
117:   default:
118:     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported device type %s", PetscDeviceTypes[device_type]);
119:   }
120:   PetscFunctionReturn(PETSC_SUCCESS);
121: }

123: PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlaceCyclic_CUPM(PetscBool hermitian_transpose, PetscInt n, PetscInt oldest_index, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
124: {
125:   using ::Petsc::device::cupm::impl::UpperTriangular;
126:   using ::Petsc::device::cupm::DeviceType;
127:   PetscDeviceContext dctx;
128:   PetscDeviceType    device_type;

130:   PetscFunctionBegin;
131:   PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
132:   PetscCall(PetscDeviceContextGetDeviceType(dctx, &device_type));
133:   switch (device_type) {
134: #if PetscDefined(HAVE_CUDA)
135:   case PETSC_DEVICE_CUDA:
136:     PetscCall(UpperTriangular<DeviceType::CUDA>::SolveInPlaceCyclic(dctx, hermitian_transpose, n, oldest_index, A, lda, x, stride));
137:     break;
138: #endif
139: #if PetscDefined(HAVE_HIP)
140:   case PETSC_DEVICE_HIP:
141:     PetscCall(UpperTriangular<DeviceType::HIP>::SolveInPlaceCyclic(dctx, hermitian_transpose, n, oldest_index, A, lda, x, stride));
142:     break;
143: #endif
144:   default:
145:     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported device type %s", PetscDeviceTypes[device_type]);
146:   }
147:   PetscFunctionReturn(PETSC_SUCCESS);
148: }