Actual source code: vecmpicupm.hpp
1: #pragma once
3: #include <petsc/private/veccupmimpl.h>
4: #include <../src/vec/vec/impls/seq/cupm/vecseqcupm.hpp>
5: #include <../src/vec/vec/impls/mpi/pvecimpl.h>
7: namespace Petsc
8: {
10: namespace vec
11: {
13: namespace cupm
14: {
16: namespace impl
17: {
19: template <device::cupm::DeviceType T>
20: class VecMPI_CUPM : public Vec_CUPMBase<T, VecMPI_CUPM<T>> {
21: public:
22: PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecMPI_CUPM<T>);
23: using VecSeq_T = VecSeq_CUPM<T>;
25: private:
26: PETSC_NODISCARD static Vec_MPI *VecIMPLCast_(Vec) noexcept;
27: PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
28: PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;
30: static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
31: static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
32: static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
33: static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;
35: static PetscErrorCode CreateMPICUPM_(Vec, PetscDeviceContext, PetscBool /*allocate_missing*/ = PETSC_TRUE, PetscInt /*nghost*/ = 0, PetscScalar * /*host_array*/ = nullptr, PetscScalar * /*device_array*/ = nullptr) noexcept;
37: public:
38: // callable directly via a bespoke function
39: static PetscErrorCode CreateMPICUPM(MPI_Comm, PetscInt, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
40: static PetscErrorCode CreateMPICUPMWithArrays(MPI_Comm, PetscInt, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;
42: static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
43: static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
44: static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
45: static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
46: static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
47: static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
48: static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
49: static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
50: static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
51: static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
52: static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
53: static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
54: };
56: } // namespace impl
58: template <device::cupm::DeviceType T>
59: inline PetscErrorCode VecCreateMPICUPMAsync(MPI_Comm comm, PetscInt n, PetscInt N, Vec *v) noexcept
60: {
61: PetscFunctionBegin;
62: PetscAssertPointer(v, 4);
63: PetscCall(impl::VecMPI_CUPM<T>::CreateMPICUPM(comm, 0, n, N, v, PETSC_TRUE));
64: PetscFunctionReturn(PETSC_SUCCESS);
65: }
67: template <device::cupm::DeviceType T>
68: inline PetscErrorCode VecCreateMPICUPMWithArrays(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v)
69: {
70: PetscFunctionBegin;
71: if (n && cpuarray) PetscAssertPointer(cpuarray, 5);
72: PetscAssertPointer(v, 7);
73: PetscCall(impl::VecMPI_CUPM<T>::CreateMPICUPMWithArrays(comm, bs, n, N, cpuarray, gpuarray, v));
74: PetscFunctionReturn(PETSC_SUCCESS);
75: }
77: template <device::cupm::DeviceType T>
78: inline PetscErrorCode VecCreateMPICUPMWithArray(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, const PetscScalar gpuarray[], Vec *v)
79: {
80: PetscFunctionBegin;
81: PetscCall(VecCreateMPICUPMWithArrays<T>(comm, bs, n, N, nullptr, gpuarray, v));
82: PetscFunctionReturn(PETSC_SUCCESS);
83: }
85: } // namespace cupm
87: } // namespace vec
89: } // namespace Petsc
91: #if PetscDefined(HAVE_CUDA)
92: extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecMPI_CUPM<::Petsc::device::cupm::DeviceType::CUDA>;
93: #endif
95: #if PetscDefined(HAVE_HIP)
96: extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecMPI_CUPM<::Petsc::device::cupm::DeviceType::HIP>;
97: #endif