#!/usr/bin/python

"""
Copyright (c) 2011-2013 Yubico AB
All rights reserved.

Redistribution and use in source and binary forms, with or
without modification, are permitted provided that the following
conditions are met:

    1. Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.
    2. Redistributions in binary form must reproduce the above
    copyright notice, this list of conditions and the following
    disclaimer in the documentation and/or other materials provided
    with the distribution.
    
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
"""



#import lib
import os
import sys
import io
import hashlib
import re
import time
import argparse

try:
    from sqlalchemy import create_engine, Table, MetaData 
except ImportError: 
    print "Please, install SQLAlchemy. Follow instruction in README.db.import-export file."
    sys.exit(1)


from os.path import abspath
sys.path.append('Lib')
from pyhsm.util import key_handle_to_int
from pyhsm.yubikey import modhex_decode
import pyhsm.aead_cmd 


##########################
# Functions Declarations #
##########################


#
# reconstruct a database url with the valid username and password
#
def build_database_url(databaseUrl, dbusername, dbpassword):
    """Return databaseUrl with its credentials set to dbusername/dbpassword.

    Two layouts are handled:

    * the url contains an ``@``: everything between ``://`` and the next
      ``@`` is replaced by ``dbusername:dbpassword``;
    * the url contains no ``@``: ``dbusername:dbpassword@`` is inserted
      right after the ``scheme://`` prefix (the url must then look like
      ``scheme://host/dbname``).

    The credentials are inserted verbatim; no URL-quoting is applied.

    If the url matches neither layout a FATAL message is printed and the
    process exits with status 1 (SystemExit).
    """
    #urltype mysql://user:pass@localhost/dbname
    credentials = dbusername + ":" + dbpassword + "@"

    if '@' in databaseUrl:
        # A callable replacement keeps the credentials literal. A plain
        # replacement string would be scanned for backslash escapes and
        # group references, breaking passwords that contain a backslash.
        rebuilt, replaced = re.subn(r'://.*?@',
                                    lambda _match: '://' + credentials,
                                    databaseUrl)
        # re.sub() always returns a string, so the substitution count is
        # the only reliable way to detect that nothing was replaced.
        if replaced:
            return rebuilt
    else:
        #fix the url with username and password
        parts = re.search(r'(.*?://)(.*?/)(.+)', databaseUrl)
        if parts:
            return parts.group(1) + credentials + parts.group(2) + parts.group(3)

    #unrecognized url format
    print("FATAL: Malformed database url. A valid url looks like this: mysql://username:password@hostname/dbname")
    sys.exit(1)



#
# extract keyhandle value from the path
#
def extract_keyhandle(path, filepath):
    """Return the first path component of filepath that lies below path.

    path is the directory prefix to drop and filepath is a "/"-separated
    file path expected to start with it. If filepath does not start with
    path, its own first component is returned.
    """
    # Remove the prefix explicitly: str.lstrip() strips a *set of characters*
    # and would also eat leading keyhandle characters that occur in path.
    if filepath.startswith(path):
        remainder = filepath[len(path):]
    else:
        remainder = filepath

    # path may or may not end with a separator; drop any that is left over.
    return remainder.lstrip("/").split("/")[0]



#
#insert_query: this function reads the response fields, builds the SQL insert query, and then executes it against the database
#
def insert_query(publicId, aead, keyhandle, aeadobj):
    """Insert one AEAD row into the table and return the execution result.

    publicId  -- value stored in the public_id column
    aead      -- object providing key_handle, nonce and data attributes
    keyhandle -- key handle in the form accepted by key_handle_to_int();
                 only used for a consistency check against aead.key_handle
    aeadobj   -- table object whose insert() builds the statement

    The statement is executed on the module-level ``connection``.
    """
    #the stored key handle always comes from the aead object; a mismatch
    #with the handle given by the caller is only reported
    expected_handle = key_handle_to_int(keyhandle)
    if expected_handle != aead.key_handle:
        print("WARNING: keyhandle does not match aead.key_handle")

    #column values for the new row
    row = {
        'public_id': publicId,
        'keyhandle': aead.key_handle,
        'nonce': aead.nonce,
        'aead': aead.data,
    }

    #build and run the insert statement
    return connection.execute(aeadobj.insert().values(**row))

#################################
# END of functions declarations #
#################################




#######################
#                     #
# Initialization Area #
#                     #
#######################

#command line: two mandatory positional arguments, the directory that is
#scanned for AEAD files and the database url
parser = argparse.ArgumentParser(description='Import AEADs into the database')
parser.add_argument('path', type=str, action="store")
parser.add_argument('dburl', action="store")

args = vars(parser.parse_args())

#anything other than "script path dburl" is rejected with a usage hint
if not len(sys.argv) == 3:
    print("\nUsage: python import_aeads.py /path/to/keyhandle/ database_url\n"
          "i.e. python import_aeads.py /root/aeads/ mysql://root:password@localhost:3306/database_name")
    sys.exit(2)

#the first argument has to be an existing directory
if not os.path.isdir(sys.argv[1]):
    print("\nInvalid path, check your spelling.\n")
    sys.exit(2)

#directory to scan and database url, as parsed by argparse
path, databaseUrl = args['path'], args['dburl']

def _connect_and_reflect(url):
    """Create an engine for url, reflect ``aead_table`` and open a connection.

    Returns the tuple (engine, metadata, table, connection). Any error from
    SQLAlchemy or the database driver is propagated to the caller.
    """
    db_engine = create_engine(url)
    db_metadata = MetaData()
    #load the table definition from the database itself
    table = Table('aead_table', db_metadata, autoload=True, autoload_with=db_engine)
    return db_engine, db_metadata, table, db_engine.connect()


try:
    #first attempt: use the url exactly as given on the command line
    engine, metadata, aeadobj, connection = _connect_and_reflect(databaseUrl)
except Exception:
    #only real errors trigger the prompt; Ctrl-C and sys.exit() propagate
    print("\nPlease provide a valid database username")
    dbusername = raw_input(">>")
    print("\nPlease provide a valid database password:")
    dbpassword = raw_input(">>")

    try:
        #rebuild the url with given username and password
        databaseUrl = build_database_url(databaseUrl, dbusername, dbpassword)

        #second and last attempt with the rebuilt url
        engine, metadata, aeadobj, connection = _connect_and_reflect(databaseUrl)
    except Exception:
        print("FATAL: Please provide valid username and password and a valid address i.e. mysql://localhost/dbname")
        sys.exit(1)


 



####################
# Computation area #
####################

#walk across the file system and insert one row (public_id, keyhandle,
#nonce, aead) for every directory that contains files
try:
    for root, subFolders, files in os.walk(path):
        if not files:
            continue

        #only the first file of a directory is imported; its name is the
        #public id
        public_id = str(files[0])

        #build file path
        filepath = os.path.join(root, files[0])

        #start from an empty aead object for every file, then fill it from
        #the file (the load used to run before the loop, where filepath did
        #not exist yet)
        aead = pyhsm.aead_cmd.YHSM_GeneratedAEAD(None, None, None)
        aead.load(filepath)

        #extract the key handle from the path
        keyhandle = extract_keyhandle(path, filepath)

        #no nonce after loading is treated as the old aead format: nonce and
        #key handle are then derived from the public id and the path
        if not aead.nonce:
            aead.nonce = pyhsm.yubikey.modhex_decode(public_id).decode('hex')
            aead.key_handle = key_handle_to_int(keyhandle)

        if not insert_query(public_id, aead, keyhandle, aeadobj):
            print("WARNING: could not insert %s" % public_id)
finally:
    #close sqlalchemy, also when the import aborts with an error
    connection.close()

#exit without error
sys.exit(0)
