#!/usr/bin/env python2.7
# -*- coding: UTF-8 -*-
#
# Copyright (C) 2010-2011 Yung-Yu Chen <yyc@solvcon.net>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

"""
Convergence test for two- and three-dimensional, anisotropic, linear elastic
solver.  2D tests are run against a two by two (meter) square with 0.2, 0.15,
0.1, 0.05 cell edge sizes.  3D tests are run against a two by two by two
(meter) cube with 0.2, 0.15, 0.1 cell edge sizes.  The meshes are generated by
calling Cubit on the fly.  The wave vector is set to (pi, pi) or (pi, pi, pi).
A command line tool "./go converge" is provided to analyze the order of
convergence of error after the arrangements are run.

Run the arrangements by ./go run cvg2d_200 cvg2d_150 cvg2d_100 cvg2d_50
cvg3d_200 cvg3d_150 cvg3d_100
"""

from solvcon.anchor import Anchor
from solvcon.hook import BlockHook
from solvcon.kerpak.cuse import CusePeriodic
from solvcon.kerpak.lincuse import PlaneWaveSolution
from solvcon.kerpak import vslin
from solvcon.cmdutil import Command

###############################################################################
# Command line.
###############################################################################
class converge(Command):
    """
    Calculate and verify the order of convergence of the error norms.

    Positional arguments are not used by this command (min_args is 0); its
    behavior is controlled by the options in the "Convergence" group:
    --wdir, --key, --idx, --order, --order-tolerance and --stop-on-over.
    """
    min_args = 0

    def __init__(self, env):
        """
        Register the "Convergence" option group on the option parser.
        """
        from optparse import OptionGroup
        super(converge, self).__init__(env)
        op = self.op

        opg = OptionGroup(op, 'Convergence')
        opg.add_option("--wdir", action="store",
            dest="wdir", default="result", help="Working directory.")
        opg.add_option("--key", action="store",
            dest="key", default="L2", help="Linf or L2 norm.")
        opg.add_option("--idx", action="store", type=int,
            dest="idx", default=0, help="Index of variable: 0--8.")
        opg.add_option("--order", action="store", type=float,
            dest="order", default=None,
            help="Error-norm should converge at the rate, if given.")
        opg.add_option("--order-tolerance", action="store", type=float,
            dest="order_tolerance", default=0.4,
            help="The variation of converge order which can be tolerated.")
        opg.add_option("--stop-on-over", action="store_true",
            dest="stop_on_over", default=False,
            help="Raise ValueError if tolerance not met.")
        op.add_option_group(opg)
        self.opg_obshock = opg

    def _convergence_test(self, mainfn):
        """
        Print the observed order of convergence between each pair of
        consecutive mesh sizes found for the given case family.

        Files named <mainfn>_<size>_norm.pickle are collected from the working
        directory.  <size> must be an integer; it is divided by 1000 to obtain
        the cell edge size that is reported and used in the order estimate.
        Each pickle is indexed as data[key][idx] to get the error norm.

        @param mainfn: leading part of the result file names, e.g. 'cvg2d'.
        @type mainfn: str
        @raise ValueError: if --order and --stop-on-over are given and an
            observed order deviates from --order by the tolerance or more; also
            if the size field of a matched file name is not an integer.
        """
        import os, sys
        from math import log
        from pickle import load
        from glob import glob
        ops, args = self.opargs
        # collect data.
        prefix = '%s_' % mainfn
        suffix = '_norm.pickle'
        # parse the size from the file name only, so that underscores in the
        # working directory path cannot be mistaken for the field separator.
        mms = sorted((
            int(os.path.basename(path)[len(prefix):-len(suffix)])
            for path in glob(os.path.join(ops.wdir, prefix+'*'+suffix))
        ), reverse=True)
        dat = []
        for mm in mms:
            fn = os.path.join(ops.wdir, '%s_%d_norm.pickle' % (mainfn, mm))
            # NOTE: unpickling executes arbitrary code; only point --wdir at
            # result files that are trusted.
            with open(fn, 'rb') as fobj:
                dat.append((mm, load(fobj)))
        # show convergence.
        sys.stdout.write(
            '%s convergence of %s error-norm at the %dth (0--8) variable:\n' % (
            mainfn, ops.key, ops.idx))
        # walk through consecutive (coarser, finer) pairs.
        for (mm0, nrm0), (mm1, nrm1) in zip(dat[:-1], dat[1:]):
            er = [nrm0[ops.key][ops.idx], nrm1[ops.key][ops.idx]]
            hr = [float(mm0)/1000, float(mm1)/1000]
            odr = log(er[1]/er[0])/log(hr[1]/hr[0])
            sys.stdout.write('  %6.4f -> %6.4f (m): %g' % (hr[0], hr[1], odr))
            if ops.order is not None:
                if abs(odr - ops.order) < ops.order_tolerance:
                    sys.stdout.write(' GOOD. Within')
                elif ops.stop_on_over:
                    raise ValueError('out of tolerance')
                else:
                    sys.stdout.write(' BAD. Out of')
                sys.stdout.write(' %g +/- %g'%(ops.order, ops.order_tolerance))
            sys.stdout.write('\n')

    def __call__(self):
        """
        Report convergence for the 2D results and then for the 3D results.
        """
        self._convergence_test('cvg2d')
        self._convergence_test('cvg3d')

################################################################################
# Mesh generation and boundary condition processor.
################################################################################
def mesher(cse):
    """
    Generate the mesh of a case by feeding a journaling template to Cubit.

    The template is cube.tmpl when the dimension is 3 and square.tmpl
    otherwise; both are opened relative to the current working directory.  The
    dimension is read from the fourth character of the base file name of the
    case.  The interval substituted into the template is the part of the base
    file name after the last underscore divided by 1000 (e.g. a name ending
    with "_200" gives 0.2); 0.2 is used if that part is not a number.

    @param cse: the case to be meshed; cse.io.basefn and cse.condition.bcmap
        are read.
    @return: the block converted from the Cubit output, with cse.condition.bcmap
        passed as the BC name mapper.
    """
    from solvcon.helper import Cubit
    try:
        itv = float(cse.io.basefn.split('_')[-1])/1000
    except ValueError:
        itv = 0.2
    ndim = int(cse.io.basefn[3])
    tmplfn = 'cube.tmpl' if ndim == 3 else 'square.tmpl'
    # close the template file deterministically instead of leaking the handle.
    with open(tmplfn) as tmplf:
        cmds = tmplf.read() % itv
    cmds = [cmd.strip() for cmd in cmds.strip().split('\n')]
    gn = Cubit(cmds, ndim)()
    return gn.toblock(bcname_mapper=cse.condition.bcmap)

def match_periodic(blk):
    """
    Couple the named boundary pairs of the block as periodic boundaries.

    'left' is linked to 'right' and 'lower' to 'upper'; for a 3D block 'rear'
    is additionally linked to 'front'.  Each entry carries a float64 'ref'
    array whose length follows the dimension of the block.

    @param blk: the block whose boundaries are to be coupled; blk.ndim is read.
    """
    from numpy import array
    from solvcon.boundcond import bctregy
    periodic = bctregy.CusePeriodic
    is3d = blk.ndim == 3
    # (BC name, linked BC name, reference values), kept in insertion order.
    pairs = [
        ('left', 'right', [0,-2,-2] if is3d else [0,-2]),
        ('lower', 'upper', [-2,0,-2] if is3d else [-2,0]),
    ]
    if is3d:
        pairs.append(('rear', 'front', [-2,-2,0]))
    bcmap = dict()
    for name, link, ref in pairs:
        bcmap[name] = (periodic, {
            'link': link,
            'ref': array(ref, dtype='float64'),
        })
    periodic.couple_all(blk, bcmap)

################################################################################
# Basic configuration.
################################################################################
def cvg_base(casename=None, mtrlname='GaAs',
    al=20.0, be=40.0, ga=50.0, wtests=None, psteps=None, ssteps=None, **kw):
    """
    Build the case object shared by all convergence arrangements.

    @keyword casename: name of the case; its fourth character is parsed as the
        number of dimensions and the name is used as the base file name.
    @keyword mtrlname: key of the material in the vslin material registry.
    @keyword al: angle in degree; converted to radian for the material.
    @keyword be: angle in degree; converted to radian for the material.
    @keyword ga: angle in degree; converted to radian for the material.
    @keyword wtests: sequence of (wave vector, index) pairs; one plane wave
        solution is created for each pair.
    @keyword psteps: step interval given to the progress hook.
    @keyword ssteps: step interval given to the CFL, plane wave and saving
        hooks.
    @note: all other keywords are forwarded to the constructor of the case.

    @return: the created Case object.
    @rtype: solvcon.case.BlockCase
    """
    import os
    from numpy import pi, zeros
    from solvcon.conf import env
    from solvcon.boundcond import bctregy
    from solvcon import hook, anchor
    from solvcon.solver import ALMOST_ZERO
    from solvcon.kerpak import cuse
    from solvcon.kerpak.lincuse import PlaneWaveHook
    from solvcon.kerpak.vslin import mltregy, VslinPWSolution
    ndim = int(casename[3])
    # every boundary starts as a plain BC; match_periodic (given as bcmod
    # below) is what sets up the periodic coupling.
    bct = bctregy.BC
    bcmap = dict({
        'left': (bct, {}),
        'right': (bct, {}),
        'upper': (bct, {}),
        'lower': (bct, {}),
    })
    if ndim == 3:
        bcmap.update({
            'front': (bct, {}),
            'rear': (bct, {}),
        })
    # material with the angles converted from degree to radian.
    mtrl = mltregy[mtrlname](al=al*pi/180.0, be=be*pi/180.0, ga=ga*pi/180.0)
    # the case writes into the "result" directory under the current one.
    basedir = os.path.join(os.path.abspath(os.getcwd()), 'result')
    cse = vslin.VslinCase(basedir=basedir, rootdir=env.projdir,
        basefn=casename, mesher=mesher, bcmap=bcmap, bcmod=match_periodic,
        mtrldict={None: mtrl}, taylor=0.0, **kw)
    runhooks = cse.runhooks
    # statistical anchors for solvers.
    runhooks.append(anchor.RuntimeStatAnchor)
    runhooks.append(anchor.MarchStatAnchor)
    runhooks.append(anchor.TpoolStatAnchor)
    # informative hooks.
    runhooks.append(hook.BlockInfoHook)
    linewidth = ssteps/psteps
    runhooks.append(hook.ProgressHook, psteps=psteps, linewidth=linewidth)
    runhooks.append(cuse.CflHook, fullstop=False, psteps=ssteps,
        cflmax=10.0, linewidth=linewidth)
    # initializer anchors.
    runhooks.append(anchor.FillAnchor, keys=('soln',), value=ALMOST_ZERO)
    runhooks.append(anchor.FillAnchor, keys=('dsoln',), value=0)
    # plane wave solutions, one for each requested (wave vector, index) pair.
    pws = [
        VslinPWSolution(amp=1.0, ctr=zeros(ndim, dtype='float64'),
            wvec=wvec, mtrl=mtrl, idx=idx)
        for wvec, idx in wtests
    ]
    runhooks.append(PlaneWaveHook, psteps=ssteps, planewaves=pws)
    # analyzing/output anchors and hooks.
    anames = [
        ('soln', False, -9),
        ('analytical', True, -9),
        ('difference', True, -9),
    ]
    runhooks.append(hook.PMarchSave, anames=anames,
        fpdtype='float64', psteps=ssteps, compressor='gz')
    return cse

def cvg2d_skel(casename, div, std, **kw):
    """
    Skeleton of the 2D arrangements: a single plane wave with wave vector
    (pi, pi) and index 8.

    @param casename: name of the case.
    @param div: number of time steps per period; twice as many steps are run
        and div is also used as the interval of the saving hooks.
    @param std: step interval of the progress report.
    @note: other keywords are forwarded to cvg_base.
    @return: the case object created by cvg_base.
    """
    from numpy import array, pi
    period = 2.649983322636356e-04
    wavevector = array([1,1], dtype='float64')*pi
    planewaves = ((wavevector, 8),)
    return cvg_base(casename=casename, wtests=planewaves,
        time_increment=period/div, steps_run=2*div,
        ssteps=div, psteps=std, **kw)

def cvg3d_skel(casename, div, std, **kw):
    """
    Skeleton of the 3D arrangements: a single plane wave with wave vector
    (pi, pi, pi) and index 8.

    @param casename: name of the case.
    @param div: number of time steps per period; twice as many steps are run
        and div is also used as the interval of the saving hooks.
    @param std: step interval of the progress report.
    @note: other keywords are forwarded to cvg_base.
    @return: the case object created by cvg_base.
    """
    from numpy import array, pi
    period = 2.353142528777195e-04
    wavevector = array([1,1,1], dtype='float64')*pi
    planewaves = ((wavevector, 8),)
    return cvg_base(casename=casename, wtests=planewaves,
        time_increment=period/div, steps_run=2*div,
        ssteps=div, psteps=std, **kw)

################################################################################
# The arrangement for 2D convergence test.
################################################################################
@vslin.VslinCase.register_arrangement
def cvg2d_200(casename, **kw):
    """
    2D convergence arrangement using 18 time steps per period.

    NOTE: keywords in kw are accepted but not forwarded to cvg2d_skel.
    """
    div, std = 18, 1
    return cvg2d_skel(casename, div, std)
@vslin.VslinCase.register_arrangement
def cvg2d_150(casename, **kw):
    """
    2D convergence arrangement using 22 time steps per period.

    NOTE: keywords in kw are accepted but not forwarded to cvg2d_skel.
    """
    div, std = 22, 1
    return cvg2d_skel(casename, div, std)
@vslin.VslinCase.register_arrangement
def cvg2d_100(casename, **kw):
    """
    2D convergence arrangement using 32 time steps per period.

    NOTE: keywords in kw are accepted but not forwarded to cvg2d_skel.
    """
    div, std = 32, 1
    return cvg2d_skel(casename, div, std)
@vslin.VslinCase.register_arrangement
def cvg2d_50(casename, **kw):
    """
    2D convergence arrangement using 64 time steps per period and a progress
    interval of 2.

    NOTE: keywords in kw are accepted but not forwarded to cvg2d_skel.
    """
    div, std = 64, 2
    return cvg2d_skel(casename, div, std)

################################################################################
# The arrangement for 3D convergence test.
################################################################################
@vslin.VslinCase.register_arrangement
def cvg3d_200(casename, **kw):
    """
    3D convergence arrangement using 32 time steps per period.

    NOTE: keywords in kw are accepted but not forwarded to cvg3d_skel.
    """
    div, std = 32, 1
    return cvg3d_skel(casename, div, std)
@vslin.VslinCase.register_arrangement
def cvg3d_150(casename, **kw):
    """
    3D convergence arrangement using 40 time steps per period.

    NOTE: keywords in kw are accepted but not forwarded to cvg3d_skel.
    """
    div, std = 40, 1
    return cvg3d_skel(casename, div, std)
@vslin.VslinCase.register_arrangement
def cvg3d_100(casename, **kw):
    """
    3D convergence arrangement using 64 time steps per period and a progress
    interval of 2.

    NOTE: keywords in kw are accepted but not forwarded to cvg3d_skel.
    """
    div, std = 64, 2
    return cvg3d_skel(casename, div, std)

################################################################################
# Invoke SOLVCON workflow.
################################################################################
if __name__ == '__main__':
    # Only when executed as a script: hand control over to the SOLVCON entry
    # point.  The import is kept local so that merely importing this module
    # does not start the workflow.
    import solvcon
    solvcon.go()
