from nox.elements import Edge
from nox.serialization import BaseSerializationRule, deserialize, nox_serializer


from .base_elements import BaseElement
from .vertex import Vertex


class ElementInstanceSerializer(BaseSerializationRule):
    """
    Serialization rule for Nox graph elements (vertices and edges).

    Serializing an element yields only its id. Deserializing accepts a dict
    tagged with ``_type`` (``'vertex'`` or ``'edge'``) and rebuilds either an
    instance of a matching trait subclass (``VertexTrait`` / ``EdgeTrait``) or,
    when no trait fits, a plain ``Vertex`` / ``Edge``.
    """
    FIELD_TYPE = '_type'
    FIELD_ID = '_id'

    FIELD_OUT_V = '_outV'
    FIELD_IN_V = '_inV'
    FIELD_LABEL = '_label'

    FIELD_PROPERTIES = '_properties'

    TYPE_EDGE = 'edge'
    TYPE_VERTEX = 'vertex'

    # Fields that must be present and non-None for a dict to count as an edge.
    _EDGE_REQUIRED_FIELDS = (FIELD_ID, FIELD_OUT_V, FIELD_IN_V)

    def is_applicable_for_serialization(self, obj):
        """Return True if ``obj`` is a graph element (``BaseElement``)."""
        return isinstance(obj, BaseElement)

    def serialize(self, element):
        """
        Serialize a graph element as its id.

        :param BaseElement element: Nox graph element to serialize
        :return: the value returned by ``element.get_id()``
        """
        return element.get_id()

    def is_applicable_for_deserialization(self, obj):
        """Return True if ``obj`` is a dict describing a vertex or an edge."""
        is_dict = isinstance(obj, dict)
        return is_dict and (self._is_vertex(obj) or self._is_edge(obj))

    def _is_vertex(self, element_dict):
        """
        Check whether ``element_dict`` describes a vertex.

        :param dict element_dict:
        :return: True if the type tag is ``'vertex'`` and a non-None id is present
        :rtype: bool
        """
        try:
            return (element_dict[self.FIELD_TYPE] == self.TYPE_VERTEX
                    and element_dict[self.FIELD_ID] is not None)
        except KeyError:
            return False

    def _is_edge(self, element_dict):
        """
        Check whether ``element_dict`` describes an edge.

        :param dict element_dict:
        :return: True if the type tag is ``'edge'`` and id, out-vertex and
            in-vertex are all present and non-None
        :rtype: bool
        """
        try:
            return element_dict[self.FIELD_TYPE] == self.TYPE_EDGE and all(
                element_dict[field] is not None for field in self._EDGE_REQUIRED_FIELDS)
        except KeyError:
            return False

    def deserialize(self, element_dict):
        """
        Rebuild a vertex or edge from its dict representation.

        :param dict element_dict:
        :raises ValueError: if the dict describes neither a vertex nor an edge
        """
        if self._is_vertex(element_dict):
            return self._deserialize_vertex(element_dict)
        elif self._is_edge(element_dict):
            return self._deserialize_edge(element_dict)

        raise ValueError('Wrong element_dict for ElementInstanceSerializer!')

    def _deserialize_vertex(self, element_dict):
        """
        Build a vertex, preferring a matching ``VertexTrait`` subclass.

        :rtype: Vertex
        """
        vertex = self._deserialize_as_vertex_trait(element_dict)
        if vertex is None:
            vertex = self._deserialize_as_general_vertex(element_dict)

        return vertex

    def _deserialize_as_vertex_trait(self, element_dict):
        """
        Try to build the vertex as an instance of a known ``VertexTrait`` subclass.

        A trait matches when its allowed keys cover every property key.
        Vertices without properties are never mapped to a trait.

        :rtype: VertexTrait|None
        """
        properties = self._get_deserialized_properties(element_dict)
        if not properties:
            return None

        # Imported lazily (presumably to avoid a circular import).
        from nox.abstractions.vertex_trait import VertexTrait
        best_trait = self._find_trait_for_keys(self._get_known_traits(VertexTrait), set(properties))
        if best_trait is None:
            return None

        vid = int(element_dict[self.FIELD_ID])
        return best_trait(vid, **properties)

    def _deserialize_as_general_vertex(self, element_dict):
        """
        Build a plain ``Vertex`` with an int id and the deserialized properties.

        :rtype: Vertex
        """
        vid = int(element_dict[self.FIELD_ID])
        return Vertex(vid, **self._get_deserialized_properties(element_dict))

    def _deserialize_edge(self, element_dict):
        """
        Build an edge, preferring a matching ``EdgeTrait`` subclass.

        :rtype: Edge
        """
        edge = self._deserialize_as_edge_trait(element_dict)
        if edge is None:
            edge = self._deserialize_as_general_edge(element_dict)

        return edge

    def _deserialize_as_edge_trait(self, element_dict):
        """
        Try to build the edge as an instance of a known ``EdgeTrait`` subclass.

        Traits whose ``__label__`` equals the edge label are preferred.
        Otherwise any trait whose allowed keys cover the property keys is used.
        An edge without properties is only mapped to a trait by label.

        :raises KeyError: if ``'_label'`` is missing (``_is_edge`` does not require it)
        :rtype: EdgeTrait|None
        """
        eid = element_dict[self.FIELD_ID]
        out_v = element_dict[self.FIELD_OUT_V]
        in_v = element_dict[self.FIELD_IN_V]
        label = element_dict[self.FIELD_LABEL]

        # Imported lazily (presumably to avoid a circular import).
        from nox.abstractions.edge_trait import EdgeTrait
        known_traits = self._get_known_traits(EdgeTrait)
        label_matches = [t for t in known_traits if getattr(t, '__label__', None) == label]

        properties = self._get_deserialized_properties(element_dict)
        if not properties:
            # With no properties, the label is the only discriminator. Every trait
            # trivially "covers" an empty key set, so without a label match we
            # fall back to a plain Edge instead of an arbitrary trait.
            if label_matches:
                return self._build_edge_trait(label_matches[0], eid, out_v, in_v, label)
            return None

        deserialized_trait = self._try_deserialize_edge_trait(
            eid, out_v, in_v, label, properties, label_matches)
        if deserialized_trait is None:
            deserialized_trait = self._try_deserialize_edge_trait(
                eid, out_v, in_v, label, properties, known_traits)

        return deserialized_trait

    def _try_deserialize_edge_trait(self, eid, out_v, in_v, label, properties, trait_list):
        """
        Build the first trait in ``trait_list`` that accepts all property keys.

        :rtype: EdgeTrait|None
        """
        best_trait = self._find_trait_for_keys(trait_list, set(properties))
        if best_trait is None:
            return None

        return self._build_edge_trait(best_trait, eid, out_v, in_v, label, properties)

    def _build_edge_trait(self, trait_class, eid, out_v, in_v, label, properties=None):
        """
        Instantiate ``trait_class`` for an edge.

        ``properties`` is optional: the label-only path has no properties.
        """
        return trait_class(eid, out_v, in_v, label, **(properties or {}))

    def _deserialize_as_general_edge(self, element_dict):
        """
        Build a plain ``Edge`` from the dict fields and deserialized properties.

        :rtype: Edge
        """
        eid = element_dict[self.FIELD_ID]
        out_v = element_dict[self.FIELD_OUT_V]
        in_v = element_dict[self.FIELD_IN_V]
        label = element_dict[self.FIELD_LABEL]
        return Edge(eid, out_v, in_v, label, **self._get_deserialized_properties(element_dict))

    def _find_trait_for_keys(self, traits, property_keys):
        """
        Return the first trait whose ``get_allowed_keys()`` covers ``property_keys``.

        :rtype: type|None
        """
        # TODO: Make more advanced - add distance checking
        return next((trait for trait in traits
                     if trait.get_allowed_keys().issuperset(property_keys)), None)

    def _get_known_traits(self, trait_class):
        """
        Return the direct subclasses of ``trait_class``.

        A list (in ``__subclasses__()`` order) is returned instead of a set, so
        that "first matching trait" is deterministic rather than hash-dependent.

        :rtype: list[type]
        """
        return list(trait_class.__subclasses__())

    def _get_deserialized_properties(self, element_dict):
        """
        Return the element's ``'_properties'`` with each value deserialized.

        A missing or ``None`` properties entry yields an empty dict.

        :rtype: dict
        """
        properties = element_dict.get(self.FIELD_PROPERTIES) or {}
        return {k: deserialize(v) for k, v in properties.items()}

# Import-time side effect: pass an instance of this rule to nox_serializer
# (presumably registering it with nox.serialization — confirm there).
nox_serializer(ElementInstanceSerializer())
