{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5d0c1f7e-3b9a-4c2e-9f61-0a7d2e8b4c13",
   "metadata": {},
   "source": [
    "# `np_cache`: LRU caching for functions that take NumPy arrays\n",
    "\n",
    "`functools.lru_cache` only accepts hashable arguments, and `np.ndarray` is not hashable. `np_cache` converts selected positional array arguments into hashable keys (shape, dtype and values) before the cache lookup, and rebuilds the arrays before calling the wrapped function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aeccefd6-7729-4830-afd8-e8ed0423d92c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from functools import lru_cache, wraps\n",
    "import numpy as np\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "65e726fb-d3d3-4b47-a96e-9a09a7c774d7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def _array_to_key(array):\n",
    "    \"\"\"Encode ``array`` as a hashable ``(shape, dtype, values)`` cache key.\n",
    "\n",
    "    Shape and dtype are part of the key so that arrays with equal values but a\n",
    "    different dtype (e.g. ``[1, 2]`` vs ``[1.0, 2.0]``) or a different empty shape\n",
    "    do not share a cache entry, and so that ``_key_to_array`` can rebuild the\n",
    "    array exactly.\n",
    "    \"\"\"\n",
    "    return (array.shape, array.dtype, tuple(array.ravel().tolist()))\n",
    "\n",
    "\n",
    "def _key_to_array(key):\n",
    "    \"\"\"Rebuild the ``np.ndarray`` encoded by ``_array_to_key``.\"\"\"\n",
    "    shape, dtype, values = key\n",
    "    return np.array(values, dtype=dtype).reshape(shape)\n",
    "\n",
    "\n",
    "def np_cache(*lru_args, array_argument_indexs=(0,), **lru_kwargs):\n",
    "    \"\"\"\n",
    "    LRU cache decorator for functions that take NumPy arrays as positional arguments.\n",
    "\n",
    "    ``functools.lru_cache`` requires hashable arguments and ``np.ndarray`` is not\n",
    "    hashable. Every positional argument whose index is listed in\n",
    "    ``array_argument_indexs`` is converted to a hashable key (shape, dtype and\n",
    "    values) before the cache lookup, and rebuilt into an ``np.ndarray`` before\n",
    "    the wrapped function is called.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    *lru_args, **lru_kwargs\n",
    "        Forwarded unchanged to ``functools.lru_cache`` (e.g. ``maxsize=256``).\n",
    "    array_argument_indexs : tuple of int, default (0,)\n",
    "        Positions in ``*args`` of the array arguments. These must be passed\n",
    "        positionally; array-likes (e.g. lists) are converted with ``np.asarray``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    RuntimeError\n",
    "        If an array argument has more than 2 dimensions.\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    On a cache hit the same result object is returned again, so mutating a\n",
    "    returned array in place also changes the cached value. Arrays containing\n",
    "    NaN never produce a cache hit, because NaN compares unequal to itself.\n",
    "\n",
    "    Example:\n",
    "        >>> array = np.array([[1, 2, 3], [4, 5, 6]])\n",
    "        >>> @np_cache(maxsize=256)\n",
    "        ... def multiply(array, factor):\n",
    "        ...     return factor * array\n",
    "        >>> _ = multiply(array, 2)\n",
    "        >>> _ = multiply(array, 2)\n",
    "        >>> multiply.cache_info()\n",
    "        CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)\n",
    "    \"\"\"\n",
    "\n",
    "    def decorator(function):\n",
    "        @lru_cache(*lru_args, **lru_kwargs)\n",
    "        def cached_wrapper(*args, **kwargs):\n",
    "            # Only reached on a cache miss: turn the keys back into arrays.\n",
    "            args = list(args)\n",
    "            for array_argument_index in array_argument_indexs:\n",
    "                args[array_argument_index] = _key_to_array(args[array_argument_index])\n",
    "            return function(*args, **kwargs)\n",
    "\n",
    "        @wraps(function)\n",
    "        def wrapper(*args, **kwargs):\n",
    "            args = list(args)\n",
    "            for array_argument_index in array_argument_indexs:\n",
    "                np_array = np.asarray(args[array_argument_index])\n",
    "                if np_array.ndim > 2:\n",
    "                    raise RuntimeError(\n",
    "                        f\"np_cache is currently only supported for arrays of dim. less than 3 but got shape: {np_array.shape}\"\n",
    "                    )\n",
    "                args[array_argument_index] = _array_to_key(np_array)\n",
    "            return cached_wrapper(*args, **kwargs)\n",
    "\n",
    "        # Expose lru_cache's introspection/control helpers on the public wrapper.\n",
    "        for attribute in (\"cache_info\", \"cache_clear\", \"cache_parameters\"):\n",
    "            if hasattr(cached_wrapper, attribute):\n",
    "                setattr(wrapper, attribute, getattr(cached_wrapper, attribute))\n",
    "        return wrapper\n",
    "\n",
    "    return decorator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a4e2b71-6c0d-4f3a-b5e9-2d17c9f04a86",
   "metadata": {},
   "source": [
    "## Example\n",
    "\n",
    "Both positional arguments are arrays; the second identical call is served from the cache."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c3f9d5a2-1e7b-4b8c-a0d6-7f52e1b9c348",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@np_cache(maxsize=256, array_argument_indexs=(0, 1))\n",
    "def multiply(array, array2):\n",
    "    return array2 * array\n",
    "\n",
    "\n",
    "array = np.array([[1, 2, 3], [4, 5, 6]])\n",
    "multiply(array, array)\n",
    "multiply(array, array)\n",
    "multiply.cache_info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f7b1e6c4-9d2a-4e05-8c3f-4a6b0d8e2f91",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 1-D arrays are supported, and dtype is part of the cache key.\n",
    "assert multiply(np.array([1, 2]), np.array([3, 4])).tolist() == [3, 8]\n",
    "assert multiply(np.array([1.0, 2.0]), np.array([3, 4])).dtype == np.float64"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}