Actual source code: sfcupm_impl.hpp

  1: #pragma once

  3: #include "sfcupm.hpp"
  4: #include <../src/sys/objects/device/impls/cupm/kernels.hpp>
  5: #include <petsc/private/cupmatomics.hpp>

  7: namespace Petsc
  8: {

 10: namespace sf
 11: {

 13: namespace cupm
 14: {

 16: namespace kernels
 17: {

 19: /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
 20: PETSC_NODISCARD static PETSC_DEVICE_INLINE_DECL PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid) noexcept
 21: {
 22:   PetscInt        i, j, k, m, n, r;
 23:   const PetscInt *offset, *start, *dx, *dy, *X, *Y;

 25:   n      = opt[0];
 26:   offset = opt + 1;
 27:   start  = opt + n + 2;
 28:   dx     = opt + 2 * n + 2;
 29:   dy     = opt + 3 * n + 2;
 30:   X      = opt + 5 * n + 2;
 31:   Y      = opt + 6 * n + 2;
 32:   for (r = 0; r < n; r++) {
 33:     if (tid < offset[r + 1]) break;
 34:   }
 35:   m = (tid - offset[r]);
 36:   k = m / (dx[r] * dy[r]);
 37:   j = (m - k * dx[r] * dy[r]) / dx[r];
 38:   i = m - k * dx[r] * dy[r] - j * dx[r];

 40:   return start[r] + k * X[r] * Y[r] + j * X[r] + i;
 41: }

 43: /*====================================================================================*/
 44: /*  Templated CUPM kernels for pack/unpack. The Op can be regular or atomic           */
 45: /*====================================================================================*/

 47: /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then
 48:    <Type> is PetscReal, which is the primitive type we operate on.
 49:    <bs>   is 16, which says <unit> contains 16 primitive types.
 50:    <BS>   is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>.
 51:    <EQ>   is 0, which is (bs == BS ? 1 : 0)

 53:   If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant.
 54:   For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled.
 55: */
 56: template <class Type, PetscInt BS, PetscInt EQ>
 57: PETSC_KERNEL_DECL static void d_Pack(PetscInt bs, PetscInt count, PetscInt start, const PetscInt *opt, const PetscInt *idx, const Type *data, Type *buf)
 58: {
 59:   const PetscInt M   = (EQ) ? 1 : bs / BS; /* If EQ, then M=1 enables compiler's const-propagation */
 60:   const PetscInt MBS = M * BS;             /* MBS=bs. We turn MBS into a compile-time const when EQ=1. */

 62:   ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) {
 63:     PetscInt t = (opt ? MapTidToIndex(opt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
 64:     PetscInt s = tid * MBS;
 65:     for (PetscInt i = 0; i < MBS; i++) buf[s + i] = data[t + i];
 66:   });
 67: }

 69: template <class Type, class Op, PetscInt BS, PetscInt EQ>
 70: PETSC_KERNEL_DECL static void d_UnpackAndOp(PetscInt bs, PetscInt count, PetscInt start, const PetscInt *opt, const PetscInt *idx, Type *data, const Type *buf)
 71: {
 72:   const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS;
 73:   Op             op;

 75:   ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) {
 76:     PetscInt t = (opt ? MapTidToIndex(opt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
 77:     PetscInt s = tid * MBS;
 78:     for (PetscInt i = 0; i < MBS; i++) op(data[t + i], buf[s + i]);
 79:   });
 80: }

 82: template <class Type, class Op, PetscInt BS, PetscInt EQ>
 83: PETSC_KERNEL_DECL static void d_FetchAndOp(PetscInt bs, PetscInt count, PetscInt rootstart, const PetscInt *rootopt, const PetscInt *rootidx, Type *rootdata, Type *leafbuf)
 84: {
 85:   const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS;
 86:   Op             op;

 88:   ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) {
 89:     PetscInt r = (rootopt ? MapTidToIndex(rootopt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
 90:     PetscInt l = tid * MBS;
 91:     for (PetscInt i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]);
 92:   });
 93: }

 95: template <class Type, class Op, PetscInt BS, PetscInt EQ>
 96: PETSC_KERNEL_DECL static void d_ScatterAndOp(PetscInt bs, PetscInt count, PetscInt srcx, PetscInt srcy, PetscInt srcX, PetscInt srcY, PetscInt srcStart, const PetscInt *srcIdx, const Type *src, PetscInt dstx, PetscInt dsty, PetscInt dstX, PetscInt dstY, PetscInt dstStart, const PetscInt *dstIdx, Type *dst)
 97: {
 98:   const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS;
 99:   Op             op;

101:   ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) {
102:     PetscInt s, t;

104:     if (!srcIdx) { /* src is either contiguous or 3D */
105:       PetscInt k = tid / (srcx * srcy);
106:       PetscInt j = (tid - k * srcx * srcy) / srcx;
107:       PetscInt i = tid - k * srcx * srcy - j * srcx;

109:       s = srcStart + k * srcX * srcY + j * srcX + i;
110:     } else {
111:       s = srcIdx[tid];
112:     }

114:     if (!dstIdx) { /* dst is either contiguous or 3D */
115:       PetscInt k = tid / (dstx * dsty);
116:       PetscInt j = (tid - k * dstx * dsty) / dstx;
117:       PetscInt i = tid - k * dstx * dsty - j * dstx;

119:       t = dstStart + k * dstX * dstY + j * dstX + i;
120:     } else {
121:       t = dstIdx[tid];
122:     }

124:     s *= MBS;
125:     t *= MBS;
126:     for (PetscInt i = 0; i < MBS; i++) op(dst[t + i], src[s + i]);
127:   });
128: }

130: template <class Type, class Op, PetscInt BS, PetscInt EQ>
131: PETSC_KERNEL_DECL static void d_FetchAndOpLocal(PetscInt bs, PetscInt count, PetscInt rootstart, const PetscInt *rootopt, const PetscInt *rootidx, Type *rootdata, PetscInt leafstart, const PetscInt *leafopt, const PetscInt *leafidx, const Type *leafdata, Type *leafupdate)
132: {
133:   const PetscInt M = (EQ) ? 1 : bs / BS, MBS = M * BS;
134:   Op             op;

136:   ::Petsc::device::cupm::kernels::util::grid_stride_1D(count, [&](PetscInt tid) {
137:     PetscInt r = (rootopt ? MapTidToIndex(rootopt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
138:     PetscInt l = (leafopt ? MapTidToIndex(leafopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS;
139:     for (PetscInt i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]);
140:   });
141: }

143: /*====================================================================================*/
144: /*                             Regular operations on device                           */
145: /*====================================================================================*/
146: template <typename Type>
147: struct Insert {
148:   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
149:   {
150:     Type old = x;
151:     x        = y;
152:     return old;
153:   }
154: };
155: template <typename Type>
156: struct Add {
157:   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
158:   {
159:     Type old = x;
160:     x += y;
161:     return old;
162:   }
163: };
164: template <typename Type>
165: struct Mult {
166:   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
167:   {
168:     Type old = x;
169:     x *= y;
170:     return old;
171:   }
172: };
173: template <typename Type>
174: struct Min {
175:   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
176:   {
177:     Type old = x;
178:     x        = PetscMin(x, y);
179:     return old;
180:   }
181: };
182: template <typename Type>
183: struct Max {
184:   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
185:   {
186:     Type old = x;
187:     x        = PetscMax(x, y);
188:     return old;
189:   }
190: };
191: template <typename Type>
192: struct LAND {
193:   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
194:   {
195:     Type old = x;
196:     x        = x && y;
197:     return old;
198:   }
199: };
200: template <typename Type>
201: struct LOR {
202:   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
203:   {
204:     Type old = x;
205:     x        = x || y;
206:     return old;
207:   }
208: };
209: template <typename Type>
210: struct LXOR {
211:   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
212:   {
213:     Type old = x;
214:     x        = !x != !y;
215:     return old;
216:   }
217: };
218: template <typename Type>
219: struct BAND {
220:   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
221:   {
222:     Type old = x;
223:     x        = x & y;
224:     return old;
225:   }
226: };
227: template <typename Type>
228: struct BOR {
229:   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
230:   {
231:     Type old = x;
232:     x        = x | y;
233:     return old;
234:   }
235: };
236: template <typename Type>
237: struct BXOR {
238:   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
239:   {
240:     Type old = x;
241:     x        = x ^ y;
242:     return old;
243:   }
244: };
245: template <typename Type>
246: struct Minloc {
247:   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
248:   {
249:     Type old = x;
250:     if (y.a < x.a) x = y;
251:     else if (y.a == x.a) x.b = min(x.b, y.b);
252:     return old;
253:   }
254: };
255: template <typename Type>
256: struct Maxloc {
257:   PETSC_DEVICE_DECL Type operator()(Type &x, Type y) const
258:   {
259:     Type old = x;
260:     if (y.a > x.a) x = y;
261:     else if (y.a == x.a) x.b = min(x.b, y.b); /* See MPI MAXLOC */
262:     return old;
263:   }
264: };

266: } // namespace kernels

268: namespace impl
269: {

271: /*====================================================================================*/
272: /*  Wrapper functions of cupm kernels. Function pointers are stored in 'link'         */
273: /*====================================================================================*/
274: template <device::cupm::DeviceType T>
275: template <typename Type, PetscInt BS, PetscInt EQ>
276: inline PetscErrorCode SfInterface<T>::Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data, void *buf) noexcept
277: {
278:   const PetscInt *iarray = opt ? opt->array : NULL;

280:   PetscFunctionBegin;
281:   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
282:   if (PetscDefined(USING_NVCC) && !opt && !idx) { /* It is a 'CUDA data to nvshmem buf' memory copy */
283:     PetscCallCUPM(cupmMemcpyAsync(buf, (char *)data + start * link->unitbytes, count * link->unitbytes, cupmMemcpyDeviceToDevice, link->stream));
284:   } else {
285:     PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_Pack<Type, BS, EQ>, link->bs, count, start, iarray, idx, (const Type *)data, (Type *)buf));
286:   }
287:   PetscFunctionReturn(PETSC_SUCCESS);
288: }

290: template <device::cupm::DeviceType T>
291: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
292: inline PetscErrorCode SfInterface<T>::UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, const void *buf) noexcept
293: {
294:   const PetscInt *iarray = opt ? opt->array : NULL;

296:   PetscFunctionBegin;
297:   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
298:   if (PetscDefined(USING_NVCC) && std::is_same<Op, kernels::Insert<Type>>::value && !opt && !idx) { /* It is a 'nvshmem buf to CUDA data' memory copy */
299:     PetscCallCUPM(cupmMemcpyAsync((char *)data + start * link->unitbytes, buf, count * link->unitbytes, cupmMemcpyDeviceToDevice, link->stream));
300:   } else {
301:     PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_UnpackAndOp<Type, Op, BS, EQ>, link->bs, count, start, iarray, idx, (Type *)data, (const Type *)buf));
302:   }
303:   PetscFunctionReturn(PETSC_SUCCESS);
304: }

306: template <device::cupm::DeviceType T>
307: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
308: inline PetscErrorCode SfInterface<T>::FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf) noexcept
309: {
310:   const PetscInt *iarray = opt ? opt->array : NULL;

312:   PetscFunctionBegin;
313:   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
314:   PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_FetchAndOp<Type, Op, BS, EQ>, link->bs, count, start, iarray, idx, (Type *)data, (const Type *)buf));
315:   PetscFunctionReturn(PETSC_SUCCESS);
316: }

318: template <device::cupm::DeviceType T>
319: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
320: inline PetscErrorCode SfInterface<T>::ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst) noexcept
321: {
322:   PetscInt nthreads = 256;
323:   PetscInt nblocks  = (count + nthreads - 1) / nthreads;
324:   PetscInt srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0;

326:   PetscFunctionBegin;
327:   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
328:   nblocks = PetscMin(nblocks, link->maxResidentThreadsPerGPU / nthreads);

330:   /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use 3D grid and block */
331:   if (srcOpt) {
332:     srcx     = srcOpt->dx[0];
333:     srcy     = srcOpt->dy[0];
334:     srcX     = srcOpt->X[0];
335:     srcY     = srcOpt->Y[0];
336:     srcStart = srcOpt->start[0];
337:     srcIdx   = NULL;
338:   } else if (!srcIdx) {
339:     srcx = srcX = count;
340:     srcy = srcY = 1;
341:   }

343:   if (dstOpt) {
344:     dstx     = dstOpt->dx[0];
345:     dsty     = dstOpt->dy[0];
346:     dstX     = dstOpt->X[0];
347:     dstY     = dstOpt->Y[0];
348:     dstStart = dstOpt->start[0];
349:     dstIdx   = NULL;
350:   } else if (!dstIdx) {
351:     dstx = dstX = count;
352:     dsty = dstY = 1;
353:   }

355:   PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_ScatterAndOp<Type, Op, BS, EQ>, link->bs, count, srcx, srcy, srcX, srcY, srcStart, srcIdx, (const Type *)src, dstx, dsty, dstX, dstY, dstStart, dstIdx, (Type *)dst));
356:   PetscFunctionReturn(PETSC_SUCCESS);
357: }

359: template <device::cupm::DeviceType T>
360: /* Specialization for Insert since we may use cupmMemcpyAsync */
361: template <typename Type, PetscInt BS, PetscInt EQ>
362: inline PetscErrorCode SfInterface<T>::ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst) noexcept
363: {
364:   PetscFunctionBegin;
365:   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
366:   /*src and dst are contiguous */
367:   if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
368:     PetscCallCUPM(cupmMemcpyAsync((Type *)dst + dstStart * link->bs, (const Type *)src + srcStart * link->bs, count * link->unitbytes, cupmMemcpyDeviceToDevice, link->stream));
369:   } else {
370:     PetscCall(ScatterAndOp<Type, kernels::Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst));
371:   }
372:   PetscFunctionReturn(PETSC_SUCCESS);
373: }

375: template <device::cupm::DeviceType T>
376: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
377: inline PetscErrorCode SfInterface<T>::FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata, void *leafupdate) noexcept
378: {
379:   const PetscInt *rarray = rootopt ? rootopt->array : NULL;
380:   const PetscInt *larray = leafopt ? leafopt->array : NULL;

382:   PetscFunctionBegin;
383:   if (!count) PetscFunctionReturn(PETSC_SUCCESS);
384:   PetscCall(PetscCUPMLaunchKernel1D(count, 0, link->stream, kernels::d_FetchAndOpLocal<Type, Op, BS, EQ>, link->bs, count, rootstart, rarray, rootidx, (Type *)rootdata, leafstart, larray, leafidx, (const Type *)leafdata, (Type *)leafupdate));
385:   PetscFunctionReturn(PETSC_SUCCESS);
386: }

388: /*====================================================================================*/
389: /*  Init various types and instantiate pack/unpack function pointers                  */
390: /*====================================================================================*/
391: template <device::cupm::DeviceType T>
392: template <typename Type, PetscInt BS, PetscInt EQ>
393: inline void SfInterface<T>::PackInit_RealType(PetscSFLink link) noexcept
394: {
395:   /* Pack/unpack for remote communication */
396:   link->d_Pack            = Pack<Type, BS, EQ>;
397:   link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, BS, EQ>;
398:   link->d_UnpackAndAdd    = UnpackAndOp<Type, kernels::Add<Type>, BS, EQ>;
399:   link->d_UnpackAndMult   = UnpackAndOp<Type, kernels::Mult<Type>, BS, EQ>;
400:   link->d_UnpackAndMin    = UnpackAndOp<Type, kernels::Min<Type>, BS, EQ>;
401:   link->d_UnpackAndMax    = UnpackAndOp<Type, kernels::Max<Type>, BS, EQ>;
402:   link->d_FetchAndAdd     = FetchAndOp<Type, kernels::Add<Type>, BS, EQ>;

404:   /* Scatter for local communication */
405:   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */
406:   link->d_ScatterAndAdd    = ScatterAndOp<Type, kernels::Add<Type>, BS, EQ>;
407:   link->d_ScatterAndMult   = ScatterAndOp<Type, kernels::Mult<Type>, BS, EQ>;
408:   link->d_ScatterAndMin    = ScatterAndOp<Type, kernels::Min<Type>, BS, EQ>;
409:   link->d_ScatterAndMax    = ScatterAndOp<Type, kernels::Max<Type>, BS, EQ>;
410:   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, kernels::Add<Type>, BS, EQ>;

412:   /* Atomic versions when there are data-race possibilities */
413:   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
414:   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
415:   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
416:   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
417:   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
418:   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicAdd<Type>, BS, EQ>;

420:   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
421:   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
422:   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
423:   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
424:   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
425:   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicAdd<Type>, BS, EQ>;
426: }

428: /* Have this templated class to specialize for char integers */
429: template <device::cupm::DeviceType T>
430: template <typename Type, PetscInt BS, PetscInt EQ, PetscInt size /*sizeof(Type)*/>
431: struct SfInterface<T>::PackInit_IntegerType_Atomic {
432:   static inline void Init(PetscSFLink link) noexcept
433:   {
434:     link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
435:     link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
436:     link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
437:     link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
438:     link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
439:     link->da_UnpackAndLAND   = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>;
440:     link->da_UnpackAndLOR    = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>;
441:     link->da_UnpackAndLXOR   = UnpackAndOp<Type, AtomicLXOR<Type>, BS, EQ>;
442:     link->da_UnpackAndBAND   = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>;
443:     link->da_UnpackAndBOR    = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>;
444:     link->da_UnpackAndBXOR   = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
445:     link->da_FetchAndAdd     = FetchAndOp<Type, AtomicAdd<Type>, BS, EQ>;

447:     link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
448:     link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
449:     link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
450:     link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
451:     link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
452:     link->da_ScatterAndLAND   = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>;
453:     link->da_ScatterAndLOR    = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>;
454:     link->da_ScatterAndLXOR   = ScatterAndOp<Type, AtomicLXOR<Type>, BS, EQ>;
455:     link->da_ScatterAndBAND   = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>;
456:     link->da_ScatterAndBOR    = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>;
457:     link->da_ScatterAndBXOR   = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
458:     link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicAdd<Type>, BS, EQ>;
459:   }
460: };

462: /* CUDA does not support atomics on chars. It is TBD in PETSc. */
463: template <device::cupm::DeviceType T>
464: template <typename Type, PetscInt BS, PetscInt EQ>
465: struct SfInterface<T>::PackInit_IntegerType_Atomic<Type, BS, EQ, 1> {
466:   static inline void Init(PetscSFLink) { /* Nothing to leave function pointers NULL */ }
467: };

469: template <device::cupm::DeviceType T>
470: template <typename Type, PetscInt BS, PetscInt EQ>
471: inline void SfInterface<T>::PackInit_IntegerType(PetscSFLink link) noexcept
472: {
473:   link->d_Pack            = Pack<Type, BS, EQ>;
474:   link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, BS, EQ>;
475:   link->d_UnpackAndAdd    = UnpackAndOp<Type, kernels::Add<Type>, BS, EQ>;
476:   link->d_UnpackAndMult   = UnpackAndOp<Type, kernels::Mult<Type>, BS, EQ>;
477:   link->d_UnpackAndMin    = UnpackAndOp<Type, kernels::Min<Type>, BS, EQ>;
478:   link->d_UnpackAndMax    = UnpackAndOp<Type, kernels::Max<Type>, BS, EQ>;
479:   link->d_UnpackAndLAND   = UnpackAndOp<Type, kernels::LAND<Type>, BS, EQ>;
480:   link->d_UnpackAndLOR    = UnpackAndOp<Type, kernels::LOR<Type>, BS, EQ>;
481:   link->d_UnpackAndLXOR   = UnpackAndOp<Type, kernels::LXOR<Type>, BS, EQ>;
482:   link->d_UnpackAndBAND   = UnpackAndOp<Type, kernels::BAND<Type>, BS, EQ>;
483:   link->d_UnpackAndBOR    = UnpackAndOp<Type, kernels::BOR<Type>, BS, EQ>;
484:   link->d_UnpackAndBXOR   = UnpackAndOp<Type, kernels::BXOR<Type>, BS, EQ>;
485:   link->d_FetchAndAdd     = FetchAndOp<Type, kernels::Add<Type>, BS, EQ>;

487:   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
488:   link->d_ScatterAndAdd    = ScatterAndOp<Type, kernels::Add<Type>, BS, EQ>;
489:   link->d_ScatterAndMult   = ScatterAndOp<Type, kernels::Mult<Type>, BS, EQ>;
490:   link->d_ScatterAndMin    = ScatterAndOp<Type, kernels::Min<Type>, BS, EQ>;
491:   link->d_ScatterAndMax    = ScatterAndOp<Type, kernels::Max<Type>, BS, EQ>;
492:   link->d_ScatterAndLAND   = ScatterAndOp<Type, kernels::LAND<Type>, BS, EQ>;
493:   link->d_ScatterAndLOR    = ScatterAndOp<Type, kernels::LOR<Type>, BS, EQ>;
494:   link->d_ScatterAndLXOR   = ScatterAndOp<Type, kernels::LXOR<Type>, BS, EQ>;
495:   link->d_ScatterAndBAND   = ScatterAndOp<Type, kernels::BAND<Type>, BS, EQ>;
496:   link->d_ScatterAndBOR    = ScatterAndOp<Type, kernels::BOR<Type>, BS, EQ>;
497:   link->d_ScatterAndBXOR   = ScatterAndOp<Type, kernels::BXOR<Type>, BS, EQ>;
498:   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, kernels::Add<Type>, BS, EQ>;
499:   PackInit_IntegerType_Atomic<Type, BS, EQ, sizeof(Type)>::Init(link);
500: }

#if defined(PETSC_HAVE_COMPLEX)
/* Install function pointers for a complex Type: only insert, add (with fetch-and-add) and mult are set here.
   Some atomic variants are explicitly set to NULL because they are unavailable for complex values. */
template <device::cupm::DeviceType T>
template <typename Type, PetscInt BS, PetscInt EQ>
inline void SfInterface<T>::PackInit_ComplexType(PetscSFLink link) noexcept
{
  link->d_Pack = Pack<Type, BS, EQ>;

  /* Insert */
  link->d_UnpackAndInsert   = UnpackAndOp<Type, kernels::Insert<Type>, BS, EQ>;
  link->d_ScatterAndInsert  = ScatterAndInsert<Type, BS, EQ>;
  link->da_UnpackAndInsert  = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
  link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;

  /* Add */
  link->d_UnpackAndAdd     = UnpackAndOp<Type, kernels::Add<Type>, BS, EQ>;
  link->d_ScatterAndAdd    = ScatterAndOp<Type, kernels::Add<Type>, BS, EQ>;
  link->d_FetchAndAdd      = FetchAndOp<Type, kernels::Add<Type>, BS, EQ>;
  link->d_FetchAndAddLocal = FetchAndOpLocal<Type, kernels::Add<Type>, BS, EQ>;
  link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
  link->da_ScatterAndAdd   = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
  link->da_FetchAndAdd     = nullptr; /* Return value of atomicAdd on complex is not atomic */

  /* Mult */
  link->d_UnpackAndMult  = UnpackAndOp<Type, kernels::Mult<Type>, BS, EQ>;
  link->d_ScatterAndMult = ScatterAndOp<Type, kernels::Mult<Type>, BS, EQ>;
  link->da_UnpackAndMult = nullptr; /* Not implemented yet */
}
#endif

528: typedef signed char   SignedChar;
529: typedef unsigned char UnsignedChar;
530: typedef struct {
531:   int a;
532:   int b;
533: } PairInt;
534: typedef struct {
535:   PetscInt a;
536:   PetscInt b;
537: } PairPetscInt;

539: template <device::cupm::DeviceType T>
540: template <typename Type>
541: inline void SfInterface<T>::PackInit_PairType(PetscSFLink link) noexcept
542: {
543:   link->d_Pack            = Pack<Type, 1, 1>;
544:   link->d_UnpackAndInsert = UnpackAndOp<Type, kernels::Insert<Type>, 1, 1>;
545:   link->d_UnpackAndMaxloc = UnpackAndOp<Type, kernels::Maxloc<Type>, 1, 1>;
546:   link->d_UnpackAndMinloc = UnpackAndOp<Type, kernels::Minloc<Type>, 1, 1>;

548:   link->d_ScatterAndInsert = ScatterAndOp<Type, kernels::Insert<Type>, 1, 1>;
549:   link->d_ScatterAndMaxloc = ScatterAndOp<Type, kernels::Maxloc<Type>, 1, 1>;
550:   link->d_ScatterAndMinloc = ScatterAndOp<Type, kernels::Minloc<Type>, 1, 1>;
551:   /* Atomics for pair types are not implemented yet */
552: }

554: template <device::cupm::DeviceType T>
555: template <typename Type, PetscInt BS, PetscInt EQ>
556: inline void SfInterface<T>::PackInit_DumbType(PetscSFLink link) noexcept
557: {
558:   link->d_Pack             = Pack<Type, BS, EQ>;
559:   link->d_UnpackAndInsert  = UnpackAndOp<Type, kernels::Insert<Type>, BS, EQ>;
560:   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
561:   /* Atomics for dumb types are not implemented yet */
562: }

564: /* Some device-specific utilities */
565: template <device::cupm::DeviceType T>
566: inline PetscErrorCode SfInterface<T>::LinkSyncDevice(PetscSFLink) noexcept
567: {
568:   PetscFunctionBegin;
569:   PetscCallCUPM(cupmDeviceSynchronize());
570:   PetscFunctionReturn(PETSC_SUCCESS);
571: }

573: template <device::cupm::DeviceType T>
574: inline PetscErrorCode SfInterface<T>::LinkSyncStream(PetscSFLink link) noexcept
575: {
576:   PetscFunctionBegin;
577:   PetscCallCUPM(cupmStreamSynchronize(link->stream));
578:   PetscFunctionReturn(PETSC_SUCCESS);
579: }

581: template <device::cupm::DeviceType T>
582: inline PetscErrorCode SfInterface<T>::LinkMemcpy(PetscSFLink link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n) noexcept
583: {
584:   PetscFunctionBegin;
585:   cupmMemcpyKind_t kinds[2][2] = {
586:     {cupmMemcpyHostToHost,   cupmMemcpyHostToDevice  },
587:     {cupmMemcpyDeviceToHost, cupmMemcpyDeviceToDevice}
588:   };

590:   if (n) {
591:     if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) { /* Separate HostToHost so that pure-cpu code won't call cupm runtime */
592:       PetscCall(PetscMemcpy(dst, src, n));
593:     } else {
594:       int stype = PetscMemTypeDevice(srcmtype) ? 1 : 0;
595:       int dtype = PetscMemTypeDevice(dstmtype) ? 1 : 0;
596:       PetscCallCUPM(cupmMemcpyAsync(dst, src, n, kinds[stype][dtype], link->stream));
597:     }
598:   }
599:   PetscFunctionReturn(PETSC_SUCCESS);
600: }

602: template <device::cupm::DeviceType T>
603: inline PetscErrorCode SfInterface<T>::Malloc(PetscMemType mtype, size_t size, void **ptr) noexcept
604: {
605:   PetscFunctionBegin;
606:   if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr));
607:   else if (PetscMemTypeDevice(mtype)) {
608:     PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM()));
609:     PetscCallCUPM(cupmMalloc(ptr, size));
610:   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
611:   PetscFunctionReturn(PETSC_SUCCESS);
612: }

614: template <device::cupm::DeviceType T>
615: inline PetscErrorCode SfInterface<T>::Free(PetscMemType mtype, void *ptr) noexcept
616: {
617:   PetscFunctionBegin;
618:   if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr));
619:   else if (PetscMemTypeDevice(mtype)) PetscCallCUPM(cupmFree(ptr));
620:   else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
621:   PetscFunctionReturn(PETSC_SUCCESS);
622: }

624: /* Destructor when the link uses MPI for communication on CUPM device */
625: template <device::cupm::DeviceType T>
626: inline PetscErrorCode SfInterface<T>::LinkDestroy_MPI(PetscSF, PetscSFLink link) noexcept
627: {
628:   PetscFunctionBegin;
629:   for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) {
630:     PetscCallCUPM(cupmFree(link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
631:     PetscCallCUPM(cupmFree(link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
632:   }
633:   PetscFunctionReturn(PETSC_SUCCESS);
634: }

636: /*====================================================================================*/
637: /*                Main driver to init MPI datatype on device                          */
638: /*====================================================================================*/

/*
  LinkSetUp - Completes the device-side setup of a PetscSFLink for the MPI datatype `unit`.

  Some fields of link are initialized by PetscSFPackSetUp_Host; this routine only does what is needed on the device:
    1. Classifies `unit` (a pair type, a contiguous run of a single primitive type, or otherwise an opaque
       "dumb" type) and calls the matching PackInit_* instantiation on `link`.
    2. Fills sf->maxResidentThreadsPerGPU once (from the current device's properties) and copies it into `link`.
    3. Stores the stream of the current CUPM device context in link->stream.
    4. Installs the Destroy/SyncDevice/SyncStream/Memcpy callbacks and sets link->deviceinited.
  If link->deviceinited is already set, the routine returns immediately without doing anything.

  Input Parameters:
+ sf   - the star forest; its maxResidentThreadsPerGPU is computed here if it is still zero
. link - the link to set up
- unit - the MPI datatype of one unit of data

  Notes:
  The type dispatch interleaves preprocessor guards with if/else-if chains. When PETSC_HAVE_DEVICE is defined,
  each `#if !defined(PETSC_HAVE_DEVICE)` section is removed, so only the trailing `<Type, 1, 0>` instantiation
  (BS = 1, EQ = 0) is compiled for each type. On a build without PETSC_HAVE_DEVICE, the always-true `% 1 == 0`
  test is what makes that trailing statement the body of the final `else if`.
*/
template <device::cupm::DeviceType T>
inline PetscErrorCode SfInterface<T>::LinkSetUp(PetscSF sf, PetscSFLink link, MPI_Datatype unit) noexcept
{
  PetscInt  nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0;
  PetscBool is2Int, is2PetscInt;
#if defined(PETSC_HAVE_COMPLEX)
  PetscInt nPetscComplex = 0;
#endif

  PetscFunctionBegin;
  /* Device-side setup has already been done for this link */
  if (link->deviceinited) PetscFunctionReturn(PETSC_SUCCESS);
  /* Presumably each n* becomes the number of contiguous copies of that primitive type making up unit
     (0 when unit is not such a contiguous type) -- see MPIPetsc_Type_compare_contig */
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar));
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar));
  /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt));
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt));
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal));
#if defined(PETSC_HAVE_COMPLEX)
  PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex));
#endif
  PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int));
  PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt));

  /* Select the pack/unpack implementation. Pair types are checked first, then the primitive types in the order
     PetscReal, PetscInt (only when it is as wide as llint), int, signed char, unsigned char, PetscComplex. */
  if (is2Int) {
    PackInit_PairType<PairInt>(link);
  } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
    PackInit_PairType<PairPetscInt>(link);
  } else if (nPetscReal) {
#if !defined(PETSC_HAVE_DEVICE)
    if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link);
    else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link);
    else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link);
    else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link);
    else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link);
    else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link);
    else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link);
    else if (nPetscReal % 1 == 0)
#endif
      PackInit_RealType<PetscReal, 1, 0>(link);
  } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) {
    /* 64-bit PetscInt; with 32-bit PetscInt a PetscInt unit is instead handled by the nInt branch below */
#if !defined(PETSC_HAVE_DEVICE)
    if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link);
    else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link);
    else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link);
    else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link);
    else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link);
    else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link);
    else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link);
    else if (nPetscInt % 1 == 0)
#endif
      PackInit_IntegerType<llint, 1, 0>(link);
  } else if (nInt) {
#if !defined(PETSC_HAVE_DEVICE)
    if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link);
    else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link);
    else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link);
    else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link);
    else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link);
    else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link);
    else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link);
    else if (nInt % 1 == 0)
#endif
      PackInit_IntegerType<int, 1, 0>(link);
  } else if (nSignedChar) {
#if !defined(PETSC_HAVE_DEVICE)
    if (nSignedChar == 8) PackInit_IntegerType<SignedChar, 8, 1>(link);
    else if (nSignedChar % 8 == 0) PackInit_IntegerType<SignedChar, 8, 0>(link);
    else if (nSignedChar == 4) PackInit_IntegerType<SignedChar, 4, 1>(link);
    else if (nSignedChar % 4 == 0) PackInit_IntegerType<SignedChar, 4, 0>(link);
    else if (nSignedChar == 2) PackInit_IntegerType<SignedChar, 2, 1>(link);
    else if (nSignedChar % 2 == 0) PackInit_IntegerType<SignedChar, 2, 0>(link);
    else if (nSignedChar == 1) PackInit_IntegerType<SignedChar, 1, 1>(link);
    else if (nSignedChar % 1 == 0)
#endif
      PackInit_IntegerType<SignedChar, 1, 0>(link);
  } else if (nUnsignedChar) {
#if !defined(PETSC_HAVE_DEVICE)
    if (nUnsignedChar == 8) PackInit_IntegerType<UnsignedChar, 8, 1>(link);
    else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<UnsignedChar, 8, 0>(link);
    else if (nUnsignedChar == 4) PackInit_IntegerType<UnsignedChar, 4, 1>(link);
    else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<UnsignedChar, 4, 0>(link);
    else if (nUnsignedChar == 2) PackInit_IntegerType<UnsignedChar, 2, 1>(link);
    else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<UnsignedChar, 2, 0>(link);
    else if (nUnsignedChar == 1) PackInit_IntegerType<UnsignedChar, 1, 1>(link);
    else if (nUnsignedChar % 1 == 0)
#endif
      PackInit_IntegerType<UnsignedChar, 1, 0>(link);
#if defined(PETSC_HAVE_COMPLEX)
  } else if (nPetscComplex) {
  #if !defined(PETSC_HAVE_DEVICE)
    if (nPetscComplex == 8) PackInit_ComplexType<PetscComplex, 8, 1>(link);
    else if (nPetscComplex % 8 == 0) PackInit_ComplexType<PetscComplex, 8, 0>(link);
    else if (nPetscComplex == 4) PackInit_ComplexType<PetscComplex, 4, 1>(link);
    else if (nPetscComplex % 4 == 0) PackInit_ComplexType<PetscComplex, 4, 0>(link);
    else if (nPetscComplex == 2) PackInit_ComplexType<PetscComplex, 2, 1>(link);
    else if (nPetscComplex % 2 == 0) PackInit_ComplexType<PetscComplex, 2, 0>(link);
    else if (nPetscComplex == 1) PackInit_ComplexType<PetscComplex, 1, 1>(link);
    else if (nPetscComplex % 1 == 0)
  #endif
      PackInit_ComplexType<PetscComplex, 1, 0>(link);
#endif
  } else {
    /* Unrecognized ("dumb") type, e.g. MPI_CHAR or a derived type: it is packed as raw int units when its extent is
       a multiple of sizeof(int), else as raw chars. Only a zero lower bound is supported. */
    MPI_Aint lb, nbyte;

    PetscCallMPI(MPI_Type_get_extent(unit, &lb, &nbyte));
    PetscCheck(lb == 0, PETSC_COMM_SELF, PETSC_ERR_SUP, "Datatype with nonzero lower bound %ld", (long)lb);
    if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
      /* NOTE(review): when sizeof(int) == 4, the nbyte == 4 and nbyte % 4 == 0 tests below can never be true here;
         they only exist on the path compiled without PETSC_HAVE_DEVICE. */
#if !defined(PETSC_HAVE_DEVICE)
      if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link);
      else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link);
      else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link);
      else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link);
      else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link);
      else if (nbyte % 1 == 0)
#endif
        PackInit_DumbType<char, 1, 0>(link);
    } else {
      /* Reuse nInt as the number of ints spanned by one unit; PetscIntCast errors out if it does not fit */
      PetscCall(PetscIntCast(nbyte / sizeof(int), &nInt));
#if !defined(PETSC_HAVE_DEVICE)
      if (nInt == 8) PackInit_DumbType<int, 8, 1>(link);
      else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link);
      else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link);
      else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link);
      else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link);
      else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link);
      else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link);
      else if (nInt % 1 == 0)
#endif
        PackInit_DumbType<int, 1, 0>(link);
    }
  }

  /* Query the current device only once per sf: resident threads = threads per multiprocessor * multiprocessor count */
  if (!sf->maxResidentThreadsPerGPU) { /* Not initialized */
    int              device;
    cupmDeviceProp_t props;

    PetscCallCUPM(cupmGetDevice(&device));
    PetscCallCUPM(cupmGetDeviceProperties(&props, device));
    sf->maxResidentThreadsPerGPU = props.maxThreadsPerMultiProcessor * props.multiProcessorCount;
  }
  link->maxResidentThreadsPerGPU = sf->maxResidentThreadsPerGPU;

  /* Record the stream of the current device context (asserted to be a CUPM context) in the link;
     presumably this is the stream the link's device operations are issued on -- confirm against the callers */
  {
    cupmStream_t      *stream;
    PetscDeviceContext dctx;

    PetscCall(PetscDeviceContextGetCurrentContextAssertType_Internal(&dctx, PETSC_DEVICE_CUPM()));
    PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
    link->stream = *stream;
  }
  /* Install the device implementations of the link callbacks and mark the device-side setup as done */
  link->Destroy      = LinkDestroy_MPI;
  link->SyncDevice   = LinkSyncDevice;
  link->SyncStream   = LinkSyncStream;
  link->Memcpy       = LinkMemcpy;
  link->deviceinited = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}

799: } // namespace impl

801: } // namespace cupm

803: } // namespace sf

805: } // namespace Petsc