#!/usr/bin/env python2.7
# -*- coding: UTF-8 -*-
#
# Copyright (C) 2010-2011 Yung-Yu Chen <yyc@solvcon.net>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

"""
Demonstrate how to customize a new solver based on solvcon.kerpak.cese.  The
custom solver is for linear elasticity.  A convergence test with a plane wave
is used as the model problem; the wave vector is set to (pi, pi, 0).  The
command line tool "./go converge" is provided to analyze the order of
convergence of the error after the arrangements have been run.

Run the arrangements with: ./go run elas3d_200 elas3d_150 elas3d_100
"""

from solvcon.anchor import Anchor
from solvcon.hook import BlockHook
from solvcon.kerpak.cese import CesePeriodic
from solvcon.cmdutil import Command
import elastic

###############################################################################
# Command line.
###############################################################################
class converge(Command):
    """
    Calculate and verify the order of convergence.

    The error norms are read from the files elas3d_<M>_norm.pickle (M = 200,
    150, 100) found in the working directory (--wdir).  No positional argument
    is needed; the behavior is controlled by the options in the "Convergence"
    group.
    """
    min_args = 0

    def __init__(self, env):
        from optparse import OptionGroup
        super(converge, self).__init__(env)
        op = self.op

        opg = OptionGroup(op, 'Convergence')
        opg.add_option("--wdir", action="store",
            dest="wdir", default="result", help="Working directory.")
        opg.add_option("--key", action="store",
            dest="key", default="L2", help="Linf or L2 norm.")
        opg.add_option("--idx", action="store", type=int,
            dest="idx", default=0, help="Index of variable: 0--8.")
        opg.add_option("--order", action="store", type=float,
            dest="order", default=None,
            help="Error-norm should converge at the rate, if given.")
        opg.add_option("--order-tolerance", action="store", type=float,
            dest="order_tolerance", default=0.4,
            help="The variation of converge order which can be tolerated.")
        opg.add_option("--stop-on-over", action="store_true",
            dest="stop_on_over", default=False,
            help="Raise ValueError if tolerance not met.")
        op.add_option_group(opg)
        self.opg_obshock = opg

    def __call__(self):
        """
        Print the observed order of convergence between successive cases.

        @raise ValueError: when --stop-on-over is given and the observed order
            deviates from --order by --order-tolerance or more.
        """
        import os, sys
        from pickle import load
        from math import log
        ops, args = self.opargs
        # load the norm dictionaries.  The files are written in binary mode
        # with a binary pickle protocol, so they are read back in binary mode
        # and closed right after loading.
        # NOTE: pickle is only safe on trusted files; point --wdir only to
        # result directories produced by your own runs.
        dat = list()
        for mm in (200, 150, 100):
            fname = os.path.join(ops.wdir, 'elas3d_%d_norm.pickle' % mm)
            with open(fname, 'rb') as fobj:
                dat.append((mm, load(fobj)))
        sys.stdout.write(
            'Convergence of %s error-norm at the %dth (0--8) variable:\n' % (
            ops.key, ops.idx))
        # compare each case with the next one in the list.
        for (mm0, norm0), (mm1, norm1) in zip(dat[:-1], dat[1:]):
            er = [norm0[ops.key][ops.idx], norm1[ops.key][ops.idx]]
            hr = [float(mm0)/1000, float(mm1)/1000]
            # observed order: slope of log(error) against log(size).
            odr = log(er[1]/er[0])/log(hr[1]/hr[0])
            sys.stdout.write('  %6.4f -> %6.4f (m): %g' % (hr[0], hr[1], odr))
            if ops.order is not None:
                if abs(odr - ops.order) < ops.order_tolerance:
                    sys.stdout.write(' GOOD. Within')
                else:
                    if ops.stop_on_over:
                        raise ValueError('out of tolerance')
                    else:
                        sys.stdout.write(' BAD. Out of')
                sys.stdout.write(' %g +/- %g'%(ops.order, ops.order_tolerance))
            sys.stdout.write('\n')

###############################################################################
# Plane wave solution and initializer.
###############################################################################
class PlaneWaveSolution(object):
    """
    A single plane wave built from one eigen-pair of the material Jacobian
    projected onto the direction of the wave vector.

    Attributes set by the constructor: amp (eigenvector scaled to the
    requested amplitude), ctr, wvec, afreq (eigenvalue times the length of the
    wave vector) and wsp (the eigenvalue itself).
    """
    def __init__(self, mtrl, idx, amp, ctr, wvec):
        """
        @param mtrl: object providing the matrices jacox, jacoy (and jacoz for
            a 3-component wave vector).
        @param idx: position of the wanted eigenvalue in ascending order.
        @param amp: amplitude; the selected eigenvector is scaled so that its
            Euclidean length equals this value.
        @param ctr: array stored as self.ctr; must be as long as wvec.
        @param wvec: wave vector array.
        """
        from numpy import sqrt
        from numpy.linalg import eig
        assert len(ctr) == len(wvec)
        # unit direction of propagation.
        wnorm = sqrt((wvec**2).sum())
        direction = wvec/wnorm
        # Jacobian projected onto the propagation direction.
        jaco = mtrl.jacox*direction[0] + mtrl.jacoy*direction[1]
        if len(direction) == 3:
            jaco += mtrl.jacoz*direction[2]
        # select the idx-th eigen-pair, counted by ascending eigenvalue.
        eigvals, eigvecs = eig(jaco)
        pick = eigvals.argsort()[idx]
        speed = eigvals[pick]
        mode = eigvecs[:,pick]
        # store data to self.
        self.amp = mode * (amp / sqrt((mode**2).sum()))
        self.ctr = ctr
        self.wvec = wvec
        self.afreq = speed * wnorm
        self.wsp = speed
    def __call__(self, svr, asol, adsol):
        """
        Hand the wave parameters and the arrays asol and adsol to the C
        function calc_planewave of the solver library.
        """
        from ctypes import byref, c_double
        func = svr._clib_elastic.calc_planewave
        cargs = [byref(svr.exd)]
        cargs.extend(arr.ctypes._as_parameter_
            for arr in (asol, adsol, self.amp, self.ctr, self.wvec))
        cargs.append(c_double(self.afreq))
        func(*cargs)
class PlaneWaveAnchor(Anchor):
    """
    Anchor that evaluates the plane waves on the solver, uses them as the
    initial solution, and keeps the derived arrays 'analytical' and
    'difference' (solution minus analytical) in svr.der.
    """
    def __init__(self, svr, **kw):
        # 'planewaves' is consumed here; the remaining keywords go to Anchor.
        self.planewaves = kw.pop('planewaves')
        super(PlaneWaveAnchor, self).__init__(svr, **kw)
    def _calculate(self, asol):
        """
        Call every plane wave with the solver, asol and self.adsol.
        """
        for planewave in self.planewaves:
            planewave(self.svr, asol, self.adsol)
    def provide(self):
        """
        Allocate the derived arrays, evaluate the plane waves and copy the
        result into the non-ghost part of svr.soln and svr.dsoln.
        """
        from numpy import empty
        svr = self.svr
        field_shape = (svr.ncell, svr.neq)
        # plane wave solution.
        exact = empty(field_shape, dtype=svr.fpdtype)
        svr.der['analytical'] = exact
        grad = empty((svr.ncell, svr.neq, svr.ndim), dtype=svr.fpdtype)
        self.adsol = grad
        exact.fill(0.0)
        self._calculate(exact)
        svr.soln[svr.ngstcell:,:] = exact
        svr.dsoln[svr.ngstcell:,:,:] = grad
        # difference.
        delta = empty(field_shape, dtype=svr.fpdtype)
        svr.der['difference'] = delta
        delta[:,:] = svr.soln[svr.ngstcell:,:] - exact
    def postfull(self):
        """
        Re-evaluate the plane waves and refresh the 'difference' array.
        """
        svr = self.svr
        # plane wave solution.
        exact = svr.der['analytical']
        exact.fill(0.0)
        self._calculate(exact)
        # difference.
        delta = svr.der['difference']
        delta[:,:] = svr.soln[svr.ngstcell:,:] - exact
class PlaneWaveHook(BlockHook):
    """
    Hook that attaches a PlaneWaveAnchor to the solver, computes the Linf and
    L2 norms of the 'difference' array for the first 9 variables, and pickles
    the norms after the time loop.
    """
    def __init__(self, svr, **kw):
        # 'planewaves' is consumed here; the remaining keywords go to the base.
        self.planewaves = kw.pop('planewaves')
        # filled by _calculate() with the keys 'Linf' and 'L2'.
        self.norm = dict()
        super(PlaneWaveHook, self).__init__(svr, **kw)
    def drop_anchor(self, svr):
        """
        Append a PlaneWaveAnchor sharing this hook's plane waves to svr.
        """
        svr.runanchors.append(
            PlaneWaveAnchor(svr, planewaves=self.planewaves)
        )
    def _calculate(self):
        """
        Update self.norm['Linf'] and self.norm['L2'] from the collected
        'difference' array; the L2 norm is weighted by self.blk.clvol.
        """
        # use numpy.absolute explicitly instead of shadowing the builtin abs.
        from numpy import empty, sqrt, absolute
        # NOTE(review): the analytical field is collected as before although
        # its value is not used for the norms; drop the call once it is
        # confirmed that the collection has no needed side effect.
        self._collect_interior(
            'analytical', inder=True, consider_ghost=False)
        diff = self._collect_interior(
            'difference', inder=True, consider_ghost=False)
        norm_Linf = empty(9, dtype='float64')
        norm_L2 = empty(9, dtype='float64')
        clvol = self.blk.clvol
        for it in range(9):
            norm_Linf[it] = absolute(diff[:,it]).max()
            norm_L2[it] = sqrt((diff[:,it]**2*clvol).sum())
        self.norm['Linf'] = norm_Linf
        self.norm['L2'] = norm_L2
    def preloop(self):
        """
        Compute the norms once via postmarch() and report each plane wave.
        """
        from numpy import pi
        self.postmarch()
        for ipw, pw in enumerate(self.planewaves):
            self.info("planewave[%d]:\n" % ipw)
            self.info("  c = %g, omega = %g, T = %.15e\n" % (
                pw.wsp, pw.afreq, 2*pi/pw.afreq))
    def postmarch(self):
        """
        Recompute the norms whenever the current step is a multiple of psteps.
        """
        psteps = self.psteps
        istep = self.cse.execution.step_current
        if istep%psteps == 0:
            self._calculate()
    def postloop(self):
        """
        Pickle self.norm to <basedir>/<basefn>_norm.pickle and report the
        norms of the first three variables.
        """
        import os
        try:
            from cPickle import dump
        except ImportError:
            # cPickle is not available; fall back to the pure pickle module.
            from pickle import dump
        fname = '%s_norm.pickle' % self.cse.io.basefn
        fname = os.path.join(self.cse.io.basedir, fname)
        # close the file deterministically so that the norms are flushed to
        # disk before anything else tries to read them.
        with open(fname, 'wb') as fobj:
            dump(self.norm, fobj, -1)
        self.info('Linf norm in velocity:\n')
        self.info('  %e, %e, %e\n' % tuple(self.norm['Linf'][:3]))
        self.info('L2 norm in velocity:\n')
        self.info('  %e, %e, %e\n' % tuple(self.norm['L2'][:3]))

################################################################################
# Mesh generation and boundary condition processor.
################################################################################
def mesher(cse):
    """
    Generate a cube according to journaling file cube.tmpl.

    The number substituted into the template is the trailing "_<number>" of
    the case base name divided by 1000 (e.g., elas3d_200 gives 0.2); 0.2 is
    used when the trailing part is not a number.  cube.tmpl is looked up in
    the current working directory.

    @param cse: the case object; cse.io.basefn and cse.condition.bcmap are
        read.
    @return: the block converted from the Cubit output.
    """
    from solvcon.helper import Cubit
    try:
        itv = float(cse.io.basefn.split('_')[-1])/1000
    except ValueError:
        itv = 0.2
    # read the template inside a context manager so the file handle is closed
    # even if the substitution fails.
    with open('cube.tmpl') as tmpl:
        cmds = tmpl.read() % itv
    cmds = [cmd.strip() for cmd in cmds.strip().split('\n')]
    gn = Cubit(cmds, 3)()
    return gn.toblock(bcname_mapper=cse.condition.bcmap)

def match_periodic(blk):
    """
    Couple the boundary pairs left/right, lower/upper and rear/front of the
    block as CesePeriodic boundary conditions.

    @param blk: the block whose boundary conditions are to be coupled.
    """
    from numpy import array
    from solvcon.boundcond import bctregy
    periodic = bctregy.CesePeriodic
    # (boundary name, linked boundary name, reference vector).
    pairs = (
        ('left', 'right', [0,-2,-2]),
        ('lower', 'upper', [-2,0,-2]),
        ('rear', 'front', [-2,-2,0]),
    )
    bcmap = dict()
    for name, partner, ref in pairs:
        bcmap[name] = (periodic, {
            'link': partner,
            'ref': array(ref, dtype='float64'),
        })
    periodic.couple_all(blk, bcmap)

################################################################################
# Basic configuration.
################################################################################
def elas3d_base(casename=None, mtrlname='GaAs',
    al=20.0, be=40.0, ga=50.0, wtests=None, psteps=None, ssteps=None, **kw):
    """
    Assemble the case object shared by all arrangements in this file.

    @keyword casename: base file name of the case.
    @keyword mtrlname: key of the material in elastic.mltregy.
    @keyword al: angle in degree; converted to radian for the material.
    @keyword be: angle in degree; converted to radian for the material.
    @keyword ga: angle in degree; converted to radian for the material.
    @keyword wtests: iterable of (wave vector, eigen-index) pairs; one
        PlaneWaveSolution is created for each.  Must be given.
    @keyword psteps: step interval for the progress hook.  Must be given.
    @keyword ssteps: step interval for CFL report, norm calculation and data
        saving.  Must be given.
    @return: the created case object (an elastic.ElasticCase).
    """
    import os
    from numpy import pi, array
    from solvcon.conf import env
    from solvcon.boundcond import bctregy
    from solvcon import hook, anchor
    from solvcon.solver import ALMOST_ZERO
    from solvcon.kerpak import cese
    from elastic import mltregy
    # every face starts with the plain BC type; match_periodic is given to
    # the case as bcmod.
    plain = bctregy.BC
    bcmap = dict()
    for face in ('front', 'rear', 'left', 'right', 'upper', 'lower'):
        bcmap[face] = (plain, {})
    # material (angles converted from degree to radian) and the case.
    mtrlcls = mltregy[mtrlname]
    mtrl = mtrlcls(al=al*pi/180.0, be=be*pi/180.0, ga=ga*pi/180.0)
    basedir = os.path.join(os.path.abspath(os.getcwd()), 'result')
    cse = elastic.ElasticCase(basedir=basedir, rootdir=env.projdir,
        basefn=casename, mesher=mesher, bcmap=bcmap, bcmod=match_periodic,
        mtrldict={None: mtrl}, taylor=0.0, **kw)
    add = cse.runhooks.append
    # statistical anchors for solvers.
    for prefix in ('Runtime', 'March', 'Tpool'):
        add(getattr(anchor, '%sStatAnchor' % prefix))
    # informative hooks.
    add(hook.BlockInfoHook)
    linewidth = ssteps/psteps
    add(hook.ProgressHook, psteps=psteps, linewidth=linewidth)
    add(cese.CflHook, fullstop=False, psteps=ssteps,
        cflmax=10.0, linewidth=linewidth)
    # initializer anchors.
    add(anchor.FillAnchor, keys=('soln',), value=ALMOST_ZERO)
    add(anchor.FillAnchor, keys=('dsoln',), value=0)
    # plane wave solutions, all centered at the origin with unit amplitude.
    pws = [
        PlaneWaveSolution(mtrl=mtrl, amp=1.0, idx=idx,
            ctr=array([0,0,0], dtype='float64'), wvec=wvec)
        for wvec, idx in wtests
    ]
    add(PlaneWaveHook, psteps=ssteps, planewaves=pws)
    # analyzing/output anchors and hooks.
    saved = [('soln', False, -9)]
    add(hook.PMarchSave, anames=saved,
        fpdtype='float64', psteps=ssteps, compressor='gz')
    return cse

def elas3d_skel(casename, div, std, **kw):
    """
    Skeleton of the convergence-test arrangements.

    @param casename: base file name of the case.
    @param div: number of time steps per period; the case runs 2*div steps
        and div is also used as ssteps.
    @param std: used as psteps.
    @return: the case object created by elas3d_base.
    """
    from numpy import array, pi
    # presumably the period T of the tested plane wave -- compare with the
    # value reported by PlaneWaveHook.preloop.
    period = 2.649983322636356e-04
    wavevector = array([1,1,0], dtype='float64')*pi
    settings = dict(
        time_increment=period/div,
        steps_run=2*div,
        ssteps=div,
        psteps=std,
        wtests=((wavevector, 8),),
    )
    settings.update(kw)
    return elas3d_base(casename=casename, **_checked(settings, kw))


def _checked(settings, kw):
    """
    Return settings after making sure kw does not repeat a keyword that
    elas3d_skel sets itself, which is a TypeError in a direct call.
    """
    fixed = ('time_increment', 'steps_run', 'ssteps', 'psteps', 'wtests',
        'casename')
    for key in fixed:
        if key in kw:
            raise TypeError(
                "elas3d_base() got multiple values for keyword argument '%s'"
                % key)
    return settings

################################################################################
# The arrangement for demonstration.
################################################################################
@elastic.ElasticCase.register_arrangement
def elas3d(casename, **kw):
    """
    Demonstration arrangement: 36 time steps per period, run for two periods,
    with both ssteps and psteps set to 1.

    @param casename: base file name of the case.
    @return: the case object created by elas3d_base.
    """
    from numpy import array, pi
    # presumably the period T of the tested plane wave -- compare with the
    # value reported by PlaneWaveHook.preloop.
    period = 2.649983322636356e-04
    div, std = 36, 1
    wavevector = array([1,1,0], dtype='float64')*pi
    tests = ((wavevector, 8),)
    return elas3d_base(casename=casename,
        time_increment=period/div, steps_run=2*div, ssteps=std, psteps=std,
        wtests=tests, **kw)

################################################################################
# The arrangement for convergence test.
################################################################################
@elastic.ElasticCase.register_arrangement
def elas3d_200(casename, **kw):
    """
    Convergence-test arrangement using 34 time steps per period.

    Extra keyword arguments are forwarded to elas3d_skel (and from there to
    the case constructor) instead of being silently dropped, in the same way
    as the elas3d arrangement does.
    """
    return elas3d_skel(casename=casename, div=34, std=1, **kw)
@elastic.ElasticCase.register_arrangement
def elas3d_150(casename, **kw):
    """
    Convergence-test arrangement using 44 time steps per period.

    Extra keyword arguments are forwarded to elas3d_skel (and from there to
    the case constructor) instead of being silently dropped, in the same way
    as the elas3d arrangement does.
    """
    return elas3d_skel(casename=casename, div=44, std=1, **kw)
@elastic.ElasticCase.register_arrangement
def elas3d_100(casename, **kw):
    """
    Convergence-test arrangement using 72 time steps per period and a
    progress interval of 2 steps.

    Extra keyword arguments are forwarded to elas3d_skel (and from there to
    the case constructor) instead of being silently dropped, in the same way
    as the elas3d arrangement does.
    """
    return elas3d_skel(casename=casename, div=72, std=2, **kw)

################################################################################
# Invoke SOLVCON workflow.
################################################################################
if __name__ == '__main__':
    # entry point when the file is executed as a script (e.g., "./go run" or
    # "./go converge"): hand over to the SOLVCON driver.
    from solvcon import go
    go()
