Actual source code: fnsqrt.c

slepc-3.13.4 2020-09-02
Report Typos and Errors
  1: /*
  2:    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  3:    SLEPc - Scalable Library for Eigenvalue Problem Computations
  4:    Copyright (c) 2002-2020, Universitat Politecnica de Valencia, Spain

  6:    This file is part of SLEPc.
  7:    SLEPc is distributed under a 2-clause BSD license (see LICENSE).
  8:    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  9: */
 10: /*
 11:    Square root function  sqrt(x)
 12: */

 14:  #include <slepc/private/fnimpl.h>
 15:  #include <slepcblaslapack.h>

 17: PetscErrorCode FNEvaluateFunction_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
 18: {
 20: #if !defined(PETSC_USE_COMPLEX)
 21:   if (x<0.0) SETERRQ(PETSC_COMM_SELF,1,"Function not defined in the requested value");
 22: #endif
 23:   *y = PetscSqrtScalar(x);
 24:   return(0);
 25: }

 27: PetscErrorCode FNEvaluateDerivative_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
 28: {
 30:   if (x==0.0) SETERRQ(PETSC_COMM_SELF,1,"Derivative not defined in the requested value");
 31: #if !defined(PETSC_USE_COMPLEX)
 32:   if (x<0.0) SETERRQ(PETSC_COMM_SELF,1,"Derivative not defined in the requested value");
 33: #endif
 34:   *y = 1.0/(2.0*PetscSqrtScalar(x));
 35:   return(0);
 36: }

 38: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Schur(FN fn,Mat A,Mat B)
 39: {
 41:   PetscBLASInt   n;
 42:   PetscScalar    *T;
 43:   PetscInt       m;

 46:   if (A!=B) { MatCopy(A,B,SAME_NONZERO_PATTERN); }
 47:   MatDenseGetArray(B,&T);
 48:   MatGetSize(A,&m,NULL);
 49:   PetscBLASIntCast(m,&n);
 50:   SlepcSqrtmSchur(n,T,n,PETSC_FALSE);
 51:   MatDenseRestoreArray(B,&T);
 52:   return(0);
 53: }

 55: PetscErrorCode FNEvaluateFunctionMatVec_Sqrt_Schur(FN fn,Mat A,Vec v)
 56: {
 58:   PetscBLASInt   n;
 59:   PetscScalar    *T;
 60:   PetscInt       m;
 61:   Mat            B;

 64:   FN_AllocateWorkMat(fn,A,&B);
 65:   MatDenseGetArray(B,&T);
 66:   MatGetSize(A,&m,NULL);
 67:   PetscBLASIntCast(m,&n);
 68:   SlepcSqrtmSchur(n,T,n,PETSC_TRUE);
 69:   MatDenseRestoreArray(B,&T);
 70:   MatGetColumnVector(B,v,0);
 71:   FN_FreeWorkMat(fn,&B);
 72:   return(0);
 73: }

 75: PetscErrorCode FNEvaluateFunctionMat_Sqrt_DBP(FN fn,Mat A,Mat B)
 76: {
 78:   PetscBLASInt   n;
 79:   PetscScalar    *T;
 80:   PetscInt       m;

 83:   if (A!=B) { MatCopy(A,B,SAME_NONZERO_PATTERN); }
 84:   MatDenseGetArray(B,&T);
 85:   MatGetSize(A,&m,NULL);
 86:   PetscBLASIntCast(m,&n);
 87:   SlepcSqrtmDenmanBeavers(n,T,n,PETSC_FALSE);
 88:   MatDenseRestoreArray(B,&T);
 89:   return(0);
 90: }

 92: PetscErrorCode FNEvaluateFunctionMat_Sqrt_NS(FN fn,Mat A,Mat B)
 93: {
 95:   PetscBLASInt   n;
 96:   PetscScalar    *Ba;
 97:   PetscInt       m;

100:   if (A!=B) { MatCopy(A,B,SAME_NONZERO_PATTERN); }
101:   MatDenseGetArray(B,&Ba);
102:   MatGetSize(A,&m,NULL);
103:   PetscBLASIntCast(m,&n);
104:   SlepcSqrtmNewtonSchulz(n,Ba,n,PETSC_FALSE);
105:   MatDenseRestoreArray(B,&Ba);
106:   return(0);
107: }

109: #define MAXIT 50

111: /*
112:    Computes the principal square root of the matrix A using the
113:    Sadeghi iteration. A is overwritten with sqrtm(A).
114:  */
115: static PetscErrorCode SlepcSqrtmSadeghi(PetscBLASInt n,PetscScalar *A,PetscBLASInt ld)
116: {
117:   PetscScalar        *M,*M2,*G,*X=A,*work,work1,alpha,sqrtnrm;
118:   PetscScalar        szero=0.0,sone=1.0,smfive=-5.0,s1d16=1.0/16.0;
119:   PetscReal          tol,Mres=0.0,nrm,rwork[1];
120:   PetscBLASInt       N,i,it,*piv=NULL,info,lwork,query=-1;
121:   const PetscBLASInt one=1;
122:   PetscBool          converged=PETSC_FALSE;
123:   PetscErrorCode     ierr;
124:   unsigned int       ftz;

127:   N = n*n;
128:   tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
129:   SlepcSetFlushToZero(&ftz);

131:   /* query work size */
132:   PetscStackCallBLAS("LAPACKgetri",LAPACKgetri_(&n,A,&ld,piv,&work1,&query,&info));
133:   PetscBLASIntCast((PetscInt)PetscRealPart(work1),&lwork);

135:   PetscMalloc5(N,&M,N,&M2,N,&G,lwork,&work,n,&piv);
136:   PetscArraycpy(M,A,N);

138:   /* scale M */
139:   nrm = LAPACKlange_("fro",&n,&n,M,&n,rwork);
140:   if (nrm>1.0) {
141:     sqrtnrm = PetscSqrtReal(nrm);
142:     alpha = 1.0/nrm;
143:     PetscStackCallBLAS("BLASscal",BLASscal_(&N,&alpha,M,&one));
144:     tol *= nrm;
145:   }
146:   PetscInfo2(NULL,"||A||_F = %g, new tol: %g\n",(double)nrm,(double)tol);

148:   /* X = I */
149:   PetscArrayzero(X,N);
150:   for (i=0;i<n;i++) X[i+i*ld] = 1.0;

152:   for (it=0;it<MAXIT && !converged;it++) {

154:     /* G = (5/16)*I + (1/16)*M*(15*I-5*M+M*M) */
155:     PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M,&ld,M,&ld,&szero,M2,&ld));
156:     PetscStackCallBLAS("BLASaxpy",BLASaxpy_(&N,&smfive,M,&one,M2,&one));
157:     for (i=0;i<n;i++) M2[i+i*ld] += 15.0;
158:     PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&s1d16,M,&ld,M2,&ld,&szero,G,&ld));
159:     for (i=0;i<n;i++) G[i+i*ld] += 5.0/16.0;

161:     /* X = X*G */
162:     PetscArraycpy(M2,X,N);
163:     PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M2,&ld,G,&ld,&szero,X,&ld));

165:     /* M = M*inv(G*G) */
166:     PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,G,&ld,&szero,M2,&ld));
167:     PetscStackCallBLAS("LAPACKgetrf",LAPACKgetrf_(&n,&n,M2,&ld,piv,&info));
168:     SlepcCheckLapackInfo("getrf",info);
169:     PetscStackCallBLAS("LAPACKgetri",LAPACKgetri_(&n,M2,&ld,piv,work,&lwork,&info));
170:     SlepcCheckLapackInfo("getri",info);

172:     PetscArraycpy(G,M,N);
173:     PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,M2,&ld,&szero,M,&ld));

175:     /* check ||I-M|| */
176:     PetscArraycpy(M2,M,N);
177:     for (i=0;i<n;i++) M2[i+i*ld] -= 1.0;
178:     Mres = LAPACKlange_("fro",&n,&n,M2,&n,rwork);
179:     PetscIsNanReal(Mres);
180:     if (Mres<=tol) converged = PETSC_TRUE;
181:     PetscInfo2(NULL,"it: %D res: %g\n",it,(double)Mres);
182:     PetscLogFlops(8.0*n*n*n+2.0*n*n+2.0*n*n*n/3.0+4.0*n*n*n/3.0+2.0*n*n*n+2.0*n*n);
183:   }

185:   if (Mres>tol) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations",MAXIT);

187:   /* undo scaling */
188:   if (nrm>1.0) PetscStackCallBLAS("BLASscal",BLASscal_(&N,&sqrtnrm,A,&one));

190:   PetscFree5(M,M2,G,work,piv);
191:   SlepcResetFlushToZero(&ftz);
192:   return(0);
193: }

195: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Sadeghi(FN fn,Mat A,Mat B)
196: {
198:   PetscBLASInt   n;
199:   PetscScalar    *Ba;
200:   PetscInt       m;

203:   if (A!=B) { MatCopy(A,B,SAME_NONZERO_PATTERN); }
204:   MatDenseGetArray(B,&Ba);
205:   MatGetSize(A,&m,NULL);
206:   PetscBLASIntCast(m,&n);
207:   SlepcSqrtmSadeghi(n,Ba,n);
208:   MatDenseRestoreArray(B,&Ba);
209:   return(0);
210: }

212: PetscErrorCode FNView_Sqrt(FN fn,PetscViewer viewer)
213: {
215:   PetscBool      isascii;
216:   char           str[50];
217:   const char     *methodname[] = {
218:                   "Schur method for the square root",
219:                   "Denman-Beavers (product form)",
220:                   "Newton-Schulz iteration",
221:                   "Sadeghi iteration"
222:   };
223:   const int      nmeth=sizeof(methodname)/sizeof(methodname[0]);

226:   PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&isascii);
227:   if (isascii) {
228:     if (fn->beta==(PetscScalar)1.0) {
229:       if (fn->alpha==(PetscScalar)1.0) {
230:         PetscViewerASCIIPrintf(viewer,"  Square root: sqrt(x)\n");
231:       } else {
232:         SlepcSNPrintfScalar(str,50,fn->alpha,PETSC_TRUE);
233:         PetscViewerASCIIPrintf(viewer,"  Square root: sqrt(%s*x)\n",str);
234:       }
235:     } else {
236:       SlepcSNPrintfScalar(str,50,fn->beta,PETSC_TRUE);
237:       if (fn->alpha==(PetscScalar)1.0) {
238:         PetscViewerASCIIPrintf(viewer,"  Square root: %s*sqrt(x)\n",str);
239:       } else {
240:         PetscViewerASCIIPrintf(viewer,"  Square root: %s",str);
241:         PetscViewerASCIIUseTabs(viewer,PETSC_FALSE);
242:         SlepcSNPrintfScalar(str,50,fn->alpha,PETSC_TRUE);
243:         PetscViewerASCIIPrintf(viewer,"*sqrt(%s*x)\n",str);
244:         PetscViewerASCIIUseTabs(viewer,PETSC_TRUE);
245:       }
246:     }
247:     if (fn->method<nmeth) {
248:       PetscViewerASCIIPrintf(viewer,"  computing matrix functions with: %s\n",methodname[fn->method]);
249:     }
250:   }
251:   return(0);
252: }

254: SLEPC_EXTERN PetscErrorCode FNCreate_Sqrt(FN fn)
255: {
257:   fn->ops->evaluatefunction          = FNEvaluateFunction_Sqrt;
258:   fn->ops->evaluatederivative        = FNEvaluateDerivative_Sqrt;
259:   fn->ops->evaluatefunctionmat[0]    = FNEvaluateFunctionMat_Sqrt_Schur;
260:   fn->ops->evaluatefunctionmat[1]    = FNEvaluateFunctionMat_Sqrt_DBP;
261:   fn->ops->evaluatefunctionmat[2]    = FNEvaluateFunctionMat_Sqrt_NS;
262:   fn->ops->evaluatefunctionmat[3]    = FNEvaluateFunctionMat_Sqrt_Sadeghi;
263:   fn->ops->evaluatefunctionmatvec[0] = FNEvaluateFunctionMatVec_Sqrt_Schur;
264:   fn->ops->view                      = FNView_Sqrt;
265:   return(0);
266: }