import sys
import json
import time
import subprocess
import logging
from grandfatherson import to_delete, MONDAY
from dateutil.parser import parse
from datetime import date, datetime, timedelta
from opsaws.awsargparse import AwsArgParser
from opsaws.awscli import AwsCli

# Module logger. Its records propagate to the root logger, whose level and
# format are set by logging.basicConfig in parse_options().
logger = logging.getLogger(__name__)


def tagdict2list(tag_dict):
    """Turn a ``{key: value}`` mapping into a list of ``{'Key', 'Value'}`` dicts.

    Example: ``{'Name': 'db'}`` becomes ``[{'Key': 'Name', 'Value': 'db'}]``.
    Entries appear in the mapping's iteration order.
    """
    return [{'Key': name, 'Value': tag_dict[name]} for name in tag_dict]


def taglist2dict(tags):
    """Collapse a list of ``{'Key': k, 'Value': v}`` dicts into ``{k: v}``.

    If a key appears more than once, the last occurrence wins.
    """
    return {entry['Key']: entry['Value'] for entry in tags}


class EBSSnapshot(AwsCli):
    """Create, tag, list, rotate and delete EBS snapshots via the ``aws ec2`` CLI.

    Every method builds an ``aws ...`` command string and hands it to
    ``self.run`` (inherited from ``AwsCli``). Return values are taken from
    whatever ``self.run`` yields for that command.
    """

    @staticmethod
    def _json_arg(value):
        """Return *value* JSON-encoded and wrapped in single quotes.

        The quotes keep JSON that contains spaces together as one argument,
        the same way ``--tags`` is passed. This presumably relies on
        ``self.run`` splitting the command shell-style; confirm in AwsCli.
        """
        return "'" + json.dumps(value) + "'"

    def describe_volumes(self, filters=None):
        """Return the ``Volumes`` list from ``aws ec2 describe-volumes``.

        :param filters: optional filter list, passed JSON-encoded as
            ``--filters``. Presumably in the AWS ``Name``/``Values`` shape.
        """
        command = 'aws ec2 describe-volumes'
        if filters:
            command += ' --filters ' + self._json_arg(filters)
        return self.run(command)['Volumes']

    def describe_snapshots(self, snapshot_ids=None, owner_id='self', filters=None):
        """Return the ``Snapshots`` list from ``aws ec2 describe-snapshots``.

        :param snapshot_ids: optional iterable of snapshot ids to restrict to.
            When empty or None, no id restriction is applied.
        :param owner_id: value for ``--owner-ids`` (``'self'`` by default).
            Pass a falsy value to omit the option.
        :param filters: optional filter list, passed JSON-encoded as ``--filters``.
        """
        command = 'aws ec2 describe-snapshots'

        if owner_id:
            command += ' --owner-ids ' + owner_id
        if filters:
            command += ' --filters ' + self._json_arg(filters)
        if snapshot_ids:
            command += ' --snapshot-ids ' + ' '.join(snapshot_ids)

        return self.run(command)['Snapshots']

    def add_tag_to_snapshot(self, snapshot_id, tags, dry_run=False):
        """Attach *tags* (a list of ``{'Key', 'Value'}`` dicts) to *snapshot_id*.

        :returns: whatever ``self.run`` returns for ``aws ec2 create-tags``.
        """
        logger.info('Add tag to snapshot: {0} tag: {1}'.format(snapshot_id, tags))
        command = 'aws ec2 create-tags --resources ' + snapshot_id
        if dry_run:
            command += ' --dry-run'

        command += ' --tags ' + self._json_arg(tags)
        return self.run(command)

    def create_snapshot(self, volume_id, instance_id=None, description=None, dry_run=False):
        """Snapshot *volume_id* and tag the snapshot with its provenance.

        The snapshot is always tagged with the source volume id, the creation
        unix time and the creation date. When *instance_id* is given, it also
        gets the instance id plus the instance's ``Tenant`` and ``Name`` tags,
        when those tags exist.

        :returns: the ``create-snapshot`` result. Under *dry_run* this is
            replaced by the fake ``{'SnapshotId': 'snap-12345678'}``.
        """
        logger.info(
            'Create snapshot for volume: {0} instance_id: {1}'.format(volume_id, instance_id))
        command = 'aws ec2 create-snapshot --volume-id ' + volume_id
        if dry_run:
            command += ' --dry-run'
        if description:
            command += ' --description "' + description + '"'
        snapshot = self.run(command)

        # fake data for dry_run
        if dry_run:
            snapshot = {'SnapshotId': 'snap-12345678'}

        snapshot_tag_dict = {
            'TakenFromVolumeID': volume_id,
            'TakenAtUnixTime': str(time.time()),
            'TakenDate': date.today().isoformat(),
        }

        # add instance_id information if available
        if instance_id:
            command = 'aws ec2 describe-instances --instance-ids ' + instance_id
            instance = self.run(command)['Reservations'][0]['Instances'][0]
            # An instance may have no tags at all, or no Name tag. Failing
            # here would leave the already-created snapshot untagged.
            instance_tag_dict = taglist2dict(instance.get('Tags') or [])

            if 'Tenant' in instance_tag_dict:
                snapshot_tag_dict['Tenant'] = instance_tag_dict['Tenant']
            snapshot_tag_dict['TakenWhileAttachedToID'] = instance_id
            if 'Name' in instance_tag_dict:
                snapshot_tag_dict['TakenWhileAttachedToName'] = instance_tag_dict['Name']

        self.add_tag_to_snapshot(snapshot['SnapshotId'], tagdict2list(snapshot_tag_dict),
                                 dry_run=dry_run)

        return snapshot

    def delete_snapshot(self, snapshot_id, dry_run=False):
        """Delete *snapshot_id*.

        :raises subprocess.CalledProcessError: when the result's ``'return'``
            entry is not ``'true'``.
        """
        logger.info('Delete snapshot: {0}'.format(snapshot_id))
        # The CLI option is spelled with a hyphen; ``--snapshot_id`` is rejected.
        command = 'aws ec2 delete-snapshot --snapshot-id ' + snapshot_id
        if dry_run:
            command += ' --dry-run'
        result = self.run(command)

        # NOTE(review): ``aws ec2 delete-snapshot`` prints nothing on success.
        # This check depends on how ``self.run`` represents that output -- confirm.
        if result['return'] != 'true':
            raise subprocess.CalledProcessError(-1, command)

    def create_snapshot_all_volumes(self, filters=None, dry_run=False):
        """Snapshot every volume matched by *filters*.

        The description is ``<volume-id>-<YYYY-MM-DD>``. For an attached
        volume, the first attachment's instance id is passed along for tagging.

        :returns: list of ``create_snapshot`` results, one per volume.
        """
        volumes = self.describe_volumes(filters)

        snapshots = list()
        for volume in volumes:
            volume_id = volume['VolumeId']
            description = volume_id + '-' + date.today().isoformat()
            # Unattached volumes report an empty Attachments list rather than
            # omitting the key, so test for emptiness, not key presence.
            attachments = volume.get('Attachments') or []
            instance_id = attachments[0].get('InstanceId') if attachments else None
            snapshot = self.create_snapshot(volume_id, instance_id,
                                            description=description,
                                            dry_run=dry_run)
            logger.info('Snapshot_id = ' + snapshot['SnapshotId'])
            snapshots.append(snapshot)

        return snapshots

    def rotate_snapshot(self, retention, dry_run=False):
        """Delete snapshots that fall outside a grandfather-father-son policy.

        :param retention: ``'days:weeks:months'`` string, e.g. ``'7:4:1'``.
        :param dry_run: forwarded to ``delete_snapshot``.
        """
        days, weeks, months = map(int, retention.split(':'))

        # GFS retention is only meaningful over a whole backup history. Group
        # snapshots per source volume (dates truncated to the day) and
        # evaluate each group at once. Evaluating snapshots one by one would
        # keep everything newer than the longest window.
        # NOTE(review): this covers every snapshot owned by the account, not
        # only ones created by this tool -- confirm that is intended.
        by_volume = dict()
        for snapshot in self.describe_snapshots():
            t1 = parse(snapshot['StartTime'])
            day = datetime(t1.year, t1.month, t1.day)
            by_volume.setdefault(snapshot.get('VolumeId'), []).append((day, snapshot))

        for entries in by_volume.values():
            expired = to_delete([day for day, _ in entries], days=days, weeks=weeks,
                                months=months, firstweekday=MONDAY)
            for day, snapshot in entries:
                if day in expired:
                    self.delete_snapshot(snapshot['SnapshotId'], dry_run=dry_run)


def add_backup_argument_group(parser):
    """Register the 'backup options' argument group on *parser*.

    Options: ``--retention`` (default ``'7:4:1'``), repeatable ``--tenants``,
    ``--dry-run`` flag and ``--time-limit`` (int seconds, default 3600).
    """
    backup_opts = parser.add_argument_group('backup options')

    backup_opts.add_argument(
        '--retention',
        help="days:weeks:months worth of backups to keep",
        default='7:4:1',
    )
    backup_opts.add_argument('--tenants', action="append", default=None)
    backup_opts.add_argument(
        '--dry-run',
        help="Don't execute snapshot run or delete run",
        action="store_true",
    )
    backup_opts.add_argument('--time-limit', default=3600, type=int)


def parse_options():
    """Parse the command line and configure root logging.

    Builds an ``AwsArgParser`` with the backup option group plus
    ``-d/--debug``, parses ``sys.argv``, and calls ``logging.basicConfig``
    at DEBUG level when ``--debug`` is set, otherwise at INFO.

    :returns: the parsed options namespace.
    """
    arg_parser = AwsArgParser()
    add_backup_argument_group(arg_parser)
    arg_parser.add_argument('-d', '--debug', default=False, action='store_true',
                            help='Turn on debug info')

    opts = arg_parser.parse_args()

    level = logging.DEBUG if opts.debug else logging.INFO
    logging.basicConfig(level=level, format='%(asctime)s %(levelname)s %(message)s')

    logging.debug(opts)

    return opts


def validate_snapshots(snapshots, state='completed'):
    """Return True iff every snapshot's ``'State'`` equals *state*.

    Stops at the first mismatch. An empty iterable yields True.
    """
    return all(snap['State'] == state for snap in snapshots)


def main():
    options = parse_options()

    # setup ebs_backup AWS connection information
    ebs_backup = EBSSnapshot(options.conf_file, options.profile, options.region)

    # create all snapshot
    snapshots = ebs_backup.create_snapshot_all_volumes(dry_run=options.dry_run)

    # delete old retention
    if options.retention:
        ebs_backup.rotate_snapshot(options.retention)

    # check if all snapshot ready
    start_time = time.time()
    snapshot_ids = list()

    if not options.dry_run:
        while True:
            time.sleep(10)
            snapshot_ids = [s['SnapshotId'] for s in snapshots if s['State'] != 'completed']
            logger.info('Checking snapshot: {0}'.format(snapshot_ids))
            snapshots = ebs_backup.describe_snapshots(snapshot_ids)
            if validate_snapshots(snapshots):
                snapshots = None
                break
            elif time.time() - start_time > options.time_limit:
                snapshot_ids = [s['SnapshotId'] for s in snapshots if s['State'] != 'completed']
                break
    else:
        snapshots = None

    if not snapshots:
        msg = 'OK: All snapshot created'
        logger.info(msg)
        print msg
        return 0
    else:
        msg = 'WARNING: Some snapshot is not ready. {0}'.format(snapshot_ids)
        logger.info(msg)
        print >> sys.stderr, msg
        return 1

if __name__ == "__main__":
    sys.exit(main())