#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
    Copyright 2010 cloudControl UG (haftungsbeschraenkt)

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
"""
import socket
import ssl
import sys
import time
import thread

import argparse

from pycclib.version import __version__ as cclibversion

import cctrl.settings as settings
from cctrl.output import get_version


def main():
    """
      Script entry point: hand control to the command line parser, which
      parses the arguments and runs the selected action.
    """
    parse_cmdline()


def parse_cmdline():
    """
      Build the argparse parser for the tunnel, parse the command line
      and call the handler stored in the ``func`` default (``listen``)
      with the parsed namespace.

      argparse comes from http://code.google.com/p/argparse/.
    """
    # Version string assembled by get_version from the cctrl settings
    # version and the pycclib version.
    version_string = get_version(settings.VERSION, cclibversion)

    parser = argparse.ArgumentParser(
        usage='%(prog)s [-h, --help] [db-type]',
        description='%(prog)s ssl secured db tunnel',
        epilog="And now you're in control")

    parser.add_argument(
        '-v', '--version',
        action='version',
        version=version_string)

    # Positional argument: the database service to tunnel to.
    parser.add_argument('type', choices=['mysql', 'mongodb'])

    # Optional overrides for the local listener and the remote endpoint.
    parser.add_argument(
        '--bind', default='127.0.0.1:4711', help="the ip:port to bind to")
    parser.add_argument(
        '--endpoint', default=None, help="set tunnel endpoint to host:port")

    parser.set_defaults(func=listen)

    options = parser.parse_args()
    options.func(options)


def listen(args):
    """
      Start the tunnel in a background thread and block until Ctrl-C.

      args is the parsed command line namespace. It is passed unchanged
      to main_thread, and args.bind is shown in the startup message.

      On KeyboardInterrupt the process is terminated via sys.exit('bye').
      This function never returns normally.
    """
    # The previous ``thread.daemon = True`` only set an attribute on the
    # thread module and had no effect, so it was dropped. The former
    # ``else:`` branch could never run (the loop below only ends through
    # an exception) and would have deadlocked on a double lock.acquire().
    try:
        thread.start_new_thread(main_thread, (args, ))
        # Parenthesised single argument: prints the same text whether
        # ``print`` is a statement or a function.
        print('listening on %s...' % args.bind)
        # Park the main thread; all work happens in the tunnel threads.
        while True:
            time.sleep(100)
    except KeyboardInterrupt:
        sys.exit('bye')


def main_thread(args):
    """
      Accept local TCP connections and tunnel each one over SSL.

      args is the parsed command line namespace:
        args.bind     -- local listen address in ip:port format
        args.type     -- 'mysql' or 'mongodb'; selects the default
                         destination host and port
        args.endpoint -- optional host:port overriding the destination

      For every accepted connection an SSL connection to the destination
      is opened and two forward() threads are started, one per direction.
      A failed connection attempt is reported on stderr and only that
      client is dropped; the listener keeps accepting.

      This function loops forever and does not return.
    """
    # The destination is the same for every connection, so work it out
    # once instead of on each accept.
    default_ports = {'mysql': 3307, 'mongodb': 27018}
    dest_host = '%s.cloudcontrolled.com' % args.type
    dest_port = default_ports.get(args.type)
    if args.endpoint:
        dest_host, _, dest_port = args.endpoint.partition(':')
    dest_addr = (dest_host, int(dest_port))

    loc_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # args.bind should be in ip:port format
    bind_host, _, bind_port = args.bind.partition(':')
    loc_sock.bind((bind_host, int(bind_port)))
    loc_sock.listen(1)

    while True:
        int_sock = loc_sock.accept()[0]
        ext_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # PROTOCOL_SSLv23 negotiates the highest protocol version both
        # sides support. The previously hard-coded PROTOCOL_SSLv3 is
        # insecure and is not available in every OpenSSL build.
        # NOTE(review): the peer certificate is not verified here
        # (wrap_socket defaults to CERT_NONE) -- confirm this is intended.
        ssl_sock = ssl.wrap_socket(ext_sock, ssl_version=ssl.PROTOCOL_SSLv23)
        try:
            ssl_sock.connect(dest_addr)
        except socket.error as err:
            # Without this an unreachable endpoint would kill the accept
            # loop and leak the client socket that was just accepted.
            sys.stderr.write(
                'could not connect to %s:%s: %s\n' % (dest_host, dest_port, err))
            ssl_sock.close()
            int_sock.close()
            continue
        thread.start_new_thread(forward, (int_sock, ssl_sock))
        thread.start_new_thread(forward, (ssl_sock, int_sock))


def forward(source, destination):
    """
      Copy data from source to destination until the stream ends.

      source and destination are connected socket objects. Data is read
      from source in chunks of up to 1024 bytes and written to
      destination. Copying stops when source.recv() returns no data or
      when either socket raises socket.error (for example a reset or
      already closed connection).

      Afterwards the read side of source and the write side of
      destination are shut down so the peer behind destination sees
      end-of-stream. Errors during shutdown are ignored.
    """
    try:
        while True:
            data = source.recv(1024)
            if not data:
                break
            destination.sendall(data)
    except socket.error:
        # A broken connection ends this direction just like a regular
        # end-of-stream; fall through to the half-close below.
        pass

    # Shut each side down on its own: a failure on source (e.g. it is
    # already closed) must not prevent destination from being half-closed.
    for sock, how in ((source, socket.SHUT_RD), (destination, socket.SHUT_WR)):
        try:
            sock.shutdown(how)
        except socket.error:
            pass


# Start the tunnel only when this file is executed as a script.
if __name__ == '__main__':
    main()
