Actual source code: ex1.c
/* Help text printed by -help; only used by PetscInitialize/PetscOptionsBegin in this file. */
static const char help[] = "Test TAOLMVM on a least-squares problem";
3: #include <petsctao.h>
4: #include <petscdevice.h>
6: typedef struct _n_AppCtx {
7: Mat A;
8: Vec b;
9: Vec r;
10: } AppCtx;
12: static PetscErrorCode LSObjAndGrad(Tao tao, Vec x, PetscReal *obj, Vec g, void *_ctx)
13: {
14: PetscFunctionBegin;
15: AppCtx *ctx = (AppCtx *)_ctx;
16: PetscCall(VecAXPBY(ctx->r, -1.0, 0.0, ctx->b));
17: PetscCall(MatMultAdd(ctx->A, x, ctx->r, ctx->r));
18: PetscCall(VecDotRealPart(ctx->r, ctx->r, obj));
19: *obj *= 0.5;
20: PetscCall(MatMultTranspose(ctx->A, ctx->r, g));
21: PetscFunctionReturn(PETSC_SUCCESS);
22: }
24: int main(int argc, char **argv)
25: {
26: PetscCall(PetscInitialize(&argc, &argv, NULL, help));
27: MPI_Comm comm = PETSC_COMM_WORLD;
28: AppCtx ctx;
29: Vec sol;
30: PetscBool flg, cuda = PETSC_FALSE;
32: PetscInt M = 10;
33: PetscInt N = 10;
34: PetscOptionsBegin(comm, "", help, "TAO");
35: PetscCall(PetscOptionsInt("-m", "data size", NULL, M, &M, NULL));
36: PetscCall(PetscOptionsInt("-n", "data size", NULL, N, &N, NULL));
37: PetscCall(PetscOptionsGetBool(NULL, NULL, "-cuda", &cuda, &flg));
38: PetscOptionsEnd();
40: if (cuda) {
41: VecType vec_type;
42: PetscCall(VecCreateSeqCUDA(comm, N, &ctx.b));
43: PetscCall(VecGetType(ctx.b, &vec_type));
44: PetscCall(MatCreateDenseFromVecType(comm, vec_type, M, N, PETSC_DECIDE, PETSC_DECIDE, -1, NULL, &ctx.A));
45: PetscCall(MatCreateVecs(ctx.A, &sol, NULL));
46: } else {
47: PetscCall(MatCreateDense(comm, PETSC_DECIDE, PETSC_DECIDE, M, N, NULL, &ctx.A));
48: PetscCall(MatCreateVecs(ctx.A, &sol, &ctx.b));
49: }
50: PetscCall(VecDuplicate(ctx.b, &ctx.r));
51: PetscCall(VecZeroEntries(sol));
53: PetscRandom rand;
54: PetscCall(PetscRandomCreate(comm, &rand));
55: PetscCall(PetscRandomSetFromOptions(rand));
56: PetscCall(MatSetRandom(ctx.A, rand));
57: PetscCall(VecSetRandom(ctx.b, rand));
58: PetscCall(PetscRandomDestroy(&rand));
60: Tao tao;
61: PetscCall(TaoCreate(comm, &tao));
62: PetscCall(TaoSetSolution(tao, sol));
63: PetscCall(TaoSetObjectiveAndGradient(tao, NULL, LSObjAndGrad, &ctx));
64: PetscCall(TaoSetType(tao, TAOLMVM));
65: PetscCall(TaoSetFromOptions(tao));
66: PetscCall(TaoSolve(tao));
67: PetscCall(TaoDestroy(&tao));
69: PetscCall(VecDestroy(&ctx.r));
70: PetscCall(VecDestroy(&sol));
71: PetscCall(VecDestroy(&ctx.b));
72: PetscCall(MatDestroy(&ctx.A));
74: PetscCall(PetscFinalize());
75: return 0;
76: }
78: /*TEST
80: build:
81: requires: !complex !__float128 !single !defined(PETSC_USE_64BIT_INDICES)
83: test:
84: suffix: 0
85: args: -tao_monitor -tao_ls_gtol 1.e-6 -tao_view -tao_lmvm_mat_lmvm_hist_size 20 -tao_ls_type more-thuente -tao_lmvm_mat_lmvm_scale_type none -tao_lmvm_mat_type lmvmbfgs
87: test:
88: suffix: 1
89: args: -tao_monitor -tao_ls_gtol 1.e-6 -tao_view -tao_lmvm_mat_lmvm_hist_size 20 -tao_ls_type more-thuente -tao_lmvm_mat_lmvm_scale_type none -tao_lmvm_mat_type lmvmdbfgs
91: test:
92: suffix: 2
93: args: -tao_monitor -tao_ls_gtol 1.e-6 -tao_view -tao_lmvm_mat_lmvm_hist_size 20 -tao_ls_type more-thuente -tao_lmvm_mat_type lmvmdbfgs -tao_lmvm_mat_lmvm_scale_type none -tao_lmvm_mat_lbfgs_type {{inplace reorder}}
95: TEST*/