Actual source code: ngmresfunc.c
1: #include <../src/snes/impls/ngmres/snesngmres.h>
2: #include <petscblaslapack.h>
4: PetscErrorCode SNESNGMRESGetAdditiveLineSearch_Private(SNES snes, SNESLineSearch *linesearch)
5: {
6: SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
8: PetscFunctionBegin;
9: if (!ngmres->additive_linesearch) {
10: const char *optionsprefix;
11: PetscCall(SNESGetOptionsPrefix(snes, &optionsprefix));
12: PetscCall(SNESLineSearchCreate(PetscObjectComm((PetscObject)snes), &ngmres->additive_linesearch));
13: PetscCall(SNESLineSearchSetSNES(ngmres->additive_linesearch, snes));
14: PetscCall(SNESLineSearchSetType(ngmres->additive_linesearch, SNESLINESEARCHL2));
15: PetscCall(SNESLineSearchAppendOptionsPrefix(ngmres->additive_linesearch, "snes_ngmres_additive_"));
16: PetscCall(SNESLineSearchAppendOptionsPrefix(ngmres->additive_linesearch, optionsprefix));
17: PetscCall(PetscObjectIncrementTabLevel((PetscObject)ngmres->additive_linesearch, (PetscObject)snes, 1));
18: }
19: *linesearch = ngmres->additive_linesearch;
20: PetscFunctionReturn(PETSC_SUCCESS);
21: }
23: PetscErrorCode SNESNGMRESUpdateSubspace_Private(SNES snes, PetscInt ivec, PetscInt l, Vec F, PetscReal fnorm, Vec X)
24: {
25: SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
26: Vec *Fdot = ngmres->Fdot;
27: Vec *Xdot = ngmres->Xdot;
29: PetscFunctionBegin;
30: PetscCheck(ivec <= l, PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "Cannot update vector %" PetscInt_FMT " with space size %" PetscInt_FMT "!", ivec, l);
31: PetscCall(VecCopy(F, Fdot[ivec]));
32: PetscCall(VecCopy(X, Xdot[ivec]));
34: ngmres->fnorms[ivec] = fnorm;
35: PetscFunctionReturn(PETSC_SUCCESS);
36: }
/*
  SNESNGMRESFormCombinedSolution_Private - Builds the accelerated (additive) candidate XA from
  the current candidate XM and the stored history, and the matching residual FA.

  Input:
    snes   - the SNES context (NGMRES data in snes->data)
    ivec   - history slot whose row/column of Q is refreshed in this call
    l      - number of history entries in use (Fdot[0..l-1], Xdot[0..l-1])
    XM, FM - current candidate and its residual
    fMnorm - ||FM||_2
    X      - iterate the step is measured from (Y = XA - X is passed to the post-check)

  Output:
    XA - (1 - sum_i beta_i) XM + sum_i beta_i Xdot[i]
    FA - residual at XA (via SNESApplyNPC for left NPC, else SNESComputeFunction), or, when
         ngmres->approxfunc is set, the same linear combination of FM and Fdot[i]

  In real arithmetic, beta solves the normal equations of
      min || FM + sum_i beta_i (Fdot[i] - FM) ||_2 ,
  i.e. H(i,j) = <Fdot[i] - FM, Fdot[j] - FM> with right-hand side <FM, FM - Fdot[i]>.
  NOTE(review): Q(.,.) and H(.,.) are accessor macros not visible here (presumably over
  ngmres->q / ngmres->h from snesngmres.h). Only row/column ivec of Q is written below, so
  the other entries are assumed to hold inner products from earlier calls - confirm.
  The changed_y/changed_w flags from the post-check are not used afterwards.
*/
PetscErrorCode SNESNGMRESFormCombinedSolution_Private(SNES snes, PetscInt ivec, PetscInt l, Vec XM, Vec FM, PetscReal fMnorm, Vec X, Vec XA, Vec FA)
{
  SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
  PetscInt     i, j;
  Vec         *Fdot       = ngmres->Fdot;
  Vec         *Xdot       = ngmres->Xdot;
  PetscScalar *beta       = ngmres->beta; /* right-hand side before the solve, coefficients after it */
  PetscScalar *xi         = ngmres->xi;   /* xi[i] = inner product of FM with Fdot[i] */
  PetscScalar  alph_total = 0.;
  PetscReal    nu;
  Vec          Y = snes->vec_sol_update;
  PetscBool    changed_y, changed_w;

  PetscFunctionBegin;
  nu = fMnorm * fMnorm; /* ||FM||_2^2 */

  /* construct the right-hand side and xi factors */
  if (l > 0) {
    /* split-phase reductions: both are posted first, then completed in the same order */
    PetscCall(VecMDotBegin(FM, l, Fdot, xi));
    PetscCall(VecMDotBegin(Fdot[ivec], l, Fdot, beta));
    PetscCall(VecMDotEnd(FM, l, Fdot, xi));
    PetscCall(VecMDotEnd(Fdot[ivec], l, Fdot, beta));
    /* beta temporarily holds the inner products of Fdot[ivec] with every stored residual */
    for (i = 0; i < l; i++) {
      Q(i, ivec) = beta[i];
      Q(ivec, i) = beta[i];
    }
  } else {
    Q(0, 0) = ngmres->fnorms[ivec] * ngmres->fnorms[ivec];
  }

  /* right-hand side: ||FM||^2 - <FM, Fdot[i]> */
  for (i = 0; i < l; i++) beta[i] = nu - xi[i];

  /* construct h */
  for (j = 0; j < l; j++) {
    for (i = 0; i < l; i++) H(i, j) = Q(i, j) - xi[i] - xi[j] + nu;
  }
  if (l == 1) {
    /* simply set alpha[0] = beta[0] / H[0, 0] */
    if (H(0, 0) != 0.) beta[0] = beta[0] / H(0, 0);
    else beta[0] = 0.;
  } else {
    /* SVD-based minimum-norm least-squares solve (LAPACK GELSS) of the l x l system in
       ngmres->h; the solution overwrites ngmres->beta. rcond < 0 makes LAPACK use machine
       precision for the rank cutoff. Floating-point trapping is switched off around the call,
       presumably so LAPACK-internal FP exceptions are not reported as errors. */
    PetscCall(PetscBLASIntCast(l, &ngmres->m));
    PetscCall(PetscBLASIntCast(l, &ngmres->n));
    ngmres->info  = 0;
    ngmres->rcond = -1.;
    PetscCall(PetscFPTrapPush(PETSC_FP_TRAP_OFF));
#if defined(PETSC_USE_COMPLEX)
    PetscCallBLAS("LAPACKgelss", LAPACKgelss_(&ngmres->m, &ngmres->n, &ngmres->nrhs, ngmres->h, &ngmres->lda, ngmres->beta, &ngmres->ldb, ngmres->s, &ngmres->rcond, &ngmres->rank, ngmres->work, &ngmres->lwork, ngmres->rwork, &ngmres->info));
#else
    PetscCallBLAS("LAPACKgelss", LAPACKgelss_(&ngmres->m, &ngmres->n, &ngmres->nrhs, ngmres->h, &ngmres->lda, ngmres->beta, &ngmres->ldb, ngmres->s, &ngmres->rcond, &ngmres->rank, ngmres->work, &ngmres->lwork, &ngmres->info));
#endif
    PetscCall(PetscFPTrapPop());
    /* info < 0: illegal argument; info > 0: the SVD did not converge */
    PetscCheck(ngmres->info >= 0, PetscObjectComm((PetscObject)snes), PETSC_ERR_LIB, "Bad argument to GELSS");
    PetscCheck(ngmres->info <= 0, PetscObjectComm((PetscObject)snes), PETSC_ERR_LIB, "SVD failed to converge");
  }
  for (i = 0; i < l; i++) PetscCheck(!PetscIsInfOrNanScalar(beta[i]), PetscObjectComm((PetscObject)snes), PETSC_ERR_LIB, "SVD generated inconsistent output");
  alph_total = 0.;
  for (i = 0; i < l; i++) alph_total += beta[i];

  /* XA = (1 - sum beta) XM + sum_i beta_i Xdot[i] */
  PetscCall(VecAXPBY(XA, 1.0 - alph_total, 0.0, XM));
  PetscCall(VecMAXPY(XA, l, beta, Xdot));
  /* check the validity of the step */
  PetscCall(VecWAXPY(Y, -1.0, X, XA)); /* Y = XA - X */
  PetscCall(SNESLineSearchPostCheck(snes->linesearch, X, Y, XA, &changed_y, &changed_w));
  if (!ngmres->approxfunc) {
    if (snes->npc && snes->npcside == PC_LEFT) {
      PetscCall(SNESApplyNPC(snes, XA, NULL, FA));
    } else {
      PetscCall(SNESComputeFunction(snes, XA, FA));
    }
  } else {
    /* approximate F(XA) by the same combination of the stored residuals */
    PetscCall(VecAXPBY(FA, 1.0 - alph_total, 0.0, FM));
    PetscCall(VecMAXPY(FA, l, beta, Fdot));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
115: PetscErrorCode SNESNGMRESNorms_Private(SNES snes, PetscInt l, Vec X, Vec F, Vec XM, Vec FM, Vec XA, Vec FA, Vec D, PetscReal *dnorm, PetscReal *dminnorm, PetscReal *xMnorm, PetscReal *fMnorm, PetscReal *yMnorm, PetscReal *xAnorm, PetscReal *fAnorm, PetscReal *yAnorm)
116: {
117: SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
118: PetscReal dcurnorm, dmin = -1.0;
119: Vec *Xdot = ngmres->Xdot;
120: PetscInt i;
122: PetscFunctionBegin;
123: if (xMnorm) PetscCall(VecNormBegin(XM, NORM_2, xMnorm));
124: if (fMnorm) PetscCall(VecNormBegin(FM, NORM_2, fMnorm));
125: if (yMnorm) {
126: PetscCall(VecWAXPY(D, -1.0, XM, X));
127: PetscCall(VecNormBegin(D, NORM_2, yMnorm));
128: }
129: if (xAnorm) PetscCall(VecNormBegin(XA, NORM_2, xAnorm));
130: if (fAnorm) PetscCall(VecNormBegin(FA, NORM_2, fAnorm));
131: if (yAnorm) {
132: PetscCall(VecWAXPY(D, -1.0, XA, X));
133: PetscCall(VecNormBegin(D, NORM_2, yAnorm));
134: }
135: if (dnorm) {
136: PetscCall(VecWAXPY(D, -1.0, XM, XA));
137: PetscCall(VecNormBegin(D, NORM_2, dnorm));
138: }
139: if (dminnorm) {
140: for (i = 0; i < l; i++) {
141: PetscCall(VecWAXPY(D, -1.0, XA, Xdot[i]));
142: PetscCall(VecNormBegin(D, NORM_2, &ngmres->xnorms[i]));
143: }
144: }
145: if (xMnorm) PetscCall(VecNormEnd(XM, NORM_2, xMnorm));
146: if (fMnorm) PetscCall(VecNormEnd(FM, NORM_2, fMnorm));
147: if (yMnorm) PetscCall(VecNormEnd(D, NORM_2, yMnorm));
148: if (xAnorm) PetscCall(VecNormEnd(XA, NORM_2, xAnorm));
149: if (fAnorm) PetscCall(VecNormEnd(FA, NORM_2, fAnorm));
150: if (yAnorm) PetscCall(VecNormEnd(D, NORM_2, yAnorm));
151: if (dnorm) PetscCall(VecNormEnd(D, NORM_2, dnorm));
152: if (dminnorm) {
153: for (i = 0; i < l; i++) {
154: PetscCall(VecNormEnd(D, NORM_2, &ngmres->xnorms[i]));
155: dcurnorm = ngmres->xnorms[i];
156: if ((dcurnorm < dmin) || (dmin < 0.0)) dmin = dcurnorm;
157: }
158: *dminnorm = dmin;
159: }
160: PetscFunctionReturn(PETSC_SUCCESS);
161: }
163: PetscErrorCode SNESNGMRESSelect_Private(SNES snes, PetscInt k_restart, Vec XM, Vec FM, PetscReal xMnorm, PetscReal fMnorm, PetscReal yMnorm, PetscReal objM, Vec XA, Vec FA, PetscReal xAnorm, PetscReal fAnorm, PetscReal yAnorm, PetscReal objA, PetscReal dnorm, PetscReal objmin, PetscReal dminnorm, Vec X, Vec F, Vec Y, PetscReal *xnorm, PetscReal *fnorm, PetscReal *ynorm)
164: {
165: SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
166: SNESLineSearchReason lssucceed;
167: PetscBool selectA;
169: PetscFunctionBegin;
170: if (ngmres->select_type == SNES_NGMRES_SELECT_LINESEARCH) {
171: /* X = X + \lambda(XA - X) */
172: if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "obj(X_A) = %e, ||F_A||_2 = %e, obj(X_M) = %e, ||F_M||_2 = %e\n", (double)objA, (double)fAnorm, (double)objM, (double)fMnorm));
173: /* Test if is XA - XM is a descent direction: we want < F(XM), XA - XM > not positive
174: If positive, GMRES will be restarted see https://epubs.siam.org/doi/pdf/10.1137/110835530 */
175: PetscCall(VecCopy(FM, F));
176: PetscCall(VecCopy(XM, X));
177: PetscCall(VecWAXPY(Y, -1.0, XA, X)); /* minus sign since linesearch expects to find Xnew = X - lambda * Y */
178: PetscCall(VecDotRealPart(FM, Y, &ngmres->descent_ls_test)); /* this is actually < F(XM), XM - XA > */
179: *fnorm = fMnorm;
180: if (ngmres->descent_ls_test < 0) { /* XA - XM is not a descent direction, select XM */
181: *xnorm = xMnorm;
182: *fnorm = fMnorm;
183: *ynorm = yMnorm;
184: PetscCall(VecWAXPY(Y, -1.0, X, XM));
185: PetscCall(VecCopy(FM, F));
186: PetscCall(VecCopy(XM, X));
187: } else {
188: PetscCall(SNESNGMRESGetAdditiveLineSearch_Private(snes, &ngmres->additive_linesearch));
189: PetscCall(SNESLineSearchApply(ngmres->additive_linesearch, X, F, fnorm, Y));
190: PetscCall(SNESLineSearchGetReason(ngmres->additive_linesearch, &lssucceed));
191: PetscCall(SNESLineSearchGetNorms(ngmres->additive_linesearch, xnorm, fnorm, ynorm));
192: if (lssucceed) {
193: if (++snes->numFailures >= snes->maxFailures) {
194: snes->reason = SNES_DIVERGED_LINE_SEARCH;
195: PetscFunctionReturn(PETSC_SUCCESS);
196: }
197: }
198: }
199: if (ngmres->monitor) {
200: PetscReal objT = *fnorm;
201: SNESObjectiveFn *objective;
203: PetscCall(SNESGetObjective(snes, &objective, NULL));
204: if (objective) PetscCall(SNESComputeObjective(snes, X, &objT));
205: PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "Additive solution: objective = %e\n", (double)objT));
206: }
207: } else if (ngmres->select_type == SNES_NGMRES_SELECT_DIFFERENCE) {
208: /* Conditions for choosing the accelerated answer:
209: Criterion A -- the objective function isn't increased above the minimum by too much
210: Criterion B -- the choice of x^A isn't too close to some other choice
211: */
212: selectA = (PetscBool)(/* A */ (objA < ngmres->gammaA * objmin) && /* B */ (ngmres->epsilonB * dnorm < dminnorm || objA < ngmres->deltaB * objmin));
214: if (selectA) {
215: if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "picked X_A, obj(X_A) = %e, ||F_A||_2 = %e, obj(X_M) = %e, ||F_M||_2 = %e\n", (double)objA, (double)fAnorm, (double)objM, (double)fMnorm));
216: /* copy it over */
217: *xnorm = xAnorm;
218: *fnorm = fAnorm;
219: *ynorm = yAnorm;
220: PetscCall(VecCopy(FA, F));
221: PetscCall(VecCopy(XA, X));
222: } else {
223: if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "picked X_M, obj(X_A) = %e, ||F_A||_2 = %e, obj(X_M) = %e, ||F_M||_2 = %e\n", (double)objA, (double)fAnorm, (double)objM, (double)fMnorm));
224: *xnorm = xMnorm;
225: *fnorm = fMnorm;
226: *ynorm = yMnorm;
227: PetscCall(VecWAXPY(Y, -1.0, X, XM));
228: PetscCall(VecCopy(FM, F));
229: PetscCall(VecCopy(XM, X));
230: }
231: } else { /* none */
232: *xnorm = xAnorm;
233: *fnorm = fAnorm;
234: *ynorm = yAnorm;
235: PetscCall(VecCopy(FA, F));
236: PetscCall(VecCopy(XA, X));
237: }
238: PetscFunctionReturn(PETSC_SUCCESS);
239: }
241: PetscErrorCode SNESNGMRESSelectRestart_Private(SNES snes, PetscInt l, PetscReal obj, PetscReal objM, PetscReal objA, PetscReal dnorm, PetscReal objmin, PetscReal dminnorm, PetscBool *selectRestart)
242: {
243: SNES_NGMRES *ngmres = (SNES_NGMRES *)snes->data;
245: PetscFunctionBegin;
246: *selectRestart = PETSC_FALSE;
247: if (ngmres->select_type == SNES_NGMRES_SELECT_LINESEARCH) {
248: if (ngmres->descent_ls_test < 0) { /* XA - XM is not a descent direction */
249: if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "ascent restart: %e > 0\n", (double)-ngmres->descent_ls_test));
250: *selectRestart = PETSC_TRUE;
251: }
252: } else if (ngmres->select_type == SNES_NGMRES_SELECT_DIFFERENCE) {
253: /* difference stagnation restart */
254: if (ngmres->epsilonB * dnorm > dminnorm && objA > ngmres->deltaB * objmin && l > 0) {
255: if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "difference restart: %e > %e\n", (double)(ngmres->epsilonB * dnorm), (double)dminnorm));
256: *selectRestart = PETSC_TRUE;
257: }
258: /* residual stagnation restart */
259: if (objA > ngmres->gammaC * objmin) {
260: if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "residual restart: %e > %e\n", (double)objA, (double)(ngmres->gammaC * objmin)));
261: *selectRestart = PETSC_TRUE;
262: }
264: /* F_M stagnation restart */
265: if (ngmres->restart_fm_rise && objM > obj) {
266: if (ngmres->monitor) PetscCall(PetscViewerASCIIPrintf(ngmres->monitor, "F_M rise restart: %e > %e\n", (double)objM, (double)obj));
267: *selectRestart = PETSC_TRUE;
268: }
269: }
270: PetscFunctionReturn(PETSC_SUCCESS);
271: }