Actual source code: fsolvebaij.F90

  1: !
  2: !
  3: !    Fortran kernel for sparse triangular solve in the BAIJ matrix format
  4: ! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
  5: ! with MatSolve_SeqBAIJ_4_NaturalOrdering()
  6: !
  7: #include <petsc/finclude/petscsys.h>
  8: !

 10:       subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
 11:       implicit none
 12:       MatScalar   a(0:*)
 13:       PetscScalar x(0:*)
 14:       PetscScalar b(0:*)
 15:       PetscInt    n
 16:       PetscInt    ai(0:*)
 17:       PetscInt    aj(0:*)
 18:       PetscInt    adiag(0:*)

 20:       PetscInt    i,j,jstart,jend
 21:       PetscInt    idx,ax,jdx
 22:       PetscScalar s1,s2,s3,s4
 23:       PetscScalar x1,x2,x3,x4
 24: !
 25: !     Forward Solve
 26: !
 27:       PETSC_AssertAlignx(16,a(1))
 28:       PETSC_AssertAlignx(16,x(1))
 29:       PETSC_AssertAlignx(16,b(1))
 30:       PETSC_AssertAlignx(16,ai(1))
 31:       PETSC_AssertAlignx(16,aj(1))
 32:       PETSC_AssertAlignx(16,adiag(1))

 34:          x(0) = b(0)
 35:          x(1) = b(1)
 36:          x(2) = b(2)
 37:          x(3) = b(3)
 38:          idx  = 0
 39:          do 20 i=1,n-1
 40:             jstart = ai(i)
 41:             jend   = adiag(i) - 1
 42:             ax    = 16*jstart
 43:             idx    = idx + 4
 44:             s1     = b(idx)
 45:             s2     = b(idx+1)
 46:             s3     = b(idx+2)
 47:             s4     = b(idx+3)
 48:             do 30 j=jstart,jend
 49:               jdx   = 4*aj(j)

 51:               x1    = x(jdx)
 52:               x2    = x(jdx+1)
 53:               x3    = x(jdx+2)
 54:               x4    = x(jdx+3)
 55:               s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
 56:               s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
 57:               s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
 58:               s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
 59:               ax = ax + 16
 60:  30         continue
 61:             x(idx)   = s1
 62:             x(idx+1) = s2
 63:             x(idx+2) = s3
 64:             x(idx+3) = s4
 65:  20      continue

 67: !
 68: !     Backward solve the upper triangular
 69: !
 70:          do 40 i=n-1,0,-1
 71:             jstart  = adiag(i) + 1
 72:             jend    = ai(i+1) - 1
 73:             ax     = 16*jstart
 74:             s1      = x(idx)
 75:             s2      = x(idx+1)
 76:             s3      = x(idx+2)
 77:             s4      = x(idx+3)
 78:             do 50 j=jstart,jend
 79:               jdx   = 4*aj(j)
 80:               x1    = x(jdx)
 81:               x2    = x(jdx+1)
 82:               x3    = x(jdx+2)
 83:               x4    = x(jdx+3)
 84:               s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
 85:               s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
 86:               s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
 87:               s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
 88:               ax = ax + 16
 89:  50         continue
 90:             ax      = 16*adiag(i)
 91:             x(idx)   = a(ax)*s1  +a(ax+4)*s2+a(ax+8)*s3 +a(ax+12)*s4
 92:             x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
 93:             x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
 94:             x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
 95:             idx      = idx - 4
 96:  40      continue
 97:       end

 99: !   version that does not call BLAS 2 operation for each row block
100: !
101:       subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
102:       implicit none
103:       MatScalar   a(0:*)
104:       PetscScalar x(0:*),b(0:*),w(0:*)
105:       PetscInt  n,ai(0:*),aj(0:*),adiag(0:*)
106:       PetscInt  ii,jj,i,j

108:       PetscInt  jstart,jend,idx,ax,jdx,kdx,nn
109:       PetscScalar s(0:3)

111: !
112: !     Forward Solve
113: !

115:       PETSC_AssertAlignx(16,a(1))
116:       PETSC_AssertAlignx(16,w(1))
117:       PETSC_AssertAlignx(16,x(1))
118:       PETSC_AssertAlignx(16,b(1))
119:       PETSC_AssertAlignx(16,ai(1))
120:       PETSC_AssertAlignx(16,aj(1))
121:       PETSC_AssertAlignx(16,adiag(1))

123:       x(0) = b(0)
124:       x(1) = b(1)
125:       x(2) = b(2)
126:       x(3) = b(3)
127:       idx  = 0
128:       do 20 i=1,n-1
129: !
130: !        Pack required part of vector into work array
131: !
132:          kdx    = 0
133:          jstart = ai(i)
134:          jend   = adiag(i) - 1
135:          if (jend - jstart .ge. 500) then
136:            write(6,*) 'Overflowing vector FortranSolveBAIJ4()'
137:          endif
138:          do 30 j=jstart,jend

140:            jdx       = 4*aj(j)

142:            w(kdx)    = x(jdx)
143:            w(kdx+1)  = x(jdx+1)
144:            w(kdx+2)  = x(jdx+2)
145:            w(kdx+3)  = x(jdx+3)
146:            kdx       = kdx + 4
147:  30      continue

149:          ax       = 16*jstart
150:          idx      = idx + 4
151:          s(0)     = b(idx)
152:          s(1)     = b(idx+1)
153:          s(2)     = b(idx+2)
154:          s(3)     = b(idx+3)
155: !
156: !    s = s - a(ax:)*w
157: !
158:          nn = 4*(jend - jstart + 1) - 1
159:          do 100, ii=0,3
160:            do 110, jj=0,nn
161:              s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
162:  110       continue
163:  100     continue

165:          x(idx)   = s(0)
166:          x(idx+1) = s(1)
167:          x(idx+2) = s(2)
168:          x(idx+3) = s(3)
169:  20   continue

171: !
172: !     Backward solve the upper triangular
173: !
174:       do 40 i=n-1,0,-1
175:          jstart    = adiag(i) + 1
176:          jend      = ai(i+1) - 1
177:          ax        = 16*jstart
178:          s(0)      = x(idx)
179:          s(1)      = x(idx+1)
180:          s(2)      = x(idx+2)
181:          s(3)      = x(idx+3)
182: !
183: !   Pack each chunk of vector needed
184: !
185:          kdx = 0
186:          if (jend - jstart .ge. 500) then
187:            write(6,*) 'Overflowing vector FortranSolveBAIJ4()'
188:          endif
189:          do 50 j=jstart,jend
190:            jdx      = 4*aj(j)
191:            w(kdx)   = x(jdx)
192:            w(kdx+1) = x(jdx+1)
193:            w(kdx+2) = x(jdx+2)
194:            w(kdx+3) = x(jdx+3)
195:            kdx      = kdx + 4
196:  50      continue
197:          nn = 4*(jend - jstart + 1) - 1
198:          do 200, ii=0,3
199:            do 210, jj=0,nn
200:              s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
201:  210       continue
202:  200     continue

204:          ax      = 16*adiag(i)
205:          x(idx)  = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
206:          x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
207:          x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
208:          x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
209:          idx     = idx - 4
210:  40   continue

212:       end