import copy as _copy
import os as _os
import logging

from graphlab.connect.aws._ec2 import DEFAULT_INSTANCE_TYPE, get_credentials
import graphlab as _gl

class Enum(set):
    """Minimal enumeration built on ``set``.

    Every member is a value stored in the set. Accessing a member as an
    attribute (``Enum(['A']).A``) returns the member itself, so each member
    evaluates to its own name.
    """

    def __getattr__(self, name):
        """Return ``name`` if it is a member of this enumeration.

        ``__getattr__`` is only consulted after normal attribute lookup has
        failed, so regular ``set`` attributes and methods are unaffected.

        Raises
        ------
        AttributeError
            If ``name`` is not a member. The message names the missing
            attribute so that typos are easy to diagnose.
        """
        if name in self:
            return name
        raise AttributeError("%r is not a member of this Enum; valid members are %s"
                             % (name, sorted(self, key=str)))

# The kinds of execution environment accepted by Environment.__init__.
# Members are read as attributes (e.g. ENVIRONMENT.EC2) and evaluate to their
# own name string.
ENVIRONMENT = Enum(['Local', 'LocalAsynchronous', 'EC2', 'Hadoop'])
# Module-level logger, named after this module.
__LOGGER__ = logging.getLogger(__name__)

class Environment(object):
    """Base class for handling execution environments.

    An environment has a ``name``, a kind (``env``, one of the members of
    ``ENVIRONMENT``) and a free-form ``params`` dictionary. On construction it
    is registered with ``graphlab.deploy._default_session``, which is also the
    session used by ``save()``.
    """

    def __init__(self, name, env):
        """
        Parameters
        ----------
        name : string
            The name for the environment.

        env : string
            The kind of environment; must be a member of ``ENVIRONMENT``.

        Raises
        ------
        TypeError
            If ``name`` is None or is not a string.
        Exception
            If ``env`` is not a member of ``ENVIRONMENT``.
        """
        if name is None:
            raise TypeError("Name is required when creating an Environment.")

        if not isinstance(name, str):
            raise TypeError("Name is required to be a string.")

        if env not in ENVIRONMENT:
            raise Exception("Invalid type of environment. Environments must be one of %s" % ENVIRONMENT)

        self._session = _gl.deploy._default_session
        self.params = dict()
        self.name = name
        self.env = env
        # Tri-state dirty flag: None until something assigns it a boolean
        # (nothing in this class does), after which _set_dirty_bit() can flip
        # False -> True.
        self._modified_since_last_saved = None
        self._typename = Environment.__name__

        self._session.register(self)

    def get_max_degree_of_parallelism(self):
        """Return the maximum degree of parallelism; subclasses must override."""
        raise NotImplementedError

    def set_params(self, params):
        """Merge ``params`` into this environment's parameters.

        Returns ``self`` so that calls can be chained.
        """
        self.params.update(params)
        self._set_dirty_bit()
        return self

    def _set_dirty_bit(self):
        """Mark the environment as modified, if modification tracking is active.

        The flag is only flipped when it currently holds a false, non-None
        value; a value of None is left untouched.
        """
        if self._modified_since_last_saved is not None and not self._modified_since_last_saved:
            self._modified_since_last_saved = True

    def clone(self):
        """Return a deep copy of this environment.

        The deep copy goes through ``__getstate__``, which drops the session
        reference, so the session is re-attached to the copy afterwards. The
        copy therefore shares this object's session instead of duplicating it.
        """
        duplicate = _copy.deepcopy(self)
        if '_session' in self.__dict__:
            duplicate._session = self._session
        return duplicate

    def save(self):
        """Save this environment through its session and log the success."""
        self._session.save(self, typename=self._typename)
        __LOGGER__.info("Environment saved successfully.")

    def __getstate__(self):
        """Return the instance state used for pickling/copying, without the session.

        The previous implementation tested ``hasattr(odict, '_session')`` on a
        dict, which is never true, so the session was never removed. The key
        is now removed from a copy of ``__dict__``; the live instance keeps
        its session.
        """
        state = dict(self.__dict__)
        state.pop('_session', None)
        return state

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return "Environment: [name: %s, type: %s]" % (self.name, self.env)


class LocalAsynchronous(Environment):
    """Execution environment that runs jobs on this machine in the background (asynchronously)."""

    def __init__(self, name):
        """Create a local asynchronous environment called ``name``."""
        env_kind = ENVIRONMENT.LocalAsynchronous
        super(LocalAsynchronous, self).__init__(name, env_kind)

    def get_max_degree_of_parallelism(self):
        """Return the maximum degree of parallelism, which is always 1 here."""
        return 1

    def __str__(self):
        described = (self.name, self.params)
        return "Local (asynchronous): [name: %s, params: %s]" % described

class LocalEnvironment(Environment):
    """Execution environment that runs jobs on this machine."""

    def __init__(self, name):
        """Create a local environment called ``name``."""
        env_kind = ENVIRONMENT.Local
        super(LocalEnvironment, self).__init__(name, env_kind)

    def get_max_degree_of_parallelism(self):
        """Return the maximum degree of parallelism, which is always 1 here."""
        return 1

    def __str__(self):
        described = (self.name, self.params)
        return "Local: [name: %s, params: %s]" % described

class EC2Environment(Environment):
    """Environment for executing jobs on AWS EC2."""
    def __init__(self, name, s3_log_folder_path, aws_access_key = None, aws_secret_key = None,
                 num_hosts = 1, region = 'us-west-2', instance_type = DEFAULT_INSTANCE_TYPE,
                 security_group = None, tags = None, CIDR_rule = None):
        """
        Parameters
        ----------
        name : string
            The name for the environment.

        s3_log_folder_path : string
            The S3 path to folder where the log file will be saved. Format should be: 's3://bucket_name/path/'

        aws_access_key : string, optional
            The AWS Access Key to use to launch the host(s). This parameter must be set or the
            AWS_ACCESS_KEY_ID environment variable must be set.

        aws_secret_key : string, optional
            The AWS Secret Key to use to launch the host(s).  This parameter must be set or the
            AWS_SECRET_ACCESS_KEY environment variable must be set.

        num_hosts : int, optional
            The number of EC2 host(s) to use for this environment.

            NOTE: only graphlab.toolkits.model_parameter_search currently supports using more than
            one host; using more than one host for anything else will yield no additional benefit from
            using one host.

        region : string, optional
            The AWS region in which to launch your instance. Default is 'us-west-2'.

        instance_type : string, optional
            The EC2 instance type to launch, default is m3.xlarge. We support all instance types
            except: 't2.micro', 't2.small', 't2.medium' and 'm3.medium'. For a list of instance_types,
            please refer to `here <http://aws.amazon.com/ec2/instance-types/#Instance_Types>`_.

        security_group : string, optional
            The name of the security group for the EC2 instance to use.

        tags : dict, optional
            A dictionary containing the name/value tag pairs to be assigned to the instance. If you want
            to create only a tag name, the value for that tag should be the empty string (i.e. '').
            In addition to these specified tags, a 'GraphLab' tag will also be assigned.

        CIDR_rule : string or list of strings, optional
            The Classless Inter-Domain Routing rule(s) to use for the instance. Useful for restricting the IP
            Address Range for a client. Default is no restriction. If you specify CIDR_rule(s), you must also
            specify a security group to use.

        Raises
        ------
        TypeError
            If num_hosts cannot be converted to a positive integer, if region,
            instance_type, security_group, tags or CIDR_rule have the wrong type,
            or if s3_log_folder_path is not an 's3://bucket_name/...' path.
        AssertionError
            If only one of aws_access_key / aws_secret_key is given, or if neither
            is given and get_credentials() returns a false value.
        """
        # NOTE(review): the base constructor registers the environment with the
        # session before the checks below run, so a failed validation may leave
        # a partially initialized environment registered -- confirm whether
        # that matters for the session.
        super(EC2Environment, self).__init__(name, ENVIRONMENT.EC2)

        if aws_access_key is not None and aws_secret_key is not None:
            self.aws_access_key, self.aws_secret_key = aws_access_key, aws_secret_key
        else:
            # Raised explicitly instead of via ``assert`` so that the checks
            # (and the get_credentials() call) still run under ``python -O``.
            # AssertionError is kept for backward compatibility.
            if aws_access_key is not None or aws_secret_key is not None:
                raise AssertionError("aws_access_key and aws_secret_key must be specified together.")
            if not get_credentials():
                raise AssertionError("AWS credentials are required: pass aws_access_key and aws_secret_key, "
                                     "or set the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment "
                                     "variables.")

        # Convert first, then range-check the *converted* value; the original
        # compared the raw argument, which let values such as '0' slip through.
        try:
            host_count = int(num_hosts)
        except (TypeError, ValueError):
            raise TypeError("num_hosts is not an integer")
        if host_count <= 0:
            raise TypeError("num_hosts must be a positive integer")
        self.num_hosts = host_count

        if not isinstance(region, str):
            raise TypeError('region must be a string')

        if not isinstance(instance_type, str):
            raise TypeError('instance type must be a string')

        if security_group and not isinstance(security_group, str):
            raise TypeError('security_group must be a string')

        # The original checked ``security_group`` here, which rejected every
        # non-empty ``tags`` value unless security_group happened to be a dict.
        if tags and not isinstance(tags, dict):
            raise TypeError('tags must be a dict')

        if CIDR_rule and not isinstance(CIDR_rule, (str, list)):
            raise TypeError('CIDR_rule must be a string or a list of strings')

        (self.s3_bucket, self.s3_log_folder_path) = EC2Environment._parse_s3_path(s3_log_folder_path)

        self.region = region
        self.instance_type = instance_type

        self.security_group = security_group
        self.tags = tags
        self.CIDR_rule = CIDR_rule


    @staticmethod
    def _parse_s3_path(s3_log_path):
        """Split an 's3://bucket_name/path' string into (bucket_name, folder).

        The returned folder is '' when no path follows the bucket name;
        otherwise it always ends with '/'.

        Raises
        ------
        TypeError
            If s3_log_path does not start with 's3://' (including values that
            are not strings at all) or if the bucket name is empty.
        """
        try:
            has_scheme = s3_log_path.startswith('s3://')
        except AttributeError:
            # Not string-like (e.g. None): report it the same way as a bad path.
            has_scheme = False
        if not has_scheme:
            raise TypeError('s3_log_path must be an S3 path: s3://bucket_name/path')

        # 's3://bucket/a/b'.split('/') -> ['s3:', '', 'bucket', 'a', 'b']
        tokens = s3_log_path.split('/')
        bucket_name = tokens[2]
        if not bucket_name:
            raise TypeError('s3_log_path must be an S3 path: s3://bucket_name/path')

        s3_folder = '/'.join(tokens[3:])
        if s3_folder and not s3_folder.endswith('/'):
            s3_folder = s3_folder + '/'

        return (bucket_name, s3_folder)


    def __str__(self):
        # aws_access_key is only set when both keys were passed to __init__.
        return "EC2: [name: %s, access_key: %s, instance_type: %s, region: %s, params: %s, log folder: s3://%s/%s]" % \
               (self.name, getattr(self, 'aws_access_key', 'None'),
                self.instance_type, self.region, self.params, self.s3_bucket, self.s3_log_folder_path)

    def get_max_degree_of_parallelism(self):
        """Return the number of hosts configured for this environment."""
        return self.num_hosts


class HadoopEnvironment(Environment):
    """
    Environment for running from a Hadoop 2 cluster
    Requires 'hadoop' and 'yarn' commands to be available on the local path.
    """

    def __init__(self, name, config_dir=None, memory_mb=4096, virtual_cores=2, gl_source=None):
        """
        Parameters
        ----------
         name: string
            The name of the HadoopEnvironment
         config_dir: string, optional
            The location, on the local filesystem, of the hadoop config directory.
         memory_mb: int, optional
            The memory in MB required for job execution.  Default is 4096
         virtual_cores: int, optional
            The number of virtual cores to use; must be at least 2.  Default is 2
         gl_source: string, optional
            Where GraphLab and all its requirements come from.  Options:
                None || 'none': The GraphLab client will download and package GraphLab binaries
                    and all required dependencies and place them on hdfs.  virtualenv-2.7 is required on the
                    path.
                <hdfs url> : The GraphLab client will use a tar.gz file available on hdfs as the source
                for GraphLab binaries and all dependencies.
                'native' : The GraphLab client will use a native GraphLab installation on the hadoop cluster.

        Raises
        ------
        TypeError
            If gl_source is not one of the accepted values, or if virtual_cores
            is less than two.
        """
        super(HadoopEnvironment, self).__init__(name, ENVIRONMENT.Hadoop)
        self.path = HadoopEnvironment._normalize_gl_source(gl_source, 'gl_source')
        HadoopEnvironment._check_virtual_cores(virtual_cores)

        self.config_dir = config_dir
        self.container_memory = memory_mb
        self.vcores = virtual_cores

    @staticmethod
    def _normalize_gl_source(source, arg_name):
        """Validate a GraphLab source value and return the form to store.

        None and 'none' both normalize to 'none'; 'native', 'Native' and
        strings starting with 'hdfs://' are returned unchanged.  Anything else
        (including non-string values) raises TypeError, with ``arg_name`` used
        in the message.
        """
        if source is None or source == 'none':
            return 'none'
        if source in ('native', 'Native'):
            return source
        try:
            is_hdfs_url = source.startswith('hdfs://')
        except AttributeError:
            # Not string-like: reject with the same TypeError as other bad values.
            is_hdfs_url = False
        if not is_hdfs_url:
            raise TypeError("%s must be 'Native', None, or HDFS URL." % arg_name)
        return source

    @staticmethod
    def _check_virtual_cores(num_cores):
        """Raise TypeError unless ``num_cores`` is at least two."""
        if not num_cores >= 2:
            raise TypeError("Virtual cores must be greater than or equal to two.")

    def set_graphlab_source(self, path):
        """Set the GraphLab source; accepts the same values as ``gl_source`` in __init__.

        None is stored as 'none', matching the constructor, so that
        get_graphlab_source() reports the same value either way.
        """
        # TODO valid values are "none", an hdfs path, a local os path, or native
        self.path = HadoopEnvironment._normalize_gl_source(path, 'path')
        self._set_dirty_bit()

    def get_container_memory(self):
        """Return the container memory (in MB) configured for job execution."""
        return self.container_memory

    def set_container_memory(self, memory_mb):
        """Set the container memory (in MB) and mark the environment as modified."""
        self.container_memory = memory_mb
        self._set_dirty_bit()

    def get_config_dir(self):
        """Return the hadoop config directory, or None if none was given."""
        return self.config_dir

    def set_config_dir(self, config_dir):
        """Set the hadoop config directory and mark the environment as modified."""
        self.config_dir = config_dir
        self._set_dirty_bit()

    def get_graphlab_source(self):
        """Return the GraphLab source as a string."""
        return str(self.path)

    def set_virtual_cores(self, num_cores):
        """Set the number of virtual cores (at least two, as in __init__).

        Raises
        ------
        TypeError
            If num_cores is less than two.
        """
        HadoopEnvironment._check_virtual_cores(num_cores)
        self.vcores = num_cores
        self._set_dirty_bit()

    def get_virtual_cores(self):
        """Return the number of virtual cores, as a string."""
        return str(self.vcores)

    def __str__(self):
        # The original format string was missing the space after the comma
        # between container_memory and virtual_cores.
        return "Hadoop: [name: %s, config_dir: %s, container_memory: %s, " \
               "virtual_cores: %s, gl_source: %s]" % (self.name,
                self.config_dir, self.container_memory, self.vcores, self.get_graphlab_source())

