Actual source code: vecmpicupm_impl.hpp

#pragma once

#include "vecmpicupm.hpp"

#include <../src/sys/objects/device/impls/cupm/kernels.hpp>

#include <petsc/private/sfimpl.h>

namespace Petsc
{

namespace vec
{

namespace cupm
{

namespace impl
{

template <device::cupm::DeviceType T>
inline Vec_MPI *VecMPI_CUPM<T>::VecIMPLCast_(Vec v) noexcept
{
  return static_cast<Vec_MPI *>(v->data);
}

template <device::cupm::DeviceType T>
inline constexpr VecType VecMPI_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECMPICUPM();
}

template <device::cupm::DeviceType T>
inline constexpr VecType VecMPI_CUPM<T>::VECIMPL_() noexcept
{
  return VECMPI;
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecSeq_T::ClearAsyncFunctions(v));
  PetscCall(VecDestroy_MPI(v));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  return VecResetArray_MPI(v);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_MPI(v, a);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt nghost, PetscScalar *) noexcept
{
  PetscFunctionBegin;
  if (alloc_missing) *alloc_missing = PETSC_TRUE;
  // note: the host_array argument is always ignored; we never create it as part of the
  // construction sequence for VecMPI, since we always want to either allocate it ourselves
  // with pinned memory or set it in Initialize_CUPMBase()
  PetscCall(VecCreate_MPI_Private(v, PETSC_FALSE, nghost, nullptr));
  PetscCall(VecSeq_T::InitializeAsyncFunctions(v));
  PetscFunctionReturn(PETSC_SUCCESS);
}
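
// (Pinned, i.e. page-locked, host memory is what allows host<->device copies to run
// asynchronously on a stream, which is why VecMPI_CUPM allocates the host buffer itself
// rather than letting VecCreate_MPI_Private() do it.)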

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::CreateMPICUPM_(Vec v, PetscDeviceContext dctx, PetscBool allocate_missing, PetscInt nghost, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, nghost));
  PetscCall(Initialize_CUPMBase(v, allocate_missing, host_array, device_array, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ================================================================================== //
//                                                                                    //
//                                  public methods                                    //
//                                                                                    //
// ================================================================================== //

// ================================================================================== //
//                             constructors/destructors                               //

// VecCreateMPICUPM()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::CreateMPICUPM(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  PetscCall(Create_CUPMBase(comm, bs, n, N, v, call_set_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
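
// Illustrative usage sketch (not part of this file): in a CUDA build this path is
// reached through the public C API, e.g.
//
//   Vec x;
//
//   PetscCall(VecCreateMPICUDA(PETSC_COMM_WORLD, PETSC_DECIDE, 100, &x));
//   PetscCall(VecSet(x, 1.0));
//   PetscCall(VecDestroy(&x));
//
// (VecCreateMPIHIP() is the HIP analogue.)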

// VecCreateMPICUPMWithArray[s]()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::CreateMPICUPMWithArrays(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // do NOT call VecSetType(), otherwise ops->create() -> create() ->
  // CreateMPICUPM_() is called!
  PetscCall(CreateMPICUPM(comm, bs, n, N, v, PETSC_FALSE));
  PetscCall(CreateMPICUPM_(*v, dctx, PETSC_FALSE, 0, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}
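
// Illustrative usage sketch (not part of this file): the corresponding public entry
// point lets callers supply pre-existing host and/or device buffers, e.g.
//
//   PetscScalar *host_buf;   // user-managed host memory, length n
//   PetscScalar *device_buf; // user-managed device memory, length n
//   Vec          x;
//
//   PetscCall(VecCreateMPICUDAWithArrays(PETSC_COMM_WORLD, 1, n, PETSC_DECIDE, host_buf, device_buf, &x));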

// v->ops->duplicate
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
{
  const auto         vimpl  = VecIMPLCast(v);
  const auto         nghost = vimpl->nghost;
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // does not call VecSetType(); we set up the data structures ourselves
  PetscCall(Duplicate_CUPMBase(v, y, dctx, [=](Vec z) { return CreateMPICUPM_(z, dctx, PETSC_FALSE, nghost); }));

  /* save the local representation of the parallel vector (and its scatter) if it exists */
  if (const auto locrep = vimpl->localrep) {
    const auto   yimpl   = VecIMPLCast(*y);
    auto        &ylocrep = yimpl->localrep;
    PetscScalar *array;

    PetscCall(VecGetArray(*y, &array));
    PetscCall(VecCreateSeqWithArray(PETSC_COMM_SELF, std::abs(v->map->bs), v->map->n + nghost, array, &ylocrep));
    PetscCall(VecRestoreArray(*y, &array));
    ylocrep->ops[0] = locrep->ops[0];
    if (const auto scatter = (yimpl->localupdate = vimpl->localupdate)) PetscCall(PetscObjectReference(PetscObjectCast(scatter)));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
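
// Note: localrep/localupdate are only present for ghosted vectors (e.g. those created
// with VecCreateGhost()); for plain MPI vectors the branch above is skipped entirely.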

// v->ops->bindtocpu
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

  VecSetOp_CUPM(dot, VecDot_MPI, Dot);
  VecSetOp_CUPM(mdot, VecMDot_MPI, MDot);
  VecSetOp_CUPM(norm, VecNorm_MPI, Norm);
  VecSetOp_CUPM(tdot, VecTDot_MPI, TDot);
  VecSetOp_CUPM(resetarray, VecResetArray_MPI, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_MPI, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(max, VecMax_MPI, Max);
  VecSetOp_CUPM(min, VecMin_MPI, Min);
  PetscFunctionReturn(PETSC_SUCCESS);
}
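
// Illustrative: each VecSetOp_CUPM() line above swaps the op between its host (CPU)
// and device implementation depending on usehost. Users drive this via, e.g.
//
//   PetscCall(VecBindToCPU(x, PETSC_TRUE));  // subsequent operations run on the host
//   PetscCall(VecBindToCPU(x, PETSC_FALSE)); // hand execution back to the device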

// ================================================================================== //
//                                   compute methods
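
// These all follow the same pattern: the *_MPI_Default() helper performs the local
// (on-device) computation through the VecSeq_T callback, then combines the per-rank
// results with an MPI reduction.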

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::Norm(Vec v, NormType type, PetscReal *z) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecNorm_MPI_Default(v, type, z, VecSeq_T::Norm));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::ErrorWnorm(Vec U, Vec Y, Vec E, NormType wnormtype, PetscReal atol, Vec vatol, PetscReal rtol, Vec vrtol, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecErrorWeightedNorms_MPI_Default(U, Y, E, wnormtype, atol, vatol, rtol, vrtol, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc, VecSeq_T::ErrorWnorm));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::Dot(Vec x, Vec y, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecXDot_MPI_Default(x, y, z, VecSeq_T::Dot));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::TDot(Vec x, Vec y, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecXDot_MPI_Default(x, y, z, VecSeq_T::TDot));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::MDot(Vec x, PetscInt nv, const Vec y[], PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecMXDot_MPI_Default(x, nv, y, z, VecSeq_T::MDot));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::DotNorm2(Vec x, Vec y, PetscScalar *dp, PetscScalar *nm) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecDotNorm2_MPI_Default(x, y, dp, nm, VecSeq_T::DotNorm2));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::Max(Vec x, PetscInt *idx, PetscReal *z) noexcept
{
  const MPI_Op ops[] = {MPIU_MAXLOC, MPIU_MAX};

  PetscFunctionBegin;
  PetscCall(VecMinMax_MPI_Default(x, idx, z, VecSeq_T::Max, ops));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::Min(Vec x, PetscInt *idx, PetscReal *z) noexcept
{
  const MPI_Op ops[] = {MPIU_MINLOC, MPIU_MIN};

  PetscFunctionBegin;
  PetscCall(VecMinMax_MPI_Default(x, idx, z, VecSeq_T::Min, ops));
  PetscFunctionReturn(PETSC_SUCCESS);
}
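
// In both cases the two-element ops array gives VecMinMax_MPI_Default() a choice of
// reduction: the *LOC variant when the caller also wants the index of the extremum
// (idx != nullptr), the plain MAX/MIN otherwise (an inference from the API shape, not
// something this file spells out).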

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::SetPreallocationCOO(Vec x, PetscCount ncoo, const PetscInt coo_i[]) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(VecSetPreallocationCOO_MPI(x, ncoo, coo_i));
  // both the host and device representations must exist for this to work
  PetscCall(VecCUPMAllocateCheck_(x));
  {
    const auto vcu  = VecCUPMCast(x);
    const auto vmpi = VecIMPLCast(x);

    // clang-format off
    PetscCall(
      SetPreallocationCOO_CUPMBase(
        x, ncoo, coo_i, dctx,
        util::make_array(
          make_coo_pair(vcu->imap2_d, vmpi->imap2, vmpi->nnz2),
          make_coo_pair(vcu->jmap2_d, vmpi->jmap2, vmpi->nnz2 + 1),
          make_coo_pair(vcu->perm2_d, vmpi->perm2, vmpi->recvlen),
          make_coo_pair(vcu->Cperm_d, vmpi->Cperm, vmpi->sendlen)
        ),
        util::make_array(
          make_coo_pair(vcu->sendbuf_d, vmpi->sendbuf, vmpi->sendlen),
          make_coo_pair(vcu->recvbuf_d, vmpi->recvbuf, vmpi->recvlen)
        )
      )
    );
    // clang-format on
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
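
// Illustrative sketch of the public COO API that drives SetPreallocationCOO() above
// and SetValuesCOO() below (the entry points are public PETSc API, not defined here):
//
//   PetscInt    idx[] = {0, 1, 1};       // repeated (and off-rank) indices are allowed
//   PetscScalar val[] = {1.0, 2.0, 3.0}; // may reside in host or device memory
//
//   PetscCall(VecSetPreallocationCOO(x, 3, idx));
//   PetscCall(VecSetValuesCOO(x, val, ADD_VALUES)); // x[0] += 1.0, x[1] += 2.0 + 3.0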

namespace kernels
{

namespace
{
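
// Gather kernel: buf[i] = vv[perm[i]] for i in [0, nnz), i.e. pack the COO values
// destined for other ranks into a contiguous send buffer via a grid-stride loop.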

PETSC_KERNEL_DECL void pack_coo_values(const PetscScalar *PETSC_RESTRICT vv, PetscCount nnz, const PetscCount *PETSC_RESTRICT perm, PetscScalar *PETSC_RESTRICT buf)
{
  Petsc::device::cupm::kernels::util::grid_stride_1D(nnz, [=](PetscCount i) { buf[i] = vv[perm[i]]; });
  return;
}
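
// Scatter-add kernel: accumulate the values received from other ranks into the owned
// vector entries selected by imap2; ADD_VALUES is hard-coded because remote
// contributions must always be summed, regardless of the user's insert mode.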

PETSC_KERNEL_DECL void add_remote_coo_values(const PetscScalar *PETSC_RESTRICT vv, PetscCount nnz2, const PetscCount *PETSC_RESTRICT imap2, const PetscCount *PETSC_RESTRICT jmap2, const PetscCount *PETSC_RESTRICT perm2, PetscScalar *PETSC_RESTRICT xv)
{
  add_coo_values_impl(vv, nnz2, jmap2, perm2, ADD_VALUES, xv, [=](PetscCount i) { return imap2[i]; });
  return;
}

} // namespace

#if PetscDefined(USING_HCC)
namespace do_not_use
{

// Needed to silence the clang warning:
//
// warning: function 'FUNCTION NAME' is not needed and will not be emitted
//
// The warning is silly, since the function *is* used, however the host compiler does
// not appear to see this, likely because the function using it is in a template.
//
// This warning appeared in clang-11 and still persists as of clang-15 (21/02/2023)
inline void silence_warning_function_pack_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)pack_coo_values;
}

inline void silence_warning_function_add_remote_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)add_remote_coo_values;
}

} // namespace do_not_use
#endif

} // namespace kernels
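
// SetValuesCOO() pipeline: (1) pack values bound for other ranks into sendbuf_d,
// (2) start the PetscSF reduction that ships them, (3) overlap communication with the
// kernel that adds the locally-owned values, (4) finish the reduction, (5) add the
// received remote values. All kernels are enqueued on the same stream.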

template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept
{
  PetscDeviceContext dctx;
  PetscMemType       v_memtype;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  PetscCall(PetscGetMemType(v, &v_memtype));
  {
    const auto vmpi      = VecIMPLCast(x);
    const auto vcu       = VecCUPMCast(x);
    const auto sf        = vmpi->coo_sf;
    const auto sendbuf_d = vcu->sendbuf_d;
    const auto recvbuf_d = vcu->recvbuf_d;
    const auto xv        = imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data();
    auto       vv        = const_cast<PetscScalar *>(v);

    if (PetscMemTypeHost(v_memtype)) {
      const auto size = vmpi->coo_n;

      /* if the user gave v[] on the host we need to copy it to the device first */
      PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
      PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
    }

    /* pack the entries to be sent to remote ranks */
    if (const auto sendlen = vmpi->sendlen) {
      PetscCall(PetscCUPMLaunchKernel1D(sendlen, 0, stream, kernels::pack_coo_values, vv, sendlen, vcu->Cperm_d, sendbuf_d));
      // need to sync up here since we are about to send this to petscsf
      // REVIEW ME: no we don't, sf just needs to learn to use PetscDeviceContext
      PetscCallCUPM(cupmStreamSynchronize(stream));
    }

    PetscCall(PetscSFReduceWithMemTypeBegin(sf, MPIU_SCALAR, PETSC_MEMTYPE_CUPM(), sendbuf_d, PETSC_MEMTYPE_CUPM(), recvbuf_d, MPI_REPLACE));

    if (const auto n = x->map->n) PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, xv));

    PetscCall(PetscSFReduceEnd(sf, MPIU_SCALAR, sendbuf_d, recvbuf_d, MPI_REPLACE));

    /* add the received remote entries */
    if (const auto nnz2 = vmpi->nnz2) PetscCall(PetscCUPMLaunchKernel1D(nnz2, 0, stream, kernels::add_remote_coo_values, recvbuf_d, nnz2, vcu->imap2_d, vcu->jmap2_d, vcu->perm2_d, xv));

    if (PetscMemTypeHost(v_memtype)) PetscCall(PetscDeviceFree(dctx, vv));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

} // namespace impl

} // namespace cupm

} // namespace vec

} // namespace Petsc