#!/usr/bin/env python2.7
# -*- coding: UTF-8 -*-
#
# Copyright (C) 2010-2011 Yung-Yu Chen <yyc@solvcon.net>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

"""
Test for the group (energy) velocity of beryl.  The simulation stops at 91
microseconds.

This example implements a point source for the velocity-stress equation solver,
enabled by the corresponding Anchor class.
"""

from solvcon.anchor import Anchor
from solvcon.hook import BlockHook
from solvcon.kerpak import vslin

################################################################################
# Source term treatment.
################################################################################
class PointSource(object):
    """
    Base class of a point source that is added to the solution of one cell.

    Subclasses provide the time dependency by overriding L{calc_timedep}.
    Calling an instance with a solver returns the source contribution.

    @ivar name: name of the source.
    @ivar crd: coordinate of the source; only the first two components are
        used by L{locate_cell}.
    @ivar pcl: first element of what C{svr.locate_point} returned (presumably
        the index of the cell holding the point), or -1 before L{locate_cell}
        has been called.
    @ivar amp: amplitude that scales the time dependency.
    @ivar delay: time delay, interpreted by the subclass.
    @ivar freq: frequency, interpreted by the subclass.
    """
    def __init__(self, name, crd, amp, delay, freq):
        self.name = name
        self.crd = crd
        # -1 marks the source as not located yet.
        self.pcl = -1
        self.amp = amp
        self.delay = delay
        self.freq = freq
    def locate_cell(self, svr):
        """
        Query the solver with the first two coordinates of the source and
        store the first element of the result in C{self.pcl}.

        @param svr: solver object providing C{locate_point(x, y)}.
        """
        idx = svr.locate_point(self.crd[0], self.crd[1])
        self.pcl = idx[0]
    def calc_timedep(self, svr):
        """
        Return the time-dependent factor of the source; must be overridden.

        @param svr: solver object.
        @raise NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
    def __call__(self, svr):
        """
        Evaluate the source term for the given solver.

        @param svr: solver object providing C{cevol} and C{ngstcell}.
        @return: C{amp} times the time dependency divided by column 0 of
            C{svr.cevol} at row C{ngstcell+pcl} (presumably the volume of the
            hosting cell).
        """
        tdep = self.calc_timedep(svr)
        # pcl is shifted by ngstcell to index the solver array.
        return self.amp * tdep / svr.cevol[svr.ngstcell+self.pcl,0]
class PointSourceRicker(PointSource):
    """
    Point source whose time dependency is built from the function
    C{t*exp(-(pi*freq*t)**2)}, which is the antiderivative of the Ricker
    wavelet C{(1-2*(pi*freq*t)**2)*exp(-(pi*freq*t)**2)}.
    """
    def calc_timedep(self, svr):
        """
        Difference of the antiderivative between the delayed time and half a
        time increment later, divided by two.

        @param svr: solver object providing C{exd.time} and
            C{exd.time_increment}.
        @return: the time-dependent factor.
        """
        from math import exp, pi
        freq = self.freq
        def antiderivative(t):
            return t * exp(-(pi*freq*t)**2)
        # interval runs from the delayed time to half an increment later.
        start = svr.exd.time - self.delay
        stop = start + svr.exd.time_increment/2
        return (antiderivative(stop) - antiderivative(start)) / 2
class PointSourceCos(PointSource):
    """
    Point source whose time dependency is the integral of
    C{cos(2*pi*freq*t)} over half a time increment.
    """
    def calc_timedep(self, svr):
        """
        Evaluate C{(sin(w*t2) - sin(w*t1)) / w} with C{w = 2*pi*freq},
        C{t1} the solver time minus the delay and C{t2} half a time increment
        after C{t1}.

        @param svr: solver object providing C{exd.time} and
            C{exd.time_increment}.
        @return: the time-dependent factor.
        """
        from math import sin, pi
        start = svr.exd.time - self.delay
        stop = start + svr.exd.time_increment/2
        omega = 2*pi*self.freq
        return (sin(omega*stop) - sin(omega*start)) / omega

class PSAnchor(Anchor):
    """
    Anchor that adds the contribution of a list of point sources to the
    solution array of the solver.
    """
    def __init__(self, svr, **kw):
        """
        @param svr: solver object the anchor is attached to.
        @keyword sources: sequence of source objects; required, and removed
            from the keywords before they reach the base class.
        """
        sources = kw.pop('sources')
        self.sources = sources
        super(PSAnchor, self).__init__(svr, **kw)
    def preloop(self):
        """
        Let every source look up its cell in the solver.
        """
        solver = self.svr
        for source in self.sources:
            source.locate_cell(solver)
    def postcalcsoln(self):
        """
        Add each source term to the solution row of its cell; sources with a
        negative C{pcl} are skipped.
        """
        solver = self.svr
        offset = solver.ngstcell
        for source in self.sources:
            if source.pcl < 0:
                continue
            solver.soln[offset+source.pcl,:] += source(solver)

################################################################################
# Basic configuration.
################################################################################
def grpv_base(casename=None, meshname=None, mtrlname='Beryl',
    al=0.0, be=90.0, ga=0.0, psteps=None, ssteps=None, **kw):
    """
    Fundamental configuration of the simulation and return the case object.

    @param casename: base file name of the case (passed as C{basefn}).
    @param meshname: file name of the mesh, joined to the directory returned
        by C{env.find_scdata_mesh()}.
    @param mtrlname: key into C{vslin.mltregy} selecting the material class.
    @param al: angle C{al} of the material in degrees (converted to radians).
    @param be: angle C{be} of the material in degrees (converted to radians).
    @param ga: angle C{ga} of the material in degrees (converted to radians).
    @param psteps: step interval handed to C{ProgressHook}; required.
    @param ssteps: step interval handed to C{CflHook} and C{PMarchSave};
        required.  C{ssteps/psteps} is used as the line width of the hooks.
    @param kw: additional keywords forwarded to C{vslin.VslinCase}.
    @raise TypeError: if C{psteps} or C{ssteps} is left as C{None}.
    @return: the created Case object.
    @rtype: solvcon.case.BlockCase
    """
    import os
    from numpy import pi, array
    from solvcon.conf import env
    from solvcon.boundcond import bctregy
    from solvcon import hook, anchor
    from solvcon.solver import ALMOST_ZERO
    from solvcon.kerpak import cuse
    from solvcon.kerpak.vslin import mltregy
    # the None defaults cannot be divided below; fail early and say why.
    if psteps is None or ssteps is None:
        raise TypeError('both psteps and ssteps must be specified')
    # set up BCs.
    bcmap = {
        'left': (bctregy.CuseNonrefl, {}),
        'right': (bctregy.CuseNonrefl, {}),
        'upper': (bctregy.CuseNonrefl, {}),
        'lower': (bctregy.CuseNonrefl, {}),
    }
    # set up case.
    mtrl = mltregy[mtrlname](al=al*pi/180.0, be=be*pi/180.0, ga=ga*pi/180.0)
    basedir = os.path.join(os.path.abspath(os.getcwd()), 'result')
    cse = vslin.VslinCase(basedir=basedir, rootdir=env.projdir,
        basefn=casename, meshfn=os.path.join(env.find_scdata_mesh(), meshname),
        bcmap=bcmap, mtrldict={None: mtrl}, **kw)
    # statistical anchors for solvers.
    for name in 'Runtime', 'March', 'Tpool':
        cse.runhooks.append(getattr(anchor, name+'StatAnchor'))
    # informative hooks.
    cse.runhooks.append(hook.BlockInfoHook)
    # shared by the progress and CFL hooks.
    linewidth = ssteps/psteps
    cse.runhooks.append(hook.ProgressHook, psteps=psteps,
        linewidth=linewidth)
    cse.runhooks.append(cuse.CflHook, fullstop=False, psteps=ssteps,
        cflmax=10.0, linewidth=linewidth)
    # initializer anchors.
    cse.runhooks.append(anchor.FillAnchor, keys=('soln',), value=ALMOST_ZERO)
    cse.runhooks.append(anchor.FillAnchor, keys=('dsoln',), value=0)
    ## point source.
    # the amplitude is nonzero only in the first three of its nine entries.
    src = PointSourceRicker(name='source', crd=(0.0, 0.0),
        delay=0.0, freq=1.e5,
        amp=array([
            1, 1, 1, 0,0,0,0,0,0,
        ], dtype='float64'),
    )
    cse.runhooks.append(PSAnchor, sources=[src])
    # analyzing/output anchors and hooks.
    cse.runhooks.append(vslin.VslinOAnchor)
    cse.runhooks.append(hook.PMarchSave, anames=[
            ('soln', False, -9),
            ('energy', True, 0),
        ], fpdtype='float64', psteps=ssteps, compressor='gz')
    return cse

################################################################################
# The arrangement.
################################################################################
@vslin.VslinCase.register_arrangement
def grpv(casename, **kw):
    """
    Arrangement of the group-velocity case on the C{psquare_10mm.neu.gz}
    mesh: 250 steps with a time increment of 3.64e-7, progress reported
    every step and output every 50 steps.

    @param casename: name of the case, forwarded to L{grpv_base}.
    @param kw: additional keywords forwarded to L{grpv_base}.
    @return: the case object created by L{grpv_base}.
    """
    mesh = 'psquare_10mm.neu.gz'
    progress_every = 1
    save_every = 50
    return grpv_base(
        casename=casename,
        meshname=mesh,
        psteps=progress_every,
        ssteps=save_every,
        time_increment=3.64e-7,
        steps_run=250,
        steps_stride=1,
        **kw
    )

################################################################################
# Invoke SOLVCON workflow.
################################################################################
if __name__ == '__main__':
    # Executed as a script: invoke the SOLVCON workflow through solvcon.go().
    import solvcon
    solvcon.go()
