#include "Sundance.hpp"
#include <cstdio>

/**
 * \example timeStepHeat1D.cpp 
 *
 * This example shows how to do timestepping in Sundance. We solve the
 * transient heat equation on the unit square using Crank-Nicolson time
 * discretization. The time discretization is done at the symbolic level.
 * Spatial discretization is done via LinearProblem, yielding an elliptic
 * problem that is solved once per timestep to march the solution in time. In
 * this example we march with a simple constant timestep loop; for more
 * difficult problems one would use a high-quality DAE/ODE integration code
 * such as DASSLQ or PVODE.
 *
 * We solve the heat equation u_xx + u_yy = u_t on the unit square with
 * homogeneous Dirichlet boundary conditions (u = 0 on all four sides) and
 * initial condition u(x,y,t=0) = sin(pi x) sin(pi y). The exact solution is
 * u(x,y,t) = exp(-2 pi^2 t) sin(pi x) sin(pi y).
 */

/* Coordinate predicates used in main() to select the boundary edges of the
 * unit-square mesh: each one compares a single coordinate of x against 0 or 1
 * with an absolute tolerance of 1.0e-10.
 * NOTE(review): x is presumably the point being classified, supplied by the
 * CELL_PREDICATE macro -- confirm against the macro's definition.
 * The bodies are passed as a macro argument, so they must not contain
 * unparenthesized commas (braces do not protect commas from the
 * preprocessor). */
CELL_PREDICATE(LeftPointTest, {return fabs(x[0]) < 1.0e-10;});     /* x = 0 */
CELL_PREDICATE(BottomPointTest, {return fabs(x[1]) < 1.0e-10;});   /* y = 0 */
CELL_PREDICATE(RightPointTest, {return fabs(x[0]-1.0) < 1.0e-10;}); /* x = 1 */
CELL_PREDICATE(TopPointTest, {return fabs(x[1]-1.0) < 1.0e-10;});  /* y = 1 */

int main(int argc, void** argv)
{
  try
    {
      MPISession::init(&argc, &argv);
      int np = MPIComm::world().getNProc();

      /* We will do our linear algebra using Epetra */
      VectorType<double> vecType = new EpetraVectorType();

      /* create a simple mesh on the unit line */
      int nx = 200;
      int ny = 200;
      MeshType meshType = new BasicSimplicialMeshType();
      MeshSource mesher = new PartitionedRectangleMesher(0.0, 1.0, nx, np,
                                                         0.0, 1.0, ny, 1,
                                                         meshType);
      Mesh mesh = mesher.getMesh();

      CellFilter interior = new MaximalCellFilter();
      CellFilter edges = new DimensionalCellFilter(1);

      CellFilter left = edges.subset(new LeftPointTest());
      CellFilter right = edges.subset(new RightPointTest());
      CellFilter top = edges.subset(new TopPointTest());
      CellFilter bottom = edges.subset(new BottomPointTest());

      /* create unknown and variational functions */
      Expr delU = new TestFunction(new Lagrange(2), "delU");
      Expr U = new UnknownFunction(new Lagrange(2), "U");

      /* create a differentiation operator */
      Expr dx = new Derivative(0);
      Expr dy = new Derivative(1);
      Expr gradient = List(dx, dy);

      /* the initial conditions will be u0(x,t=0) = sin(pi*x[0]).
       * create a coordinate expression to represent x, then
       * create sin(pi*x), and then project it onto a discrete function. */
      Expr x = new CoordExpr(0);
      Expr y = new CoordExpr(1);

      double pi = 4.0*atan(1.0);

      /* Create a discrete space, and discretize the function 1.0+x on it */
      DiscreteSpace discreteSpace(mesh, new Lagrange(2), vecType);

      L2Projector projector(discreteSpace, sin(pi*x) * sin(pi*y));
      Expr u0 = projector.project();

      /* We need a quadrature rule for doing the integrations */
      QuadratureFamily quad = new GaussianQuadrature(4);

      /* 
         set up crank-nicolson stepping with timestep = 0.02. The time
         discretization is done at the symbolic level, yielding
         an elliptic problem that we solve repeatedly for the updated
         solution at each time level. 
      */

      double deltaT = 0.02;
      //Expr cnStep = delU*(U - u0) + deltaT*((gradient*delU)*(gradient*U));
      Expr cnStep = delU*(U - u0) + deltaT*((gradient*delU)*(gradient*(U+u0)/2.0));
      Expr eqn = Integral(interior, cnStep, quad);

      /* Define BCs to be zero at all sides */
		Expr bc = EssentialBC(left, delU*U, quad) + 
			EssentialBC(right, delU*U, quad) + 
			EssentialBC(top, delU*U, quad) + 
			EssentialBC(bottom, delU*U, quad);

      /* No flux BCs everywhere */
		//Expr bc;
			

      /* OLD_CODE create a solver object */
      //TSFPreconditionerFactory prec = new ILUKPreconditionerFactory(1);
      //TSFLinearSolver solver = new BICGSTABSolver(prec, 1.0e-14, 300);
      //solver.setVerbosityLevel(0);     

      ParameterXMLFileReader reader("bicgstab.xml");
      ParameterList solverParams = reader.getParameters();
      cout << "params = " << solverParams << endl;


      LinearSolver<double> solver 
        = LinearSolverBuilder::createSolver(solverParams);


      /* 
         put the time-discretized eqn into a StaticLinearProblem object
         which will do the spatial discretization.
      */
		LinearProblem prob(mesh, eqn, bc, delU, U, vecType);

      // OLD_CODE LinearProblem prob(mesh, eqn, bc, List(vx, q), 
      //                   List(ux, p), vecType);

      /* 
         Now, loop over timesteps, solving the elliptic problem for u at each
         step. At the end of each step, assign the solution solnU into u0.
         Because Exprs are stored by reference, the updating of u0 propagates
         to the copies of u0 in the equation set and in the 
         StaticLinearProblem. The same StaticLinearProblem can be reused
         at all timesteps.
      */
      int nSteps = 10;
      for (int i=0; i<nSteps; i++)
        {
          /* solve the problem */
          Expr soln = prob.solve(solver);
          Vector<double> solnVec = DiscreteFunction::discFunc(soln)->getVector();
			 DiscreteFunction::discFunc(u0)->setVector(solnVec);
          cerr << "eqn = " << cnStep << endl;
          cerr << "u0 = " << u0 << endl;
          /* write the solution at step i to a file */
          char fName[20];
          sprintf(fName, "OCtimeStepHeat%d.dat", i);
          ofstream of(fName);
          FieldWriter writer = new MatlabWriter(fName);
			 writer.addMesh(mesh);
			 writer.addField("u0", new ExprFieldWrapper(soln[0]));
			 writer.write();
          cerr << "[" << i << "]";
          /* flush the matrix and RHS values */
          //prob.flushMatrixValues();
        }
      cerr << endl;

      /* compute the exact solution and the error */
      double tFinal = nSteps * deltaT;
			Expr exactSoln = exp(-2 * pi*pi*tFinal) * sin(pi*x) * sin(pi*y);
			
      /*
        compute the norm of the error
      */
      Expr uxErr = u0 - exactSoln;
      Expr errExpr = Integral(interior, 
                              uxErr*uxErr,
                              new GaussianQuadrature(4));

      FunctionalEvaluator errInt(mesh, errExpr);

      double errorSq = errInt.evaluate();
      double errorNorm = sqrt(errorSq);
      cerr << "error norm = " << errorNorm << endl << endl;


    }
	catch(exception& e)
		{
      cerr << e.what() << endl;
		}
  Sundance::finalize();
}





