Source code for sisl._category

# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from __future__ import annotations

import logging
from abc import ABCMeta, abstractmethod
from collections import ChainMap, defaultdict
from collections.abc import Iterable
from functools import lru_cache, singledispatchmethod, wraps

from ._internal import set_module

# Public names exported by this module.
__all__ = [
    "Category",
    "CompositeCategory",
    "NullCategory",
    "AndCategory",
    "OrCategory",
    "XOrCategory",
    "InstanceCache",
]

# Module-level logger.
_log = logging.getLogger(__name__)


class InstanceCache:
    """Wraps an instance to cache *all* results based on `functools.lru_cache`

    Attribute access is forwarded to the wrapped object. The first time a
    callable attribute is requested it is wrapped in `functools.lru_cache`
    (unless it is excluded, see `no_cache`) and the result is stored on this
    wrapper, so subsequent accesses bypass `__getattr__` entirely.

    Parameters
    ----------
    obj : object
        the object to get cached results
    lru_size : int or dict, optional
        initial size of the lru cache. For integers this
        is the default size of the cache, for a dictionary
        it should return the ``maxsize`` argument for `functools.lru_cache`.
        A plain dictionary may contain the key ``"default"`` which is used
        for all names not explicitly listed (defaults to 1).
        A size of 0 disables caching for that attribute.
    no_cache : searchable (list or dict)
        a list-like (or dictionary) for searching for
        methods that don't require caches (e.g. small methods)
    """

    def __init__(self, obj, lru_size=1, no_cache=None):
        self.__obj = obj

        # Handle user input for lru_size
        if isinstance(lru_size, defaultdict):
            # fine, user did everything good
            self.__lru_size = lru_size
        elif isinstance(lru_size, dict):
            # Read (do not pop) the default so the caller's dictionary
            # is left untouched.
            default = lru_size.get("default", 1)
            self.__lru_size = ChainMap(lru_size, defaultdict(lambda: default))
        else:
            self.__lru_size = defaultdict(lambda: lru_size)

        if no_cache is None:
            self.__no_cache = []
        else:
            self.__no_cache = no_cache

    def __getattr__(self, name):
        # Only reached when `name` is not yet stored on the wrapper itself.
        attr = getattr(self.__obj, name)
        # Check if the attribute has the cached functionality
        try:
            attr.cache_info()
        except AttributeError:
            # Not cached yet: wrap it, unless the user explicitly opted out
            # or it is plain data (lru_cache only accepts callables).
            if callable(attr) and name not in self.__no_cache:
                maxsize = self.__lru_size[name]
                if maxsize != 0:
                    attr = wraps(attr)(lru_cache(maxsize)(attr))

        # offload the attribute to this class (to minimize overhead)
        object.__setattr__(self, name, attr)
        return attr


class CategoryMeta(ABCMeta):
    """
    Metaclass controlling what happens when a category class is called.
    """

    def __call__(cls, *args, **kwargs):
        """
        Instantiate the category, or fall back to the keyword builder.

        A regular instantiation is tried first. When it fails with a
        `TypeError` (e.g. the class is abstract such as `AtomCategory`)
        and *no* positional arguments were passed, the keyword arguments
        are forwarded to ``cls.kw`` which builds the requested categories.
        If positional arguments were given, the user most likely did not
        intend to use the builder and the original error is re-raised.

        Examples
        -----------

        >>> AtomZ(6) # returns an AtomZ category with the Z parameter set to 6
        >>> AtomCategory(Z=6) # returns exactly the same since it uses `AtomCategory.kw`
        """
        try:
            instance = super().__call__(*args, **kwargs)
        except TypeError as err:
            if args:
                # Positional arguments: not a builder request, propagate.
                raise err
            return cls.kw(**kwargs)
        return instance


@set_module("sisl.category")
class Category(metaclass=CategoryMeta):
    r"""A category

    Base class for all categories. Sub-classes must implement `is_class`
    (used by `kw` to resolve keyword names into category classes) and
    `categorize`.

    Categories can be combined using the operators ``&``, ``|``, ``^``
    and negated using ``~``.

    Parameters
    ----------
    name : str, optional
        name of the category, defaults to the class name
    """

    __slots__ = ("_name", "_wrapper")

    def __init__(self, name=None):
        if name is None:
            self._name = self.__class__.__name__
        else:
            self._name = name

    @property
    def name(self):
        r"""Name of category"""
        return self._name

    @name.setter
    def name(self, name):
        r"""Override the name of the categorization"""
        self._name = name

    @classmethod
    @abstractmethod
    def is_class(cls, name, case=True):
        r"""Query whether `name` matches the class name by removing a prefix `kw`

        This is important to ensure that users match the full class name
        by omitting the prefix returned from this method.

        This is an abstract method to ensure sub-classes of `Category`
        implements it.

        For instance:

        .. code::

            class MyCategory(Category):

                @classmethod
                def is_class(cls, name, case=True):
                     # strip "My" and do comparison
                     if case:
                         return cls.__name__[2:] == name
                     return cls.__name__.lower()[2:] == name.lower()

        would enable one to compare against the *base* category scheme.

        This has the option to search case-sensitivity or not.

        Parameters
        ----------
        name : str
            the name to compare against the (prefix-stripped) class name
        case : bool, optional
            whether the comparison is case-sensitive
        """
        pass

    @classmethod
    def kw(cls, **kwargs):
        """Create categories based on keywords

        This will search through the inherited classes and
        return an ``&`` (and) category object for all keywords.

        Since this is a class method one should use this
        on the base category class in the given section
        of the code.

        Each keyword name is matched against all (recursive) sub-classes
        of `cls` through their `is_class` method, first case-sensitively and,
        only if nothing matched, case-insensitively. The keyword value is
        passed to the matching class, dictionaries are expanded as keyword
        arguments.

        Returns
        -------
        Category or None
            the combined category, or ``None`` if no keywords are passed

        Raises
        ------
        ValueError
            if a keyword matches none, or more than one, sub-class
        """

        # Collect every class deriving (directly or indirectly) from cls
        subcls = set()
        work = [cls]
        while work:
            parent = work.pop()
            for child in parent.__subclasses__():
                if child not in subcls:
                    subcls.add(child)
                    work.append(child)

        del work

        def get_cat(cl, args):
            if isinstance(args, dict):
                return cl(**args)
            return cl(args)

        # Now search keywords and create category
        cat = None
        for key, args in kwargs.items():
            found = None
            # First search case-sensitive
            for cl in subcls:
                if cl.is_class(key):
                    if found is not None:
                        raise ValueError(
                            f"{cls.__name__}.kw got a non-unique argument for category name:\n"
                            f"    Searching for {key} and found matches {found.__name__} and {cl.__name__}."
                        )
                    found = cl

            if found is None:
                # Nothing matched exactly, retry ignoring the case
                for cl in subcls:
                    if cl.is_class(key, case=False):
                        if found is not None:
                            raise ValueError(
                                f"{cls.__name__}.kw got a non-unique argument for category name:\n"
                                f"    Searching for {key} and found matches {found.__name__.lower()} and {cl.__name__.lower()}."
                            )
                        found = cl

            if found is None:
                raise ValueError(
                    f"{cls.__name__}.kw got an argument for category name:\n"
                    f"    Searching for {key} but found no matches."
                )

            if cat is None:
                cat = get_cat(found, args)
            else:
                cat = cat & get_cat(found, args)

        return cat

    @abstractmethod
    def categorize(self, *args, **kwargs):
        r"""Do categorization"""
        pass

    def __str__(self):
        r"""String representation of the class (non-distinguishable between equivalent classifiers)"""
        return self.name

    def __repr__(self):
        r"""String representation of the class (non-distinguishable between equivalent classifiers)"""
        return self.name

    @singledispatchmethod
    def __eq__(self, other):
        """Comparison of two categories, they are compared by class-type

        Composite categories are equal when they are of the same composite
        class and their two operands are pair-wise equal (in either order).
        Any other category is equal to `other` when both are of the exact
        same class.

        When `other` is an iterable, a list with the element-wise
        comparisons is returned.
        """
        # This is not totally safe since composites *could* be generated
        # in different sequences and result in the same boolean expression.
        # This we do not check and thus are not fool proof...
        # The exact action also depends on whether we are dealing with
        # an And/Or/XOr operation....
        # I.e. something like
        # (A & B & C) != (A & C & B)
        # (A ^ B ^ C) != (C ^ A ^ B)
        if isinstance(self, CompositeCategory):
            if isinstance(other, CompositeCategory):
                return self.__class__ is other.__class__ and (
                    (self.A == other.A and self.B == other.B)
                    or (self.A == other.B and self.B == other.A)
                )
            # only one of them is a composite category, then they cannot
            # be the same category
            return False
        # Comparing by class-type only. Do *not* delegate to `self == other`
        # here, that would re-enter this very method and never terminate.
        return self.__class__ == other.__class__

    @__eq__.register(Iterable)
    def _(self, other):
        return [self.__eq__(o) for o in other]

    def __ne__(self, other):
        eq = self == other
        if isinstance(eq, Iterable):
            # element-wise negation for iterable comparisons
            return [not e for e in eq]
        return not eq

    # Implement logical operators to enable composition of sets
    def __and__(self, other):
        return AndCategory(self, other)

    def __or__(self, other):
        return OrCategory(self, other)

    def __xor__(self, other):
        return XOrCategory(self, other)

    def __invert__(self):
        # a double negation returns the original category
        if isinstance(self, NotCategory):
            return self._cat
        return NotCategory(self)


@set_module("sisl.category")
class GenericCategory(Category):
    """Used to indicate that the category does not act on specific objects

    It serves to identify categories such as `NullCategory`, `NotCategory`
    and `CompositeCategory` and distinguish them from categories that have
    a specific object in which they act.
    """

    __slots__ = ()

    @classmethod
    def is_class(cls, name, case=True):
        """Never match a generic class by name

        The `case` argument is accepted (and ignored) so the signature
        agrees with `Category.is_class`; `Category.kw` calls
        ``is_class(key, case=False)`` on every sub-class.
        """
        # never allow one to match a generic class
        # I.e. you can't instantiate a Null/Not/And/Or/XOr category by name
        return False


@set_module("sisl.category")
class NullCategory(GenericCategory):
    r"""Special Null class which always represents a classification not being *anything*"""

    __slots__ = ()

    def categorize(self, *args, **kwargs):
        r"""Return the null category itself, regardless of the arguments"""
        return self

    @singledispatchmethod
    def __eq__(self, other):
        """A null category equals ``None`` and any object of the same class"""
        if other is None:
            return True
        return self.__class__ == other.__class__

    @__eq__.register(Iterable)
    def _(self, other):
        # defer to the parent handling of iterables
        return super().__eq__(other)

    @property
    def name(self):
        r"""Name of the null category (fixed)"""
        return "∅"

    @name.setter
    def name(self, name):
        raise ValueError(
            f"One cannot overwrite the name of a {self.__class__.__name__}"
        )


@set_module("sisl.category")
class NotCategory(GenericCategory):
    """A class returning the *opposite* of this class (NullCategory) if it is categorized as such

    Parameters
    ----------
    cat : Category
        the category to negate
    """

    __slots__ = ("_cat",)

    def __init__(self, cat):
        super().__init__()
        # composites are wrapped in parentheses to keep the name unambiguous
        if isinstance(cat, CompositeCategory):
            self.name = f"~({cat})"
        else:
            self.name = f"~{cat}"
        self._cat = cat

    def categorize(self, *args, **kwargs):
        r"""Base method for querying whether an object is a certain category

        The arguments are forwarded to the negated category. Wherever that
        category returns a `NullCategory` this object is returned, otherwise
        a `NullCategory` is returned. Iterable results are handled
        element-wise and returned as a list.
        """
        cat = self._cat.categorize(*args, **kwargs)

        _null = NullCategory()

        def check(cat):
            if isinstance(cat, NullCategory):
                return self
            return _null

        if isinstance(cat, Iterable):
            return [check(c) for c in cat]
        return check(cat)

    @singledispatchmethod
    def __eq__(self, other):
        """Two negations are equal when the categories they negate are equal"""
        if isinstance(other, NotCategory):
            return self._cat == other._cat
        return False

    @__eq__.register(Iterable)
    def _(self, other):
        # this will call the list approach
        return super().__eq__(other)


def _composite_name(sep):
    """Create the ``name`` property of a composite category

    The getter returns the explicitly set name if there is one, otherwise
    it joins the names of the operands ``A`` and ``B`` with `sep`; operands
    that are themselves composites are wrapped in parentheses.
    """

    def getter(self):
        if self._name is not None:
            return self._name

        # Name is unset, we simply return the other parts
        if isinstance(self.A, CompositeCategory):
            nameA = f"({self.A.name})"
        else:
            nameA = self.A.name
        if isinstance(self.B, CompositeCategory):
            nameB = f"({self.B.name})"
        else:
            nameB = self.B.name

        return f"{nameA} {sep} {nameB}"

    def setter(self, name):
        self._name = name

    return property(getter, setter)


@set_module("sisl.category")
class CompositeCategory(GenericCategory):
    """A composite class consisting of two categories, an abstract class to always be inherited

    This should take 2 categories as arguments

    Sub-classes must pass the class keyword ``composite_name`` which is the
    separator used when generating the default name.

    Parameters
    ----------
    A : Category
       the left hand side of the set operation
    B : Category
       the right hand side of the set operation
    """

    __slots__ = ("A", "B")

    def __init__(self, A, B):
        # To ensure we always get composite name
        super().__init__()
        self._name = None
        self.A = A
        self.B = B

    def __init_subclass__(cls, /, composite_name: str, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.name = _composite_name(composite_name)


@set_module("sisl.category")
class OrCategory(CompositeCategory, composite_name="|"):
    """A class consisting of two categories combined with a logical *or*

    Parameters
    ----------
    A : Category
       the left hand side of the set operation
    B : Category
       the right hand side of the set operation
    """

    __slots__ = ()

    def categorize(self, *args, **kwargs):
        r"""Base method for querying whether an object is a certain category

        Returns the result of ``A`` where it is not a `NullCategory`,
        otherwise the result of ``B``. ``B`` is only evaluated if ``A``
        returned at least one `NullCategory`.
        """
        catA = self.A.categorize(*args, **kwargs)
        if isinstance(catA, Iterable):
            if all(not isinstance(a, NullCategory) for a in catA):
                return catA
        elif not isinstance(catA, NullCategory):
            return catA

        catB = self.B.categorize(*args, **kwargs)

        def cmp(a, b):
            if isinstance(a, NullCategory):
                return b
            return a

        if isinstance(catA, Iterable):
            return [cmp(a, b) for a, b in zip(catA, catB)]
        return cmp(catA, catB)


@set_module("sisl.category")
class AndCategory(CompositeCategory, composite_name="&"):
    """A class consisting of two categories combined with a logical *and*

    Parameters
    ----------
    A : Category
       the left hand side of the set operation
    B : Category
       the right hand side of the set operation
    """

    __slots__ = ()

    def categorize(self, *args, **kwargs):
        r"""Base method for querying whether an object is a certain category

        Returns this object where neither ``A`` nor ``B`` results in a
        `NullCategory`, otherwise the `NullCategory` that was found.
        ``B`` is only evaluated if ``A`` returned at least one
        non-`NullCategory`.
        """
        catA = self.A.categorize(*args, **kwargs)
        if isinstance(catA, Iterable):
            if all(isinstance(a, NullCategory) for a in catA):
                return catA
        elif isinstance(catA, NullCategory):
            return catA

        # We can now get B and categorize
        catB = self.B.categorize(*args, **kwargs)

        def cmp(a, b):
            if isinstance(a, NullCategory):
                return a
            if isinstance(b, NullCategory):
                return b
            return self

        if isinstance(catA, Iterable):
            return [cmp(a, b) for a, b in zip(catA, catB)]
        return cmp(catA, catB)


@set_module("sisl.category")
class XOrCategory(CompositeCategory, composite_name="⊕"):
    """A class consisting of two categories combined with a logical *exclusive or*

    Parameters
    ----------
    A : Category
       the left hand side of the set operation
    B : Category
       the right hand side of the set operation
    """

    __slots__ = ()

    def categorize(self, *args, **kwargs):
        r"""Base method for querying whether an object is a certain category

        Returns the result of the one operand that is not a `NullCategory`.
        If both, or neither, are a `NullCategory` a `NullCategory` is returned.
        """
        catA = self.A.categorize(*args, **kwargs)
        catB = self.B.categorize(*args, **kwargs)

        def cmp(a, b):
            if isinstance(a, NullCategory):
                return b
            if isinstance(b, NullCategory):
                return a
            # both are not NullCategory, in which case nothing
            # is exclusive, so we return the NullCategory
            return NullCategory()

        if isinstance(catA, Iterable):
            return [cmp(a, b) for a, b in zip(catA, catB)]
        return cmp(catA, catB)