"""Solve a toy nonlinear system with PETSc SNES via petsc4py.

The residual is F(x) = (x - e)**2 (elementwise) for a random target vector e,
so the solution is x = e.  Run as a script; PETSc options such as ``-m``,
``-impl``, ``-snes_monitor`` and ``-ksp_type`` are read from the command line.
"""
import sys

import numpy as np

# Default problem size; overridable at run time with the PETSc option ``-m``.
N = 10


def residue(expected_soln_vec_npy, x, f):
    """Write the residual (x - expected)**2 into ``f`` in place.

    Parameters
    ----------
    expected_soln_vec_npy : numpy.ndarray
        Target solution e.
    x : numpy.ndarray
        Current iterate, same shape as ``expected_soln_vec_npy``.
    f : numpy.ndarray
        Writable output array (a view of the PETSc residual vector when
        called from ``SimpleEqn.evalFunction``); overwritten with the result.

    Returns
    -------
    numpy.ndarray
        ``f`` itself, after it has been filled.
    """
    print("Call")
    # Slice-assign so the caller's buffer (PETSc's F) is filled; a plain
    # ``f = ...`` would only rebind the local name and leave F untouched.
    f[:] = (x - expected_soln_vec_npy) ** 2
    print("Residue:", f)
    return f


class SimpleEqn:
    """Application context for the nonlinear problem F(x) = (x - e)**2.

    Holds the problem size and the target vector e, and evaluates the
    residual for SNES through ``evalFunction``.
    """

    def __init__(self, m, expected_soln_vec_petsc, impl='python'):
        """Store problem data and pick the array memory order.

        Parameters
        ----------
        m : int
            Number of unknowns.
        expected_soln_vec_petsc : PETSc.Vec
            Target solution e, expected to have length ``m``.
        impl : {'python', 'fortran'}
            Selects C ('python') or Fortran ('fortran') order when viewing
            PETSc arrays; both variants use the Python ``residue`` kernel.

        Raises
        ------
        ValueError
            If ``impl`` is neither 'python' nor 'fortran'.
        """
        self.m = m
        self.expected_soln_vec_petsc = expected_soln_vec_petsc
        if impl == 'python':
            order = 'c'
        elif impl == 'fortran':
            order = 'f'
        else:
            raise ValueError('invalid implementation')
        # Stored as an instance attribute, so calling it does not bind self.
        self.compute = residue
        self.order = order

    def evalFunction(self, snes, X, F):
        """SNES residual callback: fill ``F`` with F(X).

        ``snes`` is unused but required by the callback signature.
        """
        m = self.m
        order = self.order
        x = X.getArray(readonly=1).reshape(m, order=order)
        # Writable view of F's storage; ``residue`` fills it in place.
        f = F.getArray(readonly=0).reshape(m, order=order)
        expected = self.expected_soln_vec_petsc.getArray(readonly=1).reshape(
            m, order=order)
        self.compute(expected, x, f)


def main():
    """Set up and run the SNES solve, then print convergence diagnostics."""
    # PETSc must be initialised with argv before ``from petsc4py import PETSc``;
    # doing it here keeps import of this module free of side effects.
    import petsc4py
    petsc4py.init(sys.argv)
    from petsc4py import PETSc

    # Convenience access to the PETSc options database.
    opt_db = PETSc.Options()
    m = opt_db.getInt('m', N)
    impl = opt_db.getString('impl', 'python')

    # Size the target from the resolved ``-m`` so setArray never sees a
    # length mismatch when the user overrides the default N.
    expected_soln_vec_npy = 100 * np.random.randn(m)
    expected_soln_vec_petsc = PETSc.Vec().createSeq(m)
    expected_soln_vec_petsc.setArray(expected_soln_vec_npy)

    # Application context and PETSc nonlinear solver.
    appc = SimpleEqn(m, expected_soln_vec_petsc, impl)
    snes = PETSc.SNES().create()

    # Register the function in charge of computing the nonlinear residual.
    f = PETSc.Vec().createSeq(m)
    snes.setFunction(appc.evalFunction, f)

    # No Jacobian callback is registered, so use a matrix-free
    # (finite-difference) Jacobian as the comment in the original intended.
    snes.setUseMF(True)
    # J = diag(2*(x - e)) has entries of mixed sign, i.e. it is indefinite;
    # CG requires an SPD operator, so default to GMRES (``-ksp_type`` still wins).
    snes.getKSP().setType('gmres')
    # Set defaults before setFromOptions so ``-snes_max_it`` etc. can override.
    snes.setTolerances(max_it=100)
    snes.setFromOptions()

    # Solve the nonlinear problem from a zero initial guess.
    b, x = None, f.duplicate()
    x.set(0)
    snes.solve(b, x)

    cr = snes.getConvergedReason()
    print('Converged Reason:', cr)
    # Positive reasons mean convergence; <= 0 means diverged or still
    # iterating.  (getErrorIfNotConverged() only reports a config flag.)
    ncr = cr <= 0
    print("Not converged?", ncr)
    tols = snes.getTolerances()
    print("Tols:", tols)
    print("Iterations:", snes.getIterationNumber())


if __name__ == "__main__":
    main()