Actual source code: solve_performance.c
/* Program description handed to PetscInitialize() and used as the options-block title in main().
   Internal linkage: nothing outside this program refers to it. */
static const char help[] = "Profile the performance of MATLMVM MatSolve() in a loop";
3: #include <petscksp.h>
4: #include <petscmath.h>
6: int main(int argc, char **argv)
7: {
8: PetscInt n = 1000;
9: PetscInt n_epochs = 10;
10: PetscInt n_iters = 10;
11: Vec x, g, dx, df, p;
12: PetscRandom rand;
13: PetscLogStage matsolve_loop, main_stage;
14: Mat B;
16: PetscCall(PetscInitialize(&argc, &argv, NULL, help));
17: PetscOptionsBegin(PETSC_COMM_WORLD, NULL, help, "KSP");
18: PetscCall(PetscOptionsInt("-n", "Vector size", __FILE__, n, &n, NULL));
19: PetscCall(PetscOptionsInt("-epochs", "Number of epochs", __FILE__, n_epochs, &n_epochs, NULL));
20: PetscCall(PetscOptionsInt("-iters", "Number of iterations per epoch", __FILE__, n_iters, &n_iters, NULL));
21: PetscOptionsEnd();
22: PetscCall(VecCreateMPI(PETSC_COMM_WORLD, PETSC_DETERMINE, n, &x));
23: PetscCall(VecSetFromOptions(x));
24: PetscCall(VecDuplicate(x, &g));
25: PetscCall(VecDuplicate(x, &dx));
26: PetscCall(VecDuplicate(x, &df));
27: PetscCall(VecDuplicate(x, &p));
28: PetscCall(MatCreateLMVMBFGS(PETSC_COMM_WORLD, PETSC_DETERMINE, n, &B));
29: PetscCall(MatSetFromOptions(B));
30: PetscCall(MatLMVMAllocate(B, x, g));
31: PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &rand));
32: PetscCall(PetscRandomSetInterval(rand, -1.0, 1.0));
33: PetscCall(PetscRandomSetFromOptions(rand));
34: PetscCall(PetscLogStageRegister("LMVM MatSolve Loop", &matsolve_loop));
35: PetscCall(PetscLogStageGetId("Main Stage", &main_stage));
36: PetscCall(PetscLogStageSetVisible(main_stage, PETSC_FALSE));
37: for (PetscInt epoch = 0; epoch < n_epochs + 1; epoch++) {
38: PetscScalar dot;
39: PetscReal xscale, fscale, absdot;
40: PetscInt history_size;
42: PetscCall(VecSetRandom(dx, rand));
43: PetscCall(VecSetRandom(df, rand));
44: PetscCall(VecDot(dx, df, &dot));
45: absdot = PetscAbsScalar(dot);
46: PetscCall(VecSetRandom(x, rand));
47: PetscCall(VecSetRandom(g, rand));
48: xscale = 1.0;
49: fscale = absdot / PetscRealPart(dot);
50: PetscCall(MatLMVMGetHistorySize(B, &history_size));
52: PetscCall(MatLMVMUpdate(B, x, g));
53: for (PetscInt iter = 0; iter < history_size; iter++, xscale *= -1.0, fscale *= -1.0) {
54: PetscCall(VecAXPY(x, xscale, dx));
55: PetscCall(VecAXPY(g, fscale, df));
56: PetscCall(MatLMVMUpdate(B, x, g));
57: PetscCall(MatSolve(B, g, p));
58: }
59: if (epoch > 0) PetscCall(PetscLogStagePush(matsolve_loop));
60: for (PetscInt iter = 0; iter < n_iters; iter++, xscale *= -1.0, fscale *= -1.0) {
61: PetscCall(VecAXPY(x, xscale, dx));
62: PetscCall(VecAXPY(g, fscale, df));
63: PetscCall(MatLMVMUpdate(B, x, g));
64: PetscCall(MatSolve(B, g, p));
65: }
66: PetscCall(MatLMVMReset(B, PETSC_FALSE));
67: if (epoch > 0) PetscCall(PetscLogStagePop());
68: }
69: PetscCall(MatView(B, PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD)));
70: PetscCall(PetscRandomDestroy(&rand));
71: PetscCall(MatDestroy(&B));
72: PetscCall(VecDestroy(&p));
73: PetscCall(VecDestroy(&df));
74: PetscCall(VecDestroy(&dx));
75: PetscCall(VecDestroy(&g));
76: PetscCall(VecDestroy(&x));
77: PetscCall(PetscFinalize());
78: return 0;
79: }
81: /*TEST
83: test:
84: suffix: 0
85: args: -mat_lmvm_scale_type none
87: TEST*/