reversi/Untitled.ipynb

138 lines
3.9 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"id": "aeccefd6-7729-4830-afd8-e8ed0423d92c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from functools import lru_cache, wraps\n",
"import numpy as np\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "65e726fb-d3d3-4b47-a96e-9a09a7c774d7",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def _array_to_key(value):\n",
"    \"\"\"Return a hashable ``(shape, dtype, payload)`` key that fully describes ``value``.\n",
"\n",
"    ``payload`` is the array's raw C-order bytes. Object-dtype arrays are keyed on\n",
"    a tuple of their elements instead (their raw bytes would be pointers), so\n",
"    those elements must be hashable.\n",
"    \"\"\"\n",
"    np_array = np.asarray(value)\n",
"    if np_array.dtype.hasobject:\n",
"        payload = tuple(np_array.ravel().tolist())\n",
"    else:\n",
"        payload = np_array.tobytes()\n",
"    return np_array.shape, np_array.dtype, payload\n",
"\n",
"\n",
"def _array_from_key(key):\n",
"    \"\"\"Rebuild a fresh, writable array from a key produced by ``_array_to_key``.\"\"\"\n",
"    shape, dtype, payload = key\n",
"    rebuilt = np.empty(shape, dtype=dtype)\n",
"    if dtype.hasobject:\n",
"        # Assign element by element (C order, matching ``ravel``) instead of\n",
"        # letting np.array re-infer a shape from sequence-valued elements.\n",
"        for index, element in zip(np.ndindex(shape), payload):\n",
"            rebuilt[index] = element\n",
"    elif rebuilt.size:\n",
"        # np.frombuffer is a read-only view of the bytes; copying into ``rebuilt``\n",
"        # gives the wrapped function its own writable data. Empty arrays skip\n",
"        # this step and keep the shape they got from np.empty.\n",
"        rebuilt[...] = np.frombuffer(payload, dtype=dtype).reshape(shape)\n",
"    return rebuilt\n",
"\n",
"\n",
"def np_cache(*lru_args, array_argument_indexs=(0,), **lru_kwargs):\n",
"    \"\"\"\n",
"    LRU cache for functions whose positional arguments at ``array_argument_indexs`` are numpy arrays.\n",
"\n",
"    numpy arrays are unhashable, so each of those arguments is turned into a\n",
"    hashable ``(shape, dtype, data)`` key before it reaches ``functools.lru_cache``;\n",
"    on a cache miss the array is rebuilt from that key and passed to the wrapped\n",
"    function. Arrays of any dimension (including 0-d and empty arrays) are\n",
"    supported, and arrays that compare equal element-wise but differ in dtype\n",
"    or shape get separate cache entries.\n",
"\n",
"    Args:\n",
"        *lru_args, **lru_kwargs: forwarded unchanged to ``functools.lru_cache``\n",
"            (e.g. ``maxsize``, ``typed``).\n",
"        array_argument_indexs: positional indexes of the array arguments. These\n",
"            arguments must be passed positionally; every other argument must be\n",
"            hashable, exactly as with ``functools.lru_cache``.\n",
"\n",
"    Notes:\n",
"        * The wrapped function receives a rebuilt copy of each array, never the\n",
"          caller's object.\n",
"        * Cached results are returned as-is: mutating a returned array also\n",
"          mutates the cached value.\n",
"\n",
"    Example:\n",
"        >>> array = np.array([[1, 2, 3], [4, 5, 6]])\n",
"        >>> @np_cache(maxsize=256)\n",
"        ... def multiply(array, factor):\n",
"        ...     return factor * array\n",
"        >>> _ = multiply(array, 2)\n",
"        >>> _ = multiply(array, 2)\n",
"        >>> multiply.cache_info()\n",
"        CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)\n",
"    \"\"\"\n",
"\n",
"    def decorator(function):\n",
"        @lru_cache(*lru_args, **lru_kwargs)\n",
"        def cached_wrapper(*hashable_args, **kwargs):\n",
"            # Cache miss: turn the keys back into arrays for the real function.\n",
"            call_args = list(hashable_args)\n",
"            for array_argument_index in array_argument_indexs:\n",
"                call_args[array_argument_index] = _array_from_key(hashable_args[array_argument_index])\n",
"            return function(*call_args, **kwargs)\n",
"\n",
"        @wraps(function)\n",
"        def wrapper(*args, **kwargs):\n",
"            # Always read from the untouched ``args`` so a repeated index is harmless.\n",
"            hashable_args = list(args)\n",
"            for array_argument_index in array_argument_indexs:\n",
"                hashable_args[array_argument_index] = _array_to_key(args[array_argument_index])\n",
"            return cached_wrapper(*hashable_args, **kwargs)\n",
"\n",
"        # Expose the lru_cache introspection helpers on the public wrapper.\n",
"        wrapper.cache_info = cached_wrapper.cache_info\n",
"        wrapper.cache_clear = cached_wrapper.cache_clear\n",
"        return wrapper\n",
"\n",
"    return decorator\n",
"\n",
"\n",
"@np_cache(maxsize=256, array_argument_indexs=(0, 1))\n",
"def multiply(array, array2):\n",
"    \"\"\"Element-wise product of ``array2`` and ``array``; both arguments are cache keys.\"\"\"\n",
"    product = array2 * array\n",
"    return product\n",
"\n",
"\n",
"array = np.array([[1, 2, 3],\n",
"                  [4, 5, 6]])\n",
"# Same input twice: the first call is a cache miss, the second a hit.\n",
"for _call in range(2):\n",
"    multiply(array, array)\n",
"multiply.cache_info()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e064b94-aa85-4a5f-a99f-0adeeb897153",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Sanity check: the cached result must equal the direct, uncached computation.\n",
"np.testing.assert_array_equal(multiply(array, array), array * array)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33887602-8325-4bde-8e8a-faf7155ff143",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}