Actual source code: vecmpicupm.hpp

#pragma once

#include <petsc/private/veccupmimpl.h>
#include <../src/vec/vec/impls/seq/cupm/vecseqcupm.hpp>
#include <../src/vec/vec/impls/mpi/pvecimpl.h>

namespace Petsc
{

namespace vec
{

namespace cupm
{

namespace impl
{

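// MPI-parallel vector implementation for the CUPM (CUDA/HIP) device backends. The
// concrete backend is selected by the DeviceType template parameter, and the class
// plugs into the common machinery of Vec_CUPMBase via CRTP. VecSeq_CUPM<T> is the
// corresponding sequential implementation, on which the local portion of the
// operations declared below is built.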
template <device::cupm::DeviceType T>
class VecMPI_CUPM : public Vec_CUPMBase<T, VecMPI_CUPM<T>> {
public:
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecMPI_CUPM<T>);
  using VecSeq_T = VecSeq_CUPM<T>;

private:
  PETSC_NODISCARD static Vec_MPI          *VecIMPLCast_(Vec) noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;

  static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  static PetscErrorCode CreateMPICUPM_(Vec, PetscDeviceContext, PetscBool /*allocate_missing*/ = PETSC_TRUE, PetscInt /*nghost*/ = 0, PetscScalar * /*host_array*/ = nullptr, PetscScalar * /*device_array*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  static PetscErrorCode CreateMPICUPM(MPI_Comm, PetscInt, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  static PetscErrorCode CreateMPICUPMWithArrays(MPI_Comm, PetscInt, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
  static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
  static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
  static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
  static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
  static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
};

} // namespace impl

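// Inline convenience wrappers around impl::VecMPI_CUPM<T>: they validate the output
// (and optional array) pointer arguments and forward to the corresponding static
// creation routines of the implementation class.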
template <device::cupm::DeviceType T>
inline PetscErrorCode VecCreateMPICUPMAsync(MPI_Comm comm, PetscInt n, PetscInt N, Vec *v) noexcept
{
  PetscFunctionBegin;
  PetscAssertPointer(v, 4);
  PetscCall(impl::VecMPI_CUPM<T>::CreateMPICUPM(comm, 0, n, N, v, PETSC_TRUE));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecCreateMPICUPMWithArrays(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
{
  PetscFunctionBegin;
  if (n && cpuarray) PetscAssertPointer(cpuarray, 5);
  PetscAssertPointer(v, 7);
  PetscCall(impl::VecMPI_CUPM<T>::CreateMPICUPMWithArrays(comm, bs, n, N, cpuarray, gpuarray, v));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecCreateMPICUPMWithArray(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, const PetscScalar gpuarray[], Vec *v) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecCreateMPICUPMWithArrays<T>(comm, bs, n, N, nullptr, gpuarray, v));
  PetscFunctionReturn(PETSC_SUCCESS);
}

} // namespace cupm

} // namespace vec

} // namespace Petsc

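// The extern template declarations suppress implicit instantiation of VecMPI_CUPM in
// every translation unit that includes this header; the explicit instantiations are
// provided elsewhere in the library for whichever backends are enabled.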
#if PetscDefined(HAVE_CUDA)
extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecMPI_CUPM<::Petsc::device::cupm::DeviceType::CUDA>;
#endif

#if PetscDefined(HAVE_HIP)
extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecMPI_CUPM<::Petsc::device::cupm::DeviceType::HIP>;
#endif
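// Illustrative usage (a sketch, not part of the PETSc sources; variable names are
// hypothetical): in a CUDA-enabled build, the async creation helper defined above can
// be called directly from backend code, e.g.
//
//   Vec      v;
//   PetscInt nlocal = 128, Nglobal = PETSC_DETERMINE;
//
//   PetscCall(::Petsc::vec::cupm::VecCreateMPICUPMAsync<::Petsc::device::cupm::DeviceType::CUDA>(PETSC_COMM_WORLD, nlocal, Nglobal, &v));
//   PetscCall(VecSet(v, 1.0));
//   PetscCall(VecDestroy(&v));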