Files
reversi/Untitled.ipynb

3.9 KiB

In [4]:
from functools import lru_cache, wraps
import numpy as np
import time
In [27]:
def np_cache(*lru_args, array_argument_indexs=(0,), **lru_kwargs):
    """
    LRU cache for functions that take numpy arrays (ndim <= 2) as positional arguments.

    ``functools.lru_cache`` cannot hash ``np.ndarray``. For every positional argument
    listed in ``array_argument_indexs`` this decorator builds a hashable key made of the
    array's dtype, shape and nested-tuple contents, caches on that key, and rebuilds an
    equivalent array (same dtype, same shape) before calling the wrapped function.

    Args:
        *lru_args, **lru_kwargs: forwarded unchanged to ``functools.lru_cache``
            (e.g. ``maxsize=256``).
        array_argument_indexs: positions in ``*args`` that hold arrays (or array-likes,
            which are converted with ``np.asarray``). These arguments must be passed
            positionally; an array passed as a keyword reaches ``lru_cache`` unconverted
            and fails with ``TypeError`` (unhashable).

    Raises:
        RuntimeError: if an array argument has more than 2 dimensions.

    Notes:
        The wrapped function receives a freshly built array, so mutating it cannot
        corrupt the cache key. The *return value* is shared by all calls that hit the
        cache, so callers must not mutate it in place.

    Example:
    >>> array = np.array([[1, 2, 3], [4, 5, 6]])
    >>> @np_cache(maxsize=256)
    ... def multiply(array, factor):
    ...     return factor * array
    >>> _ = multiply(array, 2)
    >>> _ = multiply(array, 2)
    >>> multiply.cache_info()
    CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)
    """
    # Dedupe while keeping order: converting the same position twice would try to
    # treat an already-built key as an array.
    indexes = tuple(dict.fromkeys(array_argument_indexs))

    def _to_hashable(value):
        # ``ndarray.tolist`` yields nested lists of Python scalars; turn the lists into
        # tuples so the result can be hashed by lru_cache.
        if isinstance(value, list):
            return tuple(_to_hashable(item) for item in value)
        return value

    def decorator(function):
        @wraps(function)
        def wrapper(*args, **kwargs):
            args = list(args)
            for array_argument_index in indexes:
                np_array = np.asarray(args[array_argument_index])
                if np_array.ndim > 2:
                    raise RuntimeError(
                        f"np_cache is currently only supported for arrays of dim. less than 3 but got shape: {np_array.shape}"
                    )
                # dtype and shape belong in the key: without them an int array and a
                # float array with equal values collide (1 == 1.0 with equal hashes),
                # and 0-size arrays such as shape (0, 3) cannot be rebuilt correctly.
                args[array_argument_index] = (
                    np_array.dtype,
                    np_array.shape,
                    _to_hashable(np_array.tolist()),
                )
            return cached_wrapper(*args, **kwargs)

        @lru_cache(*lru_args, **lru_kwargs)
        def cached_wrapper(*args, **kwargs):
            args = list(args)
            for array_argument_index in indexes:
                dtype, shape, contents = args[array_argument_index]
                # Explicit dtype + reshape restore exactly what the caller passed in.
                args[array_argument_index] = np.array(contents, dtype=dtype).reshape(shape)
            return function(*args, **kwargs)

        # Expose the lru_cache introspection helpers on the public wrapper.
        wrapper.cache_info = cached_wrapper.cache_info
        wrapper.cache_clear = cached_wrapper.cache_clear
        return wrapper

    return decorator


@np_cache(maxsize=256, array_argument_indexs=(0, 1))
def multiply(array, array2):
    """Element-wise product of two arrays, memoised on both array arguments."""
    product = array2 * array
    return product


array = np.array([[1, 2, 3], [4, 5, 6]])
# Call twice with identical inputs: the first call misses, the second is a cache hit.
for _ in range(2):
    multiply(array, array)
multiply.cache_info()
Out[27]:
CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)
In [ ]:
# Leftover scratch cell: the bare name `arg` is not defined by any cell in this
# notebook, so evaluating it raised NameError on Restart & Run All; removed.
In [ ]: