138 lines
3.9 KiB
Plaintext
138 lines
3.9 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "aeccefd6-7729-4830-afd8-e8ed0423d92c",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from functools import lru_cache, wraps\n",
|
|
"import numpy as np\n",
|
|
"import time"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"id": "65e726fb-d3d3-4b47-a96e-9a09a7c774d7",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)"
|
|
]
|
|
},
|
|
"execution_count": 27,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"def np_cache(*lru_args, array_argument_indexs=(0,), **lru_kwargs):\n",
"    \"\"\"\n",
"    LRU cache for functions that take numpy arrays (or array-likes) of at most\n",
"    2 dimensions at the positional indices listed in ``array_argument_indexs``.\n",
"\n",
"    numpy arrays are unhashable, so ``functools.lru_cache`` cannot use them as\n",
"    keys directly. Each selected argument is turned into a hashable\n",
"    ``(shape, dtype, values)`` key before the cache lookup and rebuilt into an\n",
"    array with the same shape and dtype before the wrapped function runs.\n",
"\n",
"    Args:\n",
"        *lru_args, **lru_kwargs: Forwarded to ``functools.lru_cache``\n",
"            (e.g. ``maxsize=256``).\n",
"        array_argument_indexs: Positional indices of the array arguments.\n",
"            These arrays must be passed positionally, not by keyword.\n",
"\n",
"    Returns:\n",
"        A decorator. The decorated function also exposes ``cache_info``,\n",
"        ``cache_clear`` and ``cache_parameters`` of the underlying cache.\n",
"\n",
"    Raises:\n",
"        RuntimeError: If a selected argument has more than 2 dimensions.\n",
"\n",
"    Note:\n",
"        On a cache hit the previously returned object itself is returned;\n",
"        do not mutate results in place.\n",
"\n",
"    Example:\n",
"        >>> array = np.array([[1, 2, 3], [4, 5, 6]])\n",
"        >>> @np_cache(maxsize=256)\n",
"        ... def multiply(array, factor):\n",
"        ...     return factor * array\n",
"        >>> _ = multiply(array, 2)\n",
"        >>> _ = multiply(array, 2)\n",
"        >>> multiply.cache_info()\n",
"        CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)\n",
"    \"\"\"\n",
"\n",
"    def to_hashable(np_array):\n",
"        # Shape and dtype are part of the key so that arrays with equal values\n",
"        # but different dtypes (e.g. int vs float), or empty arrays of\n",
"        # different shapes, do not share a cache entry. Flattening handles\n",
"        # 0-D, 1-D and 2-D arrays uniformly (tuple-of-rows broke on 0-D/1-D).\n",
"        return np_array.shape, np_array.dtype, tuple(np_array.ravel().tolist())\n",
"\n",
"    def from_hashable(key):\n",
"        # Inverse of to_hashable: rebuild an array with the original shape/dtype.\n",
"        shape, dtype, values = key\n",
"        return np.array(values, dtype=dtype).reshape(shape)\n",
"\n",
"    def decorator(function):\n",
"        @wraps(function)\n",
"        def wrapper(*args, **kwargs):\n",
"            args = list(args)\n",
"            for array_argument_index in array_argument_indexs:\n",
"                np_array = np.asarray(args[array_argument_index])\n",
"                if np_array.ndim > 2:\n",
"                    raise RuntimeError(\n",
"                        f\"np_cache is currently only supported for arrays of dim. less than 3 but got shape: {np_array.shape}\"\n",
"                    )\n",
"                args[array_argument_index] = to_hashable(np_array)\n",
"            return cached_wrapper(*args, **kwargs)\n",
"\n",
"        @lru_cache(*lru_args, **lru_kwargs)\n",
"        def cached_wrapper(*args, **kwargs):\n",
"            args = list(args)\n",
"            for array_argument_index in array_argument_indexs:\n",
"                args[array_argument_index] = from_hashable(args[array_argument_index])\n",
"            return function(*args, **kwargs)\n",
"\n",
"        # Expose the lru_cache helpers on the public wrapper.\n",
"        wrapper.cache_info = cached_wrapper.cache_info\n",
"        wrapper.cache_clear = cached_wrapper.cache_clear\n",
"        wrapper.cache_parameters = cached_wrapper.cache_parameters\n",
"        return wrapper\n",
"\n",
"    return decorator\n",
|
|
"\n",
|
|
"\n",
|
|
"@np_cache(maxsize=256, array_argument_indexs=(0, 1))\n",
"def multiply(array, array2):\n",
"    \"\"\"Element-wise product ``array2 * array``, cached on both array arguments.\"\"\"\n",
"    product = array2 * array\n",
"    return product\n",
|
|
"\n",
|
|
"\n",
|
|
"array = np.array([[1, 2, 3], [4, 5, 6]])\n",
"first_result = multiply(array, array)  # miss: computed and stored\n",
"second_result = multiply(array, array)  # hit: served from the cache\n",
"multiply.cache_info()"
|
|
]
|
|
},
|
|
{
"cell_type": "code",
"execution_count": null,
"id": "2e064b94-aa85-4a5f-a99f-0adeeb897153",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Sanity check: arrays with equal contents hit the cache, different contents miss it.\n",
"multiply.cache_clear()\n",
"first = multiply(array, array)\n",
"second = multiply(array.copy(), array.copy())\n",
"np.testing.assert_array_equal(first, array * array)\n",
"np.testing.assert_array_equal(second, first)\n",
"multiply(array + 1, array)\n",
"info = multiply.cache_info()\n",
"assert (info.hits, info.misses) == (1, 2), info\n",
"info"
]
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.8"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|