"""
Author:      www.tropofy.com

Copyright 2013 Tropofy Pty Ltd, all rights reserved.

This source file is part of Tropofy and governed by the Tropofy terms of service
available at: http://www.tropofy.com/terms_of_service.html

This source file is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE. See the license files for details.
"""


from sqlalchemy import create_engine
from sqlalchemy.sql import select


def create_in_memory_db(metadata, tables=None, echo=True, **kw):
    """Create an in-memory SQLite db and create empty tables in it.

    :param metadata: sqla metadata related to the tables
    :param tables: (Optional) list of sqla table objects. If None, all tables in metadata are created.
    :param echo: (Optional) forwarded to ``create_engine``. When True (the default, matching previous
        behaviour) the engine logs every SQL statement it emits; pass False to silence that logging.
    :param kw: Accepted for call-site compatibility; currently ignored.
    :return: sqla engine for in memory db
    """
    in_memory_engine = create_engine('sqlite:///:memory:', echo=echo)

    # MetaData.create_all interprets tables=None as "every table in metadata",
    # so the optional argument can be passed straight through.
    metadata.create_all(bind=in_memory_engine, tables=tables)
    return in_memory_engine


def copy_data_between_similar_dbs(source_db_con, destination_db_con, table_matching, data_set_id=None, **kw):
    """Copy data between similar tables in two dbs.

    Requires that rows of data queried from each source table can be directly inserted into the
    matching destination table. Tables are copied in the order provided in the table_matching list;
    use this ordering to ensure db constraints can be satisfied on the destination db during the copy.

    :param source_db_con: sqla connection object to the source db.
    :param destination_db_con: sqla connection object to the destination db (e.g. an in memory db).
    :param table_matching: List of sqla table pairs. Each element consists of (source_table, destination_table).
    :param data_set_id: (Optional) data_set id. If not None, assumes column 'data_set_id' exists in each
        source table and copies only the rows whose data_set_id matches it.
    :param kw: Accepted for call-site compatibility; currently ignored.
    """
    for source_table, destination_table in table_matching:
        query = select([source_table])
        # Compare against None explicitly so that falsy ids (e.g. 0) still filter.
        if data_set_id is not None:
            query = query.where(source_table.c.data_set_id == data_set_id)

        rows = source_db_con.execute(query).fetchall()

        # Executing an insert with an empty parameter list is not reliably a no-op in sqla
        # (it may run a single insert with no values), so skip tables with nothing to copy.
        if not rows:
            continue

        destination_db_con.execute(destination_table.insert(), rows)
