{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Deep Othello AI\n",
"\n",
"Reversi is a very good game to apply deep learning methods to.\n",
"\n",
"Othello, also known as Reversi, is a board game first published in 1883 in England by either Lewis Waterman or John W. Mollet (each of them denounced the other as a fraud).\n",
"It is a strictly turn-based zero-sum game with a clear Markov chain and no hidden state, unlike card games with an unknown distribution of cards or games with unknown player allegiance.\n",
"As in Go, there is only one kind of stone in two colors, which is much easier to abstract than chess with its six distinct piece types.\n",
"The game board is symmetrical, which allows rotating or mirroring a state around an axis. This can be used to break up sequences, to build interesting ANN architectures, to multiply the data generated by simulation, or to design test cases in which the moves should be symmetric as well if the AI reaches an \"objective\" policy."
|
|
]
|
|
},
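{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the symmetry argument, assuming a single 8 x 8 board encoded as in the rest of this notebook (1 for the player, -1 for the enemy, 0 for empty fields). The helper `dihedral_transforms` is hypothetical and not used elsewhere. A square board has eight symmetries: four rotations, each optionally mirrored along the main diagonal. Under these the starting position yields only two distinct boards, because the 90 and 270 degree rotations and the horizontal and vertical mirrorings swap the colors of the center square, while the other four symmetries leave it unchanged.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def dihedral_transforms(board: np.ndarray) -> np.ndarray:\n",
"    \"\"\"Returns the eight symmetric variants of a square board.\"\"\"\n",
"    variants = []\n",
"    for quarter_turns in range(4):\n",
"        rotated = np.rot90(board, quarter_turns)  # rotation by 0, 90, 180 or 270 degrees\n",
"        variants.append(rotated)\n",
"        variants.append(rotated.T)  # the same rotation, mirrored along the main diagonal\n",
"    return np.stack(variants)\n",
"\n",
"\n",
"start = np.zeros((8, 8), dtype=int)\n",
"start[3:5, 3:5] = [[-1, 1], [1, -1]]\n",
"\n",
"variants = dihedral_transforms(start)\n",
"distinct = np.unique(variants.reshape(len(variants), -1), axis=0)\n",
"print(len(variants), len(distinct))  # 8 2\n",
"```\n",
"\n",
"If the AI reaches an objective policy, its moves on each transformed board should be the correspondingly transformed moves."
]
},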
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"\n",
|
|
"## Content\n",
|
|
"\n",
|
|
"* [The game rules](#the-game-rules): a short overview of the rules of the game.\n",
"* [Some common Othello strategies](#some-common-othello-strategies): introduces some simple approaches to a classic Othello AI and defines some behavioral expectations.\n",
"* [Initial design decisions](#initial-design-decisions): explains some initial design decisions and assumptions.\n",
"* [Imports and dependencies](#imports-and-dependencies): explains which libraries were used."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## The game rules\n",
|
|
"\n",
|
|
"Othello is played by two players on a board with 8 x 8 fields.\n",
"The board geometry is the same as that of a chess board.\n",
"The game is played with stones that are black on one side and white on the other.\n",
"\n",
"The players take turns.\n",
"A player places a stone with his or her color facing up on the board.\n",
"A stone may only be placed where it encloses one or more of the opponent's stones in a continuous line between the new stone and an already placed stone of the player's color.\n",
"These lines can run horizontally, vertically and/or diagonally.\n",
"All stones enclosed in this way are flipped to the player's color.\n",
"A move is therefore only possible if it flips at least one of the opponent's stones. If a player cannot move, he or she is skipped.\n",
"The game ends when neither player can move. The player with the most stones wins.\n",
"If the score is counted in detail, the empty fields go to the player with more stones of his or her color on the board.\n",
"The game begins with four stones placed in the center of the board, two per player, placed diagonally to each other.\n",
|
|
"\n",
|
|
"\n",
|
|
"<img alt=\"Startaufstellung.png\" src=\"Startaufstellung.png\"/>\n",
|
|
"\n",
|
|
"## Some common Othello strategies\n",
|
|
"\n",
|
|
"Placing stones on the board is always a careful balance of attack and defence.\n",
"If a player occupies large homogeneous stretches of the board, they can be attacked more easily.\n",
"The corners of the board provide safety: a stone in a corner can never be flipped, which makes it a stable base for holding territory. However, a corner can only be reached if the opponent is forced to allow it or accepts the cost of giving away such a base, so corners are difficult to obtain.\n",
"There are some texts on Othello computer strategies that implement greedy algorithms for Reversi based on a modified score for each field.\n",
"These field values are score modifiers for a traditional greedy algorithm.\n",
"If a player's stone occupies such a field, the score it contributes is multiplied by the modifier.\n",
"The total score is the player's score minus the opponent's score.\n",
"The modifiers change in the course of the game and converge towards one. This gives some indication of what to expect from an Othello AI.\n",
|
|
"\n",
|
|
"<img alt=\"ComputerPossitionScore\" src=\"computer-score.png\"/>\n"
|
|
]
|
|
},
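{
"cell_type": "markdown",
"metadata": {},
"source": [
"The weighted greedy evaluation described above can be sketched as follows. The modifier matrix `FIELD_WEIGHTS` is a hypothetical placeholder and does not reproduce the values from the figure: it only encodes the idea that corners are valuable while the fields next to them are risky, and it stays fixed instead of converging towards one during the game. A greedy agent would play the move whose resulting board has the highest `weighted_score`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical modifiers: 4.0 on the corners, 0.5 on the three fields next to each corner, 1.0 elsewhere.\n",
"FIELD_WEIGHTS = np.ones((8, 8))\n",
"for row, col in [(0, 0), (0, 7), (7, 0), (7, 7)]:\n",
"    FIELD_WEIGHTS[row, col] = 4.0\n",
"    for d_row, d_col in [(0, 1), (1, 0), (1, 1)]:\n",
"        # step from the corner towards the inside of the board\n",
"        r = row + d_row if row == 0 else row - d_row\n",
"        c = col + d_col if col == 0 else col - d_col\n",
"        FIELD_WEIGHTS[r, c] = 0.5\n",
"\n",
"\n",
"def weighted_score(board: np.ndarray, weights: np.ndarray = FIELD_WEIGHTS) -> float:\n",
"    \"\"\"Player score minus enemy score, with each field multiplied by its modifier.\"\"\"\n",
"    player = np.sum(weights[board == 1])\n",
"    enemy = np.sum(weights[board == -1])\n",
"    return float(player - enemy)\n",
"\n",
"\n",
"board = np.zeros((8, 8), dtype=int)\n",
"board[0, 0] = 1  # the player holds a corner\n",
"board[1, 1] = -1  # the enemy sits diagonally next to it\n",
"print(weighted_score(board))  # 3.5\n",
"```"
]
},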
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Initial design decisions\n",
|
|
"\n",
|
|
"At the beginning of this project I made some design decisions.\n",
|
|
"The first one was not to use a gym library, because it limits the data formats that can be used.\n",
"I chose to implement the whole game as entries in a stack of numpy arrays, which makes interfacing with a neural network easier and allows using scipy pattern recognition tools to implement some game mechanics for a fast simulation cycle.\n",
"I chose to ignore player colors as far as I could and used a player perspective instead. This allows changing the perspective by flipping the sign (multiplying by -1).\n",
"The array format should also allow for data multiplication, or for breaking up strict sequences, by flipping the game along one of the four axes (horizontal, vertical, and the transpose along both diagonals).\n",
"\n",
"I wanted to implement different agents as classes that act on those game stacks.\n",
"\n",
"Since computation time is critical, all computational results are saved.\n",
"The analysis of those results is then rerun each time the notebook is executed. If a section needs to be recalculated, its save file can be deleted and the code executed again."
|
|
]
|
|
},
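{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the perspective switch and of the four mirrorings, assuming stacks of shape n x 8 x 8 and moves stored as [x, y] indices into each board, with [-1, -1] for a pass (the convention used further below). The function `augment_stack` is hypothetical. When boards are mirrored to multiply data, the recorded moves have to be transformed in the same way, otherwise they no longer fit their boards; a pass stays a pass.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"N = 7  # highest index on an 8 x 8 board\n",
"\n",
"\n",
"def augment_stack(boards: np.ndarray, actions: np.ndarray):\n",
"    \"\"\"Returns the stack followed by its four mirror images, with the moves mapped alongside.\"\"\"\n",
"    x, y = actions[:, 0], actions[:, 1]\n",
"    swapped = np.transpose(boards, (0, 2, 1))\n",
"    variants = [\n",
"        (boards, x, y),  # original\n",
"        (np.flip(boards, axis=1), N - x, y),  # (x, y) -> (7 - x, y)\n",
"        (np.flip(boards, axis=2), x, N - y),  # (x, y) -> (x, 7 - y)\n",
"        (swapped, y, x),  # main diagonal: (x, y) -> (y, x)\n",
"        (np.flip(swapped, axis=(1, 2)), N - y, N - x),  # anti-diagonal: (x, y) -> (7 - y, 7 - x)\n",
"    ]\n",
"    is_pass = np.all(actions == -1, axis=1)\n",
"    new_boards, new_actions = [], []\n",
"    for mirrored, new_x, new_y in variants:\n",
"        mapped = np.stack([new_x, new_y], axis=1)\n",
"        mapped[is_pass] = -1  # a pass stays a pass\n",
"        new_boards.append(mirrored)\n",
"        new_actions.append(mapped)\n",
"    return np.concatenate(new_boards), np.concatenate(new_actions)\n",
"\n",
"\n",
"boards = np.zeros((2, 8, 8), dtype=int)\n",
"actions = np.array([[2, 3], [-1, -1]])\n",
"boards[0, 2, 3] = 1  # marks the field of the first move\n",
"big_boards, big_actions = augment_stack(boards, actions)\n",
"enemy_view = -big_boards  # switching the perspective only flips the sign\n",
"print(big_boards.shape, big_actions.shape)  # (10, 8, 8) (10, 2)\n",
"```"
]
},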
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%load_ext blackcellmagic\n",
|
|
"%load_ext line_profiler\n",
|
|
"%load_ext memory_profiler"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Imports and dependencies\n",
|
|
"\n",
|
|
"The following direct dependencies were used for this project:\n",
|
|
"```toml\n",
|
|
"jupyter = \"^1.0.0\"\n",
|
|
"matplotlib = \"^3.6.3\"\n",
|
|
"numpy = \"^1.24.1\"\n",
|
|
"pytest = \"^7.2.1\"\n",
|
|
"python = \"3.10.*\"\n",
|
|
"scipy = \"^1.10.0\"\n",
|
|
"tqdm = \"^4.64.1\"\n",
|
|
"jupyterlab = \"^3.6.1\"\n",
|
|
"torchvision = \"^0.14.1\"\n",
|
|
"torchaudio = \"^0.13.1\"\n",
|
|
"```\n",
|
|
"* `Jupyter` and `jupyterlab` (with PyCharm) were used as the IDE; IPython was used to implement this code.\n",
"* `matplotlib` was used for visualisation and statistics.\n",
"* `numpy` was used for array support and mathematical functions.\n",
"* `tqdm` was used for progress bars.\n",
"* `scipy` contains fast pattern recognition tools for images. It was used to make an initial estimate of where possible moves could be.\n",
"* `torch` supplied the ANN functionality."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import json\n",
|
|
"import pickle\n",
|
|
"import abc\n",
|
|
"import itertools\n",
|
|
"import os.path\n",
|
|
"import warnings\n",
|
|
"from abc import ABC\n",
|
|
"from enum import Enum\n",
|
|
"from typing import Final\n",
|
|
"from IPython.display import clear_output\n",
|
|
"from pathlib import Path\n",
|
|
"import glob\n",
|
|
"import copy\n",
|
|
"\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"import seaborn as sns\n",
|
|
"import torch\n",
|
|
"import torch.nn as nn\n",
|
|
"import torch.nn.functional as F\n",
|
|
"import torch.optim as optim\n",
|
|
"from ipywidgets import interact\n",
|
|
"from scipy.ndimage import binary_dilation\n",
|
|
"from tqdm.notebook import tqdm"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Constants\n",
|
|
"\n",
|
|
"Some general constants need to be defined, such as the board size and the representations of player and enemy, as well as the directional offsets and the initial placement of the stones."
|
|
]
|
|
},
|
|
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"BOARD_SIZE: Final[int] = 8 # defines the board side length as 8\n",
|
|
"PLAYER: Final[int] = 1 # defines the number symbolising the player as 1\n",
|
|
"ENEMY: Final[int] = -1 # defines the number symbolising the enemy as -1\n",
|
|
"EXAMPLE_STACK_SIZE: Final[int] = 1000 # defines the game stack size for examples\n",
|
|
"IMPOSSIBLE: Final[np.ndarray] = np.array([-1, -1], dtype=int) # the move used when no stone can be placed\n",
"IMPOSSIBLE.setflags(write=False) # makes the constant read-only\n",
"SIMULATE_TURNS: Final[int] = 70\n",
"VERIFY_POLICY: Final[bool] = True\n",
"TRINING_RESULT_PATH: Final[Path] = Path(\"training_data\") # directory for the saved computation results\n",
"TRINING_RESULT_PATH.mkdir(exist_ok=True) # creates the directory if it does not exist yet"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The directions array contains the numerical offsets needed to move in each of the 8 directions on a two-dimensional grid. This allows walking across the game board step by step from any field.\n",
|
|
"\n",
|
|
""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([[-1, -1],\n",
|
|
" [-1, 0],\n",
|
|
" [-1, 1],\n",
|
|
" [ 0, -1],\n",
|
|
" [ 0, 1],\n",
|
|
" [ 1, -1],\n",
|
|
" [ 1, 0],\n",
|
|
" [ 1, 1]])"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"DIRECTIONS: Final[np.ndarray] = np.array(\n",
|
|
" [[i, j] for i in range(-1, 2) for j in range(-1, 2) if j != 0 or i != 0],\n",
|
|
" dtype=int,\n",
|
|
")\n",
|
|
"DIRECTIONS.setflags(write=False)\n",
|
|
"DIRECTIONS"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Another constant needed is the initial start square at the center of the board."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([[-1, 1],\n",
|
|
" [ 1, -1]])"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"START_SQUARE: Final[np.ndarray] = np.array(\n",
|
|
" [[ENEMY, PLAYER], [PLAYER, ENEMY]], dtype=int\n",
|
|
")\n",
|
|
"START_SQUARE.setflags(write=False)\n",
|
|
"START_SQUARE"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Creating new boards\n",
|
|
"\n",
|
|
"The first function implemented and tested generates the starting position as a stack of games.\n",
"As described above, it simply places the 2 by 2 starting square in the center of each board of an empty stack."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([[ 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [ 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [ 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [ 0, 0, 0, -1, 1, 0, 0, 0],\n",
|
|
" [ 0, 0, 0, 1, -1, 0, 0, 0],\n",
|
|
" [ 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [ 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
" [ 0, 0, 0, 0, 0, 0, 0, 0]])"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"def get_new_games(number_of_games: int) -> np.ndarray:\n",
|
|
" \"\"\"Generates a stack of initialised game boards.\n",
|
|
"\n",
|
|
" Args:\n",
|
|
" number_of_games: The size of the board stack.\n",
|
|
"\n",
|
|
"    Returns: The generated stack of games with the shape n x 8 x 8.\n",
|
|
"\n",
|
|
" \"\"\"\n",
|
|
" empty = np.zeros([number_of_games, BOARD_SIZE, BOARD_SIZE], dtype=int)\n",
|
|
" empty[:, 3:5, 3:5] = START_SQUARE\n",
|
|
" return empty\n",
|
|
"\n",
|
|
"\n",
|
|
"get_new_games(1)[0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"test_number_of_games = 3\n",
|
|
"assert get_new_games(test_number_of_games).shape == (\n",
|
|
" test_number_of_games,\n",
|
|
" BOARD_SIZE,\n",
|
|
" BOARD_SIZE,\n",
|
|
")\n",
|
|
"np.testing.assert_equal(\n",
|
|
" get_new_games(test_number_of_games).sum(axis=1),\n",
|
|
" np.zeros(\n",
|
|
" [\n",
|
|
" test_number_of_games,\n",
|
|
" 8,\n",
|
|
" ]\n",
|
|
" ),\n",
|
|
")\n",
|
|
"np.testing.assert_equal(\n",
|
|
" get_new_games(test_number_of_games).sum(axis=2),\n",
|
|
" np.zeros(\n",
|
|
" [\n",
|
|
" test_number_of_games,\n",
|
|
" 8,\n",
|
|
" ]\n",
|
|
" ),\n",
|
|
")\n",
|
|
"assert np.all(get_new_games(test_number_of_games)[:, 3:5, 3:5] != 0)\n",
|
|
"del test_number_of_games"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Visualisation tools\n",
|
|
"\n",
|
|
"In this section a visualisation helper was implemented for debugging the game and for a proper display of the results.\n",
"For this visualisation ChatGPT was used as a prompted code generator; the generated code was later reviewed and refactored by hand to integrate seamlessly into the project as a whole.\n",
"White stones represent the player, black stones the enemy. A single plot can be used as a subplot when the `ax` argument is given."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 300x300 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"def plot_othello_board(\n",
|
|
" board: np.ndarray,\n",
|
|
" action: np.ndarray | None = None,\n",
|
|
" ax=None,\n",
|
|
") -> None:\n",
|
|
"    \"\"\"Plots a single Othello board.\n",
"\n",
"    If a matplotlib axis object is given, the board is plotted into that axis. Otherwise a new figure is created and shown directly.\n",
"\n",
"    Args:\n",
"        board: The board that should be plotted. Only a single game is allowed; a numpy array of shape 8x8 is expected.\n",
"        action: Optional [x, y] position that is highlighted with a red marker.\n",
"        ax: Optional matplotlib axis object used to place the board as a subplot into a bigger figure.\n",
|
|
" \"\"\"\n",
|
|
" assert board.shape == (8, 8)\n",
|
|
" plot_all = False\n",
|
|
" if ax is None:\n",
|
|
" fig_size = 3\n",
|
|
" plot_all = True\n",
|
|
" fig, ax = plt.subplots(figsize=(fig_size, fig_size))\n",
|
|
"\n",
|
|
" ax.set_facecolor(\"#0f6b28\")\n",
|
|
" if action is not None:\n",
|
|
" ax.scatter(action[0], action[1], s=350 if plot_all else 200, c=\"red\")\n",
|
|
" for x_pos, y_pos in itertools.product(range(BOARD_SIZE), range(BOARD_SIZE)):\n",
|
|
" if board[x_pos, y_pos] == PLAYER:\n",
|
|
" color = \"white\"\n",
|
|
" elif board[x_pos, y_pos] == ENEMY:\n",
|
|
" color = \"black\"\n",
|
|
" else:\n",
|
|
" continue\n",
|
|
" ax.scatter(x_pos, y_pos, s=280 if plot_all else 140, c=color)\n",
|
|
" for x_pos in range(-1, 8):\n",
|
|
" ax.axhline(x_pos + 0.5, color=\"black\", lw=2)\n",
|
|
" ax.axvline(x_pos + 0.5, color=\"black\", lw=2)\n",
|
|
" ax.set_xlim(-0.5, 7.5)\n",
|
|
" ax.set_ylim(7.5, -0.5)\n",
|
|
" ax.set_xticks(np.arange(8))\n",
|
|
" ax.set_xticklabels(list(\"ABCDEFGH\"))\n",
|
|
" ax.set_yticks(np.arange(8))\n",
|
|
" ax.set_yticklabels(list(\"12345678\"))\n",
|
|
" ax.set_xlabel(\n",
|
|
"        f\"W{np.sum(board == PLAYER)} / {np.sum(board == 0)} / B{np.sum(board == ENEMY)}\"\n",
|
|
" )\n",
|
|
" if plot_all:\n",
|
|
" plt.tight_layout()\n",
|
|
" plt.show()\n",
|
|
"\n",
|
|
"\n",
|
|
"plot_othello_board(get_new_games(1)[0], action=np.array([3, 3]))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def plot_othello_boards(boards: np.ndarray, actions: np.ndarray | None = None) -> None:\n",
|
|
"    \"\"\"Plots multiple boards into subplots.\n",
"\n",
"    The plots are shown directly.\n",
"\n",
"    Args:\n",
"        boards: A stack of boards to plot into subplots, four per row. Fewer than 70 boards are accepted.\n",
"        actions: Optional stack of [x, y] positions, one per board, that are highlighted in red.\n",
|
|
" \"\"\"\n",
|
|
" assert len(boards.shape) == 3\n",
|
|
" assert boards.shape[1:] == (BOARD_SIZE, BOARD_SIZE)\n",
|
|
" assert boards.shape[0] < 70\n",
|
|
"\n",
|
|
" if actions is not None:\n",
|
|
" assert len(actions.shape) == 2\n",
|
|
" assert actions.shape[1] == 2\n",
|
|
" assert boards.shape[0] == actions.shape[0]\n",
|
|
"\n",
|
|
" plots_per_row = 4\n",
|
|
" rows = int(np.ceil(boards.shape[0] / plots_per_row))\n",
|
|
" fig, axs = plt.subplots(rows, plots_per_row, figsize=(12, 3 * rows))\n",
|
|
" for game_index, ax in enumerate(axs.flatten()):\n",
|
|
" if game_index >= boards.shape[0]:\n",
|
|
" fig.delaxes(ax)\n",
|
|
" else:\n",
|
|
" action = actions[game_index] if actions is not None else None\n",
|
|
" plot_othello_board(boards[game_index], action=action, ax=ax)\n",
|
|
" plt.tight_layout()\n",
|
|
" plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def drop_duplicate_boards(\n",
|
|
" boards: np.ndarray, actions: np.ndarray | None\n",
|
|
") -> tuple[np.ndarray, np.ndarray | None]:\n",
|
|
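"    \"\"\"Drops boards that are identical to the board directly before them.\n",
"\n",
"    Args:\n",
"        boards: A sequence of boards to be reduced.\n",
"        actions: Optional actions belonging to the boards. They are rolled by one position, so each kept board is paired with the action at the previous index.\n",
"\n",
"    Returns:\n",
"        The boards without consecutive duplicates and the matching rolled actions (or None).\n",
"        Because np.roll wraps around, the first board is compared with the last one.\n",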
|
|
" \"\"\"\n",
|
|
" non_duplicates = ~np.all(boards == np.roll(boards, axis=0, shift=1), axis=(1, 2))\n",
|
|
" return (\n",
|
|
" boards[non_duplicates],\n",
|
|
" np.roll(actions, axis=0, shift=1)[non_duplicates]\n",
|
|
" if actions is not None\n",
|
|
" else None,\n",
|
|
" )"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Find possible actions to take\n",
|
|
"\n",
|
|
"The first step in implementing an AI like this is to get an overview of the actions that are possible in a given situation.\n",
"The design choice here was to first find the fields that are empty and have at least one neighbouring enemy stone.\n",
"This was implemented with an element-wise check for empty fields and a binary dilation that marks all fields neighbouring an enemy stone.\n",
"For the dilation the `SURROUNDING` mask was used. Both arrays are then combined element-wise with a logical and.\n",
"The resulting array contains all fields where a move could potentially be made. These are then checked in detail.\n",
"Because these element-wise numpy operations discard most fields before the slower detailed check, they speed up this step considerably.\n",
"\n",
"The detailed check follows each direction step by step as long as there are enemy stones in that direction.\n",
"If the edge of the board or an empty field is reached before a field occupied by the player, that direction does not enclose enemy stones.\n",
"If at least one direction encloses enemy stones, the move is possible.\n",
"This detailed step is implemented as a recursion and has to pass at least one enemy stone to return a positive result."
|
|
]
|
|
},
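{
"cell_type": "markdown",
"metadata": {},
"source": [
"The vectorised pre-selection can be illustrated in isolation. This sketch assumes the encoding used here (enemy stones are -1) and, like the notebook's `SURROUNDING` mask, uses a structuring element with a leading axis of length 1, so the dilation only spreads within each board and never between neighbouring games of the stack. `candidate_fields` is a hypothetical name. For the starting position it keeps 10 of the 60 empty fields; only 4 of them are legal moves, so just these 10 fields go through the slower check per direction.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.ndimage import binary_dilation\n",
"\n",
"# 3 x 3 neighbourhood without the centre; the leading axis of length 1 keeps games separate\n",
"NEIGHBOURS = np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]], dtype=bool)\n",
"\n",
"\n",
"def candidate_fields(boards: np.ndarray) -> np.ndarray:\n",
"    \"\"\"Marks empty fields that touch at least one enemy stone.\"\"\"\n",
"    next_to_enemy = binary_dilation(boards == -1, structure=NEIGHBOURS)\n",
"    return (boards == 0) & next_to_enemy\n",
"\n",
"\n",
"boards = np.zeros((2, 8, 8), dtype=int)\n",
"boards[:, 3:5, 3:5] = [[-1, 1], [1, -1]]\n",
"candidates = candidate_fields(boards)\n",
"print(candidates.sum(axis=(1, 2)))  # [10 10]\n",
"```"
]
},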
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([[[1, 1, 1],\n",
|
|
" [1, 0, 1],\n",
|
|
" [1, 1, 1]]])"
|
|
]
|
|
},
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"SURROUNDING: Final = np.array(\n",
|
|
" [[[1, 1, 1], [1, 0, 1], [1, 1, 1]]]\n",
|
|
") # defines the binary dilation mask that marks fields next to an enemy stone; the leading axis of length 1 keeps the games of a stack apart\n",
|
|
"SURROUNDING"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"9.99 ms ± 557 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
|
|
"985 ms ± 41.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([[[False, False, False, False, False, False, False, False],\n",
|
|
" [False, False, False, False, False, False, False, False],\n",
|
|
" [False, False, False, True, False, False, False, False],\n",
|
|
" [False, False, True, False, False, False, False, False],\n",
|
|
" [False, False, False, False, False, True, False, False],\n",
|
|
" [False, False, False, False, True, False, False, False],\n",
|
|
" [False, False, False, False, False, False, False, False],\n",
|
|
" [False, False, False, False, False, False, False, False]]])"
|
|
]
|
|
},
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"def _recursive_steps(\n",
|
|
" board: np.ndarray,\n",
|
|
" rec_direction: np.ndarray,\n",
|
|
" rec_position: np.ndarray,\n",
|
|
" step_one: int = 0,\n",
|
|
") -> int:\n",
|
|
"    \"\"\"Counts the enemy stones the player would enclose in one direction from a given position.\n",
"\n",
"    Args:\n",
"        board: The board that should be checked for a playable action.\n",
"        rec_direction: The direction that should be checked.\n",
"        rec_position: The position that should be checked.\n",
"        step_one: The number of enemy stones passed so far. Should be kept at the default value for proper functionality.\n",
"\n",
"    Returns:\n",
"        The number of enclosed enemy stones, or 0 if the direction encloses none. A value greater than 0 means a move is possible in this direction.\n",
|
|
" \"\"\"\n",
|
|
" rec_position = rec_position + rec_direction\n",
|
|
" if np.any((rec_position >= BOARD_SIZE) | (rec_position < 0)):\n",
|
|
" return 0\n",
|
|
" next_field = board[tuple(rec_position.tolist())]\n",
|
|
" if next_field == 0:\n",
|
|
" return 0\n",
|
|
" if next_field == -1:\n",
|
|
" return _recursive_steps(\n",
|
|
" board, rec_direction, rec_position, step_one=step_one + 1\n",
|
|
" )\n",
|
|
" if next_field == 1:\n",
|
|
" return step_one\n",
|
|
"\n",
|
|
"\n",
|
|
"def get_possible_turns(boards: np.ndarray, tqdm_on: bool = False) -> np.ndarray:\n",
|
|
"    \"\"\"Analyses a stack of boards for possible moves of the player.\n",
"\n",
"    Args:\n",
"        boards: A stack of boards to check.\n",
"        tqdm_on: Shows a progress bar over all fields if True.\n",
"\n",
"    Returns:\n",
"        A stack of game boards containing boolean values showing where moves are possible for the player.\n",
"    \"\"\"\n",
"    assert len(boards.shape) == 3, \"The number of input dimensions does not fit.\"\n",
|
|
" assert boards.shape[1:] == (\n",
|
|
" BOARD_SIZE,\n",
|
|
" BOARD_SIZE,\n",
|
|
" ), \"The input dimensions do not fit.\"\n",
|
|
"\n",
|
|
" poss_turns = boards == 0 # checks where fields are empty.\n",
|
|
" poss_turns &= binary_dilation(\n",
|
|
" boards == -1, SURROUNDING\n",
|
|
"    )  # keeps only empty fields that are next to an enemy stone\n",
|
|
" iterate_over = itertools.product(\n",
|
|
" range(boards.shape[0]), range(BOARD_SIZE), range(BOARD_SIZE)\n",
|
|
" )\n",
|
|
" if tqdm_on:\n",
|
|
" iterate_over = tqdm(iterate_over, total=np.prod(boards.shape))\n",
|
|
" for game, idx, idy in iterate_over:\n",
|
|
" if poss_turns[game, idx, idy]:\n",
|
|
" position = idx, idy\n",
|
|
" poss_turns[game, idx, idy] = any(\n",
|
|
" _recursive_steps(boards[game, :, :], direction, position) > 0\n",
|
|
" for direction in DIRECTIONS\n",
|
|
" )\n",
|
|
" return poss_turns\n",
|
|
"\n",
|
|
"\n",
|
|
"# some simple tests to ensure the function still works after small changes\n",
"# this testing is not complete, it is more of a smoke test\n",
|
|
"test_array = get_new_games(3)\n",
|
|
"expected_result = np.zeros_like(test_array, dtype=bool)\n",
|
|
"expected_result[:, 4, 5] = expected_result[:, 2, 3] = True\n",
|
|
"expected_result[:, 5, 4] = expected_result[:, 3, 2] = True\n",
|
|
"np.testing.assert_equal(get_possible_turns(test_array), expected_result)\n",
|
|
"\n",
|
|
"\n",
|
|
"%timeit get_possible_turns(get_new_games(10)) # checks turn possibility evaluation time for 10 initial games\n",
|
|
"%timeit get_possible_turns(get_new_games(EXAMPLE_STACK_SIZE)) # check turn possibility evaluation time for EXAMPLE_STACK_SIZE initial games\n",
|
|
"\n",
|
|
"# shows a single game\n",
|
|
"get_possible_turns(get_new_games(3))[:1]"
|
|
]
|
|
},
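{
"cell_type": "markdown",
"metadata": {},
"source": [
"The recursion in `_recursive_steps` can also be written as a loop. The sketch below is a hypothetical iterative counterpart with the same encoding (player 1, enemy -1, empty 0): it walks from a field in one direction, counts the enemy stones it passes and returns that count only if the walk ends on one of the player's stones. Unlike `get_possible_turns`, which relies on its candidate mask, `legal_moves` checks every field and therefore tests explicitly that it is empty. For the starting position it finds the same four moves as the test above.\n",
"\n",
"```python\n",
"import itertools\n",
"\n",
"import numpy as np\n",
"\n",
"STEPS = [(dx, dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1) if (dx, dy) != (0, 0)]\n",
"\n",
"\n",
"def enclosed_in_direction(board: np.ndarray, position, direction) -> int:\n",
"    \"\"\"Number of enemy stones enclosed from position in one direction (0 if none).\"\"\"\n",
"    x, y = position\n",
"    dx, dy = direction\n",
"    count = 0\n",
"    while True:\n",
"        x, y = x + dx, y + dy\n",
"        if not (0 <= x < board.shape[0] and 0 <= y < board.shape[1]):\n",
"            return 0  # ran off the board without reaching an own stone\n",
"        if board[x, y] == -1:\n",
"            count += 1  # keep walking over enemy stones\n",
"        elif board[x, y] == 1:\n",
"            return count  # closed by an own stone; 0 if no enemy stone was passed\n",
"        else:\n",
"            return 0  # an empty field breaks the line\n",
"\n",
"\n",
"def legal_moves(board: np.ndarray) -> list:\n",
"    \"\"\"All empty fields from which at least one direction encloses enemy stones.\"\"\"\n",
"    size = board.shape[0]\n",
"    return [\n",
"        (x, y)\n",
"        for x, y in itertools.product(range(size), repeat=2)\n",
"        if board[x, y] == 0\n",
"        and any(enclosed_in_direction(board, (x, y), step) > 0 for step in STEPS)\n",
"    ]\n",
"\n",
"\n",
"start = np.zeros((8, 8), dtype=int)\n",
"start[3:5, 3:5] = [[-1, 1], [1, -1]]\n",
"print(legal_moves(start))  # [(2, 3), (3, 2), (4, 5), (5, 4)]\n",
"```"
]
},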
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Besides generating an array of all possible moves, functions are needed that check whether a given move is possible.\n",
"One is needed for the action space validation, the other for validating a player's move."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def move_possible(board: np.ndarray, move: np.ndarray) -> bool:\n",
|
|
"    \"\"\"Checks if a move is possible.\n",
"\n",
"    If no move is possible, the input array [-1, -1] is expected as the move.\n",
"\n",
"    Args:\n",
"        board: A board on which it should be checked whether the move is possible.\n",
"        move: The move that should be taken, given as the index [x, y] of the field where a stone should be placed. If no placement is possible, [-1, -1] is expected as input.\n",
"\n",
"    Returns:\n",
"        True if the move is possible.\n",
"    \"\"\"\n",
"    if np.all(move == -1):\n",
"        return not np.any(get_possible_turns(np.reshape(board, (1, 8, 8))))\n",
"    if board[tuple(move)] != 0:  # stones can only be placed on empty fields\n",
"        return False\n",
"    return any(\n",
"        _recursive_steps(board, direction, move) > 0 for direction in DIRECTIONS\n",
"    )\n",
|
|
"\n",
|
|
"\n",
|
|
"# Some testing for this function and the underlying recursive functions that are called.\n",
|
|
"assert move_possible(get_new_games(1)[0], np.array([2, 3])) is True\n",
|
|
"assert move_possible(get_new_games(1)[0], np.array([3, 2])) is True\n",
|
|
"assert move_possible(get_new_games(1)[0], np.array([2, 2])) is False\n",
|
|
"assert move_possible(np.zeros((8, 8)), np.array([3, 2])) is False\n",
|
|
"assert move_possible(np.ones((8, 8)) * 1, np.array([-1, -1])) is True\n",
|
|
"assert move_possible(np.ones((8, 8)) * -1, np.array([-1, -1])) is True\n",
|
|
"assert move_possible(np.ones((8, 8)) * 0, np.array([-1, -1])) is True"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def moves_possible(boards: np.ndarray, moves: np.ndarray) -> np.ndarray:\n",
|
|
" \"\"\"Checks if a stack of moves can be executed on a stack of boards.\n",
|
|
"\n",
|
|
" Args:\n",
|
|
"        boards: A stack of boards where the next stone should be placed.\n",
"        moves: A stack of moves. Each move is formatted as an array [x, y]; if no move is possible, [-1, -1] is expected.\n",
"\n",
"    Returns:\n",
"        An array marking for every game and move in the stack whether the move can be executed.\n",
"    \"\"\"\n",
"    arr_moves_possible = np.zeros(boards.shape[0], dtype=bool)\n",
"    for game in range(boards.shape[0]):\n",
"        if np.all(moves[game] == -1):  # a pass is only allowed if no other move is possible\n",
"            arr_moves_possible[game] = not np.any(\n",
"                get_possible_turns(np.reshape(boards[game], (1, 8, 8)))\n",
"            )\n",
"        elif boards[game][tuple(moves[game])] == 0:  # stones can only be placed on empty fields\n",
"            arr_moves_possible[game] = any(\n",
"                _recursive_steps(boards[game, :, :], direction, moves[game]) > 0\n",
"                for direction in DIRECTIONS\n",
"            )\n",
|
|
" return arr_moves_possible\n",
|
|
"\n",
|
|
"\n",
|
|
"np.testing.assert_array_equal(\n",
|
|
" moves_possible(np.ones((3, 8, 8)) * 1, np.array([[-1, -1]] * 3)),\n",
|
|
" np.array([True] * 3),\n",
|
|
")\n",
|
|
"\n",
|
|
"np.testing.assert_array_equal(\n",
|
|
" moves_possible(get_new_games(3), np.array([[2, 3], [3, 2], [3, 2]])),\n",
|
|
" np.array([True] * 3),\n",
|
|
")\n",
|
|
"np.testing.assert_array_equal(\n",
|
|
" moves_possible(get_new_games(3), np.array([[2, 2], [1, 1], [0, 0]])),\n",
|
|
" np.array([False] * 3),\n",
|
|
")\n",
|
|
"np.testing.assert_array_equal(\n",
|
|
" moves_possible(np.ones((3, 8, 8)) * -1, np.array([[-1, -1]] * 3)),\n",
|
|
" np.array([True] * 3),\n",
|
|
")\n",
|
|
"np.testing.assert_array_equal(\n",
|
|
" moves_possible(np.zeros((3, 8, 8)), np.array([[-1, -1]] * 3)),\n",
|
|
" np.array([True] * 3),\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Reward functions\n",
"\n",
"Any kind of reinforcement learning needs a reward function.\n",
"For Othello this could be the final score, the information who won, or changes to the score.\n",
"A combination of those three would also be possible.\n",
"The current score should probably not be weighted too heavily in a reward function, since that would be too close to a classic greedy algorithm.\n",
"Some direct influence of the score would, however, increase the learning speed.\n",
"The next section implements all three reward functions so that they can be combined and weighted later as needed."
|
|
]
|
|
},
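{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of such a combination, assuming stacks of boards from the player's perspective (player 1, enemy -1) before and after a move, and a flag that marks finished games. The weights `w_win` and `w_delta` are hypothetical values chosen only for illustration: the game outcome dominates, while a small share of the score change acts as a direct signal without turning the agent into a greedy stone counter.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def stone_difference(boards: np.ndarray) -> np.ndarray:\n",
"    \"\"\"Player stones minus enemy stones for each board of the stack.\"\"\"\n",
"    return np.sum(boards, axis=(1, 2))\n",
"\n",
"\n",
"def combined_reward(before, after, game_over, w_win=1.0, w_delta=0.05):\n",
"    \"\"\"Outcome reward for finished games plus a small share of the score change.\"\"\"\n",
"    delta = stone_difference(after) - stone_difference(before)  # score change after every move\n",
"    outcome = np.where(game_over, np.sign(stone_difference(after)), 0)  # 1, 0 or -1, at the end only\n",
"    return w_win * outcome + w_delta * delta\n",
"\n",
"\n",
"before = np.zeros((1, 8, 8), dtype=int)\n",
"before[:, 3:5, 3:5] = [[-1, 1], [1, -1]]\n",
"after = before.copy()\n",
"after[:, 2, 3] = 1  # the player places a stone on (2, 3) ...\n",
"after[:, 3, 3] = 1  # ... and flips the enclosed enemy stone on (3, 3)\n",
"print(combined_reward(before, after, game_over=np.array([False])))  # [0.15]\n",
"```"
]
},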
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"213 µs ± 7.62 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
|
|
"38 µs ± 1.99 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n",
|
|
"38 µs ± 1.92 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def final_boards_evaluation(boards: np.ndarray) -> np.ndarray:\n",
|
|
" \"\"\"Evaluates the board at the end of the game.\n",
|
|
"\n",
|
|
"    All empty fields are added to the score of the player who has more stones of his or her color on the board.\n",
"    This score only applies at the end of the game.\n",
"    Normally the score is represented by the number of stones each player has.\n",
"    Here both scores are combined into their difference.\n",
"\n",
"    Args:\n",
"        boards: A stack of game boards at the end of the game.\n",
"\n",
"    Returns:\n",
"        The score difference from the player's perspective (positive if the player won).\n",
|
|
" \"\"\"\n",
|
|
" score1, score2 = np.sum(boards == 1, axis=(1, 2)), np.sum(boards == -1, axis=(1, 2))\n",
|
|
" player_1_won = score1 > score2\n",
|
|
" player_2_won = score1 < score2\n",
|
|
" score1_final = 64 - score2[player_1_won]\n",
|
|
" score2_final = 64 - score1[player_2_won]\n",
|
|
" score1[player_1_won] = score1_final\n",
|
|
" score2[player_2_won] = score2_final\n",
|
|
" return score1 - score2\n",
|
|
"\n",
|
|
"\n",
|
|
"def evaluate_boards(boards: np.ndarray) -> np.ndarray:\n",
|
|
"    \"\"\"Computes the difference between the player's and the enemy's number of stones.\n",
"\n",
"    Args:\n",
"        boards: A stack of boards for evaluation.\n",
"\n",
"    Returns:\n",
"        The stone difference from the player's perspective for each board.\n",
|
|
" \"\"\"\n",
|
|
" return np.sum(boards, axis=(1, 2))\n",
|
|
"\n",
|
|
"\n",
|
|
"def evaluate_who_won(boards: np.ndarray) -> np.ndarray:\n",
|
|
" \"\"\"Checks who won or is winning a game.\n",
|
|
"\n",
|
|
" Args:\n",
|
|
" boards: A stack of boards for evaluation.\n",
|
|
"\n",
|
|
" Returns:\n",
|
|
"        Who won or is winning each game: 1 means the player, -1 means the opponent, 0 represents a draw.\n",
|
|
" \"\"\"\n",
|
|
" return np.sign(np.sum(boards, axis=(1, 2)))\n",
|
|
"\n",
|
|
"\n",
|
|
"_boards = get_new_games(EXAMPLE_STACK_SIZE)\n",
|
|
"%timeit final_boards_evaluation(_boards)\n",
|
|
"%timeit evaluate_boards(_boards)\n",
|
|
"%timeit evaluate_who_won(_boards)"
|
|
]
|
|
},
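{
"cell_type": "markdown",
"metadata": {},
"source": [
"A worked example of the end-of-game rule in `final_boards_evaluation`, written as a hypothetical scalar helper so the arithmetic is easy to follow: with 30 player stones, 20 enemy stones and 14 empty fields, the player leads, receives the empty fields and ends with 64 - 20 = 44, so the score difference is 44 - 20 = 24. In a draw both counts stay unchanged.\n",
"\n",
"```python\n",
"def final_score_difference(player_stones: int, enemy_stones: int, fields: int = 64) -> int:\n",
"    \"\"\"Score difference after awarding the empty fields to the leading side.\"\"\"\n",
"    if player_stones > enemy_stones:\n",
"        player_stones = fields - enemy_stones  # the player also gets all empty fields\n",
"    elif enemy_stones > player_stones:\n",
"        enemy_stones = fields - player_stones  # the enemy also gets all empty fields\n",
"    return player_stones - enemy_stones\n",
"\n",
"\n",
"print(final_score_difference(30, 20))  # 24\n",
"print(final_score_difference(20, 30))  # -24\n",
"print(final_score_difference(25, 25))  # 0\n",
"```"
]
},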
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Execute a chosen action\n",
|
|
"\n",
|
|
"Once the possible turns have been evaluated, a function is needed that executes a chosen turn.\n",
|
|
"The next section implements it."
|
|
]
|
|
},
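{
"cell_type": "markdown",
"metadata": {},
"source": [
"Executing a turn comes down to flipping enemy stones along each of the eight directions. The recursive helper used in `do_moves` does this one step per call. As a hedged alternative, the same rule can be sketched iteratively: starting next to the new stone, collect consecutive enemy stones (-1), and flip them only if an own stone (1) closes the run. `flip_sketch` and `EIGHT_DIRECTIONS` are hypothetical names. The sketch assumes the notebook's convention that the player to move is 1, the opponent -1 and empty fields 0.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical stand-in for the notebook's DIRECTIONS constant.\n",
"EIGHT_DIRECTIONS = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]\n",
"\n",
"\n",
"def flip_sketch(board: np.ndarray, row: int, col: int) -> int:\n",
"    # Places a stone for player 1 at (row, col) in place if at least one stone flips.\n",
"    # Returns the number of flipped enemy stones (0 means the move is illegal).\n",
"    if board[row, col] != 0:\n",
"        return 0\n",
"    flipped = 0\n",
"    for d_row, d_col in EIGHT_DIRECTIONS:\n",
"        run = []\n",
"        r, c = row + d_row, col + d_col\n",
"        # Collect consecutive enemy stones in this direction.\n",
"        while 0 <= r < 8 and 0 <= c < 8 and board[r, c] == -1:\n",
"            run.append((r, c))\n",
"            r, c = r + d_row, c + d_col\n",
"        # Flip only if the run is non-empty and closed by an own stone.\n",
"        if run and 0 <= r < 8 and 0 <= c < 8 and board[r, c] == 1:\n",
"            for rr, cc in run:\n",
"                board[rr, cc] = 1\n",
"            flipped += len(run)\n",
"    if flipped:\n",
"        board[row, col] = 1\n",
"    return flipped\n",
"\n",
"\n",
"# Standard start position from the perspective of player 1.\n",
"board = np.zeros((8, 8), dtype=int)\n",
"board[3, 3] = board[4, 4] = -1\n",
"board[3, 4] = board[4, 3] = 1\n",
"print(flip_sketch(board, 2, 3), board[3, 3], board[2, 3])  # 1 1 1\n",
"```"
]
},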
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class InvalidTurn(ValueError):\n",
|
|
" \"\"\"\n",
|
|
" This error is raised if a given turn is not valid.\n",
|
|
" \"\"\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"99.2 ms ± 2.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 300x300 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"def do_moves(boards: np.ndarray, moves: np.ndarray) -> np.ndarray:\n",
|
|
" \"\"\"Executes a single move on each board of a stack of Othello boards.\n",
|
|
"\n",
|
|
" Args:\n",
|
|
" boards: A stack of Othello boards where the next stone should be placed.\n",
|
|
" moves: A stack of stone placement orders for the game. Formatted as coordinates in an array [x, y] of the place where the stone should be placed. Should contain [-1,-1] if no new placement is possible.\n",
|
|
"\n",
|
|
" Returns:\n",
|
|
" The new state of the board.\n",
|
|
" \"\"\"\n",
|
|
"\n",
|
|
" def _do_directional_move(\n",
|
|
" board: np.ndarray, rec_move: np.ndarray, rev_direction, step_one=True\n",
|
|
" ) -> bool:\n",
|
|
" \"\"\"Changes the color of enemy stones in one direction.\n",
|
|
"\n",
|
|
" This function works recursively. The argument step_one should always be left at its default value.\n",
|
|
"\n",
|
|
" Args:\n",
|
|
" board: A board on which a stone was placed.\n",
|
|
" rec_move: The position [x, y] from which this function is called. It advances by one step with each recursive call.\n",
|
|
" rev_direction: The direction vector in which stones are checked and flipped. It is added to rec_move in every step.\n",
|
|
" step_one: Set to true if this is the first step in the recursion. False later on.\n",
|
|
"\n",
|
|
" Returns:\n",
|
|
" True if a stone could be flipped.\n",
|
|
" All changes are made on the view of the numpy array and therefore not included in the return value.\n",
|
|
" \"\"\"\n",
|
|
" rec_position = rec_move + rev_direction\n",
|
|
" if np.any((rec_position >= 8) | (rec_position < 0)):\n",
|
|
" return False\n",
|
|
" next_field = board[tuple(rec_position.tolist())]\n",
|
|
" if next_field == 0:\n",
|
|
" return False\n",
|
|
" if next_field == 1:\n",
|
|
" return not step_one\n",
|
|
" if next_field == -1:\n",
|
|
" if _do_directional_move(board, rec_position, rev_direction, step_one=False):\n",
|
|
" board[tuple(rec_position.tolist())] = 1\n",
|
|
" return True\n",
|
|
" return False\n",
|
|
"\n",
|
|
" def _do_move(_board: np.ndarray, move: np.ndarray) -> None:\n",
|
|
" \"\"\"Executes a turn on a board.\n",
|
|
"\n",
|
|
" Args:\n",
|
|
" _board: The game board on which to place a stone.\n",
|
|
" move: The coordinates of a stone that should be placed. Should be formatted as an array of the form [x, y]. The value [-1, -1] is expected if no turn is possible.\n",
|
|
"\n",
|
|
" Returns:\n",
|
|
" All changes are made on the view of the numpy array.\n",
|
|
" \"\"\"\n",
|
|
" if np.all(move == -1):\n",
|
|
" if not move_possible(_board, move):\n",
|
|
" raise InvalidTurn(\"An action should be taken. A turn is possible.\")\n",
|
|
" return\n",
|
|
"\n",
|
|
" # noinspection PyTypeChecker\n",
|
|
" if _board[tuple(move.tolist())] != 0:\n",
|
|
" raise InvalidTurn(\"This turn is not possible.\")\n",
|
|
"\n",
|
|
" action = False\n",
|
|
" for direction in DIRECTIONS:\n",
|
|
" if _do_directional_move(_board, move, direction):\n",
|
|
" action = True\n",
|
|
" if not action:\n",
|
|
" raise InvalidTurn(\"This turn is not possible.\")\n",
|
|
"\n",
|
|
" # noinspection PyTypeChecker\n",
|
|
" _board[tuple(move.tolist())] = 1\n",
|
|
"\n",
|
|
" boards = boards.copy()\n",
|
|
" for game in range(boards.shape[0]):\n",
|
|
" _do_move(boards[game], moves[game])\n",
|
|
" return boards\n",
|
|
"\n",
|
|
"\n",
|
|
"%timeit do_moves(get_new_games(EXAMPLE_STACK_SIZE), np.array([[2, 3]] * EXAMPLE_STACK_SIZE))[0]\n",
|
|
"\n",
|
|
"plot_othello_board(\n",
|
|
" do_moves(\n",
|
|
" get_new_games(EXAMPLE_STACK_SIZE), np.array([[2, 3]] * EXAMPLE_STACK_SIZE)\n",
|
|
" )[0],\n",
|
|
" action=np.array([2, 3]),\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## An abstract reversi game policy\n",
|
|
"\n",
|
|
"To make policies easy to use, an abstract base class handles the policy generation and requests the raw action values from an inherited class.\n",
|
|
"This class filters the policy so that only valid actions are proposed; inherited classes do not need to take care of this. The base class also manages exploration and exploitation via the epsilon value."
|
|
]
|
|
},
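{
"cell_type": "markdown",
"metadata": {},
"source": [
"The filtering and epsilon mixing described above can be sketched in isolation. This minimal sketch assumes that a boolean mask of legal fields is already available; in this notebook that mask comes from `get_possible_turns`. `choose_moves_sketch` is a hypothetical name. The sketch works in three steps:\n",
"\n",
"- It blends the policy with uniform noise as `epsilon * policy + (1 - epsilon) * noise`.\n",
"- It sets illegal fields to -1.\n",
"- It takes the argmax per board. A board without any legal field yields the pass marker `[-1, -1]`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def choose_moves_sketch(policies, legal, epsilon, rng):\n",
"    # Blend the policy with uniform noise: epsilon=1 is pure policy, epsilon=0 pure random.\n",
"    mixed = epsilon * policies + (1 - epsilon) * rng.random(policies.shape)\n",
"    # Illegal fields lose the argmax (assumes mixed values of legal fields are above -1).\n",
"    mixed = np.where(legal, mixed, -1.0)\n",
"    flat = mixed.reshape(len(mixed), -1).argmax(axis=1)\n",
"    moves = np.stack(np.unravel_index(flat, mixed.shape[1:]), axis=1)\n",
"    # Boards without any legal field get the pass marker [-1, -1].\n",
"    moves[~legal.any(axis=(1, 2))] = -1\n",
"    return moves\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"policies = np.zeros((2, 8, 8))\n",
"policies[0, 2, 3] = 0.9\n",
"legal = np.zeros((2, 8, 8), dtype=bool)\n",
"legal[0, 2, 3] = legal[0, 5, 4] = True\n",
"print(choose_moves_sketch(policies, legal, 1.0, rng).tolist())  # [[2, 3], [-1, -1]]\n",
"```"
]
},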
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class GamePolicy(ABC):\n",
|
|
" \"\"\"\n",
|
|
" A game policy. Proposes where to place a stone next.\n",
|
|
" \"\"\"\n",
|
|
"\n",
|
|
" def __init__(self, epsilon: float):\n",
|
|
" \"\"\"\n",
|
|
"\n",
|
|
" Args:\n",
|
|
" epsilon: The epsilon / greedy value. Should be between zero and one. Sets the mixture of policy and exploration: one means only the policy is used, zero means only random actions are used. All mixtures in between are possible.\n",
|
|
" \"\"\"\n",
|
|
" if not 0 <= epsilon <= 1:\n",
|
|
" raise ValueError(\"Epsilon should be between zero and one.\")\n",
|
|
" self._epsilon: float = epsilon\n",
|
|
"\n",
|
|
" @property\n",
|
|
" def epsilon(self):\n",
|
|
" return self._epsilon\n",
|
|
"\n",
|
|
" @property\n",
|
|
" @abc.abstractmethod\n",
|
|
" def policy_name(self) -> str:\n",
|
|
" \"\"\"The name of this policy\"\"\"\n",
|
|
" raise NotImplementedError()\n",
|
|
"\n",
|
|
" @abc.abstractmethod\n",
|
|
" def _internal_policy(self, boards: np.ndarray) -> np.ndarray:\n",
|
|
" \"\"\"The internal, unfiltered policy. It should only be called from inside this class.\n",
|
|
"\n",
|
|
" Args:\n",
|
|
" boards: A stack of boards for which a policy should be calculated.\n",
|
|
"\n",
|
|
" Returns:\n",
|
|
" The policy for this board. Should have the same size as the boards array.\n",
|
|
" \"\"\"\n",
|
|
" raise NotImplementedError()\n",
|
|
"\n",
|
|
" def get_policy(self, boards: np.ndarray) -> np.ndarray:\n",
|
|
" \"\"\"Calculates the policy that should be followed.\n",
|
|
"\n",
|
|
" This function uses epsilon to mix the policy with random exploration.\n",
|
|
"\n",
|
|
" Args:\n",
|
|
" boards: A set of boards that show the environment where the policy should be calculated for.\n",
|
|
"\n",
|
|
" Returns:\n",
|
|
" A vector of indices. Should be formatted as an array of the form [x, y]. The value [-1, -1] is expected if no turn is possible.\n",
|
|
" \"\"\"\n",
|
|
" assert len(boards.shape) == 3\n",
|
|
" assert boards.shape[1:] == (BOARD_SIZE, BOARD_SIZE)\n",
|
|
"\n",
|
|
" if self.epsilon <= 0:\n",
|
|
" policies = np.random.rand(*boards.shape)\n",
|
|
" else:\n",
|
|
" policies = self._internal_policy(boards)\n",
|
|
" if self.epsilon < 1:\n",
|
|
" policies = policies * self.epsilon + np.random.rand(*boards.shape) * (\n",
|
|
" 1 - self.epsilon\n",
|
|
" )\n",
|
|
"\n",
|
|
" # todo talk to team about backpropagation of score and epsilon for greedy factor\n",
|
|
"\n",
|
|
" # todo possibly change this function to only validate the proposed turn and not all turns\n",
|
|
" possible_turns = get_possible_turns(boards)\n",
|
|
" policies[possible_turns == False] = -1.0\n",
|
|
" max_indices = [\n",
|
|
" np.unravel_index(policy.argmax(), policy.shape) for policy in policies\n",
|
|
" ]\n",
|
|
" policy_vector = np.array(max_indices, dtype=int)\n",
|
|
" no_turn_possible = np.all(policy_vector == 0, 1) & (policies[:, 0, 0] == -1.0)\n",
|
|
"\n",
|
|
" policy_vector[no_turn_possible, :] = IMPOSSIBLE\n",
|
|
" return policy_vector"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## A first policy\n",
|
|
"\n",
|
|
"To quantify the quality of a game AI, some benchmarks are needed.\n",
|
|
"The easiest benchmark is playing against a random player.\n",
|
|
"For this benchmark and for testing purposes, a random policy was implemented."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class RandomPolicy(GamePolicy):\n",
|
|
" \"\"\"\n",
|
|
" A policy playing a random turn by setting epsilon to 0.\n",
|
|
" \"\"\"\n",
|
|
"\n",
|
|
" def __init__(self, epsilon: float = 0):\n",
|
|
" _ = epsilon\n",
|
|
" super().__init__(epsilon=0)\n",
|
|
"\n",
|
|
" @property\n",
|
|
" def policy_name(self) -> str:\n",
|
|
" return \"random\"\n",
|
|
"\n",
|
|
" def _internal_policy(self, boards: np.ndarray) -> np.ndarray:\n",
|
|
" return np.random.rand(*boards.shape)\n",
|
|
"\n",
|
|
"\n",
|
|
"rnd_policy = RandomPolicy(1)\n",
|
|
"assert rnd_policy.policy_name == \"random\"\n",
|
|
"assert rnd_policy.epsilon == 0\n",
|
|
"\n",
|
|
"rnd_policy_result = rnd_policy.get_policy(get_new_games(10))\n",
|
|
"assert np.any((5 >= rnd_policy_result) & (rnd_policy_result >= 3))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class GreedyPolicy(GamePolicy):\n",
|
|
" \"\"\"\n",
|
|
" A policy that always plays one of the strongest turns.\n",
|
|
" \"\"\"\n",
|
|
"\n",
|
|
" def __init__(self, epsilon: float = 1):\n",
|
|
" _ = epsilon\n",
|
|
" super().__init__(1)\n",
|
|
"\n",
|
|
" @property\n",
|
|
" def policy_name(self) -> str:\n",
|
|
" return \"greedy_policy\"\n",
|
|
"\n",
|
|
" def _internal_policy(self, boards: np.ndarray) -> np.ndarray:\n",
|
|
" policies = np.random.rand(*boards.shape)\n",
|
|
" poss_turns = boards == 0 # checks where fields are empty.\n",
|
|
" poss_turns &= binary_dilation(boards == -1, SURROUNDING)\n",
|
|
" for game, idx, idy in itertools.product(\n",
|
|
" range(boards.shape[0]), range(BOARD_SIZE), range(BOARD_SIZE)\n",
|
|
" ):\n",
|
|
"\n",
|
|
" if poss_turns[game, idx, idy]:\n",
|
|
" position = idx, idy\n",
|
|
" policies[game, idx, idy] += np.sum(\n",
|
|
" np.array(\n",
|
|
" list(\n",
|
|
" _recursive_steps(boards[game, :, :], direction, position)\n",
|
|
" for direction in DIRECTIONS\n",
|
|
" )\n",
|
|
" )\n",
|
|
" )\n",
|
|
" return policies"
|
|
]
|
|
},
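{
"cell_type": "markdown",
"metadata": {},
"source": [
"The greedy policy above starts from uniform noise (`np.random.rand`) and adds a per-move value computed with `_recursive_steps`. That value is presumably the number of stones the move would flip; this is an assumption, since the helper is defined elsewhere. The noise lies in [0, 1), so it cannot change the order between integer scores that differ. It only breaks ties at random. A minimal sketch of this idea, using hypothetical flip counts:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"# Hypothetical flip counts for four candidate moves; moves 1 and 3 tie at 3 flips.\n",
"flip_counts = np.array([1, 3, 2, 3])\n",
"\n",
"picks = []\n",
"for _ in range(1000):\n",
"    # Noise in [0, 1) keeps every 3 + noise above every 2 + noise, so only the tied moves compete.\n",
"    scores = flip_counts + rng.random(flip_counts.shape)\n",
"    picks.append(int(scores.argmax()))\n",
"\n",
"print(sorted(set(picks)))  # [1, 3]\n",
"```"
]
},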
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Putting the game simulation together\n",
|
|
"Now it is time to bring everything together into a full game simulation."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Playing a single turn\n",
|
|
"\n",
|
|
"The next function requests a policy, verifies that the proposed turn is legal, places the stone and flips the captured enemy stones."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"1.18 s ± 36.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
|
|
"1.08 s ± 32.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 1200x600 with 8 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"def single_turn(\n",
|
|
" current_boards: np.ndarray, policy: GamePolicy\n",
|
|
") -> tuple[np.ndarray, np.ndarray]:\n",
|
|
" \"\"\"Execute a single turn on a board.\n",
|
|
"\n",
|
|
" Places a new stone on the board. Turns captured enemy stones.\n",
|
|
"\n",
|
|
" Args:\n",
|
|
" current_boards: The current stack of boards before the turn.\n",
|
|
" policy: The game policy to be used.\n",
|
|
"\n",
|
|
" Returns:\n",
|
|
" The new game board and the policy vector containing the index of the action used.\n",
|
|
" \"\"\"\n",
|
|
" policy_results = policy.get_policy(current_boards)\n",
|
|
"\n",
|
|
" # if the constant VERIFY_POLICY is set to true the policy is verified. Should be good though.\n",
|
|
" # todo deactivate the policy verification after some testing.\n",
|
|
" if VERIFY_POLICY:\n",
|
|
" assert np.all(moves_possible(current_boards, policy_results)), (\n",
|
|
" current_boards[(moves_possible(current_boards, policy_results) == False)],\n",
|
|
" policy_results[(moves_possible(current_boards, policy_results) == False)],\n",
|
|
" np.where(moves_possible(current_boards, policy_results) == False),\n",
|
|
" )\n",
|
|
" return do_moves(current_boards, policy_results), policy_results\n",
|
|
"\n",
|
|
"\n",
|
|
"%timeit single_turn(get_new_games(EXAMPLE_STACK_SIZE), RandomPolicy(1))\n",
|
|
"VERIFY_POLICY = False # type: ignore\n",
|
|
"%timeit single_turn(get_new_games(EXAMPLE_STACK_SIZE), RandomPolicy(1))\n",
|
|
"VERIFY_POLICY = True # type: ignore\n",
|
|
"_turn_result = single_turn(get_new_games(EXAMPLE_STACK_SIZE), RandomPolicy(1))\n",
|
|
"plot_othello_boards(_turn_result[0][:8], _turn_result[1][:8])\n",
|
|
"del _turn_result"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Simulate a stack of games\n",
|
|
"This function simulates a stack of games and returns the board history and the action history."
|
|
]
|
|
},
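{
"cell_type": "markdown",
"metadata": {},
"source": [
"`simulate_game` relies on a perspective trick. On the first policy's turns, the board is multiplied by -1, so that the player to move sees its own stones as 1. Afterwards the board is flipped back. A minimal sketch of why this is safe, assuming the notebook's convention of 1, -1 and 0: negating twice restores the board, and the stone difference only changes its sign. `to_mover_view` is a hypothetical helper, not part of the notebook.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def to_mover_view(board: np.ndarray, turn_index: int) -> np.ndarray:\n",
"    # Assumption mirrored from simulate_game: on even turns the policy plays on the negated board.\n",
"    return board * -1 if turn_index % 2 == 0 else board\n",
"\n",
"\n",
"# Hypothetical board after a first move at row 2, column 3.\n",
"board = np.zeros((8, 8), dtype=np.int8)\n",
"board[4, 4] = -1\n",
"board[2, 3] = board[3, 3] = board[3, 4] = board[4, 3] = 1\n",
"\n",
"view = to_mover_view(board, 0)\n",
"print(int(view.sum()), int(board.sum()))  # -3 3\n",
"print(np.array_equal(to_mover_view(view, 0), board))  # True\n",
"```"
]
},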
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 1200x4800 with 61 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"def simulate_game(\n",
|
|
" nr_of_games: int,\n",
|
|
" policies: tuple[GamePolicy, GamePolicy],\n",
|
|
" tqdm_on: bool = False,\n",
|
|
") -> tuple[np.ndarray, np.ndarray]:\n",
|
|
" \"\"\"Simulates a stack of games.\n",
|
|
"\n",
|
|
" Args:\n",
|
|
" nr_of_games: The number of games that should be simulated.\n",
|
|
" policies: The policies that should be used to simulate the game.\n",
|
|
" tqdm_on: Switches tqdm on.\n",
|
|
"\n",
|
|
" Returns:\n",
|
|
" A stack of board histories and actions.\n",
|
|
" \"\"\"\n",
|
|
" board_history_stack = np.zeros((SIMULATE_TURNS, nr_of_games, 8, 8), dtype=np.int8)\n",
|
|
" action_history_stack = np.zeros((SIMULATE_TURNS, nr_of_games, 2), dtype=np.int8)\n",
|
|
" current_boards = get_new_games(nr_of_games)\n",
|
|
" for turn_index in tqdm(range(SIMULATE_TURNS)) if tqdm_on else range(SIMULATE_TURNS):\n",
|
|
" policy_index = turn_index % 2\n",
|
|
" policy = policies[policy_index]\n",
|
|
" board_history_stack[turn_index, :, :, :] = current_boards\n",
|
|
" if policy_index == 0:\n",
|
|
" current_boards = current_boards * -1\n",
|
|
" current_boards, action_taken = single_turn(current_boards, policy)\n",
|
|
" action_history_stack[turn_index, :] = action_taken\n",
|
|
"\n",
|
|
" if policy_index == 0:\n",
|
|
" current_boards = current_boards * -1\n",
|
|
"\n",
|
|
" return board_history_stack, action_history_stack\n",
|
|
"\n",
|
|
"\n",
|
|
"simulation_results = simulate_game(1, (RandomPolicy(1), RandomPolicy(1)))\n",
|
|
"_unique_bords, _unique_actions = drop_duplicate_boards(\n",
|
|
" simulation_results[0].reshape(-1, 8, 8), simulation_results[1].reshape(-1, 2)\n",
|
|
")\n",
|
|
"plot_othello_boards(_unique_bords, actions=_unique_actions)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(70, 8, 8)"
|
|
]
|
|
},
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"np.reshape(simulation_results[0], (-1, 8, 8)).shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(70, 2)"
|
|
]
|
|
},
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"simulation_results[1].reshape(-1, 2).shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"peak memory: 340.06 MiB, increment: 0.29 MiB\n",
|
|
"10.3 s ± 473 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%memit simulate_game(100, (RandomPolicy(1), RandomPolicy(1)))\n",
|
|
"%timeit simulate_game(100, (RandomPolicy(1), RandomPolicy(1)))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Statistical examination of the natural action space and result\n",
|
|
"As with many projects, some statistical evaluation is in order.\n",
|
|
"\n",
|
|
"1. What is the expected distribution of scores?\n",
|
|
"2. What is the expected distribution of possible actions?\n",
|
|
"\n",
|
|
" a. over time\n",
|
|
" \n",
|
|
" b. over space\n",
|
|
"\n",
|
|
"The easiest and most robust way to analyse this is to examine randomly played games."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"For this purpose, a sample of 10k games was played and saved for later analysis."
|
|
]
|
|
},
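{
"cell_type": "markdown",
"metadata": {},
"source": [
"Caching the simulated games avoids re-running the 10k-game simulation every time the notebook is executed. Storing the boards as `int8` keeps the 70 x 10000 x 8 x 8 board history at 44.8 million bytes. The next cell uses a cache-or-compute pattern. Below is a minimal sketch of that pattern, written against a temporary directory. `load_or_compute`, `fake_simulation` and the file name are hypothetical.\n",
"\n",
"```python\n",
"import os\n",
"import tempfile\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"def load_or_compute(path, compute):\n",
"    # Reuse the cached array if it exists; otherwise compute it once and save it as int8.\n",
"    if os.path.exists(path):\n",
"        return np.load(path)\n",
"    data = compute().astype(np.int8)\n",
"    np.save(path, data)\n",
"    return data\n",
"\n",
"\n",
"calls = []\n",
"\n",
"\n",
"def fake_simulation():\n",
"    # Hypothetical stand-in for simulate_game; records how often it runs.\n",
"    calls.append(1)\n",
"    return np.ones((70, 3, 8, 8))\n",
"\n",
"\n",
"with tempfile.TemporaryDirectory() as tmp:\n",
"    path = os.path.join(tmp, 'history.npy')\n",
"    first = load_or_compute(path, fake_simulation)\n",
"    second = load_or_compute(path, fake_simulation)\n",
"    print(len(calls), second.dtype, np.array_equal(first, second))  # 1 int8 True\n",
"```"
]
},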
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"((70, 10000, 8, 8), (70, 10000, 2))"
|
|
]
|
|
},
|
|
"execution_count": 27,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"if not (os.path.exists(\"rnd_history.npy\") and os.path.exists(\"rnd_action.npy\")):\n",
|
|
" rnds = RandomPolicy(1), RandomPolicy(1)\n",
|
|
" simulation_results = simulate_game(10_000, rnds, tqdm_on=True)\n",
|
|
" _board_history, _action_history = simulation_results\n",
|
|
" np.save(\"rnd_history.npy\", _board_history.astype(np.int8))\n",
|
|
" np.save(\"rnd_action.npy\", _action_history.astype(np.int8))\n",
|
|
"else:\n",
|
|
" _board_history = np.load(\"rnd_history.npy\")\n",
|
|
" _action_history = np.load(\"rnd_action.npy\")\n",
|
|
"_board_history.shape, _action_history.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"For those 10k games, the possible actions were evaluated and saved for every turn of every game."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(70, 10000, 8, 8)"
|
|
]
|
|
},
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"if not os.path.exists(\"turn_possible.npy\"):\n",
|
|
" __board_history = _board_history.copy()\n",
|
|
" __board_history[1::2] = __board_history[1::2] * -1\n",
|
|
"\n",
|
|
" _poss_turns = get_possible_turns(\n",
|
|
" __board_history.reshape((-1, 8, 8)), tqdm_on=True\n",
|
|
" ).reshape((SIMULATE_TURNS, -1, 8, 8))\n",
|
|
" np.save(\"turn_possible.npy\", _poss_turns)\n",
|
|
" del __board_history\n",
|
|
"_poss_turns = np.load(\"turn_possible.npy\")\n",
|
|
"_poss_turns.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"These possible turns were then counted for all games in the history stack."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The action space size can be plotted as a histogram per turn and as a curve of the mean action space size over the turns.\n",
|
|
"This shows in which phases of the game the action space is too large to be searched exhaustively."
|
|
]
|
|
},
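{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell estimates the total action space as the product of the mean number of possible moves per turn. Means below 1 are clamped to 1, so that turns without a move do not zero out the product. Such products grow quickly over about 60 turns. A hedged alternative is to sum base-10 logarithms instead, which gives the same estimate as a readable exponent. The per-turn means below are hypothetical and not taken from the simulation.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical mean numbers of possible moves for six turns (0.4: the turn is often skipped).\n",
"mean_moves = np.array([4.0, 5.0, 10.0, 8.0, 2.0, 0.4])\n",
"\n",
"# Clamp to 1 as the notebook does, so a (near) skipped turn contributes a factor of 1.\n",
"clamped = np.maximum(mean_moves, 1.0)\n",
"log10_size = np.sum(np.log10(clamped))\n",
"\n",
"print(f'{np.prod(clamped):.4g} {10 ** log10_size:.4g}')  # 3200 3200\n",
"```"
]
},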
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "f676f497a974475bb7a51fd6fb3921a3",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"interactive(children=(IntSlider(value=34, description='turn', max=69), Output()), _dom_classes=('widget-intera…"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"count_poss_turns = np.sum(_poss_turns, axis=(2, 3))\n",
|
|
"mean_possibilitie_count = np.mean(count_poss_turns, axis=1)\n",
|
|
"std_possibilitie_count = np.std(count_poss_turns, axis=1)\n",
|
|
"cum_prod = count_poss_turns\n",
|
|
"\n",
|
|
"\n",
|
|
"@interact(turn=(0, 69))\n",
|
|
"def poss_turn_count(turn):\n",
|
|
" fig, axes = plt.subplots(2, 2, figsize=(15, 8))\n",
|
|
" ax1, ax2, ax3, ax4 = axes.flatten()\n",
|
|
" _mean_possibilitie_count = mean_possibilitie_count.copy()\n",
|
|
" _std_possibilitie_count = std_possibilitie_count.copy()\n",
|
|
" _mean_possibilitie_count[_mean_possibilitie_count <= 1] = 1\n",
|
|
" _std_possibilitie_count[_std_possibilitie_count <= 1] = 1\n",
|
|
" fig.suptitle(\n",
|
|
" f\"Action space size analysis\\nThe total size is estimated to be around {np.prod(_mean_possibilitie_count):.4g}\"\n",
|
|
" )\n",
|
|
" ax1.hist(count_poss_turns[turn], density=True)\n",
|
|
" ax1.set_title(f\"Histogram of the action space size for turn {turn}\")\n",
|
|
" ax1.set_xlabel(\"Action space size\")\n",
|
|
" ax1.set_ylabel(\"Action space size probability\")\n",
|
|
" ax2.set_title(f\"Mean size of the action space per turn\")\n",
|
|
" ax2.set_xlabel(\"Turn\")\n",
|
|
" ax2.set_ylabel(\"Average possible moves\")\n",
|
|
"\n",
|
|
" ax2.errorbar(\n",
|
|
" range(70),\n",
|
|
" mean_possibilitie_count,\n",
|
|
" yerr=std_possibilitie_count,\n",
|
|
" label=\"Mean action space size with error bars\",\n",
|
|
" )\n",
|
|
" ax2.scatter(turn, mean_possibilitie_count[turn], marker=\"x\")\n",
|
|
" ax2.legend()\n",
|
|
"\n",
|
|
" ax4.plot(\n",
|
|
" range(70),\n",
|
|
" np.cumprod((_mean_possibilitie_count)[::-1], axis=0)[::-1],\n",
|
|
" # yerr=np.cumprod(_std_possibilitie_count[::-1], axis=0)[::-1],\n",
|
|
" )\n",
|
|
" ax4.scatter(\n",
|
|
" turn,\n",
|
|
" np.cumprod(_mean_possibilitie_count[::-1], axis=0)[::-1][turn],\n",
|
|
" marker=\"x\",\n",
|
|
" )\n",
|
|
" ax4.set_yscale(\"log\", base=10)\n",
|
|
" ax4.set_xlabel(\"Turn\")\n",
|
|
" ax4.set_ylabel(\"Mean remaining total action space size\")\n",
|
|
" fig.delaxes(ax3)\n",
|
|
" fig.tight_layout()\n",
|
|
" plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"It is interesting to see that the total mean action space of the first player (white) is much smaller than that of the second player (black): roughly 5.7e18 versus 3.8e20 in the table below."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>Total mean actionspace</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>white</th>\n",
|
|
" <td>5.687159e+18</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>black</th>\n",
|
|
" <td>3.753117e+20</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" Total mean actionspace\n",
|
|
"white 5.687159e+18\n",
|
|
"black 3.753117e+20"
|
|
]
|
|
},
|
|
"execution_count": 30,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"white = mean_possibilitie_count[::2]\n",
|
|
"black = mean_possibilitie_count[1::2]\n",
|
|
"df = pd.DataFrame(\n",
|
|
" [\n",
|
|
" {\n",
|
|
" \"white\": np.prod(np.extract(white, white)),\n",
|
|
" \"black\": np.prod(np.extract(black, black)),\n",
|
|
" }\n",
|
|
" ],\n",
|
|
" index=[\"Total mean actionspace\"],\n",
|
|
").T\n",
|
|
"del white, black\n",
|
|
"df"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 31,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(70, 10000, 8, 8)"
|
|
]
|
|
},
|
|
"execution_count": 31,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"_poss_turns.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "8d48165bf2db47f491a1fdfea7445b27",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"interactive(children=(IntSlider(value=34, description='turn', max=69), Output()), _dom_classes=('widget-intera…"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"mean_poss_turn = np.mean(_poss_turns, axis=1)\n",
|
|
"del _poss_turns\n",
|
|
"\n",
|
|
"\n",
|
|
"@interact(turn=(0, 69))\n",
|
|
"def turn_distribution_heatmap(turn):\n",
|
|
" turn_possibility_on_field = mean_poss_turn[turn]\n",
|
|
"\n",
|
|
" sns.heatmap(\n",
|
|
" turn_possibility_on_field,\n",
|
|
" linewidth=0.5,\n",
|
|
" square=True,\n",
|
|
" annot=True,\n",
|
|
" xticklabels=\"ABCDEFGH\",\n",
|
|
" yticklabels=list(range(1, 9)),\n",
|
|
" )\n",
|
|
" plt.title(f\"Heatmap of where stones can be placed on turn {turn}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 33,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(70, 10000)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def calculate_direct_score(board_history: np.ndarray) -> np.ndarray:\n",
|
|
" boards_evaluated = np.reshape(\n",
|
|
" evaluate_boards(np.reshape(board_history, (-1, 8, 8))), (SIMULATE_TURNS, -1)\n",
|
|
" )\n",
|
|
" direct_score = boards_evaluated - np.roll(boards_evaluated, shift=-1, axis=0)\n",
|
|
" direct_score[-1] = 0\n",
|
|
" return direct_score / 64\n",
|
|
"\n",
|
|
"\n",
|
|
"print(np.max(np.abs(calculate_direct_score(_board_history))))\n",
|
|
"assert len(calculate_direct_score(_board_history).shape) == 2\n",
|
|
"assert calculate_direct_score(_board_history).shape[0] == SIMULATE_TURNS\n",
|
|
"print(calculate_direct_score(_board_history).shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "1381da4a05b24c60b58f0a8ff2d8d7be",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"interactive(children=(IntSlider(value=29, description='turn', max=59), Output()), _dom_classes=('widget-intera…"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"score_history = calculate_direct_score(_board_history) * 64\n",
|
|
"score_history[1::2] = score_history[1::2] * -1\n",
|
|
"\n",
|
|
"\n",
|
|
"@interact(turn=(0, 59))\n",
|
|
"def hist_direct_score(turn):\n",
|
|
" fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))\n",
|
|
" fig.suptitle(\n",
|
|
" f\"Action space size analysis / total size estimate {np.prod(np.extract(mean_possibilitie_count, mean_possibilitie_count)):.4g}\"\n",
|
|
" )\n",
|
|
"\n",
|
|
" ax1.set_title(\n",
|
|
" f\"Histogram of scores on turn {turn} by {'white' if turn % 2 == 0 else 'black'}\"\n",
|
|
" )\n",
|
|
"\n",
|
|
" ax1.hist(score_history[turn], density=True)\n",
|
|
" ax1.set_xlabel(\"Points made\")\n",
|
|
" ax1.set_ylabel(\"Score probability\")\n",
|
|
" ax2.set_title(f\"Points scored at turn\")\n",
|
|
" ax2.set_xlabel(\"Turn\")\n",
|
|
" ax2.set_ylabel(\"Average points scored\")\n",
|
|
"\n",
|
|
" ax2.errorbar(\n",
|
|
" range(60),\n",
|
|
" np.mean(score_history, axis=1)[:60],\n",
|
|
" yerr=np.std(score_history, axis=1)[:60],\n",
|
|
" label=\"Mean score at turn\",\n",
|
|
" )\n",
|
|
" ax2.scatter(turn, np.mean(score_history, axis=1)[turn], marker=\"x\", color=\"red\")\n",
|
|
" ax2.legend()\n",
|
|
" plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 35,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"def calculate_final_evaluation_for_history(board_history: np.ndarray) -> np.ndarray:\n",
|
|
" final_evaluation = final_boards_evaluation(board_history[-1])\n",
|
|
" return final_evaluation / 64\n",
|
|
"\n",
|
|
"\n",
|
|
"print(np.max(np.abs(calculate_final_evaluation_for_history(_board_history))))\n",
|
|
"assert len(calculate_final_evaluation_for_history(_board_history).shape) == 1\n",
|
|
"_final_eval = calculate_final_evaluation_for_history(_board_history)\n",
|
|
"plt.title(\"Histogram over the score distribution\")\n",
|
|
"plt.hist((_final_eval * 64), density=True)\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"def calculate_who_won(board_history: np.ndarray) -> np.ndarray:\n",
|
|
" who_won = evaluate_who_won(board_history[-1])\n",
|
|
" return who_won\n",
|
|
"\n",
|
|
"\n",
|
|
"plt.title(\"Win distribution\")\n",
|
|
"plt.bar(\n",
|
|
" [\"black\", \"draw\", \"white\"],\n",
|
|
" pd.Series(calculate_who_won(_board_history)).value_counts().sort_index() / 10000,\n",
|
|
")\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"def history_changed(board_history: np.ndarray) -> np.ndarray:\n",
|
|
" return ~np.all(\n",
|
|
" np.roll(board_history, shift=1, axis=0) == board_history, axis=(2, 3)\n",
|
|
" )\n",
|
|
"\n",
|
|
"\n",
|
|
"plt.title(\"Share of turns skipped\")\n",
|
|
"plt.plot(1 - np.mean(history_changed(_board_history), axis=1))\n",
|
|
"plt.xlabel(\"Turn\")\n",
|
|
"plt.ylabel(\"Factor of skipped turns\")\n",
|
|
"plt.yscale(\"log\", base=10)\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(70, 10000)"
|
|
]
|
|
},
|
|
"execution_count": 38,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"def get_gamma_table(board_history, gamma_value: float):\n",
|
|
" changed = history_changed(board_history)\n",
|
|
" gamma_values = np.ones_like(changed, dtype=float)\n",
|
|
" gamma_values[changed] = gamma_value\n",
|
|
" return gamma_values\n",
|
|
"\n",
|
|
"\n",
|
|
"get_gamma_table(_board_history, 0.8).shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([ 0.09677184, 0.0037773 , 0.12190913, 0.03519891, 0.16118614,\n",
|
|
" 0.00617017, 0.12490022, -0.03918723, 0.14632847, -0.01240192,\n",
|
|
" 0.1016851 , 0.00991888, 0.1295861 , -0.03332988, 0.07552515,\n",
|
|
" -0.10090606, 0.14730492, -0.08930635, 0.08367957, -0.09071304,\n",
|
|
" 0.1600462 , 0.08287025, 0.22077531, -0.07559336, 0.1789458 ,\n",
|
|
" 0.02836975, 0.23077469, 0.01503086, 0.13597608, -0.18159241,\n",
|
|
" -0.03167801, -0.23491001, 0.05792499, -0.04478127, 0.06121092,\n",
|
|
" -0.04067385, 0.37884519, 0.04386898, 0.17202373, -0.05840784,\n",
|
|
" 0.0441777 , -0.14009038, 0.02019953, -0.09193809, 0.15851489,\n",
|
|
" 0.08095611, 0.45275764, 0.13625955, 0.36563693, -0.05076633,\n",
|
|
" 0.28810459, -0.22580677, -0.16507096, -0.5579012 , -0.033314 ,\n",
|
|
" -0.15883 , 0.23115 , -0.45325 , -0.37125 , -0.58125 ,\n",
|
|
" -0.21875 , -0.21875 , -0.21875 , -0.21875 , -0.21875 ,\n",
|
|
" -0.21875 , -0.21875 , -0.21875 , -0.21875 , -0.21875 ])"
|
|
]
|
|
},
|
|
"execution_count": 39,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"def calculate_q_reword(\n",
|
|
" board_history: np.ndarray,\n",
|
|
" who_won_fraction: float = 0.2,\n",
|
|
"    final_score_fraction: float = 0.2,\n",
|
|
" gamma=0.8,\n",
|
|
") -> np.ndarray:\n",
|
|
" assert who_won_fraction + final_score_fraction <= 1\n",
|
|
" assert final_score_fraction >= 0\n",
|
|
" assert who_won_fraction >= 0\n",
|
|
"\n",
|
|
"    gamma_table = get_gamma_table(board_history, gamma)\n",
"    combined_score = np.zeros_like(gamma_table)\n",
"    # The weights of the three reward components sum up to 1.\n",
"    combined_score += calculate_direct_score(board_history) * (\n",
"        1 - who_won_fraction - final_score_fraction\n",
"    )\n",
"    combined_score[-1] += (\n",
"        calculate_final_evaluation_for_history(board_history) * final_score_fraction\n",
"    )\n",
"    combined_score[-1] += calculate_who_won(board_history) * who_won_fraction\n",
"    # Propagate the discounted rewards backwards: Q[t - 1] += gamma[t] * Q[t]\n",
"    for turn in range(board_history.shape[0] - 1, 0, -1):\n",
"        combined_score[turn - 1] += gamma_table[turn] * combined_score[turn]\n",
|
|
"\n",
|
|
" return combined_score\n",
|
|
"\n",
|
|
"\n",
|
|
"calculate_q_reword(\n",
|
|
" _board_history, gamma=0.8, who_won_fraction=0, final_score_fraction=1\n",
|
|
")[:, 0]"
|
|
]
|
|
},
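{
"cell_type": "markdown",
"metadata": {},
"source": [
"The backward loop in `calculate_q_reword` turns per-turn rewards into discounted returns: every turn adds the value of the following turn, scaled by that following turn's entry in the gamma table. A skipped turn has a factor of 1, so the reward of the last turn is passed back undiminished through a run of skipped turns. A minimal single-game sketch with made-up rewards; `discounted_returns` is a hypothetical helper, not part of the notebook:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def discounted_returns(rewards, gammas):\n",
"    # rewards[t]: immediate reward of turn t\n",
"    # gammas[t]: factor applied when passing the value of turn t back to turn t - 1\n",
"    returns = np.array(rewards, dtype=float)\n",
"    for turn in range(len(returns) - 1, 0, -1):\n",
"        returns[turn - 1] += gammas[turn] * returns[turn]\n",
"    return returns\n",
"\n",
"\n",
"rewards = [0.0, 0.0, 0.0, 0.0, -1.0]  # only the final turn is rewarded (a loss)\n",
"gammas = [0.8, 0.8, 0.8, 1.0, 1.0]  # turns 3 and 4 were skipped\n",
"returns = discounted_returns(rewards, gammas)  # approximately [-0.64, -0.8, -1.0, -1.0, -1.0]\n",
"```"
]
},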
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 40,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([-1.53249554e-06, -1.91561943e-06, -2.39452428e-06, -2.99315535e-06,\n",
|
|
" -3.74144419e-06, -4.67680524e-06, -5.84600655e-06, -7.30750819e-06,\n",
|
|
" -9.13438523e-06, -1.14179815e-05, -1.42724769e-05, -1.78405962e-05,\n",
|
|
" -2.23007452e-05, -2.78759315e-05, -3.48449144e-05, -4.35561430e-05,\n",
|
|
" -5.44451787e-05, -6.80564734e-05, -8.50705917e-05, -1.06338240e-04,\n",
|
|
" -1.32922800e-04, -1.66153499e-04, -2.07691874e-04, -2.59614843e-04,\n",
|
|
" -3.24518554e-04, -4.05648192e-04, -5.07060240e-04, -6.33825300e-04,\n",
|
|
" -7.92281625e-04, -9.90352031e-04, -1.23794004e-03, -1.54742505e-03,\n",
|
|
" -1.93428131e-03, -2.41785164e-03, -3.02231455e-03, -3.77789319e-03,\n",
|
|
" -4.72236648e-03, -5.90295810e-03, -7.37869763e-03, -9.22337204e-03,\n",
|
|
" -1.15292150e-02, -1.44115188e-02, -1.80143985e-02, -2.25179981e-02,\n",
|
|
" -2.81474977e-02, -3.51843721e-02, -4.39804651e-02, -5.49755814e-02,\n",
|
|
" -6.87194767e-02, -8.58993459e-02, -1.07374182e-01, -1.34217728e-01,\n",
|
|
" -1.67772160e-01, -2.09715200e-01, -2.62144000e-01, -3.27680000e-01,\n",
|
|
" -4.09600000e-01, -5.12000000e-01, -6.40000000e-01, -8.00000000e-01,\n",
|
|
" -1.00000000e+00, -1.00000000e+00, -1.00000000e+00, -1.00000000e+00,\n",
|
|
" -1.00000000e+00, -1.00000000e+00, -1.00000000e+00, -1.00000000e+00,\n",
|
|
" -1.00000000e+00, -1.00000000e+00])"
|
|
]
|
|
},
|
|
"execution_count": 40,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"calculate_q_reword(\n",
|
|
" _board_history, gamma=0.8, who_won_fraction=1, final_score_fraction=0\n",
|
|
")[:, 0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 41,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([ 3.09670969, 0.12088712, 3.9011089 , 1.12638612,\n",
|
|
" 5.15798265, 0.19747831, 3.99684789, -1.25394014,\n",
|
|
" 4.68257483, -0.39678147, 3.25402317, 0.31752896,\n",
|
|
" 4.1469112 , -1.066361 , 2.41704875, -3.22868907,\n",
|
|
" 4.71413867, -2.85732667, 2.67834167, -2.90207292,\n",
|
|
" 5.12240885, 2.65301107, 7.06626383, -2.41717021,\n",
|
|
" 5.72853724, 0.91067155, 7.38833944, 0.4854243 ,\n",
|
|
" 4.35678037, -5.80402453, -1.00503067, -7.50628834,\n",
|
|
" 1.86713958, -1.41607552, 1.9799056 , -1.27511801,\n",
|
|
" 12.15610249, 1.44512812, 5.55641015, -1.80448732,\n",
|
|
" 1.49439085, -4.38201144, 0.77248571, -2.78439287,\n",
|
|
" 5.26950892, 2.83688614, 14.79610768, 4.7451346 ,\n",
|
|
" 12.18141825, -1.02322719, 9.97096602, -6.28629248,\n",
|
|
" -4.1078656 , -16.384832 , 0.76896 , -2.7888 ,\n",
|
|
" 10.264 , -10.92 , -7.4 , -13. ,\n",
|
|
" 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. , 0. , 0. ,\n",
|
|
" 0. , 0. ])"
|
|
]
|
|
},
|
|
"execution_count": 41,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"calculate_q_reword(\n",
|
|
" _board_history, gamma=0.8, who_won_fraction=0, final_score_fraction=0\n",
|
|
")[:, 0] * 64"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 42,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def weights_init_normal(m):\n",
|
|
" \"\"\"Takes in a module and initializes all linear layers with weight\n",
|
|
" values taken from a normal distribution.\n",
|
|
" Source: https://stackoverflow.com/a/55546528/11003343\n",
|
|
" \"\"\"\n",
|
|
"\n",
|
|
" classname = m.__class__.__name__\n",
|
|
" # for every Linear layer in a model\n",
|
|
" if classname.find(\"Linear\") != -1:\n",
|
|
" y = m.in_features\n",
|
|
"        # m.weight.data should be taken from a normal distribution\n",
|
|
" m.weight.data.normal_(0.0, 1 / np.sqrt(y))\n",
|
|
" # m.bias.data should be 0\n",
|
|
" m.bias.data.fill_(0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 43,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[0.],\n",
|
|
" [0.],\n",
|
|
" [0.],\n",
|
|
" [0.],\n",
|
|
" [0.]], grad_fn=<TanhBackward0>)"
|
|
]
|
|
},
|
|
"execution_count": 43,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"BATCH_SIZE = 1000\n",
|
|
"\n",
|
|
"\n",
|
|
"class DQLNet(nn.Module):\n",
|
|
" def __init__(self, load_from: str | None = None):\n",
|
|
" super().__init__()\n",
|
|
" self.fc1 = nn.Linear(8 * 8 * 2, 128 * 2)\n",
|
|
" # self.nb1 = nn.BatchNorm1d([128 * 2])\n",
|
|
" self.fc2 = nn.Linear(128 * 2, 128 * 3)\n",
|
|
" # self.nb2 = nn.BatchNorm1d([128 * 3])\n",
|
|
" self.fc3 = nn.Linear(128 * 3, 128 * 2)\n",
|
|
" self.fc4 = nn.Linear(128 * 2, 1)\n",
|
|
"        if load_from:\n",
"            self.load_state_dict(torch.load(load_from))\n",
"        else:\n",
"            self.apply(weights_init_normal)\n",
|
|
"\n",
|
|
" def forward(self, x):\n",
|
|
" if isinstance(x, np.ndarray):\n",
|
|
" x = torch.from_numpy(x).float()\n",
|
|
" x = torch.flatten(x, 1)\n",
|
|
" x = self.fc1(x)\n",
|
|
" x = F.relu(x)\n",
|
|
" # x = self.nb1(x)\n",
|
|
" # x = self.dropout1(x)\n",
|
|
" x = self.fc2(x)\n",
|
|
" x = F.relu(x)\n",
|
|
" # x = self.nb2(x)\n",
|
|
" x = self.fc3(x)\n",
|
|
" x = F.relu(x)\n",
|
|
" x = self.fc4(x)\n",
|
|
" x = torch.tanh(x)\n",
|
|
" return x\n",
|
|
"\n",
|
|
"\n",
|
|
"DQLNet().forward(np.zeros((5, 2, 8, 8)))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 44,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"class SymmetryMode(Enum):\n",
|
|
" MULTIPLY = \"MULTIPLY\"\n",
|
|
" BREAK_SEQUENCE = \"BREAK_SEQUENCE\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 45,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"((70, 100, 8, 8), (70, 100, 2))"
|
|
]
|
|
},
|
|
"execution_count": 45,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"_board_history, _action_history = simulate_game(100, (RandomPolicy(1), RandomPolicy(1)))\n",
|
|
"_board_history.shape, _action_history.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 46,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"472 ms ± 24.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
|
|
"peak memory: 382.54 MiB, increment: 6.84 MiB\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(70, 100, 2, 8, 8)"
|
|
]
|
|
},
|
|
"execution_count": 46,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
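"def action_to_q_learning_format(\n",
"    board_history: np.ndarray, action_history: np.ndarray\n",
") -> np.ndarray:\n",
"    turns, games = board_history.shape[:2]\n",
"    q_learning_format = np.zeros((turns, games, 2, 8, 8), dtype=float)\n",
"    # Channel 0 holds the board, channel 1 marks the chosen action with a 1\n",
"    # (the same encoding QLPolicy._internal_policy feeds into the network).\n",
"    q_learning_format[:, :, 0] = board_history\n",
"    # Pair every action with its own turn and game index.\n",
"    # Skipped turns are stored as (-1, -1) and filtered out later via action_possible.\n",
"    turn_index, game_index = np.meshgrid(\n",
"        np.arange(turns), np.arange(games), indexing=\"ij\"\n",
"    )\n",
"    q_learning_format[\n",
"        turn_index, game_index, 1, action_history[:, :, 0], action_history[:, :, 1]\n",
"    ] = 1\n",
"    return q_learning_format\n",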
|
|
"\n",
|
|
"\n",
|
|
"%timeit action_to_q_learning_format(_board_history, _action_history)\n",
|
|
"%memit action_to_q_learning_format(_board_history, _action_history)\n",
|
|
"action_to_q_learning_format(_board_history, _action_history).shape"
|
|
]
|
|
},
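{
"cell_type": "markdown",
"metadata": {},
"source": [
"When several index arrays are combined, NumPy pairs them element by element, so every action has to be indexed together with its own turn and game index; mixing slices with index arrays instead writes every action onto every board. A self-contained sketch with hypothetical sizes (3 turns, 2 games) and actions stored as `(row, col)` pairs, where `(-1, -1)` marks a skipped turn as assumed from `generate_trainings_data`. The helper `encode_actions` is illustrative only:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def encode_actions(action_history: np.ndarray, board_size: int = 8) -> np.ndarray:\n",
"    # action_history has the shape (turns, games, 2) and holds (row, col) per move\n",
"    turns, games = action_history.shape[:2]\n",
"    encoded = np.zeros((turns, games, board_size, board_size))\n",
"    turn_idx, game_idx = np.meshgrid(np.arange(turns), np.arange(games), indexing='ij')\n",
"    played = action_history[:, :, 0] >= 0  # skipped turns are stored as (-1, -1)\n",
"    encoded[\n",
"        turn_idx[played],\n",
"        game_idx[played],\n",
"        action_history[:, :, 0][played],\n",
"        action_history[:, :, 1][played],\n",
"    ] = 1\n",
"    return encoded\n",
"\n",
"\n",
"actions = np.array([\n",
"    [[2, 3], [4, 5]],\n",
"    [[-1, -1], [5, 5]],\n",
"    [[3, 2], [-1, -1]],\n",
"])\n",
"encoded = encode_actions(actions)\n",
"n_moves = int(encoded.sum())  # 4, one field per played move\n",
"```"
]
},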
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 47,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"493 ms ± 27.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
|
|
"peak memory: 378.35 MiB, increment: 6.84 MiB\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(2, 2, 2, 70, 100, 2, 8, 8)"
|
|
]
|
|
},
|
|
"execution_count": 47,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"def build_symetry_action(\n",
|
|
" board_history: np.ndarray, action_history: np.ndarray\n",
|
|
") -> np.ndarray:\n",
|
|
"    board_history = board_history.copy()\n",
"    # Flip every second turn so that all boards share the moving player's perspective.\n",
"    board_history[::2] *= -1\n",
"    # Axes 0-2 select the symmetry: transpose, flip rows, flip columns (2 * 2 * 2 = 8 variants).\n",
"    q_learning_format = np.zeros(\n",
"        (2, 2, 2, SIMULATE_TURNS, board_history.shape[1], 2, 8, 8)\n",
"    )\n",
"    q_learning_format[0, 0, 0] = action_to_q_learning_format(\n",
"        board_history, action_history\n",
"    )\n",
"    # Transpose the boards by swapping their two spatial axes.\n",
"    q_learning_format[1, 0, 0] = np.transpose(\n",
"        q_learning_format[0, 0, 0], [0, 1, 2, 4, 3]\n",
"    )\n",
"    # Mirror the original and the transposed boards along the row axis ...\n",
"    q_learning_format[:, 1, 0] = q_learning_format[:, 0, 0, :, :, :, ::-1, :]\n",
"    # ... and all four resulting variants along the column axis.\n",
"    q_learning_format[:, :, 1] = q_learning_format[:, :, 0, :, :, :, :, ::-1]\n",
|
|
" return q_learning_format\n",
|
|
"\n",
|
|
"\n",
|
|
"%timeit build_symetry_action(_board_history, _action_history)\n",
|
|
"%memit build_symetry_action(_board_history, _action_history)\n",
|
|
"build_symetry_action(_board_history, _action_history).shape"
|
|
]
|
|
},
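{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three binary axes in front of the result (transpose, flip rows, flip columns) yield 2 x 2 x 2 = 8 variants, which are exactly the eight symmetries of a square board: four rotations, each with and without a mirror. This can be checked on a single 8 x 8 board whose fields all hold distinct values. The loop follows the order of operations in `build_symetry_action` but is a standalone sketch, not the notebook's code:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"board = np.arange(64).reshape(8, 8)  # every field holds a unique value\n",
"\n",
"variants = []\n",
"for transpose in (False, True):\n",
"    for flip_rows in (False, True):\n",
"        for flip_cols in (False, True):\n",
"            variant = board.T if transpose else board\n",
"            if flip_rows:\n",
"                variant = variant[::-1, :]\n",
"            if flip_cols:\n",
"                variant = variant[:, ::-1]\n",
"            variants.append(variant)\n",
"\n",
"# the dihedral group of the square: 4 rotations and their mirror images\n",
"rotations = [np.rot90(board, k) for k in range(4)]\n",
"expected = rotations + [np.fliplr(r) for r in rotations]\n",
"same_group = {v.tobytes() for v in variants} == {e.tobytes() for e in expected}\n",
"```"
]
},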
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 67,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def live_history(training_history: pd.DataFrame, trainable, max_epochs: int):\n",
|
|
" clear_output(wait=True)\n",
|
|
"    ax = training_history[\n",
"        [c for c in training_history.columns if c[0] != \"base\"]\n",
"    ].plot(\n",
"        secondary_y=[c for c in training_history.columns if c[1] == \"final_score\"]\n",
"    )\n",
"    ax.set_xlim(0, max_epochs)\n",
"    ax.set_title(f\"Training history of {trainable.policy_name}\")\n",
"    ax.set_xlabel(\"Epoch\")\n",
"    ax.set_ylabel(\"Share of evaluation games\")\n",
|
|
" plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 71,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class QLPolicy(GamePolicy):\n",
|
|
" def __init__(\n",
|
|
" self,\n",
|
|
" epsilon: float,\n",
|
|
" neural_network: DQLNet,\n",
|
|
" symmetry_mode: SymmetryMode,\n",
|
|
" gamma: float = 0.8,\n",
|
|
" who_won_fraction: float = 0,\n",
|
|
" final_score_fraction: float = 0,\n",
|
|
" optimizer: torch.optim.Optimizer | None = None,\n",
|
|
" loss: nn.modules.loss._Loss | None = None,\n",
|
|
" ):\n",
|
|
" super().__init__(epsilon)\n",
|
|
" assert 0 <= gamma <= 1\n",
|
|
"        self.gamma: float = gamma\n",
|
|
" del gamma\n",
|
|
" self.symmetry_mode: SymmetryMode = symmetry_mode\n",
|
|
" del symmetry_mode\n",
|
|
" self.neural_network: DQLNet = neural_network\n",
|
|
" del neural_network\n",
|
|
"        self.who_won_fraction: float = who_won_fraction\n",
|
|
" del who_won_fraction\n",
|
|
"        self.final_score_fraction: float = final_score_fraction\n",
|
|
" del final_score_fraction\n",
|
|
"\n",
|
|
" if optimizer is None:\n",
|
|
" self.optimizer = torch.optim.Adam(self.neural_network.parameters(), lr=5e-3)\n",
|
|
" else:\n",
|
|
" self.optimizer = optimizer\n",
|
|
" if loss is None:\n",
|
|
" self.loss = nn.MSELoss()\n",
|
|
" else:\n",
|
|
" self.loss = loss\n",
|
|
" self.training_results: list[dict[tuple[str, str], float]] = []\n",
|
|
"\n",
|
|
" @property\n",
|
|
" def policy_name(self) -> str:\n",
|
|
" symmetry_name = {SymmetryMode.MULTIPLY: \"M\", SymmetryMode.BREAK_SEQUENCE: \"B\"}\n",
|
|
" g = f\"{self.gamma:.1f}\".replace(\".\", \"\")\n",
|
|
" ww = f\"{self.who_won_fraction:.1f}\".replace(\".\", \"\")\n",
|
|
" fsf = f\"{self.final_score_fraction:.1f}\".replace(\".\", \"\")\n",
|
|
"        return f\"QL-{symmetry_name[self.symmetry_mode]}-G{g}-WW{ww}-FSF{fsf}-{self.neural_network.__class__.__name__}-{self.loss.__class__.__name__}\"\n",
|
|
"\n",
|
|
" def _internal_policy(self, boards: np.ndarray) -> np.ndarray:\n",
|
|
" results = np.zeros_like(boards, dtype=float)\n",
|
|
" results = torch.from_numpy(results).float()\n",
|
|
" q_learning_boards = np.zeros((boards.shape[0], 2, 8, 8))\n",
|
|
" q_learning_boards[:, 0, :, :] = boards\n",
|
|
" poss_turns = boards == 0 # checks where fields are empty.\n",
|
|
" poss_turns &= binary_dilation(boards == -1, SURROUNDING)\n",
|
|
" turn_possible = np.any(poss_turns, axis=0)\n",
|
|
" for action_x, action_y in itertools.product(range(8), range(8)):\n",
|
|
" if not turn_possible[action_x, action_y]:\n",
|
|
" continue\n",
|
|
" _q_learning_board = q_learning_boards[\n",
|
|
" poss_turns[:, action_x, action_y]\n",
|
|
" ].copy()\n",
|
|
" _q_learning_board[:, 1, action_x, action_y] = 1\n",
|
|
"            with torch.no_grad():\n",
"                ql_result = self.neural_network.forward(_q_learning_board)\n",
|
|
" results[poss_turns[:, action_x, action_y], action_x, action_y] = (\n",
|
|
" ql_result.reshape(-1) + 0.1\n",
|
|
" )\n",
|
|
" return results.cpu().detach().numpy()\n",
|
|
"\n",
|
|
"    def generate_trainings_data(\n",
"        self, generate_data_size: int\n",
"    ) -> tuple[torch.Tensor, torch.Tensor]:\n",
|
|
" train_boards, train_actions = simulate_game(generate_data_size, [self] * 2)\n",
|
|
" action_possible = ~np.all(train_actions[:, :] == -1, axis=(2))\n",
|
|
" q_leaning_formated_action = build_symetry_action(train_boards, train_actions)\n",
|
|
" q_rewords = calculate_q_reword(\n",
|
|
" board_history=train_boards,\n",
|
|
" who_won_fraction=self.who_won_fraction,\n",
|
|
"            final_score_fraction=self.final_score_fraction,\n",
"            gamma=self.gamma,\n",
|
|
" )\n",
|
|
" if self.symmetry_mode == SymmetryMode.MULTIPLY:\n",
|
|
" q_rewords = np.array([q_rewords] * 8)\n",
|
|
" action_possible = np.array([action_possible] * 8).reshape(-1)\n",
|
|
"\n",
|
|
" elif self.symmetry_mode == SymmetryMode.BREAK_SEQUENCE:\n",
|
|
" axis1 = np.random.randint(0, high=2, size=SIMULATE_TURNS, dtype=int)\n",
|
|
" axis2 = np.random.randint(0, high=2, size=SIMULATE_TURNS, dtype=int)\n",
|
|
" axis3 = np.random.randint(0, high=2, size=SIMULATE_TURNS, dtype=int)\n",
|
|
" q_leaning_formated_action = q_leaning_formated_action[\n",
|
|
" axis1, axis2, axis3, range(SIMULATE_TURNS)\n",
|
|
" ]\n",
|
|
" action_possible = action_possible.reshape(-1)\n",
|
|
"\n",
|
|
" return (\n",
|
|
" torch.from_numpy(\n",
|
|
" q_leaning_formated_action.reshape(-1, 2, BOARD_SIZE, BOARD_SIZE)[\n",
|
|
" action_possible\n",
|
|
" ]\n",
|
|
" ).float(),\n",
|
|
" torch.from_numpy(q_rewords.reshape(-1, 1)[action_possible]).float(),\n",
|
|
" )\n",
|
|
"\n",
|
|
" def train_batch(self, nr_of_games: int):\n",
|
|
" x_train, y_train = self.generate_trainings_data(nr_of_games)\n",
|
|
" y_pred = self.neural_network.forward(x_train)\n",
|
|
" loss_score = self.loss(y_pred, y_train)\n",
|
|
" self.optimizer.zero_grad()\n",
|
|
"\n",
|
|
" loss_score.backward()\n",
|
|
" # Update the parameters\n",
|
|
" self.optimizer.step()\n",
|
|
|
|
"\n",
|
|
" def evaluate_model(self, compare_models: list[GamePolicy], nr_of_games: int):\n",
|
|
" result_dict: dict[tuple[str, str], float] = {}\n",
|
|
"        eval_copy = copy.copy(self)\n",
|
|
" eval_copy._epsilon = 1\n",
|
|
" for model in compare_models:\n",
|
|
" boards_white, _ = simulate_game(nr_of_games, (eval_copy, model))\n",
|
|
" boards_black, _ = simulate_game(nr_of_games, (model, eval_copy))\n",
|
|
" win_eval_white = evaluate_who_won(boards_white[-1])\n",
|
|
" win_eval_black = evaluate_who_won(boards_black[-1])\n",
|
|
" result_dict[(model.policy_name, \"final_score\")] = np.mean(\n",
|
|
" final_boards_evaluation(boards_white[-1])\n",
|
|
" + final_boards_evaluation(boards_black[-1]) * -1\n",
|
|
" )\n",
|
|
" result_dict[(model.policy_name, \"white_win\")] = (\n",
|
|
" np.sum(win_eval_white == 1) / nr_of_games\n",
|
|
" )\n",
|
|
" result_dict[(model.policy_name, \"white_lose\")] = (\n",
|
|
" np.sum(win_eval_white == -1) / nr_of_games\n",
|
|
" )\n",
|
|
" result_dict[(model.policy_name, \"black_win\")] = (\n",
|
|
" np.sum(win_eval_black == 1) / nr_of_games\n",
|
|
" )\n",
|
|
" result_dict[(model.policy_name, \"black_lose\")] = (\n",
|
|
" np.sum(win_eval_black == -1) / nr_of_games\n",
|
|
" )\n",
|
|
" result_dict[(\"base\", \"base\")] = nr_of_games\n",
|
|
" return result_dict\n",
|
|
"\n",
|
|
" def save(self):\n",
|
|
" filename: str = f\"{self.policy_name}-{len(self.training_results)}\"\n",
|
|
" with open(TRINING_RESULT_PATH / Path(f\"{filename}.pickle\"), \"wb\") as f:\n",
|
|
" pickle.dump(self.training_results, f)\n",
|
|
" torch.save(\n",
|
|
" self.neural_network.state_dict(),\n",
|
|
" TRINING_RESULT_PATH / Path(f\"{filename}.torch\"),\n",
|
|
" )\n",
|
|
"\n",
|
|
" def load(self):\n",
|
|
" pickle_files = glob.glob(f\"{TRINING_RESULT_PATH}/{self.policy_name}-*.pickle\")\n",
|
|
" torch_files = glob.glob(f\"{TRINING_RESULT_PATH}/{self.policy_name}-*.torch\")\n",
|
|
"\n",
|
|
" assert len(pickle_files) == len(torch_files)\n",
|
|
" if not pickle_files:\n",
|
|
" return\n",
|
|
"\n",
|
|
" pickle_dict = {\n",
|
|
" int(file.split(\"-\")[-1].split(\".\")[0]): file for file in pickle_files\n",
|
|
" }\n",
|
|
" torch_dict = {\n",
|
|
" int(file.split(\"-\")[-1].split(\".\")[0]): file for file in torch_files\n",
|
|
" }\n",
|
|
" pickle_file = pickle_dict[max(pickle_dict.keys())]\n",
|
|
" torch_file = torch_dict[max(torch_dict.keys())]\n",
|
|
"\n",
|
|
" with open(pickle_file, \"rb\") as f:\n",
|
|
" self.training_results = pickle.load(f)\n",
|
|
"\n",
|
|
" self.neural_network.load_state_dict(torch.load(Path(torch_file)))\n",
|
|
"\n",
|
|
" def train(\n",
|
|
" self,\n",
|
|
" epochs: int,\n",
|
|
" batches: int,\n",
|
|
" batch_size: int,\n",
|
|
" eval_batch_size: int,\n",
|
|
" compare_with: list[GamePolicy],\n",
|
|
" save_every_epoch: bool = True,\n",
|
|
" live_plot: bool = True,\n",
|
|
" ) -> pd.DataFrame:\n",
|
|
" max_epochs = epochs + len(self.training_results)\n",
|
|
" assert epochs > 0\n",
|
|
" for epoch in tqdm(range(epochs)):\n",
|
|
" for batch in tqdm(range(batches)):\n",
|
|
" self.train_batch(batch_size)\n",
|
|
" self.training_results.append(\n",
|
|
" self.evaluate_model(compare_with, eval_batch_size)\n",
|
|
" )\n",
|
|
" if save_every_epoch:\n",
|
|
" self.save()\n",
|
|
" if live_plot:\n",
|
|
" live_history(self.history, self, max_epochs)\n",
|
|
" return self.history\n",
|
|
"\n",
|
|
" @property\n",
|
|
" def history(self) -> pd.DataFrame:\n",
|
|
" pandas_result = pd.DataFrame(self.training_results)\n",
|
|
" pandas_result.columns = pd.MultiIndex.from_tuples(pandas_result.columns)\n",
|
|
" return pandas_result"
|
|
]
|
|
},
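{
"cell_type": "markdown",
"metadata": {},
"source": [
"In `SymmetryMode.BREAK_SEQUENCE` the three random index arrays and `range(SIMULATE_TURNS)` are broadcast together, so exactly one of the eight symmetry variants is drawn per turn and shared by all games of that turn. Consecutive turns of one game can therefore appear in different orientations, while the result keeps the shape `(turns, games, 2, 8, 8)`. A sketch of the same indexing with small hypothetical sizes (4 turns, 3 games, a 2 x 2 payload per sample) and random data:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"turns, games = 4, 3\n",
"rng = np.random.default_rng(0)\n",
"# same layout as build_symetry_action: (transpose, flip rows, flip cols, turn, game, ...)\n",
"variants = rng.normal(size=(2, 2, 2, turns, games, 2, 2))\n",
"\n",
"axis1 = rng.integers(0, 2, size=turns)\n",
"axis2 = rng.integers(0, 2, size=turns)\n",
"axis3 = rng.integers(0, 2, size=turns)\n",
"picked = variants[axis1, axis2, axis3, np.arange(turns)]\n",
"print(picked.shape)  # (4, 3, 2, 2)\n",
"```"
]
},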
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 76,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"20.8 s ± 1.61 s per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
|
|
"peak memory: 664.29 MiB, increment: 279.84 MiB\n",
|
|
"27.9 s ± 302 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
|
|
"peak memory: 384.48 MiB, increment: 0.00 MiB\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"ql_policy = QLPolicy(\n",
|
|
" 0.95,\n",
|
|
" neural_network=DQLNet(),\n",
|
|
" symmetry_mode=SymmetryMode.MULTIPLY,\n",
|
|
" gamma=0.8,\n",
|
|
" who_won_fraction=0,\n",
|
|
" final_score_fraction=0,\n",
|
|
")\n",
|
|
"_batch_size = 100\n",
|
|
"%timeit ql_policy.train_batch(_batch_size)\n",
|
|
"%memit ql_policy.train_batch(_batch_size)\n",
|
|
"%timeit ql_policy.evaluate_model([RandomPolicy(0)], _batch_size)\n",
|
|
"%memit ql_policy.evaluate_model([RandomPolicy(0)], _batch_size)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 73,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'QL-M-G08-WW10-FSF00-DQLNet-MSELoss'"
|
|
]
|
|
},
|
|
"execution_count": 73,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"ql_policy = QLPolicy(\n",
|
|
" 0.95,\n",
|
|
" neural_network=DQLNet(),\n",
|
|
" symmetry_mode=SymmetryMode.MULTIPLY,\n",
|
|
" gamma=0.8,\n",
|
|
" who_won_fraction=1,\n",
|
|
" final_score_fraction=0,\n",
|
|
")\n",
|
|
"ql_policy.policy_name"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"gen = ql_policy.generate_trainings_data(10)\n",
|
|
"gen[0][4, 0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 74,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"ql_policy.load()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 75,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "d8c055d3efec4253af97679ddaad11ca",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/200 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "39545244041249cba2f5876dfeded265",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/10 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"ename": "KeyboardInterrupt",
|
|
"evalue": "",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
|
|
"\u001B[1;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
|
|
"Cell \u001B[1;32mIn[75], line 1\u001B[0m\n\u001B[1;32m----> 1\u001B[0m \u001B[43mql_policy\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mtrain\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m200\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m10\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m1000\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m100\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m[\u001B[49m\u001B[43mRandomPolicy\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mGreedyPolicy\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m]\u001B[49m\u001B[43m)\u001B[49m\n",
|
|
"Cell \u001B[1;32mIn[71], line 180\u001B[0m, in \u001B[0;36mQLPolicy.train\u001B[1;34m(self, epochs, batches, batch_size, eval_batch_size, compare_with, save_every_epoch, live_plot)\u001B[0m\n\u001B[0;32m 178\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m epoch \u001B[38;5;129;01min\u001B[39;00m tqdm(\u001B[38;5;28mrange\u001B[39m(epochs)):\n\u001B[0;32m 179\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m batch \u001B[38;5;129;01min\u001B[39;00m tqdm(\u001B[38;5;28mrange\u001B[39m(batches)):\n\u001B[1;32m--> 180\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mtrain_batch\u001B[49m\u001B[43m(\u001B[49m\u001B[43mbatch_size\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 181\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mtraining_results\u001B[38;5;241m.\u001B[39mappend(\n\u001B[0;32m 182\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mevaluate_model(compare_with, eval_batch_size)\n\u001B[0;32m 183\u001B[0m )\n\u001B[0;32m 184\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m save_every_epoch:\n",
|
|
"Cell \u001B[1;32mIn[71], line 97\u001B[0m, in \u001B[0;36mQLPolicy.train_batch\u001B[1;34m(self, nr_of_games)\u001B[0m\n\u001B[0;32m 96\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mtrain_batch\u001B[39m(\u001B[38;5;28mself\u001B[39m, nr_of_games: \u001B[38;5;28mint\u001B[39m):\n\u001B[1;32m---> 97\u001B[0m x_train, y_train \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mgenerate_trainings_data\u001B[49m\u001B[43m(\u001B[49m\u001B[43mnr_of_games\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 98\u001B[0m y_pred \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mneural_network\u001B[38;5;241m.\u001B[39mforward(x_train)\n\u001B[0;32m 99\u001B[0m loss_score \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mloss(y_pred, y_train)\n",
|
|
"Cell \u001B[1;32mIn[71], line 66\u001B[0m, in \u001B[0;36mQLPolicy.generate_trainings_data\u001B[1;34m(self, generate_data_size)\u001B[0m\n\u001B[0;32m 65\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mgenerate_trainings_data\u001B[39m(\u001B[38;5;28mself\u001B[39m, generate_data_size: \u001B[38;5;28mint\u001B[39m) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m np\u001B[38;5;241m.\u001B[39mndarray:\n\u001B[1;32m---> 66\u001B[0m train_boards, train_actions \u001B[38;5;241m=\u001B[39m \u001B[43msimulate_game\u001B[49m\u001B[43m(\u001B[49m\u001B[43mgenerate_data_size\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43m \u001B[49m\u001B[38;5;241;43m2\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[0;32m 67\u001B[0m action_possible \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m~\u001B[39mnp\u001B[38;5;241m.\u001B[39mall(train_actions[:, :] \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m, axis\u001B[38;5;241m=\u001B[39m(\u001B[38;5;241m2\u001B[39m))\n\u001B[0;32m 68\u001B[0m q_leaning_formated_action \u001B[38;5;241m=\u001B[39m build_symetry_action(train_boards, train_actions)\n",
|
|
"Cell \u001B[1;32mIn[23], line 25\u001B[0m, in \u001B[0;36msimulate_game\u001B[1;34m(nr_of_games, policies, tqdm_on)\u001B[0m\n\u001B[0;32m 23\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m policy_index \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m:\n\u001B[0;32m 24\u001B[0m current_boards \u001B[38;5;241m=\u001B[39m current_boards \u001B[38;5;241m*\u001B[39m \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m\n\u001B[1;32m---> 25\u001B[0m current_boards, action_taken \u001B[38;5;241m=\u001B[39m \u001B[43msingle_turn\u001B[49m\u001B[43m(\u001B[49m\u001B[43mcurrent_boards\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpolicy\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 26\u001B[0m action_history_stack[turn_index, :] \u001B[38;5;241m=\u001B[39m action_taken\n\u001B[0;32m 28\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m policy_index \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m:\n",
|
|
"Cell \u001B[1;32mIn[22], line 25\u001B[0m, in \u001B[0;36msingle_turn\u001B[1;34m(current_boards, policy)\u001B[0m\n\u001B[0;32m 19\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m VERIFY_POLICY:\n\u001B[0;32m 20\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m np\u001B[38;5;241m.\u001B[39mall(moves_possible(current_boards, policy_results)), (\n\u001B[0;32m 21\u001B[0m current_boards[(moves_possible(current_boards, policy_results) \u001B[38;5;241m==\u001B[39m \u001B[38;5;28;01mFalse\u001B[39;00m)],\n\u001B[0;32m 22\u001B[0m policy_results[(moves_possible(current_boards, policy_results) \u001B[38;5;241m==\u001B[39m \u001B[38;5;28;01mFalse\u001B[39;00m)],\n\u001B[0;32m 23\u001B[0m np\u001B[38;5;241m.\u001B[39mwhere(moves_possible(current_boards, policy_results) \u001B[38;5;241m==\u001B[39m \u001B[38;5;28;01mFalse\u001B[39;00m),\n\u001B[0;32m 24\u001B[0m )\n\u001B[1;32m---> 25\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mdo_moves\u001B[49m\u001B[43m(\u001B[49m\u001B[43mcurrent_boards\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpolicy_results\u001B[49m\u001B[43m)\u001B[49m, policy_results\n",
|
|
"Cell \u001B[1;32mIn[18], line 74\u001B[0m, in \u001B[0;36mdo_moves\u001B[1;34m(boards, moves)\u001B[0m\n\u001B[0;32m 72\u001B[0m boards \u001B[38;5;241m=\u001B[39m boards\u001B[38;5;241m.\u001B[39mcopy()\n\u001B[0;32m 73\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m game \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(boards\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m]):\n\u001B[1;32m---> 74\u001B[0m \u001B[43m_do_move\u001B[49m\u001B[43m(\u001B[49m\u001B[43mboards\u001B[49m\u001B[43m[\u001B[49m\u001B[43mgame\u001B[49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmoves\u001B[49m\u001B[43m[\u001B[49m\u001B[43mgame\u001B[49m\u001B[43m]\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 75\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m boards\n",
|
|
"Cell \u001B[1;32mIn[18], line 64\u001B[0m, in \u001B[0;36mdo_moves.<locals>._do_move\u001B[1;34m(_board, move)\u001B[0m\n\u001B[0;32m 62\u001B[0m action \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mFalse\u001B[39;00m\n\u001B[0;32m 63\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m direction \u001B[38;5;129;01min\u001B[39;00m DIRECTIONS:\n\u001B[1;32m---> 64\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[43m_do_directional_move\u001B[49m\u001B[43m(\u001B[49m\u001B[43m_board\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmove\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdirection\u001B[49m\u001B[43m)\u001B[49m:\n\u001B[0;32m 65\u001B[0m action \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mTrue\u001B[39;00m\n\u001B[0;32m 66\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m action:\n",
|
|
"Cell \u001B[1;32mIn[18], line 38\u001B[0m, in \u001B[0;36mdo_moves.<locals>._do_directional_move\u001B[1;34m(board, rec_move, rev_direction, step_one)\u001B[0m\n\u001B[0;32m 36\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m step_one\n\u001B[0;32m 37\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m next_field \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m:\n\u001B[1;32m---> 38\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[43m_do_directional_move\u001B[49m\u001B[43m(\u001B[49m\u001B[43mboard\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mrec_position\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mrev_direction\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mstep_one\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mFalse\u001B[39;49;00m\u001B[43m)\u001B[49m:\n\u001B[0;32m 39\u001B[0m board[\u001B[38;5;28mtuple\u001B[39m(rec_position\u001B[38;5;241m.\u001B[39mtolist())] \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1\u001B[39m\n\u001B[0;32m 40\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;01mTrue\u001B[39;00m\n",
|
|
"Cell \u001B[1;32mIn[18], line 30\u001B[0m, in \u001B[0;36mdo_moves.<locals>._do_directional_move\u001B[1;34m(board, rec_move, rev_direction, step_one)\u001B[0m\n\u001B[0;32m 15\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124;03m\"\"\"Changes the color of enemy stones in one direction.\u001B[39;00m\n\u001B[0;32m 16\u001B[0m \n\u001B[0;32m 17\u001B[0m \u001B[38;5;124;03mThis function works recursive. The argument step_one should always be used in its default value.\u001B[39;00m\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 27\u001B[0m \u001B[38;5;124;03m All changes are made on the view of the numpy array and therefore not included in the return value.\u001B[39;00m\n\u001B[0;32m 28\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 29\u001B[0m rec_position \u001B[38;5;241m=\u001B[39m rec_move \u001B[38;5;241m+\u001B[39m rev_direction\n\u001B[1;32m---> 30\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[43mnp\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43many\u001B[49m\u001B[43m(\u001B[49m\u001B[43m(\u001B[49m\u001B[43mrec_position\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m>\u001B[39;49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43m \u001B[49m\u001B[38;5;241;43m8\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m|\u001B[39;49m\u001B[43m \u001B[49m\u001B[43m(\u001B[49m\u001B[43mrec_position\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m<\u001B[39;49m\u001B[43m \u001B[49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m:\n\u001B[0;32m 31\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;01mFalse\u001B[39;00m\n\u001B[0;32m 32\u001B[0m next_field \u001B[38;5;241m=\u001B[39m board[\u001B[38;5;28mtuple\u001B[39m(rec_position\u001B[38;5;241m.\u001B[39mtolist())]\n",
|
|
"File \u001B[1;32m<__array_function__ internals>:180\u001B[0m, in \u001B[0;36many\u001B[1;34m(*args, **kwargs)\u001B[0m\n",
|
|
"\u001B[1;31mKeyboardInterrupt\u001B[0m: "
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"ql_policy.train(200, 10, 1000, 100, [RandomPolicy(0), GreedyPolicy(0)])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"df2 = train(ql_winner_onyl, 5, 5, 1000, 50, [RandomPolicy(0), GreedyPolicy(0)])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"df3 = train(ql_winner_onyl, 5, 5, 1000, 50, [RandomPolicy(0), GreedyPolicy(0)])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Train a model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Sources\n",
|
|
"\n",
|
|
"* Game rules and example board images [https://en.wikipedia.org/wiki/Reversi](https://en.wikipedia.org/wiki/Reversi)\n",
|
|
"* Game rules and example game images [https://de.wikipedia.org/wiki/Othello_(Spiel)](https://de.wikipedia.org/wiki/Othello_(Spiel))\n",
|
|
"* Game strategy examples [https://de.wikipedia.org/wiki/Computer-Othello](https://de.wikipedia.org/wiki/Computer-Othello)\n",
|
|
"* Image of the 8 directions [https://www.researchgate.net/journal/EURASIP-Journal-on-Image-and-Video-Processing-1687-5281](https://www.researchgate.net/journal/EURASIP-Journal-on-Image-and-Video-Processing-1687-5281)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import sys\n",
|
|
"\n",
|
|
"\n",
|
|
"def sizeof_fmt(num, suffix=\"B\"):\n",
|
|
"    \"\"\"Format a byte count with binary (IEC) prefixes; modified from Fred Cirera, https://stackoverflow.com/a/1094933/1870254\"\"\"\n",
|
|
" for unit in [\"\", \"Ki\", \"Mi\", \"Gi\", \"Ti\", \"Pi\", \"Ei\", \"Zi\"]:\n",
|
|
" if abs(num) < 1024.0:\n",
|
|
" return \"%3.1f %s%s\" % (num, unit, suffix)\n",
|
|
" num /= 1024.0\n",
|
|
" return \"%.1f %s%s\" % (num, \"Yi\", suffix)\n",
|
|
"\n",
|
|
"\n",
|
|
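"# List the 20 largest objects in the notebook namespace (sys.getsizeof reports shallow sizes only).\n",
"for name, size in sorted(\n",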
|
|
" ((name, sys.getsizeof(value)) for name, value in list(locals().items())),\n",
|
|
" key=lambda x: -x[1],\n",
|
|
")[:20]:\n",
|
|
" print(\"{:>30}: {:>8}\".format(name, sizeof_fmt(size)))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.8"
|
|
},
|
|
"toc-autonumbering": true,
|
|
"toc-showcode": false
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|