Actual source code: pbjacobi_kok.kokkos.cxx
1: #include <petscvec_kokkos.hpp>
2: #include <petsc_kokkos.hpp>
3: #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
4: #include <petscdevice.h>
5: #include <../src/ksp/pc/impls/pbjacobi/pbjacobi.h>
7: struct PC_PBJacobi_Kokkos {
8: PetscScalarKokkosDualView diag_dual;
10: PC_PBJacobi_Kokkos(PetscInt len, PetscScalar *diag_ptr_h)
11: {
12: PetscScalarKokkosViewHost diag_h(diag_ptr_h, len);
13: auto diag_d = Kokkos::create_mirror_view_and_copy(PetscGetKokkosExecutionSpace(), diag_h);
14: diag_dual = PetscScalarKokkosDualView(diag_d, diag_h);
15: }
17: PetscErrorCode Update(const PetscScalar *diag_ptr_h)
18: {
19: auto &exec = PetscGetKokkosExecutionSpace();
21: PetscFunctionBegin;
22: PetscCheck(diag_dual.view_host().data() == diag_ptr_h, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Host pointer has changed since last call");
23: PetscCallCXX(diag_dual.modify_host()); /* mark the host has newer data */
24: PetscCallCXX(diag_dual.sync_device(exec));
25: PetscFunctionReturn(PETSC_SUCCESS);
26: }
27: };
29: /* Make 'transpose' a template parameter instead of a function input parameter, so that
30: it will be a const in template instantiation and gets optimized out.
31: */
32: template <PetscBool transpose>
33: static PetscErrorCode PCApplyOrTranspose_PBJacobi_Kokkos(PC pc, Vec x, Vec y)
34: {
35: PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;
36: PC_PBJacobi_Kokkos *pckok = static_cast<PC_PBJacobi_Kokkos *>(jac->spptr);
37: ConstPetscScalarKokkosView xv;
38: PetscScalarKokkosView yv;
39: PetscScalarKokkosView Av = pckok->diag_dual.view_device();
40: const PetscInt bs = jac->bs, mbs = jac->mbs, bs2 = bs * bs;
41: const char *label = transpose ? "PCApplyTranspose_PBJacobi_Kokkos" : "PCApply_PBJacobi_Kokkos";
43: PetscFunctionBegin;
44: PetscCall(PetscLogGpuTimeBegin());
45: VecErrorIfNotKokkos(x);
46: VecErrorIfNotKokkos(y);
47: PetscCall(VecGetKokkosView(x, &xv));
48: PetscCall(VecGetKokkosViewWrite(y, &yv));
49: PetscCallCXX(Kokkos::parallel_for(
50: label, Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, bs * mbs), KOKKOS_LAMBDA(PetscInt row) {
51: const PetscScalar *Ap, *xp;
52: PetscScalar *yp;
53: PetscInt i, j, k;
55: k = row / bs; /* k-th block */
56: i = row % bs; /* this thread deals with i-th row of the block */
57: Ap = &Av(bs2 * k + i * (transpose ? bs : 1)); /* Ap points to the first entry of i-th row */
58: xp = &xv(bs * k);
59: yp = &yv(bs * k);
60: /* multiply i-th row (column) with x */
61: yp[i] = 0.0;
62: for (j = 0; j < bs; j++) {
63: yp[i] += Ap[0] * xp[j];
64: Ap += (transpose ? 1 : bs); /* block is in column major order */
65: }
66: }));
67: PetscCall(VecRestoreKokkosView(x, &xv));
68: PetscCall(VecRestoreKokkosViewWrite(y, &yv));
69: PetscCall(PetscLogGpuFlops(bs * bs * mbs * 2)); /* FMA on entries in all blocks */
70: PetscCall(PetscLogGpuTimeEnd());
71: PetscFunctionReturn(PETSC_SUCCESS);
72: }
74: static PetscErrorCode PCDestroy_PBJacobi_Kokkos(PC pc)
75: {
76: PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;
78: PetscFunctionBegin;
79: PetscCallCXX(delete static_cast<PC_PBJacobi_Kokkos *>(jac->spptr));
80: PetscCall(PCDestroy_PBJacobi(pc));
81: PetscFunctionReturn(PETSC_SUCCESS);
82: }
84: PETSC_INTERN PetscErrorCode PCSetUp_PBJacobi_Kokkos(PC pc, Mat diagPB)
85: {
86: PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;
87: PetscInt len;
89: PetscFunctionBegin;
90: PetscCall(PCSetUp_PBJacobi_Host(pc, diagPB)); /* Compute the inverse on host now. Might worth doing it on device directly */
91: len = jac->bs * jac->bs * jac->mbs;
92: if (!jac->spptr) {
93: PetscCallCXX(jac->spptr = new PC_PBJacobi_Kokkos(len, const_cast<PetscScalar *>(jac->diag)));
94: } else {
95: PC_PBJacobi_Kokkos *pckok = static_cast<PC_PBJacobi_Kokkos *>(jac->spptr);
96: PetscCall(pckok->Update(jac->diag));
97: }
98: PetscCall(PetscLogCpuToGpu(sizeof(PetscScalar) * len));
100: pc->ops->apply = PCApplyOrTranspose_PBJacobi_Kokkos<PETSC_FALSE>;
101: pc->ops->applytranspose = PCApplyOrTranspose_PBJacobi_Kokkos<PETSC_TRUE>;
102: pc->ops->destroy = PCDestroy_PBJacobi_Kokkos;
103: PetscFunctionReturn(PETSC_SUCCESS);
104: }