diff --git a/main.ipynb b/main.ipynb index d350229..c196bde 100644 --- a/main.ipynb +++ b/main.ipynb @@ -159,7 +159,10 @@ "BOARD_SIZE: Final[int] = 8 # defines the board side length as 8\n", "PLAYER: Final[int] = 1 # defines the number symbolising the player as 1\n", "ENEMY: Final[int] = -1 # defines the number symbolising the enemy as -1\n", - "EXAMPLE_STACK_SIZE: Final[int] = 1000 # defines the game stack size for examples" + "EXAMPLE_STACK_SIZE: Final[int] = 1000 # defines the game stack size for examples\n", + "IMPOSSIBLE: Final[np.ndarray] = np.array([-1, -1], dtype=int)\n", + "IMPOSSIBLE.setflags(write=False)\n", + "SIMULATE_TURNS: Final[int] = 70" ] }, { @@ -588,13 +591,71 @@ "assert move_possible(np.ones((8, 8)) * 0, np.array([-1, -1])) is True" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def moves_possible(boards: np.ndarray, moves: np.ndarray) -> np.ndarray:\n", + " \"\"\"Checks if a stack of moves can be executed on a stack of boards.\n", + "\n", + " Args:\n", + " boards: A board where the next stone should be placed.\n", + " moves: A stack stones to be placed. Each move is formatted as an array in the form of [x, y] if no turn is possible the value [-1, -1] is expected.\n", + "\n", + " Returns:\n", + " An array marking for each and every game and move in the stack if the move can be executed.\n", + " \"\"\"\n", + " arr_moves_possible = np.zeros(boards.shape[0], dtype=bool)\n", + " for game in range(boards.shape[0]):\n", + " if np.all(\n", + " moves[game] == -1\n", + " ): # can be all or any. All should be faster since most times neither value will be -1.\n", + " arr_moves_possible[game] = not np.any(\n", + " get_possible_turns(np.reshape(boards[game], (1, 8, 8)))\n", + " )\n", + " else:\n", + " arr_moves_possible[game] = any(\n", + " _recursive_steps(boards[game, :, :], direction, moves[game])\n", + " for direction in DIRECTIONS\n", + " )\n", + " return arr_moves_possible\n", + "\n", + "\n", + "np.testing.assert_array_equal(\n", + " moves_possible(np.ones((3, 8, 8)) * 1, np.array([[-1, -1]] * 3)),\n", + " np.array([True] * 3),\n", + ")\n", + "\n", + "np.testing.assert_array_equal(\n", + " moves_possible(get_new_games(3), np.array([[2, 3], [3, 2], [3, 2]])),\n", + " np.array([True] * 3),\n", + ")\n", + "np.testing.assert_array_equal(\n", + " moves_possible(get_new_games(3), np.array([[2, 2], [1, 1], [0, 0]])),\n", + " np.array([False] * 3),\n", + ")\n", + "np.testing.assert_array_equal(\n", + " moves_possible(np.ones((3, 8, 8)) * -1, np.array([[-1, -1]] * 3)),\n", + " np.array([True] * 3),\n", + ")\n", + "np.testing.assert_array_equal(\n", + " moves_possible(np.zeros((3, 8, 8)), np.array([[-1, -1]] * 3)),\n", + " np.array([True] * 3),\n", + ")" + ] + }, { "cell_type": "markdown", "source": [ "## Reword functions\n", "\n", - "For any kind of reinforcement learning is a reword function needed. For otello this would be the final score, the information who won or changes to the score. A combination of those three would also be possible.\n", - "It is probably not be possible to weight the current score to high in a reword function since that would be to close to a classic greedy algorithm. But some influce would increase learning behavior.\n", + "For any kind of reinforcement learning is a reword function needed.\n", + "For otello this would be the final score, the information who won or changes to the score.\n", + "A combination of those three would also be possible.\n", + "It is probably not be possible to weight the current score to high in a reword function since that would be to close to a classic greedy algorithm.\n", + "But some direct influence would increase the learning speed.\n", "In the next section are all three reword functions implemented to be combined and weight later on as needed." ], "metadata": { @@ -674,48 +735,16 @@ } }, { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "cell_type": "markdown", "source": [ - "def moves_possible(boards: np.ndarray, moves: np.ndarray) -> np.ndarray:\n", - " arr_moves_possible = np.zeros(boards.shape[0], dtype=bool)\n", - " for game in range(boards.shape[0]):\n", - " if np.all(moves[game] == -1):\n", - " arr_moves_possible[game] = not np.any(\n", - " get_possible_turns(np.reshape(boards[game], (1, 8, 8)))\n", - " )\n", - " else:\n", - " arr_moves_possible[game] = any(\n", - " _recursive_steps(boards[game, :, :], direction, moves[game])\n", - " for direction in DIRECTIONS\n", - " )\n", - " return arr_moves_possible\n", + "## Execute a chosen action\n", "\n", - "\n", - "np.testing.assert_array_equal(\n", - " moves_possible(np.ones((3, 8, 8)) * 1, np.array([[-1, -1]] * 3)),\n", - " np.array([True] * 3),\n", - ")\n", - "\n", - "np.testing.assert_array_equal(\n", - " moves_possible(get_new_games(3), np.array([[2, 3], [3, 2], [3, 2]])),\n", - " np.array([True] * 3),\n", - ")\n", - "np.testing.assert_array_equal(\n", - " moves_possible(get_new_games(3), np.array([[2, 2], [1, 1], [0, 0]])),\n", - " np.array([False] * 3),\n", - ")\n", - "np.testing.assert_array_equal(\n", - " moves_possible(np.ones((3, 8, 8)) * -1, np.array([[-1, -1]] * 3)),\n", - " np.array([True] * 3),\n", - ")\n", - "np.testing.assert_array_equal(\n", - " moves_possible(np.zeros((3, 8, 8)), np.array([[-1, -1]] * 3)),\n", - " np.array([True] * 3),\n", - ")" - ] + "After an evaluation what turns are possible there needs to be a function that executes a turn.\n", + "This next sections does that." + ], + "metadata": { + "collapsed": false + } }, { "cell_type": "code", @@ -724,13 +753,60 @@ "outputs": [], "source": [ "class InvalidTurn(ValueError):\n", - " pass\n", - "\n", - "\n", + " \"\"\"\n", + " This error is thrown if a given turn is not valid.\n", + " \"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "95.1 ms ± 3.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + }, + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAASIAAAEiCAYAAABdvt+2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAdqElEQVR4nO3de3BU5f0/8PdJNi4QsivEYFiyQEIsMcHw5aalGZCoIIFQ7Di0OqGCAgUJF3XaSqy2tAqLY9uhCgYK4dIh3OyIdRi5y6WiXAJCwRA0CGUhpDg27JJQV5I9vz8Ou78EctmzOec8m5z3a+aMbHbP+TwPG9885/YcSZZlGUREAkWJbgAREYOIiIRjEBGRcAwiIhKOQUREwjGIiEg4BhERCccgIiLhLEYX9Pv9qKioQFxcHCRJMro8ERlElmVcv34dDocDUVHNj3kMD6KKigo4nU6jyxKRIG63G0lJSc1+xvAgiouLC/65U3dja9+oBCADkIBOieapLbo++y6mtuj6N64o/63//3xTDA+iwO5Yp+7AxApjaxcnATWXgVgHkHfJPLVF12ffzdn3dQ4ljEI5BMOD1UQkHIOIiIRjEBGRcAwiIhKOQUREwjGIiEg4BhERCccgIiLhVAfRgQMHMG7cODgcDkiShA8++ECHZhGRmagOopqaGvTv3x9Lly7Voz1EZEKqb/HIyclBTk6OHm0hIpPS/V4zn88Hn88XfO31evUuSURtjO4Hq10uF+x2e3DhFCBEdDvdg6igoAAejye4uN1uvUsSURuj+66Z1WqF1WrVuwwRtWG8joiIhFM9IqqurkZ5eXnw9fnz53HixAl07doVPXv21LRxRGQOqoOopKQE2dnZwdcvvfQSAGDSpElYs2aNZg0jIvNQHUQjRoyALMt6tIWITIrHiIhIOAYREQnHICIi4RhERCQcg4iIhGMQEZFwDCIiEo5BRETCSbLBVyd6vV7Y7XZAUp7HbaQbVwDZD0hRQKfu5qktuj77bs6+11QAkAGPxwObzdbsZ8UFERGZQihBpPs0IE3iiMg09dl3c/Y9MCIKhbAg6pQI5F0ytmZxElBzWflCzFRbdH323Zx9X+dQgjAUPFhNRMIxiIhIOAYREQnHICIi4RhERCQcg4iIhGMQEZFwDCIiEk5VELlcLgwZMgRxcXHo1q0bnnjiCZw9e1avthGRSagKov379yM/Px+HDh3Crl27cPPmTYwaNQo1NTV6tY+ITEDVLR7bt29v8HrNmjXo1q0bjh07huHDh2vaMCIyj1bda+bxeAAAXbt2bfIzPp8PPp8v+Nrr9bamJBG1Q2EfrPb7/XjhhReQlZWFfv36Nfk5l8sFu90eXJxOZ7gliaidCjuI8vPzcfr0aWzcuLHZzxUUFMDj8QQXt9sdbkkiaqfC2jWbNWsWtm7digMHDiApKanZz1qtVlit1rAaR0TmoCqIZFnG7NmzsWXLFuzbtw/Jycl6tYuITERVEOXn52P9+vX4xz/+gbi4OFRWVgIA7HY7OnbsqEsDiaj9U3WMqLCwEB6PByNGjED37t2Dy6ZNm/RqHxGZgOpdMyIirfFeMyISjkFERMIxiIhIOAYREQnHICIi4RhERCQcg4iIhGMQEZFwkmzwVYperxd2ux2QgFiHkZWV53DLfkCKUp4Fbpbaouuz7+bse00FAFmZt8xmszX7WXFBRESmEEoQtWqGxlbhiMg09dl3c/Y9MCIKhbAg6pQI5F0ytmZxElBzWflCzFS7tfVvVgOecsDvA6KsgD0ViOlsTG0t8HsXU3+dQwnCUIgbEVFEqyoFSpcB7o8A79do+C+bBNhSAOcYIH0G0CVdVCupvWAQUQPe88A/pwOXdwGSBZBrG/mQDHjPAaWFwBfvAD1GAsOWAzbOk0dh4ul7CipbCbyXDlTsVV43GkL1BN6v2KusV7ZS3/ZR+8UgIgDA8QXAgWlA3XctB9Dt5FplvQPTlO0QqcUgIpStBEpe1WZbJa8CZUXabIvMg0Fkct7zwMHZ2m7z4Cxlu0ShYhCZ3D+nA36Vu2It8dcq2yUKlerJ8zMzM2Gz2WCz2TB06FBs27ZNr7aRzqpKlbNjao8JtUSuVbZbdUbb7VL7pSqIkpKSsGjRIhw7dgwlJSV45JFHMH78eHzxxRd6tY90VLpMOUWvB8minN4nCoWqX8Nx48Y1eL1gwQIUFhbi0KFDyMjI0LRhpD/3R9qPhgLkWsDNwTKFKOx/D+vq6vDee++hpqYGQ4cO1bJNZIDvr9+6YlpH3nPK7SFELVEdRKdOncLQoUPx3XffoXPnztiyZQvS05u+xt/n88Hn8wVfe73e8FpKmvKeQ8g3JIZNVu5RI2qJ6rNmffv2xYkTJ3D48GE8//zzmDRpEkpLS5v8vMvlgt1uDy5Op7NVDSZt+H0tf6Yt1aG2TXUQ3XXXXUhNTcWgQYPgcrnQv39//OUvf2ny8wUFBfB4PMHF7Xa3qsGkjShr+6pDbVurz5n4/f4Gu163s1qtsFr52xhp7KkAJOi7eybdqkPUAlVBVFBQgJycHPTs2RPXr1/H+vXrsW/fPuzYsUOv9pFOYjorU3l4z+lXw9ZH3bxFZF6qgujq1at45plncOXKFdjtdmRmZmLHjh0YOXKkXu0jHTnHKNf66HEKX7IAzhztt0vtk6ogKiri3YztSfoMZT4hPci1QPrz+myb2h/ea2ZiXdKVSc20vrpasijb7XK/ttul9otBZHLDlgNRGgdRlEXZLlGoGEQmZ0sGsjTePctawmljSR0GESFtKjD4DW22NWQBkDZFm22ReTCICAAw8DfA8BVAdAf1x4wki7Le8JXAgFf0aR+1bwwiCkqbCkwoBRzZyuuWAinwviNbWY8jIQoXHydEDdiSgbE76z3XbFsjN8hKysWKzhzlFD3PjlFrMYioUV3Sgay3lT+39kmvRC2RZFnWezKIBrxeL+x2OyABsQ4jK/MZ6Ow7+26kmgooU8F4PLDZbM1+VlwQEZEphBJE4nbNOCIyTX323Zx9D4yIQiEsiDolAnmXjK1ZnATUXFa+EDPVFl2ffTdn39c5lCAMBQ9WU4tEHqzmgXJzYBBRo4Kn7z+6Ncn+7afvU5RpRNJnKGfY2kttEoNBRA14zytPab28S7lgsdG5imTl2qLSQmUakR4jlZtcW3t/mcjaJBavrKagspXAe+lAxV7ldUsTpgXer9irrFe2sm3WJvEYRAQAOL4AODANqPtO/YyNcq2y3oFpynbaUm2KDAwiQtlKoORVbbZV8ipQpmIiT5G1KXIwiEzOex44OFvbbR6cpWw3kmtTZGEQmdw/pwN+jSfP99cq243k2hRZWhVEixYtgiRJeOGFFzRqDhmpqlQ5Q6X1UzzkWmW7VWciszZFnrCD6OjRo1i+fDkyMzO1bA8ZqHSZ9hPnB0gW5RR7JNamyBNWEFVXVyMvLw8rVqxAly5dtG4TGcT9kT7PNAOU7bq3RWZtijxhBVF+fj7Gjh2Lxx57TOv2kEG+v37rqmUdec8pt2hEUm2KTKoHxxs3bsTx48dx9OjRkD7v8/ng8/mCr71er9qSpIM7Zl3Ug6zcJyay9j3/p3Md0oSqEZHb7cbcuXNRXFyMDh06hLSOy+WC3W4PLk6nM6yGkrb8vpY/o1cdkbUpMqkKomPHjuHq1asYOHAgLBYLLBYL9u/fj7fffhsWiwV1dXV3rFNQUACPxxNc3G63Zo2n8EVZxdURWZsik6pds0cffRSnTp1q8LNnn30WaWlpePnllxEdHX3HOlarFVYrfyMijT0VgAR9d5GkW3UiqDZFJlVBFBcXh379+jX4WWxsLOLj4+/4OUW2mM7KdBrec/rVsPVpfO4gkbUpMvHKahNzjtH3Wh5nTmTWpsjT6l+Fffv2adAMEiF9hjKnjx7kWuWZZ5FYmyIPR0Qm1iVdmVhM65GJZFG229yDF0XWpsjDIDK5YcuBKI3DIMqibDeSa1NkYRCZnC0ZyNJ4FylrSWhTt4qsTZGFQURImwoMfkObbQ1ZAKRNaRu1KXIwiAgAMPA3wPAVQHQH9cdtJIuy3vCVwIBX2lZtigwMIgpKmwpMKAUc2crrlkIh8L4jW1mvNaMRkbVJPD5OiBqwJQNjd9Z7tti2Rm5SlZQLBp05ymlyrc5QiaxNYjGIqFFd0oGst5U/G/20VZG1SQxJlmW9J2RowOv1wm63AxIQ6zCysvIcbtkPSFHKs8DNUlt0ffbdnH2vqYAyHYvHA5vN1uxnxQUREZlCKEEkbteMIyLT1Gffzdn3wIgoFMKCqFMikHfJ2JrFSUDNZeULMVNt0fXZd3P2fZ1DCcJQ8GA1tUjkAWMrYpGAVFhgRS18+Abl8KHGmOJkGAYRNSp4Cv2jWxPd334KPUWZyiN9hnKWS0vdcT+GYwb6YQwSkAKp3uVuMvz4Bl/jND7CASzDFfABZu0Bg4ga8J5XnpR6eZdy0WCjj/yRlet7SguVqTx6jFRuNG3tPV7x6I08LEcGRqEONxGNmDs+IyEK3ZCKh/E8HsEcfIGdKMZ0fIsLrStOQvHKagoqWwm8lw5U7FVet/TcscD7FXuV9cpWhl87C1MwH6VIg3JpdWMhVF/g/TRkYz6+QBZ4aXVbxiAiAMDxBcCBaUDdd+offCjXKusdmKZsR60cvIJnsBIx6NBiAN0uGjGIQUc8g5XIAW82a6sYRISylUDJq9psq+RVoKwo9M9nYQqegJJeEqSwagbWewILkIXnwtoGicUgMjnveeDgbG23eXCWst2WxKM3nsI7kDV6nIcMGU/hHcSjtybbI+MwiEzun9MBv8bPoPfXKtttSR6WIxqWsEdCt5MgIRoxyAOnaGxrVAXR/PnzIUlSgyUtLU2vtpHOqkqVs2Nqjwm1RK5VtlvVzJn17rgfGRil+phQS6IRgwyMQiL4e9mWqB4RZWRk4MqVK8Hlk08+0aNdZIDSZfo+0qe0sOn3h2MG6nBTl9p1uImHwcd4tCWqfw0tFgsSExP1aAsZzP2R9qOhALlWmU+oKf0wRvPRUEA0YtAPOdiEubpsn7SnekT01VdfweFwICUlBXl5ebh48aIe7SKdfX/91hXTOvKeU24PuZ0VnZGAFF1rJ6APrIjVtQZpR1UQPfTQQ1izZg22b9+OwsJCnD9/HsOGDcP169ebXMfn88Hr9TZYSLw7Zj7Ug6zco3a7BPRpcNuGHiREIQGputYg7ajaNcvJ+f/P8c3MzMRDDz2EXr16YfPmzZgypfErW10uF37/+9+3rpWkOb9PXB0LrIbUNqoOtV6r/lm6++678YMf/ADl5Y38s3dLQUEBPB5PcHG73a0pSRqJMuj/0cbq1MKYFDSqDrVeq4Kouroa586dQ/fuTc+4ZLVaYbPZGiwknj0V0OjynaZJt+rc5huUQ4Zf19LKXfpN/wNJkUVVEP3yl7/E/v37ceHCBXz66af4yU9+gujoaDz99NN6tY90EtNZmcpDT7Y+jc9b5EMNvoG+R8q/wTnOW9SGqAqiS5cu4emnn0bfvn3x05/+FPHx8Th06BASEhL0ah/pyDlG3+uInDlNv38aH+l6HdFpNHPtAEUcVb+GGzdu1KsdJED6DGU+IT3Itcpzx5pyAMvwCOboUjsaMdiPZq6mpIjDe81MrEu6MqmZ1qMiyaJst7mHH17BGXyBnZqPiupwE19gJypRpul2SV8MIpMbthyI0jiIoizKdltSjOmow01N776vw00UI4Q7bimiMIhMzpYMZGm8e5a1JLRpY7/FBWzEHE3vvt+I2Zw2tg1iEBHSpgKD39BmW0MWAGkqZm09iCJ8gN8AQNgjo8B6H+AVHMSqsLZBYnHyfAIADPwN0OleZZI0f626m2Eli7I7lrVEXQgFbMNCePEfPIV3EA2Lqpth63ATdbiJjZjNEGrDOCKioLSpwIRSwKHMX9/iQezA+45sZb1wQijgIIowH+kogzJzf0sHsQPvl2Ev5iODIdTGcUREDdiSgbE76z3XbFsjN8hKysWKzhzlFH1zZ8fU+BYX8DYer/dcs5w7bpBVrpg+h9PYhv0o5NmxdoJBRI3qkg5kva382egnvV7BGWzCXGzCXD7p1SQkWZb1ngyiAa/XC7vdDkhArMPIyspzuGU/IEUpzwI3S23R9dl3c/a9pgLKVDAeT4v3mIoLIiIyhVCCSNyuGUdEpqnPvpuz74ERUSiEBVGnRCDvkrE1i5OAmsvKF2Km2qLrs+/m7Ps6hxKEoeDpeyISjkFERMIxiIhIOAYREQnHICIi4RhERCQcg4iIhGMQEZFwqoPo8uXLmDhxIuLj49GxY0c88MADKCkp0aNtRGQSqq6srqqqQlZWFrKzs7Ft2zYkJCTgq6++QpcuXfRqHxGZgKogevPNN+F0OrF69ergz5KTQ5icmIioGap2zT788EMMHjwYEyZMQLdu3TBgwACsWLGi2XV8Ph+8Xm+DhYioPlVB9PXXX6OwsBD33XcfduzYgeeffx5z5szB2rVrm1zH5XLBbrcHF6fT2epGE1H7oiqI/H4/Bg4ciIULF2LAgAH4xS9+gWnTpmHZsmVNrlNQUACPxxNc3G53qxtNRO2LqiDq3r070tPTG/zs/vvvx8WLF5tcx2q1wmazNViIiOpTFURZWVk4e/Zsg599+eWX6NWrl6aNIiJzURVEL774Ig4dOoSFCxeivLwc69evx1//+lfk5+fr1T4iMgFVQTRkyBBs2bIFGzZsQL9+/fD6669j8eLFyMvL06t9RGQCqqeKzc3NRW5urh5tISKT4r1mRCQcg4iIhGMQEZFwDCIiEo5BRETCMYiISDgGEREJxyAiIuEkWZZlIwt6vV7Y7XZAAmIdRlZWnsMt+wEpSnkWuFlqi67Pvpuz7zUVAGTA4/G0eLO7uCAiIlMIJYhU3+KhGY6ITFOffTdn3wMjolAIC6JOiUDeJWNrFicBNZeVL8RMtUXXZ9/N2fd1DiUIQ8GD1UQkHIOIiIRjEBGRcAwiIhKOQUREwjGIiEg4BhERCccgIiLhVAVR7969IUnSHQsfJ0REraHqyuqjR4+irq4u+Pr06dMYOXIkJkyYoHnDiMg8VAVRQkJCg9eLFi1Cnz598PDDD2vaKCIyl7DvNfv++++xbt06vPTSS5AkqcnP+Xw++Hy+4Guv1xtuSSJqp8I+WP3BBx/g2rVrmDx5crOfc7lcsNvtwcXpdIZbkojaqbCDqKioCDk5OXA4mp/Lo6CgAB6PJ7i43e5wSxJROxXWrtm///1v7N69G++//36Ln7VarbBareGUISKTCGtEtHr1anTr1g1jx47Vuj1EZEKqg8jv92P16tWYNGkSLBZxEzwSUfuhOoh2796Nixcv4rnnntOjPURkQqqHNKNGjYLB8+0TUTvHe82ISDgGEREJxyAiIuEYREQkHIOIiIRjEBGRcAwiIhJOkg2+KMjr9cJutwMSENv8/bKa4zPQ2Xf23Tg1FQBkwOPxwGazNftZcUFERKYQShCJu1mMIyLT1Gffzdn3wIgoFMKCqFMikHfJ2JrFSUDNZeULMVNt0fXZd3P2fZ1DCcJQ8GA1EQnHICIi4RhERCQcg4iIhGMQEZFwDCIiEo5BRETCMYiISDhVQVRXV4fXXnsNycnJ6NixI/r06YPXX3+dc1gTUauourL6zTffRGFhIdauXYuMjAyUlJTg2Wefhd1ux5w5c/RqIxG1c6qC6NNPP8X48eODD1bs3bs3NmzYgCNHjujSOCIyB1W7Zj/60Y+wZ88efPnllwCAkydP4pNPPkFOTo4ujSMic1A1Ipo3bx68Xi/S0tIQHR2Nuro6LFiwAHl5eU2u4/P54PP5gq+9Xm/4rSWidknViGjz5s0oLi7G+vXrcfz4caxduxZ//OMfsXbt2ibXcblcsNvtwcXpdLa60UTUvqgKol/96leYN28ennrqKTzwwAP4+c9/jhdffBEul6vJdQoKCuDxeIKL2+1udaOJqH1RtWt248YNREU1zK7o6Gj4/f4m17FarbBareG1johMQVUQjRs3DgsWLEDPnj2RkZGBzz//HH/+85/x3HPP6dU+IjIBVUH0zjvv4LXXXsPMmTNx9epVOBwOTJ8+Hb/97W/1ah8RmYCqIIqLi8PixYuxePFinZpDRGbEe82ISDgGEREJxyAiIuEYREQkHIOIiIRjEBGRcAwiIhKOQUREwkmywfO8ejwe3H333QCU53Eb6UYlABmABHRKNE9t0fXZdzG1RdcPPPf+2rVrsNvtzX7W8CC6dOkSpwIhMhG3242kpKRmP2N4EPn9flRUVCAuLg6SJKla1+v1wul0wu12w2az6dTCyKzPvpuvtuj6ra0tyzKuX78Oh8Nxx6wdt1N1r5kWoqKiWkzHlthsNiG/FJFQn303X23R9VtTu6VdsgAerCYi4RhERCRcmwoiq9WK3/3ud8JmfBRZn303X23R9Y2sbfjBaiKi27WpERERtU8MIiISjkFERMIxiIhIuDYVRJ999hmio6MxduxYw2pOnjwZkiQFl/j4eIwePRr/+te/DGtDZWUlZs+ejZSUFFitVjidTowbNw579uzRtW79vsfExODee+/FyJEjsWrVqmafZadH/frL6NGjda/dXP3y8nLda1dWVmLu3LlITU1Fhw4dcO+99yIrKwuFhYW4ceOGbnUnT56MJ5544o6f79u3D5Ik4dq1a7rUbVNBVFRUhNmzZ+PAgQOoqKgwrO7o0aNx5coVXLlyBXv27IHFYkFubq4htS9cuIBBgwbh448/xltvvYVTp05h+/btyM7ORn5+vu71A32/cOECtm3bhuzsbMydOxe5ubmora01rH79ZcOGDbrXba5+cnKyrjW//vprDBgwADt37sTChQvx+eef47PPPsOvf/1rbN26Fbt379a1vgiG3+IRrurqamzatAklJSWorKzEmjVr8MorrxhS22q1IjFRuXU5MTER8+bNw7Bhw/DNN98gISFB19ozZ86EJEk4cuQIYmNjgz/PyMgw5MGW9fveo0cPDBw4ED/84Q/x6KOPYs2aNZg6daph9UUQUX/mzJmwWCwoKSlp8J2npKRg/PjxaI9X3LSZEdHmzZuRlpaGvn37YuLEiVi1apWQL6S6uhrr1q1Damoq4uPjda313//+F9u3b0d+fn6DX8iAwHQqRnvkkUfQv39/vP/++0Lqt2fffvstdu7c2eR3DkD1zeJtQZsJoqKiIkycOBGAMlz2eDzYv3+/IbW3bt2Kzp07o3PnzoiLi8OHH36ITZs2tXhHcWuVl5dDlmWkpaXpWiccaWlpuHDhgu516v/dB5aFCxfqXrep+hMmTNC1XuA779u3b4Of33PPPcE2vPzyy7q2obG/85ycHF1rtolds7Nnz+LIkSPYsmULAMBiseBnP/sZioqKMGLECN3rZ2dno7CwEABQVVWFd999Fzk5OThy5Ah69eqlW91IHoLLsmzIv8z1/+4Dunbtqnvdpuo3NUrR25EjR+D3+5GXlwefz6drrcb+zg8fPhwcCOihTQRRUVERamtr4XA4gj+TZRlWqxVLliwJeaqBcMXGxiI1NTX4euXKlbDb7VixYgXeeOMN3ered999kCQJZWVlutUI15kzZ3Q/aAvc+XdvNKPrp6amQpIknD17tsHPU1JSAAAdO3bUvQ2N9fnSpUu61oz4XbPa2lr87W9/w5/+9CecOHEiuJw8eRIOh8PQMygBkiQhKioK//vf/3St07VrVzz++ONYunQpampq7nhfr1OpLfn4449x6tQpPPnkk0Lqt2fx8fEYOXIklixZ0uh33l5F/Iho69atqKqqwpQpU+4Y+Tz55JMoKirCjBkzdG2Dz+dDZWUlAGXXbMmSJaiursa4ceN0rQsAS5cuRVZWFh588EH84Q9/QGZmJmpra7Fr1y4UFhbizJkzutYP9L2urg7/+c9/sH37drhcLuTm5uKZZ57RtXb9+vVZLBbcc889utcW5d1330VWVhYGDx6M+fPnIzMzE1FRUTh69CjKysowaNAg0U3UnhzhcnNz5TFjxjT63uHDh2UA8smTJ3WrP2nSJBnK9OMyADkuLk4eMmSI/Pe//123mrerqKiQ8/Pz5V69esl33XWX3KNHD/nHP/6xvHfvXl3r1u+7xWKRExIS5Mcee0xetWqVXFdXp2vt2+vXX/r27at77UD98ePHG1LrdhUVFfKsWbPk5ORkOSYmRu7cubP84IMPym+99ZZcU1OjW92m+rx3714ZgFxVVaVLXU4DQkTCRfwxIiJq/xhERCQcg4iIhGMQEZFwDCIiEo5BRETCMYiISDgGEREJxyAiIuEYREQkHIOIiIRjEBGRcP8P3ZHAPKDQyJ0AAAAASUVORK5CYII=\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ "def do_moves(boards: np.ndarray, moves: np.ndarray) -> np.ndarray:\n", + " \"\"\"Executes a single move on a stack o Othello boards.\n", + "\n", + " Args:\n", + " boards: A stack of Othello boards where the next stone should be placed.\n", + " moves: A stack of stone placement orders for the game. Formatted as coordinates in an array [x, y] of the place where the stone should be placed. Should contain [-1,-1] if no new placement is possible.\n", + "\n", + " Returns:\n", + " The new state of the board.\n", + " \"\"\"\n", + "\n", " def _do_directional_move(\n", " board: np.ndarray, rec_move: np.ndarray, rev_direction, step_one=True\n", " ) -> bool:\n", + " \"\"\"Changes the color of enemy stones in one direction.\n", + "\n", + " This function works recursive. The argument step_one should always be used in its default value.\n", + "\n", + " Args:\n", + " board: A bord on which a stone was placed.\n", + " rec_move: The position on the board in x and y where this function is called from. Will be moved by recursive called.\n", + " rev_direction: The position where the stone was placed. Inside this recursion it will also be the last step that was checked.\n", + " step_one: Set to true if this is the first step in the recursion. False later on.\n", + "\n", + " Returns:\n", + " True if a stone could be flipped.\n", + " All changes are made on the view of the numpy array and therefore not included in the return value.\n", + " \"\"\"\n", " rec_position = rec_move + rev_direction\n", " if np.any((rec_position >= 8) | (rec_position < 0)):\n", " return False\n", @@ -746,16 +822,32 @@ " return False\n", "\n", " def _do_move(_board: np.ndarray, move: np.ndarray) -> None:\n", + " \"\"\"Executes a turn on a board.\n", + "\n", + " Args:\n", + " _board: The game board on wich to place a stone.\n", + " move: The coordinates of a stone that should be placed. Should be formatted as an array of the form [x, y]. The value [-1, -1] is expected if no turn is possible.\n", + "\n", + " Returns:\n", + " All changes are made on the view of the numpy array.\n", + " \"\"\"\n", " if np.all(move == -1):\n", + " if not move_possible(_board, move):\n", + " raise InvalidTurn(\"An action should be taken. A turn is possible.\")\n", " return\n", + "\n", + " # noinspection PyTypeChecker\n", " if _board[tuple(move.tolist())] != 0:\n", - " raise InvalidTurn\n", + " raise InvalidTurn(\"This turn is not possible.\")\n", + "\n", " action = False\n", " for direction in DIRECTIONS:\n", " if _do_directional_move(_board, move, direction):\n", " action = True\n", " if not action:\n", - " raise InvalidTurn()\n", + " raise InvalidTurn(\"This turn is not possible.\")\n", + "\n", + " # noinspection PyTypeChecker\n", " _board[tuple(move.tolist())] = 1\n", "\n", " boards = boards.copy()\n", @@ -764,8 +856,28 @@ " return boards\n", "\n", "\n", - "do_moves(get_new_games(10), np.array([[2, 3]] * 10))[0]" - ] + "%timeit do_moves(get_new_games(EXAMPLE_STACK_SIZE), np.array([[2, 3]] * EXAMPLE_STACK_SIZE))[0]\n", + "plot_othello_board(\n", + " do_moves(\n", + " get_new_games(EXAMPLE_STACK_SIZE), np.array([[2, 3]] * EXAMPLE_STACK_SIZE)\n", + " )[0]\n", + ")" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## An abstract reversi game policy\n", + "\n", + "For an easy use of policies an abstract class containing the policy generation / requests an action in an inherited instance of this class.\n", + "This class filters the policy to only propose valid actions. Inherited instance do not need to care about this." + ], + "metadata": { + "collapsed": false + } }, { "cell_type": "code", @@ -774,33 +886,66 @@ "outputs": [], "source": [ "class GamePolicy(ABC):\n", - "\n", - " IMPOSSIBLE: np.ndarray = np.array([-1, -1], dtype=int)\n", + " \"\"\"\n", + " A game policy. Proposes where to place a stone next.\n", + " \"\"\"\n", "\n", " @property\n", " @abc.abstractmethod\n", " def policy_name(self) -> str:\n", + " \"\"\"The name of this policy\"\"\"\n", " raise NotImplementedError()\n", "\n", " @abc.abstractmethod\n", - " def internal_policy(self, boards: np.ndarray) -> np.ndarray:\n", + " def _internal_policy(self, boards: np.ndarray) -> np.ndarray:\n", + " \"\"\"The internal policy is an unfiltered policy. It should only be called from inside this function\n", + "\n", + " Args:\n", + " boards: A board where a policy should be calculated for.\n", + "\n", + " Returns:\n", + " The policy for this board. Should have the same size as the boards array.\n", + " \"\"\"\n", " raise NotImplementedError()\n", "\n", - " def get_policy(self, boards: np.ndarray) -> np.ndarray:\n", - " policies = self.internal_policy(boards)\n", + " def get_policy(\n", + " self, boards: np.ndarray, epsilon: float = 1\n", + " ) -> tuple[np.ndarray, np.ndarray]:\n", + " assert len(boards.shape) == 3\n", + " assert boards.shape == (BOARD_SIZE, BOARD_SIZE)\n", + "\n", + " # todo possibly change this function to only validate the purpose turn and\n", + "\n", + " policies = self._internal_policy(boards)\n", + " raw_policy = policies.copy()\n", + " if epsilon < 1:\n", + " policies = policies + np.random.rand(*boards.shape)\n", + "\n", + " # todo talk to team about backpropagation epsilon for greedy factor\n", + "\n", " possible_turns = get_possible_turns(boards)\n", " policies[possible_turns == False] = -1.0\n", " max_indices = [\n", " np.unravel_index(policy.argmax(), policy.shape) for policy in policies\n", " ]\n", " policy_vector = np.array(max_indices)\n", - "\n", + " max_policy = policy_vector\n", " no_turn_possible = np.all(policy_vector == 0, 1) & (policies[:, 0, 0] == -1.0)\n", "\n", - " policy_vector[no_turn_possible] = GamePolicy.IMPOSSIBLE\n", - " return policy_vector" + " policy_vector[no_turn_possible] = IMPOSSIBLE\n", + " max_policy[no_turn_possible] = 0\n", + " return policy_vector, raw_policy" ] }, + { + "cell_type": "markdown", + "source": [ + "## A first policy" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "code", "execution_count": null, @@ -854,7 +999,7 @@ "metadata": {}, "outputs": [], "source": [ - "SIMULATE_TURNS = 70\n", + "\n", "\n", "\n", "def simulate_game(\n",