Actual source code: sfnvshmem.cu

  1: #include <petsc/private/cudavecimpl.h>
  2: #include <../src/vec/is/sf/impls/basic/sfpack.h>
  3: #include <mpi.h>
  4: #include <nvshmem.h>
  5: #include <nvshmemx.h>

  7: PetscErrorCode PetscNvshmemInitializeCheck(void)
  8: {
  9:   PetscFunctionBegin;
 10:   if (!PetscNvshmemInitialized) { /* Note NVSHMEM does not provide a routine to check whether it is initialized */
 11:     nvshmemx_init_attr_t attr;
 12:     attr.mpi_comm = &PETSC_COMM_WORLD;
 13:     PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA));
 14:     PetscCall(nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr));
 15:     PetscNvshmemInitialized = PETSC_TRUE;
 16:     PetscBeganNvshmem       = PETSC_TRUE;
 17:   }
 18:   PetscFunctionReturn(PETSC_SUCCESS);
 19: }
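/* NVSHMEM is initialized lazily (on first use) and collectively: every rank of PETSC_COMM_WORLD must reach the
   nvshmemx_init_attr() call above, since bootstrapping NVSHMEM over an MPI communicator is a collective operation. */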

 21: PetscErrorCode PetscNvshmemMalloc(size_t size, void **ptr)
 22: {
 23:   PetscFunctionBegin;
 24:   PetscCall(PetscNvshmemInitializeCheck());
 25:   *ptr = nvshmem_malloc(size);
 26:   PetscCheck(*ptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "nvshmem_malloc() failed to allocate %zu bytes", size);
 27:   PetscFunctionReturn(PETSC_SUCCESS);
 28: }

 30: PetscErrorCode PetscNvshmemCalloc(size_t size, void **ptr)
 31: {
 32:   PetscFunctionBegin;
 33:   PetscCall(PetscNvshmemInitializeCheck());
 34:   *ptr = nvshmem_calloc(size, 1);
 35:   PetscCheck(*ptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "nvshmem_calloc() failed to allocate %zu bytes", size);
 36:   PetscFunctionReturn(PETSC_SUCCESS);
 37: }
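/* A minimal usage sketch (not part of this file): NVSHMEM allocations are collective and return addresses in the
   symmetric heap, so every rank must request the same size. The wrappers above are therefore called with globally
   agreed-upon sizes, e.g. the *_rmax values computed in PetscSFSetUp_Basic_NVSHMEM(). With a hypothetical size <rmax>
   that is identical on all ranks:

     PetscScalar *buf;
     PetscCall(PetscNvshmemMalloc(rmax * sizeof(PetscScalar), (void **)&buf));
     ...
     PetscCall(PetscNvshmemFree(buf));
*/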

 39: PetscErrorCode PetscNvshmemFree_Private(void *ptr)
 40: {
 41:   PetscFunctionBegin;
 42:   nvshmem_free(ptr);
 43:   PetscFunctionReturn(PETSC_SUCCESS);
 44: }

 46: PetscErrorCode PetscNvshmemFinalize(void)
 47: {
 48:   PetscFunctionBegin;
 49:   nvshmem_finalize();
 50:   PetscFunctionReturn(PETSC_SUCCESS);
 51: }

 53: /* Free nvshmem related fields in the SF */
 54: PetscErrorCode PetscSFReset_Basic_NVSHMEM(PetscSF sf)
 55: {
 56:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;

 58:   PetscFunctionBegin;
 59:   PetscCall(PetscFree2(bas->leafsigdisp, bas->leafbufdisp));
 60:   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->leafbufdisp_d));
 61:   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->leafsigdisp_d));
 62:   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->iranks_d));
 63:   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->ioffset_d));

 65:   PetscCall(PetscFree2(sf->rootsigdisp, sf->rootbufdisp));
 66:   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->rootbufdisp_d));
 67:   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->rootsigdisp_d));
 68:   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->ranks_d));
 69:   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->roffset_d));
 70:   PetscFunctionReturn(PETSC_SUCCESS);
 71: }

 73: /* Set up the NVSHMEM-related fields for an SF of type SFBASIC (called only after PetscSFSetUp_Basic() has set up the fields this depends on) */
 74: static PetscErrorCode PetscSFSetUp_Basic_NVSHMEM(PetscSF sf)
 75: {
 76:   cudaError_t    cerr;
 77:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
 78:   PetscInt       i, nRemoteRootRanks, nRemoteLeafRanks;
 79:   PetscMPIInt    tag;
 80:   MPI_Comm       comm;
 81:   MPI_Request   *rootreqs, *leafreqs;
 82:   PetscInt       tmp, stmp[4], rtmp[4]; /* tmps for send/recv buffers */

 84:   PetscFunctionBegin;
 85:   PetscCall(PetscObjectGetComm((PetscObject)sf, &comm));
 86:   PetscCall(PetscObjectGetNewTag((PetscObject)sf, &tag));

 88:   nRemoteRootRanks      = sf->nranks - sf->ndranks;
 89:   nRemoteLeafRanks      = bas->niranks - bas->ndiranks;
 90:   sf->nRemoteRootRanks  = nRemoteRootRanks;
 91:   bas->nRemoteLeafRanks = nRemoteLeafRanks;

 93:   PetscCall(PetscMalloc2(nRemoteLeafRanks, &rootreqs, nRemoteRootRanks, &leafreqs));

 95:   stmp[0] = nRemoteRootRanks;
 96:   stmp[1] = sf->leafbuflen[PETSCSF_REMOTE];
 97:   stmp[2] = nRemoteLeafRanks;
 98:   stmp[3] = bas->rootbuflen[PETSCSF_REMOTE];

100:   PetscCallMPI(MPIU_Allreduce(stmp, rtmp, 4, MPIU_INT, MPI_MAX, comm));

102:   sf->nRemoteRootRanksMax  = rtmp[0];
103:   sf->leafbuflen_rmax      = rtmp[1];
104:   bas->nRemoteLeafRanksMax = rtmp[2];
105:   bas->rootbuflen_rmax     = rtmp[3];
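  /* The maxima are needed because the NVSHMEM buffers and signal arrays allocated later (in PetscSFLinkCreate_NVSHMEM())
     live in the symmetric heap, and a symmetric allocation must be made collectively with the same size on every PE. */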

107:   /* Four rounds of MPI communication in total are needed to set up the NVSHMEM fields */

109:   /* Root ranks to leaf ranks: send info about rootsigdisp[] and rootbufdisp[] */
110:   PetscCall(PetscMalloc2(nRemoteRootRanks, &sf->rootsigdisp, nRemoteRootRanks, &sf->rootbufdisp));
111:   for (i = 0; i < nRemoteRootRanks; i++) PetscCallMPI(MPIU_Irecv(&sf->rootsigdisp[i], 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm, &leafreqs[i])); /* Leaves recv */
112:   for (i = 0; i < nRemoteLeafRanks; i++) PetscCallMPI(MPI_Send(&i, 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm));                              /* Roots send. The send buffer <i> changes each iteration, so we use blocking MPI_Send. */
113:   PetscCallMPI(MPI_Waitall(nRemoteRootRanks, leafreqs, MPI_STATUSES_IGNORE));
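  /* At this point sf->rootsigdisp[i] holds my position in the i-th remote root's list of remote leaf ranks,
     i.e., the offset of my slot in that root's signal arrays. */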

115:   for (i = 0; i < nRemoteRootRanks; i++) PetscCallMPI(MPIU_Irecv(&sf->rootbufdisp[i], 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm, &leafreqs[i])); /* Leaves recv */
116:   for (i = 0; i < nRemoteLeafRanks; i++) {
117:     tmp = bas->ioffset[i + bas->ndiranks] - bas->ioffset[bas->ndiranks];
118:     PetscCallMPI(MPI_Send(&tmp, 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm)); /* Roots send. The send buffer <tmp> changes each iteration, so we use blocking MPI_Send. */
119:   }
120:   PetscCallMPI(MPI_Waitall(nRemoteRootRanks, leafreqs, MPI_STATUSES_IGNORE));

122:   PetscCallCUDA(cudaMalloc((void **)&sf->rootbufdisp_d, nRemoteRootRanks * sizeof(PetscInt)));
123:   PetscCallCUDA(cudaMalloc((void **)&sf->rootsigdisp_d, nRemoteRootRanks * sizeof(PetscInt)));
124:   PetscCallCUDA(cudaMalloc((void **)&sf->ranks_d, nRemoteRootRanks * sizeof(PetscMPIInt)));
125:   PetscCallCUDA(cudaMalloc((void **)&sf->roffset_d, (nRemoteRootRanks + 1) * sizeof(PetscInt)));

127:   PetscCallCUDA(cudaMemcpyAsync(sf->rootbufdisp_d, sf->rootbufdisp, nRemoteRootRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
128:   PetscCallCUDA(cudaMemcpyAsync(sf->rootsigdisp_d, sf->rootsigdisp, nRemoteRootRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
129:   PetscCallCUDA(cudaMemcpyAsync(sf->ranks_d, sf->ranks + sf->ndranks, nRemoteRootRanks * sizeof(PetscMPIInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
130:   PetscCallCUDA(cudaMemcpyAsync(sf->roffset_d, sf->roffset + sf->ndranks, (nRemoteRootRanks + 1) * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));

132:   /* Leaf ranks to root ranks: send info about leafsigdisp[] and leafbufdisp[] */
133:   PetscCall(PetscMalloc2(nRemoteLeafRanks, &bas->leafsigdisp, nRemoteLeafRanks, &bas->leafbufdisp));
134:   for (i = 0; i < nRemoteLeafRanks; i++) PetscCallMPI(MPIU_Irecv(&bas->leafsigdisp[i], 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm, &rootreqs[i]));
135:   for (i = 0; i < nRemoteRootRanks; i++) PetscCallMPI(MPI_Send(&i, 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm));
136:   PetscCallMPI(MPI_Waitall(nRemoteLeafRanks, rootreqs, MPI_STATUSES_IGNORE));

138:   for (i = 0; i < nRemoteLeafRanks; i++) PetscCallMPI(MPIU_Irecv(&bas->leafbufdisp[i], 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm, &rootreqs[i]));
139:   for (i = 0; i < nRemoteRootRanks; i++) {
140:     tmp = sf->roffset[i + sf->ndranks] - sf->roffset[sf->ndranks];
141:     PetscCallMPI(MPI_Send(&tmp, 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm));
142:   }
143:   PetscCallMPI(MPI_Waitall(nRemoteLeafRanks, rootreqs, MPI_STATUSES_IGNORE));

145:   PetscCallCUDA(cudaMalloc((void **)&bas->leafbufdisp_d, nRemoteLeafRanks * sizeof(PetscInt)));
146:   PetscCallCUDA(cudaMalloc((void **)&bas->leafsigdisp_d, nRemoteLeafRanks * sizeof(PetscInt)));
147:   PetscCallCUDA(cudaMalloc((void **)&bas->iranks_d, nRemoteLeafRanks * sizeof(PetscMPIInt)));
148:   PetscCallCUDA(cudaMalloc((void **)&bas->ioffset_d, (nRemoteLeafRanks + 1) * sizeof(PetscInt)));

150:   PetscCallCUDA(cudaMemcpyAsync(bas->leafbufdisp_d, bas->leafbufdisp, nRemoteLeafRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
151:   PetscCallCUDA(cudaMemcpyAsync(bas->leafsigdisp_d, bas->leafsigdisp, nRemoteLeafRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
152:   PetscCallCUDA(cudaMemcpyAsync(bas->iranks_d, bas->iranks + bas->ndiranks, nRemoteLeafRanks * sizeof(PetscMPIInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
153:   PetscCallCUDA(cudaMemcpyAsync(bas->ioffset_d, bas->ioffset + bas->ndiranks, (nRemoteLeafRanks + 1) * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));

155:   PetscCall(PetscFree2(rootreqs, leafreqs));
156:   PetscFunctionReturn(PETSC_SUCCESS);
157: }

159: PetscErrorCode PetscSFLinkNvshmemCheck(PetscSF sf, PetscMemType rootmtype, const void *rootdata, PetscMemType leafmtype, const void *leafdata, PetscBool *use_nvshmem)
160: {
161:   MPI_Comm    comm;
162:   PetscBool   isBasic;
163:   PetscMPIInt result = MPI_UNEQUAL;

165:   PetscFunctionBegin;
166:   PetscCall(PetscObjectGetComm((PetscObject)sf, &comm));
167:   /* Check if the sf is eligible for NVSHMEM, if we have not checked yet.
168:      Note the check result <use_nvshmem> must be the same over comm, since an SFLink must be collectively either NVSHMEM or MPI.
169:   */
171:   if (sf->use_nvshmem && !sf->checked_nvshmem_eligibility) {
172:     /* Only use NVSHMEM for SFBASIC on PETSC_COMM_WORLD  */
173:     PetscCall(PetscObjectTypeCompare((PetscObject)sf, PETSCSFBASIC, &isBasic));
174:     if (isBasic) PetscCallMPI(MPI_Comm_compare(PETSC_COMM_WORLD, comm, &result));
175:     if (!isBasic || (result != MPI_IDENT && result != MPI_CONGRUENT)) sf->use_nvshmem = PETSC_FALSE; /* If not eligible, clear the flag so that we don't try again */

177:     /* Do further check: If on a rank, both rootdata and leafdata are NULL, we might think they are PETSC_MEMTYPE_CUDA (or HOST)
178:        and then use NVSHMEM. But if root/leafmtypes on other ranks are PETSC_MEMTYPE_HOST (or DEVICE), this would lead to
179:        inconsistency on the return value <use_nvshmem>. To be safe, we simply disable nvshmem on these rare SFs.
180:     */
181:     if (sf->use_nvshmem) {
182:       PetscInt hasNullRank = (!rootdata && !leafdata) ? 1 : 0;
183:       PetscCallMPI(MPIU_Allreduce(MPI_IN_PLACE, &hasNullRank, 1, MPIU_INT, MPI_LOR, comm));
184:       if (hasNullRank) sf->use_nvshmem = PETSC_FALSE;
185:     }
186:     sf->checked_nvshmem_eligibility = PETSC_TRUE; /* Don't repeat the above check, whether eligible or not */
187:   }

189:   /* Check if rootmtype and leafmtype collectively are PETSC_MEMTYPE_CUDA */
190:   if (sf->use_nvshmem) {
191:     PetscInt oneCuda = (!rootdata || PetscMemTypeCUDA(rootmtype)) && (!leafdata || PetscMemTypeCUDA(leafmtype)) ? 1 : 0; /* Do I use cuda for both root&leafmtype? */
192:     PetscInt allCuda = oneCuda;                                                                                          /* Assume the same for all ranks. But if not, in opt mode, return value <use_nvshmem> won't be collective! */
193: #if defined(PETSC_USE_DEBUG)                                                                                             /* Check in debug mode. Note MPI_Allreduce is expensive, so only in debug mode */
194:     PetscCallMPI(MPIU_Allreduce(&oneCuda, &allCuda, 1, MPIU_INT, MPI_LAND, comm));
195:     PetscCheck(allCuda == oneCuda, comm, PETSC_ERR_SUP, "root/leaf mtypes are inconsistent among ranks, which may lead to SF nvshmem failure in opt mode. Add -use_nvshmem 0 to disable it.");
196: #endif
197:     if (allCuda) {
198:       PetscCall(PetscNvshmemInitializeCheck());
199:       if (!sf->setup_nvshmem) { /* Set up nvshmem related fields on this SF on-demand */
200:         PetscCall(PetscSFSetUp_Basic_NVSHMEM(sf));
201:         sf->setup_nvshmem = PETSC_TRUE;
202:       }
203:       *use_nvshmem = PETSC_TRUE;
204:     } else {
205:       *use_nvshmem = PETSC_FALSE;
206:     }
207:   } else {
208:     *use_nvshmem = PETSC_FALSE;
209:   }
210:   PetscFunctionReturn(PETSC_SUCCESS);
211: }
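/* Note the eligibility test and the NVSHMEM setup above are each done at most once per SF (cached in
   sf->checked_nvshmem_eligibility and sf->setup_nvshmem), and the returned <use_nvshmem> is collective over the SF's
   communicator. Running with -use_nvshmem 0 disables this code path altogether. */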

213: /* Build dependence between <stream> and <remoteCommStream> at the entry of NVSHMEM communication */
214: static PetscErrorCode PetscSFLinkBuildDependenceBegin(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
215: {
216:   cudaError_t    cerr;
217:   PetscSF_Basic *bas    = (PetscSF_Basic *)sf->data;
218:   PetscInt       buflen = (direction == PETSCSF_ROOT2LEAF) ? bas->rootbuflen[PETSCSF_REMOTE] : sf->leafbuflen[PETSCSF_REMOTE];

220:   PetscFunctionBegin;
221:   if (buflen) {
222:     PetscCallCUDA(cudaEventRecord(link->dataReady, link->stream));
223:     PetscCallCUDA(cudaStreamWaitEvent(link->remoteCommStream, link->dataReady, 0));
224:   }
225:   PetscFunctionReturn(PETSC_SUCCESS);
226: }

228: /* Build dependence between <stream> and <remoteCommStream> at the exit of NVSHMEM communication */
229: static PetscErrorCode PetscSFLinkBuildDependenceEnd(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
230: {
231:   cudaError_t    cerr;
232:   PetscSF_Basic *bas    = (PetscSF_Basic *)sf->data;
233:   PetscInt       buflen = (direction == PETSCSF_ROOT2LEAF) ? sf->leafbuflen[PETSCSF_REMOTE] : bas->rootbuflen[PETSCSF_REMOTE];

235:   PetscFunctionBegin;
236:   /* If unpack to non-null device buffer, build the endRemoteComm dependence */
237:   if (buflen) {
238:     PetscCallCUDA(cudaEventRecord(link->endRemoteComm, link->remoteCommStream));
239:     PetscCallCUDA(cudaStreamWaitEvent(link->stream, link->endRemoteComm, 0));
240:   }
241:   PetscFunctionReturn(PETSC_SUCCESS);
242: }
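/* The two routines above use the standard CUDA pattern for ordering work across streams without blocking the host.
   A minimal stand-alone sketch of the same idea, with hypothetical streams s1 (producer) and s2 (consumer):

     cudaEvent_t ev;
     cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
     producerKernel<<<grid, block, 0, s1>>>(buf); // writes buf on s1
     cudaEventRecord(ev, s1);                     // mark the point where s1's queued work completes
     cudaStreamWaitEvent(s2, ev, 0);              // s2 will not run past this point until ev has fired
     consumerKernel<<<grid, block, 0, s2>>>(buf); // safely reads buf on s2
*/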

244: /* Send/Put signals to remote ranks

246:  Input parameters:
247:   + n        - Number of remote ranks
248:   . sig      - Signal address in symmetric heap
249:   . sigdisp  - To i-th rank, use its signal at offset sigdisp[i]
250:   . ranks    - remote ranks
251:   - newval   - Set signals to this value
252: */
253: __global__ static void NvshmemSendSignals(PetscInt n, uint64_t *sig, PetscInt *sigdisp, PetscMPIInt *ranks, uint64_t newval)
254: {
255:   int i = blockIdx.x * blockDim.x + threadIdx.x;

257:   /* Each thread puts one remote signal */
258:   if (i < n) nvshmemx_uint64_signal(sig + sigdisp[i], newval, ranks[i]);
259: }

261: /* Wait until the local signals equal the expected value, then set them to a new value

263:  Input parameters:
264:   + n        - Number of signals
265:   . sig      - Local signal address
266:   . expval   - expected value
267:   - newval   - Set signals to this new value
268: */
269: __global__ static void NvshmemWaitSignals(PetscInt n, uint64_t *sig, uint64_t expval, uint64_t newval)
270: {
271: #if 0
272:   /* Akhil Langer@NVIDIA said using 1 thread and nvshmem_uint64_wait_until_all is better */
273:   int i = blockIdx.x*blockDim.x + threadIdx.x;
274:   if (i < n) {
275:     nvshmem_signal_wait_until(sig+i,NVSHMEM_CMP_EQ,expval);
276:     sig[i] = newval;
277:   }
278: #else
279:   nvshmem_uint64_wait_until_all(sig, n, NULL /*no mask*/, NVSHMEM_CMP_EQ, expval);
280:   for (int i = 0; i < n; i++) sig[i] = newval;
281: #endif
282: }

284: /* ===========================================================================================================

286:    A set of routines to support receiver initiated communication using the get method

288:     The getting protocol is:

290:     Sender has a send buf (sbuf) and a signal variable (ssig);  Receiver has a recv buf (rbuf) and a signal variable (rsig);
291:     All signal variables have an initial value 0.

293:     Sender:                                 |  Receiver:
294:   1.  Wait ssig to be 0, then set it to 1   |
295:   2.  Pack data into stand-alone sbuf       |
296:   3.  Put 1 to receiver's rsig              |   1. Wait rsig to be 1, then set it to 0
297:                                             |   2. Get data from remote sbuf to local rbuf
298:                                             |   3. Put 0 to sender's ssig
299:                                             |   4. Unpack data from local rbuf
300:    ===========================================================================================================*/
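/* In the routines below, the sender's step 1 is the PrePack callback (PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM());
   step 2 is the regular SF pack; the sender's step 3 and the receiver's steps 1-2 happen in PetscSFLinkGetDataBegin_NVSHMEM();
   the receiver's step 3 happens in PetscSFLinkGetDataEnd_NVSHMEM(); step 4 is the regular SF unpack. */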
301: /* PrePack operation -- since the sender is about to overwrite the send buffer that receivers might still be getting data from,
302:    it waits for signals (from the receivers) indicating that they have finished getting the data
303: */
304: static PetscErrorCode PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
305: {
306:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
307:   uint64_t      *sig;
308:   PetscInt       n;

310:   PetscFunctionBegin;
311:   if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
312:     sig = link->rootSendSig;            /* leaf ranks set my rootSendSig */
313:     n   = bas->nRemoteLeafRanks;
314:   } else { /* LEAF2ROOT */
315:     sig = link->leafSendSig;
316:     n   = sf->nRemoteRootRanks;
317:   }

319:   if (n) {
320:     NvshmemWaitSignals<<<1, 1, 0, link->remoteCommStream>>>(n, sig, 0, 1); /* wait the signals to be 0, then set them to 1 */
321:     PetscCallCUDA(cudaGetLastError());
322:   }
323:   PetscFunctionReturn(PETSC_SUCCESS);
324: }

326: /* n thread blocks; each block handles one remote rank */
327: __global__ static void GetDataFromRemotelyAccessible(PetscInt nsrcranks, PetscMPIInt *srcranks, const char *src, PetscInt *srcdisp, char *dst, PetscInt *dstdisp, PetscInt unitbytes)
328: {
329:   int         bid = blockIdx.x;
330:   PetscMPIInt pe  = srcranks[bid];

332:   if (!nvshmem_ptr(src, pe)) {
333:     PetscInt nelems = (dstdisp[bid + 1] - dstdisp[bid]) * unitbytes;
334:     nvshmem_getmem_nbi(dst + (dstdisp[bid] - dstdisp[0]) * unitbytes, src + srcdisp[bid] * unitbytes, nelems, pe);
335:   }
336: }
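/* nvshmem_ptr(sym_addr, pe) returns a local pointer through which <pe>'s copy of a symmetric object can be accessed
   directly with loads/stores (e.g., over NVLink within a node) and NULL otherwise. The kernels and host loops here use
   it only as a locality test; the NULL (truly remote) case is handled with nvshmem_getmem_nbi()/nvshmem_putmem_nbi(). */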

338: /* Start communication -- Get data in the given direction */
339: static PetscErrorCode PetscSFLinkGetDataBegin_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
340: {
341:   cudaError_t    cerr;
342:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;

344:   PetscInt nsrcranks, ndstranks, nLocallyAccessible = 0;

346:   char        *src, *dst;
347:   PetscInt    *srcdisp_h, *dstdisp_h;
348:   PetscInt    *srcdisp_d, *dstdisp_d;
349:   PetscMPIInt *srcranks_h;
350:   PetscMPIInt *srcranks_d, *dstranks_d;
351:   uint64_t    *dstsig;
352:   PetscInt    *dstsigdisp_d;

354:   PetscFunctionBegin;
355:   PetscCall(PetscSFLinkBuildDependenceBegin(sf, link, direction));
356:   if (direction == PETSCSF_ROOT2LEAF) { /* src is root, dst is leaf; we will move data from src to dst */
357:     nsrcranks = sf->nRemoteRootRanks;
358:     src       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* root buf is the send buf; it is in symmetric heap */

360:     srcdisp_h  = sf->rootbufdisp; /* for my i-th remote root rank, I will access its buf at offset rootbufdisp[i] */
361:     srcdisp_d  = sf->rootbufdisp_d;
362:     srcranks_h = sf->ranks + sf->ndranks; /* my (remote) root ranks */
363:     srcranks_d = sf->ranks_d;

365:     ndstranks = bas->nRemoteLeafRanks;
366:     dst       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* recv buf is the local leaf buf, also in symmetric heap */

368:     dstdisp_h  = sf->roffset + sf->ndranks; /* offsets of the local leaf buf. Note dstdisp[0] is not necessarily 0 */
369:     dstdisp_d  = sf->roffset_d;
370:     dstranks_d = bas->iranks_d; /* my (remote) leaf ranks */

372:     dstsig       = link->leafRecvSig;
373:     dstsigdisp_d = bas->leafsigdisp_d;
374:   } else { /* src is leaf, dst is root; we will move data from src to dst */
375:     nsrcranks = bas->nRemoteLeafRanks;
376:     src       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* leaf buf is the send buf */

378:     srcdisp_h  = bas->leafbufdisp; /* for my i-th remote leaf rank, I will access its buf at offset leafbufdisp[i] */
379:     srcdisp_d  = bas->leafbufdisp_d;
380:     srcranks_h = bas->iranks + bas->ndiranks; /* my (remote) leaf ranks */
381:     srcranks_d = bas->iranks_d;

383:     ndstranks = sf->nRemoteRootRanks;
384:     dst       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* the local root buf is the recv buf */

386:     dstdisp_h  = bas->ioffset + bas->ndiranks; /* offsets of the local root buf. Note dstdisp[0] is not necessarily 0 */
387:     dstdisp_d  = bas->ioffset_d;
388:     dstranks_d = sf->ranks_d; /* my (remote) root ranks */

390:     dstsig       = link->rootRecvSig;
391:     dstsigdisp_d = sf->rootsigdisp_d;
392:   }

394:   /* After Pack operation -- src tells dst ranks that they are allowed to get data */
395:   if (ndstranks) {
396:     NvshmemSendSignals<<<(ndstranks + 255) / 256, 256, 0, link->remoteCommStream>>>(ndstranks, dstsig, dstsigdisp_d, dstranks_d, 1); /* set signals to 1 */
397:     PetscCallCUDA(cudaGetLastError());
398:   }

400:   /* dst waits for signals (permissions) from src ranks to start getting data */
401:   if (nsrcranks) {
402:     NvshmemWaitSignals<<<1, 1, 0, link->remoteCommStream>>>(nsrcranks, dstsig, 1, 0); /* wait the signals to be 1, then set them to 0 */
403:     PetscCallCUDA(cudaGetLastError());
404:   }

406:   /* dst gets data from src ranks using non-blocking nvshmem_gets, which are finished in PetscSFLinkGetDataEnd_NVSHMEM() */

408:   /* Count number of locally accessible src ranks, which should be a small number */
409:   for (int i = 0; i < nsrcranks; i++) {
410:     if (nvshmem_ptr(src, srcranks_h[i])) nLocallyAccessible++;
411:   }

413:   /* Get data from remotely accessible PEs */
414:   if (nLocallyAccessible < nsrcranks) {
415:     GetDataFromRemotelyAccessible<<<nsrcranks, 1, 0, link->remoteCommStream>>>(nsrcranks, srcranks_d, src, srcdisp_d, dst, dstdisp_d, link->unitbytes);
416:     PetscCallCUDA(cudaGetLastError());
417:   }

419:   /* Get data from locally accessible PEs */
420:   if (nLocallyAccessible) {
421:     for (int i = 0; i < nsrcranks; i++) {
422:       int pe = srcranks_h[i];
423:       if (nvshmem_ptr(src, pe)) {
424:         size_t nelems = (dstdisp_h[i + 1] - dstdisp_h[i]) * link->unitbytes;
425:         nvshmemx_getmem_nbi_on_stream(dst + (dstdisp_h[i] - dstdisp_h[0]) * link->unitbytes, src + srcdisp_h[i] * link->unitbytes, nelems, pe, link->remoteCommStream);
426:       }
427:     }
428:   }
429:   PetscFunctionReturn(PETSC_SUCCESS);
430: }

432: /* Finish the communication (can be done before Unpack)
433:    Receiver tells its senders that they are allowed to reuse their send buffers (since the receiver has already gotten the data out of them)
434: */
435: static PetscErrorCode PetscSFLinkGetDataEnd_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
436: {
437:   cudaError_t    cerr;
438:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
439:   uint64_t      *srcsig;
440:   PetscInt       nsrcranks, *srcsigdisp;
441:   PetscMPIInt   *srcranks;

443:   PetscFunctionBegin;
444:   if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
445:     nsrcranks  = sf->nRemoteRootRanks;
446:     srcsig     = link->rootSendSig; /* I want to set their root signal */
447:     srcsigdisp = sf->rootsigdisp_d; /* offset of each root signal */
448:     srcranks   = sf->ranks_d;       /* ranks of the n root ranks */
449:   } else {                          /* LEAF2ROOT, root ranks are getting data */
450:     nsrcranks  = bas->nRemoteLeafRanks;
451:     srcsig     = link->leafSendSig;
452:     srcsigdisp = bas->leafsigdisp_d;
453:     srcranks   = bas->iranks_d;
454:   }

456:   if (nsrcranks) {
457:     nvshmemx_quiet_on_stream(link->remoteCommStream); /* Finish the nonblocking get, so that we can unpack afterwards */
458:     PetscCallCUDA(cudaGetLastError());
459:     NvshmemSendSignals<<<(nsrcranks + 511) / 512, 512, 0, link->remoteCommStream>>>(nsrcranks, srcsig, srcsigdisp, srcranks, 0); /* set signals to 0 */
460:     PetscCallCUDA(cudaGetLastError());
461:   }
462:   PetscCall(PetscSFLinkBuildDependenceEnd(sf, link, direction));
463:   PetscFunctionReturn(PETSC_SUCCESS);
464: }

466: /* ===========================================================================================================

468:    A set of routines to support sender initiated communication using the put-based method (the default)

470:     The putting protocol is:

472:     Sender has a send buf (sbuf) and a send signal var (ssig);  Receiver has a stand-alone recv buf (rbuf)
473:     and a recv signal var (rsig); All signal variables have an initial value 0. rbuf is allocated by SF and
474:     is in nvshmem space.

476:     Sender:                                 |  Receiver:
477:                                             |
478:   1.  Pack data into sbuf                   |
479:   2.  Wait ssig to be 0, then set it to 1   |
480:   3.  Put data to remote stand-alone rbuf   |
481:   4.  Fence // make sure 5 happens after 3  |
482:   5.  Put 1 to receiver's rsig              |   1. Wait rsig to be 1, then set it to 0
483:                                             |   2. Unpack data from local rbuf
484:                                             |   3. Put 0 to sender's ssig
485:    ===========================================================================================================*/
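/* In the routines below, the sender's steps 2-3 (plus the quiet needed for locally accessible PEs) happen in
   PetscSFLinkPutDataBegin_NVSHMEM(); the sender's steps 4-5 and the receiver's step 1 happen in
   PetscSFLinkPutDataEnd_NVSHMEM(); the receiver's step 3 is the PostUnpack callback
   (PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM()); steps 1 (pack) and 2 (unpack) are the regular SF pack/unpack. */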

487: /* n thread blocks; each block handles one remote rank */
488: __global__ static void WaitAndPutDataToRemotelyAccessible(PetscInt ndstranks, PetscMPIInt *dstranks, char *dst, PetscInt *dstdisp, const char *src, PetscInt *srcdisp, uint64_t *srcsig, PetscInt unitbytes)
489: {
490:   int         bid = blockIdx.x;
491:   PetscMPIInt pe  = dstranks[bid];

493:   if (!nvshmem_ptr(dst, pe)) {
494:     PetscInt nelems = (srcdisp[bid + 1] - srcdisp[bid]) * unitbytes;
495:     nvshmem_uint64_wait_until(srcsig + bid, NVSHMEM_CMP_EQ, 0); /* Wait until the sig = 0 */
496:     srcsig[bid] = 1;
497:     nvshmem_putmem_nbi(dst + dstdisp[bid] * unitbytes, src + (srcdisp[bid] - srcdisp[0]) * unitbytes, nelems, pe);
498:   }
499: }

501: /* A one-thread kernel that handles all locally accessible PEs */
502: __global__ static void WaitSignalsFromLocallyAccessible(PetscInt ndstranks, PetscMPIInt *dstranks, uint64_t *srcsig, const char *dst)
503: {
504:   for (int i = 0; i < ndstranks; i++) {
505:     int pe = dstranks[i];
506:     if (nvshmem_ptr(dst, pe)) {
507:       nvshmem_uint64_wait_until(srcsig + i, NVSHMEM_CMP_EQ, 0); /* Wait until the sig = 0 */
508:       srcsig[i] = 1;
509:     }
510:   }
511: }

513: /* Put data in the given direction  */
514: static PetscErrorCode PetscSFLinkPutDataBegin_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
515: {
516:   cudaError_t    cerr;
517:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
518:   PetscInt       ndstranks, nLocallyAccessible = 0;
519:   char          *src, *dst;
520:   PetscInt      *srcdisp_h, *dstdisp_h;
521:   PetscInt      *srcdisp_d, *dstdisp_d;
522:   PetscMPIInt   *dstranks_h;
523:   PetscMPIInt   *dstranks_d;
524:   uint64_t      *srcsig;

526:   PetscFunctionBegin;
527:   PetscCall(PetscSFLinkBuildDependenceBegin(sf, link, direction));
528:   if (direction == PETSCSF_ROOT2LEAF) {                              /* put data in rootbuf to leafbuf  */
529:     ndstranks = bas->nRemoteLeafRanks;                               /* number of (remote) leaf ranks */
530:     src       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* Both src & dst must be symmetric */
531:     dst       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

533:     srcdisp_h = bas->ioffset + bas->ndiranks; /* offsets of rootbuf. srcdisp[0] is not necessarily zero */
534:     srcdisp_d = bas->ioffset_d;
535:     srcsig    = link->rootSendSig;

537:     dstdisp_h  = bas->leafbufdisp; /* for my i-th remote leaf rank, I will access its leaf buf at offset leafbufdisp[i] */
538:     dstdisp_d  = bas->leafbufdisp_d;
539:     dstranks_h = bas->iranks + bas->ndiranks; /* remote leaf ranks */
540:     dstranks_d = bas->iranks_d;
541:   } else { /* put data in leafbuf to rootbuf */
542:     ndstranks = sf->nRemoteRootRanks;
543:     src       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
544:     dst       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

546:     srcdisp_h = sf->roffset + sf->ndranks; /* offsets of leafbuf */
547:     srcdisp_d = sf->roffset_d;
548:     srcsig    = link->leafSendSig;

550:     dstdisp_h  = sf->rootbufdisp; /* for my i-th remote root rank, I will access its root buf at offset rootbufdisp[i] */
551:     dstdisp_d  = sf->rootbufdisp_d;
552:     dstranks_h = sf->ranks + sf->ndranks; /* remote root ranks */
553:     dstranks_d = sf->ranks_d;
554:   }

556:   /* Wait for signals and then put data to dst ranks using non-blocking nvshmem puts, which are completed in PetscSFLinkPutDataEnd_NVSHMEM() */

558:   /* Count number of locally accessible neighbors, which should be a small number */
559:   for (int i = 0; i < ndstranks; i++) {
560:     if (nvshmem_ptr(dst, dstranks_h[i])) nLocallyAccessible++;
561:   }

563:   /* For remotely accessible PEs, send data to them in one kernel call */
564:   if (nLocallyAccessible < ndstranks) {
565:     WaitAndPutDataToRemotelyAccessible<<<ndstranks, 1, 0, link->remoteCommStream>>>(ndstranks, dstranks_d, dst, dstdisp_d, src, srcdisp_d, srcsig, link->unitbytes);
566:     PetscCallCUDA(cudaGetLastError());
567:   }

569:   /* For locally accessible PEs, use host API, which uses CUDA copy-engines and is much faster than device API */
570:   if (nLocallyAccessible) {
571:     WaitSignalsFromLocallyAccessible<<<1, 1, 0, link->remoteCommStream>>>(ndstranks, dstranks_d, srcsig, dst);
572:     for (int i = 0; i < ndstranks; i++) {
573:       int pe = dstranks_h[i];
574:       if (nvshmem_ptr(dst, pe)) { /* A non-NULL return means <pe> is locally accessible */
575:         size_t nelems = (srcdisp_h[i + 1] - srcdisp_h[i]) * link->unitbytes;
576:         /* Initiate the nonblocking communication */
577:         nvshmemx_putmem_nbi_on_stream(dst + dstdisp_h[i] * link->unitbytes, src + (srcdisp_h[i] - srcdisp_h[0]) * link->unitbytes, nelems, pe, link->remoteCommStream);
578:       }
579:     }
580:   }

582:   if (nLocallyAccessible) { nvshmemx_quiet_on_stream(link->remoteCommStream); /* Calling nvshmem_fence/quiet() does not fence the above nvshmemx_putmem_nbi_on_stream! */ }
583:   PetscFunctionReturn(PETSC_SUCCESS);
584: }

586: /* A one-thread kernel; the single thread handles all remote PEs */
587: __global__ static void PutDataEnd(PetscInt nsrcranks, PetscInt ndstranks, PetscMPIInt *dstranks, uint64_t *dstsig, PetscInt *dstsigdisp)
588: {
589:   /* TODO: Should we finish the non-blocking remote puts here? */

591:   /* 1. Send a signal to each dst rank */

593:   /* According to Akhil@NVIDIA, IB is ordered, so no fence is needed for remote PEs.
594:      For local PEs, we already called nvshmemx_quiet_on_stream(). Therefore, we are good to send signals to all dst ranks now.
595:   */
596:   for (int i = 0; i < ndstranks; i++) nvshmemx_uint64_signal(dstsig + dstsigdisp[i], 1, dstranks[i]); /* set sig to 1 */

598:   /* 2. Wait for signals from src ranks (if any) */
599:   if (nsrcranks) {
600:     nvshmem_uint64_wait_until_all(dstsig, nsrcranks, NULL /*no mask*/, NVSHMEM_CMP_EQ, 1); /* wait sigs to be 1, then set them to 0 */
601:     for (int i = 0; i < nsrcranks; i++) dstsig[i] = 0;
602:   }
603: }

605: /* Finish the communication -- A receiver waits until it can access its receive buffer */
606: static PetscErrorCode PetscSFLinkPutDataEnd_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
607: {
608:   cudaError_t    cerr;
609:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
610:   PetscMPIInt   *dstranks;
611:   uint64_t      *dstsig;
612:   PetscInt       nsrcranks, ndstranks, *dstsigdisp;

614:   PetscFunctionBegin;
615:   if (direction == PETSCSF_ROOT2LEAF) { /* put root data to leaf */
616:     nsrcranks = sf->nRemoteRootRanks;

618:     ndstranks  = bas->nRemoteLeafRanks;
619:     dstranks   = bas->iranks_d;      /* leaf ranks */
620:     dstsig     = link->leafRecvSig;  /* I will set my leaf ranks' RecvSig */
621:     dstsigdisp = bas->leafsigdisp_d; /* for my i-th remote leaf rank, I will access its signal at offset leafsigdisp[i] */
622:   } else {                           /* LEAF2ROOT */
623:     nsrcranks = bas->nRemoteLeafRanks;

625:     ndstranks  = sf->nRemoteRootRanks;
626:     dstranks   = sf->ranks_d;
627:     dstsig     = link->rootRecvSig;
628:     dstsigdisp = sf->rootsigdisp_d;
629:   }

631:   if (nsrcranks || ndstranks) {
632:     PutDataEnd<<<1, 1, 0, link->remoteCommStream>>>(nsrcranks, ndstranks, dstranks, dstsig, dstsigdisp);
633:     PetscCallCUDA(cudaGetLastError());
634:   }
635:   PetscCall(PetscSFLinkBuildDependenceEnd(sf, link, direction));
636:   PetscFunctionReturn(PETSC_SUCCESS);
637: }

639: /* PostUnpack operation -- A receiver tells its senders that they are allowed to put data here (i.e., its recv buffer is free to take new data) */
640: static PetscErrorCode PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
641: {
642:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
643:   uint64_t      *srcsig;
644:   PetscInt       nsrcranks, *srcsigdisp_d;
645:   PetscMPIInt   *srcranks_d;

647:   PetscFunctionBegin;
648:   if (direction == PETSCSF_ROOT2LEAF) { /* I allow my root ranks to put data to me */
649:     nsrcranks    = sf->nRemoteRootRanks;
650:     srcsig       = link->rootSendSig; /* I want to set their send signals */
651:     srcsigdisp_d = sf->rootsigdisp_d; /* offset of each root signal */
652:     srcranks_d   = sf->ranks_d;       /* ranks of the n root ranks */
653:   } else {                            /* LEAF2ROOT */
654:     nsrcranks    = bas->nRemoteLeafRanks;
655:     srcsig       = link->leafSendSig;
656:     srcsigdisp_d = bas->leafsigdisp_d;
657:     srcranks_d   = bas->iranks_d;
658:   }

660:   if (nsrcranks) {
661:     NvshmemSendSignals<<<(nsrcranks + 255) / 256, 256, 0, link->remoteCommStream>>>(nsrcranks, srcsig, srcsigdisp_d, srcranks_d, 0); /* Set remote signals to 0 */
662:     PetscCallCUDA(cudaGetLastError());
663:   }
664:   PetscFunctionReturn(PETSC_SUCCESS);
665: }

667: /* Destructor when the link uses nvshmem for communication */
668: static PetscErrorCode PetscSFLinkDestroy_NVSHMEM(PetscSF sf, PetscSFLink link)
669: {
670:   cudaError_t cerr;

672:   PetscFunctionBegin;
673:   PetscCallCUDA(cudaEventDestroy(link->dataReady));
674:   PetscCallCUDA(cudaEventDestroy(link->endRemoteComm));
675:   PetscCallCUDA(cudaStreamDestroy(link->remoteCommStream));

677:   /* NVSHMEM does not use host buffers, so those pointers should be NULL and need no freeing here */
678:   PetscCall(PetscNvshmemFree(link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
679:   PetscCall(PetscNvshmemFree(link->leafSendSig));
680:   PetscCall(PetscNvshmemFree(link->leafRecvSig));
681:   PetscCall(PetscNvshmemFree(link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
682:   PetscCall(PetscNvshmemFree(link->rootSendSig));
683:   PetscCall(PetscNvshmemFree(link->rootRecvSig));
684:   PetscFunctionReturn(PETSC_SUCCESS);
685: }

687: PetscErrorCode PetscSFLinkCreate_NVSHMEM(PetscSF sf, MPI_Datatype unit, PetscMemType rootmtype, const void *rootdata, PetscMemType leafmtype, const void *leafdata, MPI_Op op, PetscSFOperation sfop, PetscSFLink *mylink)
688: {
689:   cudaError_t    cerr;
690:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
691:   PetscSFLink   *p, link;
692:   PetscBool      match, rootdirect[2], leafdirect[2];
693:   int            greatestPriority;

695:   PetscFunctionBegin;
696:   /* Check to see if we can directly send/recv root/leafdata with the given sf, sfop and op.
697:      We only care about root/leafdirect[PETSCSF_REMOTE], since we never need intermediate buffers for the local communication with NVSHMEM.
698:   */
699:   if (sfop == PETSCSF_BCAST) { /* Move data from rootbuf to leafbuf */
700:     if (sf->use_nvshmem_get) {
701:       rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* send buffer has to be stand-alone (can't be rootdata) */
702:       leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
703:     } else {
704:       rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
705:       leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* Our put-protocol always needs a nvshmem alloc'ed recv buffer */
706:     }
707:   } else if (sfop == PETSCSF_REDUCE) { /* Move data from leafbuf to rootbuf */
708:     if (sf->use_nvshmem_get) {
709:       rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
710:       leafdirect[PETSCSF_REMOTE] = PETSC_FALSE;
711:     } else {
712:       rootdirect[PETSCSF_REMOTE] = PETSC_FALSE;
713:       leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
714:     }
715:   } else {                                    /* PETSCSF_FETCH */
716:     rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* FETCH always needs a separate rootbuf */
717:     leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* We also force allocating a separate leafbuf so that leafdata and leafupdate can share MPI requests */
718:   }

720:   /* Look for free nvshmem links in cache */
721:   for (p = &bas->avail; (link = *p); p = &link->next) {
722:     if (link->use_nvshmem) {
723:       PetscCall(MPIPetsc_Type_compare(unit, link->unit, &match));
724:       if (match) {
725:         *p = link->next; /* Remove from available list */
726:         goto found;
727:       }
728:     }
729:   }
730:   PetscCall(PetscNew(&link));
731:   PetscCall(PetscSFLinkSetUp_Host(sf, link, unit));                                          /* Compute link->unitbytes, dup link->unit etc. */
732:   if (sf->backend == PETSCSF_BACKEND_CUDA) PetscCall(PetscSFLinkSetUp_CUDA(sf, link, unit)); /* Setup pack routines, streams etc */
733: #if defined(PETSC_HAVE_KOKKOS)
734:   else if (sf->backend == PETSCSF_BACKEND_KOKKOS) PetscCall(PetscSFLinkSetUp_Kokkos(sf, link, unit));
735: #endif

737:   link->rootdirect[PETSCSF_LOCAL] = PETSC_TRUE; /* For the local part we directly use root/leafdata */
738:   link->leafdirect[PETSCSF_LOCAL] = PETSC_TRUE;

740:   /* Init signals to zero */
741:   if (!link->rootSendSig) PetscCall(PetscNvshmemCalloc(bas->nRemoteLeafRanksMax * sizeof(uint64_t), (void **)&link->rootSendSig));
742:   if (!link->rootRecvSig) PetscCall(PetscNvshmemCalloc(bas->nRemoteLeafRanksMax * sizeof(uint64_t), (void **)&link->rootRecvSig));
743:   if (!link->leafSendSig) PetscCall(PetscNvshmemCalloc(sf->nRemoteRootRanksMax * sizeof(uint64_t), (void **)&link->leafSendSig));
744:   if (!link->leafRecvSig) PetscCall(PetscNvshmemCalloc(sf->nRemoteRootRanksMax * sizeof(uint64_t), (void **)&link->leafRecvSig));

746:   link->use_nvshmem = PETSC_TRUE;
747:   link->rootmtype   = PETSC_MEMTYPE_DEVICE; /* Only need 0/1-based mtype from now on */
748:   link->leafmtype   = PETSC_MEMTYPE_DEVICE;
749:   /* Overwrite some function pointers set by PetscSFLinkSetUp_CUDA */
750:   link->Destroy = PetscSFLinkDestroy_NVSHMEM;
751:   if (sf->use_nvshmem_get) { /* get-based protocol */
752:     link->PrePack             = PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM;
753:     link->StartCommunication  = PetscSFLinkGetDataBegin_NVSHMEM;
754:     link->FinishCommunication = PetscSFLinkGetDataEnd_NVSHMEM;
755:   } else { /* put-based protocol */
756:     link->StartCommunication  = PetscSFLinkPutDataBegin_NVSHMEM;
757:     link->FinishCommunication = PetscSFLinkPutDataEnd_NVSHMEM;
758:     link->PostUnpack          = PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM;
759:   }

761:   PetscCallCUDA(cudaDeviceGetStreamPriorityRange(NULL, &greatestPriority));
762:   PetscCallCUDA(cudaStreamCreateWithPriority(&link->remoteCommStream, cudaStreamNonBlocking, greatestPriority));

764:   PetscCallCUDA(cudaEventCreateWithFlags(&link->dataReady, cudaEventDisableTiming));
765:   PetscCallCUDA(cudaEventCreateWithFlags(&link->endRemoteComm, cudaEventDisableTiming));

767: found:
768:   if (rootdirect[PETSCSF_REMOTE]) {
769:     link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char *)rootdata + bas->rootstart[PETSCSF_REMOTE] * link->unitbytes;
770:   } else {
771:     if (!link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) PetscCall(PetscNvshmemMalloc(bas->rootbuflen_rmax * link->unitbytes, (void **)&link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
772:     link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
773:   }

775:   if (leafdirect[PETSCSF_REMOTE]) {
776:     link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char *)leafdata + sf->leafstart[PETSCSF_REMOTE] * link->unitbytes;
777:   } else {
778:     if (!link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) PetscCall(PetscNvshmemMalloc(sf->leafbuflen_rmax * link->unitbytes, (void **)&link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
779:     link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
780:   }

782:   link->rootdirect[PETSCSF_REMOTE] = rootdirect[PETSCSF_REMOTE];
783:   link->leafdirect[PETSCSF_REMOTE] = leafdirect[PETSCSF_REMOTE];
784:   link->rootdata                   = rootdata; /* root/leafdata are keys to look up links in PetscSFXxxEnd */
785:   link->leafdata                   = leafdata;
786:   link->next                       = bas->inuse;
787:   bas->inuse                       = link;
788:   *mylink                          = link;
789:   PetscFunctionReturn(PETSC_SUCCESS);
790: }

792: #if defined(PETSC_USE_REAL_SINGLE)
793: PetscErrorCode PetscNvshmemSum(PetscInt count, float *dst, const float *src)
794: {
795:   PetscMPIInt num; /* Assume nvshmem's int is MPI's int */

797:   PetscFunctionBegin;
798:   PetscCall(PetscMPIIntCast(count, &num));
799:   nvshmemx_float_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
800:   PetscFunctionReturn(PETSC_SUCCESS);
801: }

803: PetscErrorCode PetscNvshmemMax(PetscInt count, float *dst, const float *src)
804: {
805:   PetscMPIInt num;

807:   PetscFunctionBegin;
808:   PetscCall(PetscMPIIntCast(count, &num));
809:   nvshmemx_float_max_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
810:   PetscFunctionReturn(PETSC_SUCCESS);
811: }
812: #elif defined(PETSC_USE_REAL_DOUBLE)
813: PetscErrorCode PetscNvshmemSum(PetscInt count, double *dst, const double *src)
814: {
815:   PetscMPIInt num;

817:   PetscFunctionBegin;
818:   PetscCall(PetscMPIIntCast(count, &num));
819:   nvshmemx_double_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
820:   PetscFunctionReturn(PETSC_SUCCESS);
821: }

823: PetscErrorCode PetscNvshmemMax(PetscInt count, double *dst, const double *src)
824: {
825:   PetscMPIInt num;

827:   PetscFunctionBegin;
828:   PetscCall(PetscMPIIntCast(count, &num));
829:   nvshmemx_double_max_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
830:   PetscFunctionReturn(PETSC_SUCCESS);
831: }
832: #endif
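/* PetscNvshmemSum()/PetscNvshmemMax() are thin wrappers over NVSHMEM's stream-ordered reductions on NVSHMEM_TEAM_WORLD:
   both dst and src must be symmetric addresses and the operations are ordered on PetscDefaultCudaStream. A minimal
   sketch, with a hypothetical element count <n> that is the same on all ranks:

     PetscReal *src, *dst;
     PetscCall(PetscNvshmemMalloc(n * sizeof(PetscReal), (void **)&src));
     PetscCall(PetscNvshmemMalloc(n * sizeof(PetscReal), (void **)&dst));
     // ... fill src on the device using PetscDefaultCudaStream ...
     PetscCall(PetscNvshmemSum(n, dst, src)); // dst[i] = sum of src[i] over all PEs
     PetscCallCUDA(cudaStreamSynchronize(PetscDefaultCudaStream));
*/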