# -*- coding: utf-8 -*-
# Copyright (C) 2011  Alibaba Cloud Computing
# 
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.
# 
import hashlib
import os
import random
import threading
import time
from collections import OrderedDict
from multiprocessing.pool import ThreadPool

import yaml

from exceptions import *
from utils import *


class Job(object):
    """A retrieval job on a vault, plus helpers to download its output.

    Attributes are populated from a ``describe_job`` response via
    ``ResponseDataParser``.  ``parts`` maps each inclusive byte range
    ``(start, end)`` of the archive output to the uppercase hex MD5 of that
    range once it has been downloaded, or ``None`` while it is still pending.
    """

    _MEGABYTE = 1024 * 1024

    # Worker threads used by download_to_file for ranged downloads.
    NumberThread = 16
    # Attempts made per byte range before giving up on it.
    NumRetry = 3

    # (response key, attribute name, value used when the key is missing/falsy)
    ResponseDataParser = (('Action', 'action', None),
                          ('ArchiveContentEtag', 'etag', None),
                          ('ArchiveId', 'archive_id', None),
                          ('ArchiveSizeInBytes', 'archive_size', 0),
                          ('Completed', 'completed', False),
                          ('CompletionDate', 'completion_date', None),
                          ('CreationDate', 'creation_date', None),
                          ('InventorySizeInBytes', 'inventory_size', 0),
                          ('JobDescription', 'description', None),
                          ('JobId', 'id', None),
                          ('StatusCode', 'status_code', None),
                          ('StatusMessage', 'status_message', None))

    def __init__(self, vault, response):
        """Build a job bound to ``vault`` from a describe-job ``response``."""
        self.vault = vault
        self.parts = OrderedDict()
        self._update(response)

    def __repr__(self):
        return 'Job: %s' % self.id

    def _parse_etag(self):
        """Split a multipart etag ``'<md5>-<part size>'`` into (md5, int)."""
        md5sum = self.etag.split('-')
        return md5sum[0], int(md5sum[1])

    def _is_multipart(self):
        """Return True if the etag has the ``'<md5>-<part size>'`` form.

        A missing etag counts as single-part instead of raising TypeError.
        """
        return bool(self.etag) and '-' in self.etag

    def _update(self, response):
        """Copy fields from ``response`` and reset ``parts`` to all-pending."""
        for response_name, attr_name, default in self.ResponseDataParser:
            value = response.get(response_name)
            setattr(self, attr_name, value or default)

        self.parts = OrderedDict()
        if self.archive_size > 0:
            if self._is_multipart():
                _, part_size = self._parse_etag()
                for byte_range in calc_ranges(part_size, self.archive_size):
                    self.parts[byte_range] = None
            else:
                self.parts[(0, self.archive_size - 1)] = None

    def update_status(self):
        """Re-query the job from the service and refresh all attributes."""
        response = self.vault.api.describe_job(self.vault.id, self.id)
        self._update(response)

    def _check_status(self, block=False):
        """Refresh the job status and make sure its output is ready.

        :param block: if True, poll (sleeping 60-300 s between polls) until
            the job reports completion; if False, fail immediately instead.
        :raises DownloadArchiveError: ``block`` is False and the job is not
            completed yet.
        """
        self.update_status()
        if not block and not self.completed:
            raise DownloadArchiveError('Job not ready')
        elif block:
            while not self.completed:
                # %s instead of '+': status_code may be None.
                print(time.ctime() + ' Job status: %s' % self.status_code)
                time.sleep(random.randint(60, 300))
                self.update_status()

    def download_by_range(self, byte_range, file_path=None, file_obj=None, chunk_size=None, block=True):
        """Download one inclusive ``byte_range`` of the job output.

        Writes to ``file_obj`` (starting at its current position) or, if it is
        not given, to a freshly opened ``file_path``.  A file opened here is
        closed on exit; a caller-supplied ``file_obj`` is only flushed.

        :param chunk_size: bytes per read from the response (default 1 MiB).
        :param block: wait for the job to complete instead of failing.
        :raises OASServerError: on a client-side service error (not retried).
        :raises DownloadArchiveError: if no attempt delivered the full range.
        """
        self._check_status(block)

        chunk_size = chunk_size or self._MEGABYTE
        f = open_file(file_path=file_path, file_obj=file_obj, mode='w+')
        offset = f.tell() if file_obj is not None else 0
        size = range_size(byte_range)

        pos = 0  # bound even if NumRetry is 0, for the error message below
        try:
            for attempt in range(self.NumRetry):
                if attempt:
                    # Rewind so a retry overwrites whatever a failed or short
                    # previous attempt wrote, instead of appending after it.
                    f.seek(offset)
                pos = 0
                try:
                    response = self.vault.api.get_job_output(
                        self.vault.id, self.id, byte_range=byte_range)
                    while True:
                        data = response.read(chunk_size)
                        if not data:
                            break
                        f.write(data)
                        pos += len(data)
                except OASServerError as e:
                    # Client-side errors will not succeed on retry.
                    if e.type == 'client':
                        raise
                    continue
                except IOError:
                    continue
                if pos == size:
                    return
        finally:
            f.flush()
            if f is not file_obj:
                f.close()

        raise DownloadArchiveError(
            'Incomplete download: %d / %d' % (pos, size))

    def download_to_file(self, file_path, chunk_size=None, block=True):
        """Download the whole job output to ``file_path``.

        Inventory output is fetched with a single :meth:`download_by_range`
        call.  Archive output is fetched range by range on a thread pool; the
        per-range MD5s are logged next to the file (``<file_path>.oas``) so an
        interrupted download can resume, and the log is removed once the
        digest derived from them matches the job's etag.

        :param chunk_size: bytes per read from the response (default 1 MiB).
        :param block: wait for the job to complete instead of failing.
        :raises DownloadArchiveError: not every range could be downloaded,
            or (``block=False``) the job is not complete.
        :raises HashDoesNotMatchError: the downloaded data's MD5 differs from
            the etag.
        """
        self._check_status(block)

        chunk_size = chunk_size or self._MEGABYTE
        if self.inventory_size > 0:
            return self.download_by_range(
                file_path=file_path, byte_range=(0, self.inventory_size - 1),
                chunk_size=chunk_size, block=block)

        if self._is_multipart():
            md5sum_comp, _ = self._parse_etag()
        else:
            md5sum_comp = self.etag

        file_dir, file_name = os.path.split(file_path)
        log_file = os.path.join(file_dir, file_name + '.oas')
        mode = self._restore_progress(log_file, file_path)

        log_lock = threading.RLock()
        file_lock = threading.RLock()

        def download_part(byte_range):
            # Random start-up jitter so workers do not all hit the service
            # at the same instant.
            time.sleep(random.randint(256, 4096) / 1000.)
            for _ in range(self.NumRetry):
                try:
                    response = self.vault.api.get_job_output(
                        self.vault.id, self.id, byte_range=byte_range)

                    md5 = hashlib.md5()
                    offset = byte_range[0]
                    while True:
                        data = response.read(chunk_size)
                        if not data:
                            break
                        md5.update(data)
                        # f is shared by all workers (bound below before the
                        # pool starts); seek+write must be atomic.
                        with file_lock:
                            f.seek(offset)
                            f.write(data)
                            f.flush()
                        offset += len(data)
                except OASServerError as e:
                    # Same policy as download_by_range: only server-side
                    # errors are worth retrying.
                    if e.type == 'client':
                        raise
                    continue
                except IOError:
                    continue

                if offset == byte_range[1] + 1:
                    with log_lock:
                        self.parts[byte_range] = md5.hexdigest().upper()
                        self._save(log_file)
                    print(time.ctime() + (' Range %d-%d download success.' % byte_range))
                    return
            print(time.ctime() + (' Range %d-%d download failed.' % byte_range))

        # NOTE(review): 'w+'/'r+' are text modes; if open_file passes the mode
        # straight to open(), binary data may be mangled on platforms that
        # translate newlines -- consider the 'b' variants.
        f = open_file(file_path=file_path, mode=mode)
        with f:
            print(time.ctime() + ' Start download.')
            pending = [byte_range for byte_range, md5sum in self.parts.items()
                       if md5sum is None]
            if pending:
                pool = ThreadPool(
                    processes=min(self.NumberThread, len(pending)))
                try:
                    pool.map(download_part, pending)
                finally:
                    # map() has returned or raised, so no queued work is lost;
                    # shut the workers down instead of leaking the threads.
                    pool.terminate()
                    pool.join()

            size = self.size_completed
            if size != self.archive_size:
                raise DownloadArchiveError(
                    'Incomplete download: %d / %d' % (size, self.archive_size))

            md5sum_list = list(self.parts.values())
            if self._is_multipart():
                md5sum_actual = compute_combine_md5(md5sum_list)
            elif md5sum_list:
                md5sum_actual = md5sum_list[0]
            else:
                # No ranges at all (empty archive): digest of zero bytes.
                md5sum_actual = hashlib.md5().hexdigest().upper()
            if md5sum_actual != md5sum_comp:
                raise HashDoesNotMatchError(
                    'MD5 not match: %s / %s (actual)' % (md5sum_comp, md5sum_actual))
            elif os.path.exists(log_file):
                os.remove(log_file)

        print(time.ctime() + ' Download finish.')

    def _restore_progress(self, log_file, file_path):
        """Resume from a previous run's progress log when it is usable.

        The logged per-range MD5s replace ``self.parts`` only if the log can
        be read, covers exactly the current byte ranges, and the partially
        downloaded ``file_path`` still exists; otherwise every range is reset
        to pending.

        :returns: ``'r+'`` to keep the bytes already on disk, or ``'w+'`` to
            start the file from scratch.
        """
        fresh = OrderedDict((byte_range, None) for byte_range in self.parts)
        try:
            self._load(log_file)
        except (IOError, yaml.YAMLError):
            self.parts = fresh
            return 'w+'

        loaded = self.parts
        if (isinstance(loaded, dict) and set(loaded) == set(fresh)
                and os.path.exists(file_path)):
            # Rebuild in range order: the combined MD5 depends on it.
            self.parts = OrderedDict(
                (byte_range, loaded[byte_range]) for byte_range in fresh)
            return 'r+'

        self.parts = fresh
        return 'w+'

    def _save(self, file_path):
        """Write ``parts`` (range -> MD5 or None) to ``file_path`` as YAML."""
        f = open_file(file_path=file_path, mode='w+')
        with f:
            f.write(yaml.dump(self.parts, default_flow_style=True))

    def _load(self, file_path):
        """Replace ``parts`` with the mapping stored by :meth:`_save`."""
        f = open_file(file_path=file_path)
        with f:
            # yaml.dump tags the OrderedDict and its tuple keys as Python
            # objects, which SafeLoader rejects, so the Loader is named
            # explicitly (yaml.load without one warns on PyYAML 5.x and is an
            # error on 6.x).  This Loader can construct arbitrary objects:
            # never point it at a log file from an untrusted source.
            self.parts = yaml.load(f, Loader=yaml.Loader)

    @property
    def size_completed(self):
        """Total bytes covered by ranges that have a recorded MD5."""
        return sum(range_size(byte_range)
                   for byte_range, md5sum in self.parts.items()
                   if md5sum is not None)
