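"""
1D wave equation with a variable wave speed

Solves u_tt = c(x)**2 * u_xx on x in [-1, 1] with u = 0 at both ends, written
as a system that is first order in time for the fields (u, ux, ut). The wave
speed c(x) = 2.5 + 1.5*tanh(x/0.02) steps sharply from about 1 (left) to about
4 (right) across x = 0. The initial condition is a narrow, right-moving
Gaussian pulse centred at x = -0.5.

This script should be run serially (because it is 1D). It optionally shows a
live plot during the run, and saves a space-time plot of the computed solution
to 'wave.png'.
"""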

import numpy as np
import matplotlib.pyplot as plt
from dedalus2.public import *
from dedalus2.extras.plot_tools import quad_mesh, pad_limits


# Setup equation: the second-order wave equation u_tt = c**2 * u_xx, rewritten
# as a system that is first order in time by introducing ux = dx(u) and
# ut = dt(u).
problem = ParsedProblem(axis_names=['x'],
                        field_names=['u', 'ux', 'ut'],
                        param_names=['c'])
for equation in ("dt(ut) - c**2*dx(ux) = 0",   # evolution equation for ut
                 "ux - dx(u) = 0",              # definition of ux
                 "ut - dt(u) = 0"):             # definition of ut
    problem.add_equation(equation)
# Homogeneous Dirichlet conditions on u at both ends of the interval.
problem.add_left_bc("u = 0")
problem.add_right_bc("u = 0")

# Build domain: two Chebyshev elements, on [-1, 0] and [0, 1], joined at x = 0
# where the wave speed changes most rapidly.
nx = 257  # modes per element; single source of truth for the resolution
xb1 = Chebyshev(nx, interval=(-1, 0))
xb2 = Chebyshev(nx, interval=(0, 1))
x_basis = Compound((xb1, xb2))
domain = Domain([x_basis], np.float64)
x = domain.grid(0)
# Variable wave speed: about 1 for x << 0 and about 4 for x >> 0, with a
# step of width ~0.02 at x = 0.
problem.parameters['c'] = c = 2.5 + 1.5*np.tanh(x/0.02)
# NOTE(review): presumably the truncation order used by dedalus2 when
# expanding the non-constant coefficient c -- confirm against problem.expand.
order = 1
# Offset of the second element's coefficients. This assumes the compound
# coefficient vector stores xb1's nx modes followed by xb2's. It was previously
# a separate literal that had to be kept in sync with the basis sizes by hand.
cut = nx
problem.expand(domain, order=order)
# cf: c with all but the first `order` coefficients of each element zeroed.
# It is used only to display the truncated wave speed in the live plot.
cf = domain.new_field()
cf['g'] = c
cf['c'][order:cut] = 0
cf['c'][cut+order:] = 0


# Build solver: initial-value problem advanced with the MCNAB2 timestepper.
# NOTE(review): presumably solver.ok becomes False once any of the limits
# below is reached (simulation time, wall-clock seconds, iteration count).
solver = solvers.IVP(problem, domain, timesteppers.MCNAB2)
solver.stop_sim_time, solver.stop_wall_time, solver.stop_iteration = 10, 30, 5000

# Initial conditions: a narrow Gaussian pulse of height 2 centred at x = -0.5.
# The initial velocity is ut = -c*ux, which holds for a profile f(x - c*t),
# so the pulse starts out moving to the right.
u, ux, ut = (solver.state[name] for name in ('u', 'ux', 'ut'))
u['g'] = 2*np.exp(-((x+0.5)/0.02)**2)
u.differentiate(0, out=ux)  # ux = du/dx, consistent with "ux - dx(u) = 0"
ut['g'] = - c * ux['g']

# Plot: optional live figure that the main loop updates as the run progresses.
liveplot = True
if liveplot:
    fig = plt.figure(1, figsize=(12, 10))
    ax1 = fig.add_subplot(211)
    ax2 = fig.add_subplot(212)
    # Top panel: truncated wave speed cf (grey dots), and u(x) as a scatter
    # plot coloured by cf.
    ax1.plot(x, cf['g'], '.', c='0.5')
    scat = ax1.scatter(x, np.copy(u['g']), c=cf['g'])
    ax1.set_xlim(-1, 1)
    ax1.set_ylim(-5, 5)
    ax1.set_xlabel('x')
    ax1.set_ylabel('u')
    # Bottom panel: squared magnitude of u's spectral coefficients, log scale.
    line2, = ax2.semilogy(x_basis.elements, np.abs(u['c'])**2, 'r.-')
    ax2.set_xlabel('l')
    ax2.set_ylabel('|u_l|^2')
    ax2.set_ylim(1e-14, 1)
    title = fig.suptitle('it: %i, t: %f' % (solver.iteration, solver.sim_time))
    # plt.draw() alone does not run the GUI event loop, so in a script the
    # window would never appear or refresh. plt.pause draws the figure, shows
    # it without blocking, and processes pending GUI events.
    plt.pause(0.001)

# Optionally keep snapshots of u for the final space-time plot. The initial
# condition is recorded first, so the plot starts at the initial time.
spacetime = True
if spacetime:
    t_list = [solver.sim_time]
    u_list = [np.copy(u['g'])]

# Main loop: fixed-size timesteps until the solver reports it should stop.
# Every `output_cadence` iterations, record a snapshot and refresh the live
# plot.
dt = 5e-4
output_cadence = 30
while solver.ok:
    solver.step(dt)
    if solver.iteration % output_cadence != 0:
        continue
    if not (spacetime or liveplot):
        continue
    # One copy of the grid data per output, shared by the snapshot list and
    # the scatter plot.
    u_snapshot = np.copy(u['g'])
    if spacetime:
        u_list.append(u_snapshot)
        t_list.append(solver.sim_time)
    if liveplot:
        # np.array builds a fresh (N, 2) offsets array, so the stored snapshot
        # is never aliased by the plot.
        scat.set_offsets(np.array([x, u_snapshot]).T)
        line2.set_ydata(np.abs(u['c'])**2)
        title.set_text('it: %i, t: %f' % (solver.iteration, solver.sim_time))
        # plt.draw() alone does not process GUI events inside this blocking
        # loop; plt.pause redraws and lets the window refresh.
        plt.pause(0.001)

if spacetime:
    # Stack the recorded snapshots (one row per output time) and draw u over
    # the x-t plane, then save the figure to disk.
    u_array = np.array(u_list)
    t_array = np.array(t_list)
    xmesh, ymesh = quad_mesh(x=x, y=t_array)
    fig2 = plt.figure(2)
    ax = fig2.gca()
    mesh = ax.pcolormesh(xmesh, ymesh, u_array, cmap='RdBu_r')
    ax.axis(pad_limits(xmesh, ymesh))
    fig2.colorbar(mesh, ax=ax)
    ax.set_xlabel('x')
    ax.set_ylabel('t')
    ax.set_title('Wave Equation')
    fig2.savefig('wave.png')

