Actual source code: pbjacobi_kok.kokkos.cxx

  1: #include <petscvec_kokkos.hpp>
  2: #include <petsc_kokkos.hpp>
  3: #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
  4: #include <petscdevice.h>
  5: #include <../src/ksp/pc/impls/pbjacobi/pbjacobi.h>

  7: struct PC_PBJacobi_Kokkos {
  8:   PetscScalarKokkosDualView diag_dual;

 10:   PC_PBJacobi_Kokkos(PetscInt len, PetscScalar *diag_ptr_h)
 11:   {
 12:     PetscScalarKokkosViewHost diag_h(diag_ptr_h, len);
 13:     auto                      diag_d = Kokkos::create_mirror_view_and_copy(PetscGetKokkosExecutionSpace(), diag_h);
 14:     diag_dual                        = PetscScalarKokkosDualView(diag_d, diag_h);
 15:   }

 17:   PetscErrorCode Update(const PetscScalar *diag_ptr_h)
 18:   {
 19:     auto &exec = PetscGetKokkosExecutionSpace();

 21:     PetscFunctionBegin;
 22:     PetscCheck(diag_dual.view_host().data() == diag_ptr_h, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Host pointer has changed since last call");
 23:     PetscCallCXX(diag_dual.modify_host()); /* mark the host has newer data */
 24:     PetscCallCXX(diag_dual.sync_device(exec));
 25:     PetscFunctionReturn(PETSC_SUCCESS);
 26:   }
 27: };

 29: /* Make 'transpose' a template parameter instead of a function input parameter, so that
 30:  it will be a const in template instantiation and gets optimized out.
 31: */
 32: template <PetscBool transpose>
 33: static PetscErrorCode PCApplyOrTranspose_PBJacobi_Kokkos(PC pc, Vec x, Vec y)
 34: {
 35:   PC_PBJacobi               *jac   = (PC_PBJacobi *)pc->data;
 36:   PC_PBJacobi_Kokkos        *pckok = static_cast<PC_PBJacobi_Kokkos *>(jac->spptr);
 37:   ConstPetscScalarKokkosView xv;
 38:   PetscScalarKokkosView      yv;
 39:   PetscScalarKokkosView      Av = pckok->diag_dual.view_device();
 40:   const PetscInt             bs = jac->bs, mbs = jac->mbs, bs2 = bs * bs;
 41:   const char                *label = transpose ? "PCApplyTranspose_PBJacobi_Kokkos" : "PCApply_PBJacobi_Kokkos";

 43:   PetscFunctionBegin;
 44:   PetscCall(PetscLogGpuTimeBegin());
 45:   VecErrorIfNotKokkos(x);
 46:   VecErrorIfNotKokkos(y);
 47:   PetscCall(VecGetKokkosView(x, &xv));
 48:   PetscCall(VecGetKokkosViewWrite(y, &yv));
 49:   PetscCallCXX(Kokkos::parallel_for(
 50:     label, Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, bs * mbs), KOKKOS_LAMBDA(PetscInt row) {
 51:       const PetscScalar *Ap, *xp;
 52:       PetscScalar       *yp;
 53:       PetscInt           i, j, k;

 55:       k  = row / bs;                                /* k-th block */
 56:       i  = row % bs;                                /* this thread deals with i-th row of the block */
 57:       Ap = &Av(bs2 * k + i * (transpose ? bs : 1)); /* Ap points to the first entry of i-th row */
 58:       xp = &xv(bs * k);
 59:       yp = &yv(bs * k);
 60:       /* multiply i-th row (column) with x */
 61:       yp[i] = 0.0;
 62:       for (j = 0; j < bs; j++) {
 63:         yp[i] += Ap[0] * xp[j];
 64:         Ap += (transpose ? 1 : bs); /* block is in column major order */
 65:       }
 66:     }));
 67:   PetscCall(VecRestoreKokkosView(x, &xv));
 68:   PetscCall(VecRestoreKokkosViewWrite(y, &yv));
 69:   PetscCall(PetscLogGpuFlops(bs * bs * mbs * 2)); /* FMA on entries in all blocks */
 70:   PetscCall(PetscLogGpuTimeEnd());
 71:   PetscFunctionReturn(PETSC_SUCCESS);
 72: }

 74: static PetscErrorCode PCDestroy_PBJacobi_Kokkos(PC pc)
 75: {
 76:   PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;

 78:   PetscFunctionBegin;
 79:   PetscCallCXX(delete static_cast<PC_PBJacobi_Kokkos *>(jac->spptr));
 80:   PetscCall(PCDestroy_PBJacobi(pc));
 81:   PetscFunctionReturn(PETSC_SUCCESS);
 82: }

 84: PETSC_INTERN PetscErrorCode PCSetUp_PBJacobi_Kokkos(PC pc, Mat diagPB)
 85: {
 86:   PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;
 87:   PetscInt     len;

 89:   PetscFunctionBegin;
 90:   PetscCall(PCSetUp_PBJacobi_Host(pc, diagPB)); /* Compute the inverse on host now. Might worth doing it on device directly */
 91:   len = jac->bs * jac->bs * jac->mbs;
 92:   if (!jac->spptr) {
 93:     PetscCallCXX(jac->spptr = new PC_PBJacobi_Kokkos(len, const_cast<PetscScalar *>(jac->diag)));
 94:   } else {
 95:     PC_PBJacobi_Kokkos *pckok = static_cast<PC_PBJacobi_Kokkos *>(jac->spptr);
 96:     PetscCall(pckok->Update(jac->diag));
 97:   }
 98:   PetscCall(PetscLogCpuToGpu(sizeof(PetscScalar) * len));

100:   pc->ops->apply          = PCApplyOrTranspose_PBJacobi_Kokkos<PETSC_FALSE>;
101:   pc->ops->applytranspose = PCApplyOrTranspose_PBJacobi_Kokkos<PETSC_TRUE>;
102:   pc->ops->destroy        = PCDestroy_PBJacobi_Kokkos;
103:   PetscFunctionReturn(PETSC_SUCCESS);
104: }