diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c3c7048..cc41585 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -65,6 +65,8 @@ repos: - types-setuptools - types-requests - types-pyOpenSSL + - types-cachetools + - loguru-mypy - repo: https://github.com/frnmst/md-toc rev: 8.2.0 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ac2c1ed..e575007 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,7 +2,7 @@ ## Dev Setup -- [Install Python 3.11](https://www.python.org/downloads/release/python-3111/) +- [Install Python 3.11](https://www.python.org/downloads/release/python-3115/) - [Install Poetry](https://python-poetry.org/docs/#installation) - [Install GiT](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git) - [Configure GiT](https://support.atlassian.com/bitbucket-cloud/docs/configure-your-dvcs-username-for-commits/) diff --git a/poetry.lock b/poetry.lock index e267d74..16ee761 100644 --- a/poetry.lock +++ b/poetry.lock @@ -229,14 +229,34 @@ lxml = ["lxml"] [[package]] name = "black" -version = "23.9.0" +version = "23.9.1" description = "The uncompromising code formatter." 
category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "black-23.9.0-py3-none-any.whl", hash = "sha256:9366c1f898981f09eb8da076716c02fd021f5a0e63581c66501d68a2e4eab844"}, - {file = "black-23.9.0.tar.gz", hash = "sha256:3511c8a7e22ce653f89ae90dfddaf94f3bb7e2587a245246572d3b9c92adf066"}, + {file = "black-23.9.1-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:d6bc09188020c9ac2555a498949401ab35bb6bf76d4e0f8ee251694664df6301"}, + {file = "black-23.9.1-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:13ef033794029b85dfea8032c9d3b92b42b526f1ff4bf13b2182ce4e917f5100"}, + {file = "black-23.9.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:75a2dc41b183d4872d3a500d2b9c9016e67ed95738a3624f4751a0cb4818fe71"}, + {file = "black-23.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13a2e4a93bb8ca74a749b6974925c27219bb3df4d42fc45e948a5d9feb5122b7"}, + {file = "black-23.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:adc3e4442eef57f99b5590b245a328aad19c99552e0bdc7f0b04db6656debd80"}, + {file = "black-23.9.1-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:8431445bf62d2a914b541da7ab3e2b4f3bc052d2ccbf157ebad18ea126efb91f"}, + {file = "black-23.9.1-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:8fc1ddcf83f996247505db6b715294eba56ea9372e107fd54963c7553f2b6dfe"}, + {file = "black-23.9.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:7d30ec46de88091e4316b17ae58bbbfc12b2de05e069030f6b747dfc649ad186"}, + {file = "black-23.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:031e8c69f3d3b09e1aa471a926a1eeb0b9071f80b17689a655f7885ac9325a6f"}, + {file = "black-23.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:538efb451cd50f43aba394e9ec7ad55a37598faae3348d723b59ea8e91616300"}, + {file = "black-23.9.1-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:638619a559280de0c2aa4d76f504891c9860bb8fa214267358f0a20f27c12948"}, + {file = "black-23.9.1-cp38-cp38-macosx_10_16_universal2.whl", 
hash = "sha256:a732b82747235e0542c03bf352c126052c0fbc458d8a239a94701175b17d4855"}, + {file = "black-23.9.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:cf3a4d00e4cdb6734b64bf23cd4341421e8953615cba6b3670453737a72ec204"}, + {file = "black-23.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf99f3de8b3273a8317681d8194ea222f10e0133a24a7548c73ce44ea1679377"}, + {file = "black-23.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:14f04c990259576acd093871e7e9b14918eb28f1866f91968ff5524293f9c573"}, + {file = "black-23.9.1-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:c619f063c2d68f19b2d7270f4cf3192cb81c9ec5bc5ba02df91471d0b88c4c5c"}, + {file = "black-23.9.1-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:6a3b50e4b93f43b34a9d3ef00d9b6728b4a722c997c99ab09102fd5efdb88325"}, + {file = "black-23.9.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:c46767e8df1b7beefb0899c4a95fb43058fa8500b6db144f4ff3ca38eb2f6393"}, + {file = "black-23.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50254ebfa56aa46a9fdd5d651f9637485068a1adf42270148cd101cdf56e0ad9"}, + {file = "black-23.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:403397c033adbc45c2bd41747da1f7fc7eaa44efbee256b53842470d4ac5a70f"}, + {file = "black-23.9.1-py3-none-any.whl", hash = "sha256:6ccd59584cc834b6d127628713e4b6b968e5f79572da66284532525a042549f9"}, + {file = "black-23.9.1.tar.gz", hash = "sha256:24b6b3ff5c6d9ea08a8888f6977eae858e1f340d7260cf56d70a49823236b62d"}, ] [package.dependencies] @@ -307,6 +327,18 @@ dev = ["CacheControl[filecache,redis]", "black", "build", "cherrypy", "mypy", "p filecache = ["filelock (>=3.8.0)"] redis = ["redis (>=2.10.5)"] +[[package]] +name = "cachetools" +version = "5.3.1" +description = "Extensible memoizing collections and decorators" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "cachetools-5.3.1-py3-none-any.whl", hash = 
"sha256:95ef631eeaea14ba2e36f06437f36463aac3a096799e876ee55e5cdccb102590"}, + {file = "cachetools-5.3.1.tar.gz", hash = "sha256:dce83f2d9b4e1f732a8cd44af8e8fab2dbe46201467fc98b3ef8f269092bf62b"}, +] + [[package]] name = "certifi" version = "2023.7.22" @@ -1278,6 +1310,18 @@ files = [ {file = "docutils-0.18.1.tar.gz", hash = "sha256:679987caf361a7539d76e584cbeddc311e3aee937877c87346f31debc63e9d06"}, ] +[[package]] +name = "et-xmlfile" +version = "1.1.0" +description = "An implementation of lxml.xmlfile for the standard library" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "et_xmlfile-1.1.0-py3-none-any.whl", hash = "sha256:a2ba85d1d6a74ef63837eed693bcb89c3f752169b0e3e7ae5b16ca5e1b3deada"}, + {file = "et_xmlfile-1.1.0.tar.gz", hash = "sha256:8eb9e2bc2f8c97e37a2dc85a09ecdcdec9d8a396530a6d5a33b30b9a92da0c5c"}, +] + [[package]] name = "exceptiongroup" version = "1.1.3" @@ -2319,14 +2363,14 @@ testing = ["black", "isort", "pytest (>=6,!=7.0.0)", "pytest-xdist (>=2)", "twin [[package]] name = "loguru" -version = "0.7.1" +version = "0.7.2" description = "Python logging made (stupidly) simple" category = "main" optional = false python-versions = ">=3.5" files = [ - {file = "loguru-0.7.1-py3-none-any.whl", hash = "sha256:046bf970cb3cad77a28d607cbf042ac25a407db987a1e801c7f7e692469982f9"}, - {file = "loguru-0.7.1.tar.gz", hash = "sha256:7ba2a7d81b79a412b0ded69bd921e012335e80fd39937a633570f273a343579e"}, + {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"}, + {file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"}, ] [package.dependencies] @@ -2334,7 +2378,22 @@ colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} [package.extras] -dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "freezegun 
(==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "pre-commit (==3.3.1)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"] +dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"] + +[[package]] +name = "loguru-mypy" +version = "0.0.4" +description = "" +category = "dev" +optional = false +python-versions = ">=3.6.0" +files = [ + {file = "loguru-mypy-0.0.4.tar.gz", hash = "sha256:1f1767d7737f1825295ce147f7e751f91837f5759b3c2f41801adc65691aeed4"}, + {file = "loguru_mypy-0.0.4-py3-none-any.whl", hash = "sha256:98e044be509887a314e683a1e851813310b396be48388c1fe4de97a2eac99d4d"}, +] + +[package.dependencies] +typing-extensions = "*" [[package]] name = "lxml" @@ -3105,6 +3164,21 @@ packaging = "*" protobuf = "*" sympy = "*" +[[package]] +name = "openpyxl" +version = "3.1.2" +description = "A Python library to read/write Excel 2010 xlsx/xlsm files" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "openpyxl-3.1.2-py2.py3-none-any.whl", hash = "sha256:f91456ead12ab3c6c2e9491cf33ba6d08357d802192379bb482f1033ade496f5"}, + {file = "openpyxl-3.1.2.tar.gz", hash = "sha256:a6f5977418eff3b2d5500d54d9db50c8277a368436f4e4f8ddb1be3422870184"}, +] + +[package.dependencies] +et-xmlfile = "*" + [[package]] name = "outcome" version = "1.2.0" @@ -3292,6 +3366,26 @@ files = [ 
[package.dependencies] ptyprocess = ">=0.5" +[[package]] +name = "pgeocode" +version = "0.4.1" +description = "Approximate geocoding" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pgeocode-0.4.1-py3-none-any.whl", hash = "sha256:0cc3916d75c41ffcd910ccc2252235a66c627346502cba5d2e97b6ea0aa83257"}, + {file = "pgeocode-0.4.1.tar.gz", hash = "sha256:08f35dedf79957769641c7137aa9cc189e1bb63033226372dce372b14973e8b2"}, +] + +[package.dependencies] +numpy = "*" +pandas = "*" +requests = "*" + +[package.extras] +fuzzy = ["thefuzz"] + [[package]] name = "pickleshare" version = "0.7.5" @@ -3564,14 +3658,14 @@ virtualenv = ">=20.10.0" [[package]] name = "prettytable" -version = "3.8.0" +version = "3.9.0" description = "A simple Python library for easily displaying tabular data in a visually appealing ASCII table format" category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "prettytable-3.8.0-py3-none-any.whl", hash = "sha256:03481bca25ae0c28958c8cd6ac5165c159ce89f7ccde04d5c899b24b68bb13b7"}, - {file = "prettytable-3.8.0.tar.gz", hash = "sha256:031eae6a9102017e8c7c7906460d150b7ed78b20fd1d8c8be4edaf88556c07ce"}, + {file = "prettytable-3.9.0-py3-none-any.whl", hash = "sha256:a71292ab7769a5de274b146b276ce938786f56c31cf7cea88b6f3775d82fe8c8"}, + {file = "prettytable-3.9.0.tar.gz", hash = "sha256:f4ed94803c23073a90620b201965e5dc0bccf1760b7a7eaf3158cab8aaffdf34"}, ] [package.dependencies] @@ -5508,6 +5602,18 @@ exceptiongroup = "*" trio = ">=0.11" wsproto = ">=0.14" +[[package]] +name = "types-cachetools" +version = "5.3.0.6" +description = "Typing stubs for cachetools" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "types-cachetools-5.3.0.6.tar.gz", hash = "sha256:595f0342d246c8ba534f5a762cf4c2f60ecb61e8002b8b2277fd5cf791d4e851"}, + {file = "types_cachetools-5.3.0.6-py3-none-any.whl", hash = "sha256:f7f8a25bfe306f2e6bc2ad0a2f949d9e72f2d91036d509c36d3810bf728bc6e1"}, +] + [[package]] 
name = "types-pyopenssl" version = "23.2.0.2" @@ -5727,14 +5833,14 @@ files = [ [[package]] name = "websocket-client" -version = "1.6.2" +version = "1.6.3" description = "WebSocket client for Python with low level API options" category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "websocket-client-1.6.2.tar.gz", hash = "sha256:53e95c826bf800c4c465f50093a8c4ff091c7327023b10bfaff40cf1ef170eaa"}, - {file = "websocket_client-1.6.2-py3-none-any.whl", hash = "sha256:ce54f419dfae71f4bdba69ebe65bf7f0a93fe71bc009ad3a010aacc3eebad537"}, + {file = "websocket-client-1.6.3.tar.gz", hash = "sha256:3aad25d31284266bcfcfd1fd8a743f63282305a364b8d0948a43bd606acc652f"}, + {file = "websocket_client-1.6.3-py3-none-any.whl", hash = "sha256:6cfc30d051ebabb73a5fa246efdcc14c8fbebbd0330f8984ac3bb6d9edd2ad03"}, ] [package.extras] @@ -5808,4 +5914,4 @@ ingest = ["selenium"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "05d03e1ed3bdaa638f75c853dd28fb501a92d9209e757daf732abd56c78f8332" +content-hash = "f15e3b3171f0b6b22635f5c9de7635114c99447c5b3d41f8b1596d005fe1dce8" diff --git a/pyproject.toml b/pyproject.toml index 99f7e31..b69d034 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ source = ["src"] [tool.mypy] disallow_untyped_defs = true -exclude = ".ipynb_checkpoints, .mypy_cache, .mytest_cache, build" +exclude = [".ipynb_checkpoints", ".mypy_cache", ".mytest_cache", "build", "venv", ".venv", "Jupyter"] follow_imports = "silent" ignore_missing_imports = true install_types = true @@ -35,20 +35,22 @@ readme = "README.md" version = "0.1.0" [tool.poetry.dependencies] -SQLAlchemy = {version = "^1.4.46", extras = ["mypy"]} -dash = "^2.11.1" -dash-bootstrap-components = "^1.4.2" +SQLAlchemy = {version = "^1.4.49", extras = ["mypy"]} +cachetools = "^5.3.1" +dash = "^2.13.0" +dash-bootstrap-components = "^1.5.0" deutschland = {git = "https://github.com/TrisNol/deutschland.git", branch = "hotfix/python-3.11-support"} loguru = 
"^0.7.0" -matplotlib = "^3.7.1" -plotly = "^5.14.1" +matplotlib = "^3.7.2" +pgeocode = "^0.4.0" +plotly = "^5.16.1" psycopg2-binary = "^2.9.7" -pymongo = "^4.4.1" +pymongo = "^4.5.0" python = "^3.11" python-dotenv = "^1.0.0" seaborn = "^0.12.2" -selenium = "^4.10.0" -tqdm = "^4.65.0" +selenium = "^4.12.0" +tqdm = "^4.66.1" [tool.poetry.extras] ingest = ["selenium"] @@ -57,6 +59,7 @@ ingest = ["selenium"] black = {extras = ["jupyter"], version = "^23.9.0"} jupyterlab = "^4.0.5" nbconvert = "^7.8.0" +openpyxl = "^3.1.2" pre-commit = "^3.4.0" rise = "^5.7.1" @@ -64,7 +67,7 @@ rise = "^5.7.1" jupyter = "^1.0.0" myst-parser = "^1.0.0" nbsphinx = "^0.9.2" -sphinx = "^6.2.1" +sphinx = "*" sphinx-copybutton = "^0.5.2" sphinx-rtd-theme = "^1.3.0" sphinx_autodoc_typehints = "*" @@ -73,11 +76,13 @@ sphinxcontrib-napoleon = "^0.7" [tool.poetry.group.lint.dependencies] black = "^23.9.0" +loguru-mypy = "^0.0.4" mypy = "^1.5.1" -pandas-stubs = "^2.0.3.230814" +pandas-stubs = "^2.0.1.230501" pip-audit = "^2.6.1" pip-licenses = "^4.3.2" ruff = "^0.0.287" +types-cachetools = "^5.3.0.6" types-pyOpenSSL = "*" types-requests = "^2.31.0.2" types-setuptools = "*" @@ -90,6 +95,10 @@ pytest-cov = "^4.1.0" pytest-mock = "^3.11.1" pytest-repeat = "^0.9.1" +[tool.poetry.scripts] +data-transfer = "aki_prj23_transparenzregister.utils.data_transfer:transfer_data" +reset-sql = "aki_prj23_transparenzregister.utils.sql.connector:reset_all_tables" + [tool.ruff] exclude = [ ".bzr", @@ -127,7 +136,7 @@ unfixable = ["B"] builtins-ignorelist = ["id"] [tool.ruff.per-file-ignores] -"tests/*.py" = ["S101", "D100", "D101", "D107", "D103"] +"tests/*.py" = ["S101", "SLF001", "S311", "D103"] [tool.ruff.pydocstyle] convention = "google" diff --git a/src/aki_prj23_transparenzregister/models/company.py b/src/aki_prj23_transparenzregister/models/company.py index d160826..11c0a5b 100644 --- a/src/aki_prj23_transparenzregister/models/company.py +++ b/src/aki_prj23_transparenzregister/models/company.py @@ -1,5 +1,4 
@@ """Company model.""" -from abc import ABC from dataclasses import asdict, dataclass from enum import Enum @@ -34,12 +33,8 @@ class Location: @dataclass -class CompanyRelationship(ABC): - """_summary_. - - Args: - ABC (_type_): _description_ - """ +class CompanyRelationship: + """_summary_.""" role: RelationshipRoleEnum location: Location @@ -92,6 +87,8 @@ class YearlyResult: @dataclass class Company: + """_summary_.""" + """Company dataclass.""" id: CompanyID @@ -102,9 +99,5 @@ class Company: # yearly_results: list[FinancialResults] def to_dict(self) -> dict: - """_summary_. - - Returns: - dict: _description_ - """ + """_summary_.""" return asdict(self) diff --git a/src/aki_prj23_transparenzregister/utils/data_transfer.py b/src/aki_prj23_transparenzregister/utils/data_transfer.py new file mode 100644 index 0000000..2d4f129 --- /dev/null +++ b/src/aki_prj23_transparenzregister/utils/data_transfer.py @@ -0,0 +1,355 @@ +"""This module contains the data transfer and refinement functionalities between staging and production DB.""" +import sys +from datetime import date +from typing import Any + +import sqlalchemy as sa +from cachetools import LRUCache, cached +from loguru import logger +from sqlalchemy.orm import Session +from tqdm import tqdm + +from aki_prj23_transparenzregister.config.config_providers import JsonFileConfigProvider +from aki_prj23_transparenzregister.utils.enum_types import RelationTypeEnum +from aki_prj23_transparenzregister.utils.mongo.company_mongo_service import ( + CompanyMongoService, +) +from aki_prj23_transparenzregister.utils.mongo.connector import MongoConnector +from aki_prj23_transparenzregister.utils.sql import entities +from aki_prj23_transparenzregister.utils.sql.connector import ( + get_session, + reset_all_tables, +) +from aki_prj23_transparenzregister.utils.string_tools import simplify_string + + +class DataInvalidError(ValueError): + """This error is thrown if a db entry can't be parsed for the production db.""" + + def 
__init__(self, message: str) -> None: + """Argument of the error to be parsed along.""" + super().__init__(message) + + +def _refine_district_court_entry(name: str, city: str | None) -> tuple[str, str]: + """Refines the district court entry and tests for consistency. + + Args: + name: The name of the court. + city: The city where the cort is placed. + + Returns: + A tuple containing cort name and court city. + """ + if not name: + raise DataInvalidError("There is no court name.") + if not name.startswith("Amtsgericht "): + raise DataInvalidError( + f"The name of the district court does not start correctly: {name}" + ) + if not city or city not in name.split(" ", 1)[1]: + city = name.split(" ", 1)[1].strip() + return name, city + + +def _read_district_court_id(name: str, city: str, db: Session) -> int | None: + """Reads a district court id for a company if the district court is registered. + + Args: + name: The name of the court. + city: The name of the city where the court is placed. + db: A session to connect to an SQL db via SQLAlchemy. + + + Returns: + The district court id as an int if the district court is known. + Otherwise, returns None. + """ + return ( + db.query(entities.DistrictCourt.id) + .filter(entities.DistrictCourt.name == name) + .filter(entities.DistrictCourt.city == city) + .scalar() + ) + + +def _read_person_id( + name: str, surname: str, date_of_birth: date, db: Session +) -> int | None: + """Reads a person id if the person is already registered. + + Args: + name: The first name of the person. + surname: The last name of the person. + date_of_birth: The date the person was born. + db: A session to connect to an SQL db via SQLAlchemy. + + Returns: + The district court id as an int if the district court is known. + Otherwise, returns None. 
+ """ + return ( + db.query(entities.Person.id) + .filter(entities.Person.name == name) + .filter(entities.Person.surname == surname) + .filter(entities.Person.date_of_birth == date_of_birth) + .scalar() + ) + + +@cached(cache=LRUCache(maxsize=1000), key=lambda name, city, db: hash((name, city))) # type: ignore +def get_district_court_id(name: str, city: str | None, db: Session) -> int: + """Determines the id of a district court. + + Determines the id of a district court and adds an entry to the table if no entry and id could be found. + A lru_cache is used to increase the speed of this application. + + Args: + name: The name of the district court. + city: The name where the court is located. + db: A session to connect to an SQL db via SQLAlchemy. + + Returns: + The id / privat key of a district court in the SQL-database. + """ + name, city = _refine_district_court_entry(name, city) + court_id = _read_district_court_id(name, city, db) + if court_id is not None: + return court_id + court = entities.DistrictCourt(name=name, city=city) + db.add(court) + db.commit() + return court.id # type: ignore + + +@cached(cache=LRUCache(maxsize=2000), key=lambda name, surname, date_of_birth, db: hash((name, surname, date_of_birth))) # type: ignore +def get_person_id( + name: str, surname: str, date_of_birth: date | str, db: Session +) -> int: + """Identifies the id of and court. + + Identifies the id of a district court and adds an entry to the table if no entry and id could be found. + A lru_cache is used to increase the speed of this application. + + Args: + name: The first name of the person. + surname: The last name of the person. + date_of_birth: The date the person was born. + db: A session to connect to an SQL db via SQLAlchemy. + + Returns: + The id / privat key of a district court in the SQL-database. 
+ """ + if isinstance(date_of_birth, str) and date_of_birth: + date_of_birth = date.fromisoformat(date_of_birth) + if not name or not surname or not date_of_birth: + raise DataInvalidError( + f'At least one of the three values name: "{name}", surname: "{surname}" or date_of_birth: "{date_of_birth}" is empty.' + ) + assert isinstance(date_of_birth, date) # noqa: S101 + person_id = _read_person_id(name, surname, date_of_birth, db) + if person_id is not None: + return person_id + person = entities.Person(name=name, surname=surname, date_of_birth=date_of_birth) + db.add(person) + db.commit() + return person.id # type: ignore + + +@cached(cache=LRUCache(maxsize=5000), key=lambda name, zip_code, city, db: hash((name, zip_code, city))) # type: ignore +def get_company_id( + name: str, zip_code: str | None, city: str | None, db: Session +) -> int: + """Queries the id of a company. + + Args: + name: The HR entry of the company. + zip_code: The zip code where the company can be found. + city: The city where the company is found in. + db: A session to connect to an SQL db via SQLAlchemy. + + Returns: + The id / privat key of a company. 
+ """ + if not name: + raise DataInvalidError("The name must be given and contain at least one sign.") + zip_code = simplify_string(zip_code) + city = simplify_string(city) + company_id = ( + db.query(entities.Company.id) + .filter( + sa.or_(entities.Company.zip_code == zip_code, entities.Company.city == city) + ) + .filter(entities.Company.name == name) + .scalar() + ) + if company_id is None and zip_code is None and city is None: + company_id = ( + db.query(entities.Company.id) + .filter(entities.Company.name == name) + .scalar() # todo ensure uniqueness + ) + if company_id is None: + raise KeyError(f"No corresponding company could be found to {name}.") + return company_id + + +@logger.catch(level="WARNING", reraise=True) +def add_company(company: dict[str, Any], db: Session) -> None: + """Add a company with all its data found in the mongodb company entry. + + Args: + company: The company to add. + db: A session to connect to an SQL db via SQLAlchemy. + """ + court_id = get_district_court_id(**company["id"]["district_court"], db=db) + location = company["location"] + name = simplify_string(company.get("name")) + if not name: + raise DataInvalidError( + "The company name needs to be valid (not empty and not only whitespace)." + ) + company_entry = entities.Company( + court_id=court_id, + hr=company["id"]["hr_number"].strip().replace(" ", " ").replace(" ", " "), + name=name, + city=simplify_string(location.get("city")), + zip_code=simplify_string(location.get("zip_code")), + street=simplify_string(location.get("street")), + last_update=company["last_update"], + ) + db.add(company_entry) + db.commit() + logger.debug(f"Added the company entry {company['name']} to the db.") + + +def add_companies(companies: list[dict[str, Any]], db: Session) -> None: + """Adds a company to the database. + + Args: + companies: The company to be added. + db: A session to connect to an SQL db via SQLAlchemy. 
+ """ + data_invalid, error_count = 0, 0 + for company in tqdm(companies, desc="Companies added"): + try: + add_company(company, db) + except DataInvalidError: + data_invalid += 1 + except Exception: + error_count += 1 + db.rollback() + if error_count + data_invalid: + logger.warning( + f"When adding companies {error_count + data_invalid} problems occurred " + f"{data_invalid} where caused by invalid data." + ) + else: + logger.info("When adding companies no problems occurred.") + + +@logger.catch(level="WARNING", reraise=True) +def add_relationship( + relationship: dict[str, Any], company_id: int, db: Session +) -> None: + """Adds a relationship to a company. + + Args: + relationship: The relationship and the relationship partner. + company_id: The company id the relations is rooted in. + db: A session to connect to an SQL db via SQLAlchemy. + """ + relation_type = RelationTypeEnum.get_enum_from_name(relationship.get("role")) + relation: entities.CompanyRelation | entities.PersonRelation + if "date_of_birth" in relationship: + name = relationship["name"] + person_id = get_person_id( + name["firstname"], + name["lastname"], + relationship["date_of_birth"], + db, + ) + relation = entities.PersonRelation( + person_id=person_id, + company_id=company_id, + relation=relation_type, + ) + else: + relation_to: int = get_company_id( + relationship["description"], + relationship["location"]["zip_code"], + relationship["location"]["city"], + db=db, + ) + if company_id == relation_to: + raise DataInvalidError( + "For a valid relation both parties can't be the same entity." + ) + relation = entities.CompanyRelation( + company_id=company_id, + relation=relation_type, + company2_id=relation_to, + ) + db.add(relation) + db.commit() + + +def add_relationships(companies: list[dict[str, dict]], db: Session) -> None: + """Add a list of companies to the database. + + Args: + companies: Companies to be added to the db. + db: A session to connect to an SQL db via SQLAlchemy. 
+ """ + total: int = sum(len(company.get("relationships", [])) for company in companies) + with tqdm( + total=total, + desc="Company connections added", + ) as pbar: + for company in companies: + relationships: list[dict[str, Any]] = company.get("relationships", []) # type: ignore + try: + company_id: int = get_company_id( + company["name"], # type: ignore + company["location"]["zip_code"], + company["location"]["city"], + db=db, + ) + except Exception: + pbar.update(len(relationships)) + db.rollback() + continue + + for relationship in relationships: + try: + add_relationship(relationship, company_id=company_id, db=db) + except Exception: + db.rollback() + pbar.update() + + logger.info("Company connections added.") + + +def transfer_data(db: Session | None) -> None: + """This functions transfers all the data from a production environment to a staging environment.""" + if db is None: + db = get_session(JsonFileConfigProvider("./secrets.json")) + logger.remove() + logger.add(sys.stdout, level="INFO") + logger.add("data-transfer.log", level="INFO", retention=5) + + reset_all_tables(db) + mongo_connector = MongoConnector( + JsonFileConfigProvider("./secrets.json").get_mongo_connection_string() + ) + mongo_company = CompanyMongoService(mongo_connector) + companies: list[dict[str, Any]] = mongo_company.get_all() # type: ignore + del mongo_company + + add_companies(companies, db) + add_relationships(companies, db) + db.close() + + +if __name__ == "__main__": + transfer_data(get_session(JsonFileConfigProvider("./secrets.json"))) diff --git a/src/aki_prj23_transparenzregister/utils/enum_types.py b/src/aki_prj23_transparenzregister/utils/enum_types.py new file mode 100644 index 0000000..51e9799 --- /dev/null +++ b/src/aki_prj23_transparenzregister/utils/enum_types.py @@ -0,0 +1,69 @@ +"""Collection of enumeration types for the whole project.""" +import enum + + +class RelationTypeEnum(enum.IntEnum): + """RelationTypeEnum.""" + + GESCHAEFTSFUEHRER = enum.auto() + 
KOMMANDITIST = enum.auto() + VORSTAND = enum.auto() + PROKURIST = enum.auto() + LIQUIDATOR = enum.auto() + INHABER = enum.auto() + PERSOENLICH_HAFTENDER_GESELLSCHAFTER = enum.auto() + PARTNER = enum.auto() + DIREKTOR = enum.auto() + + RECHTSNACHFOLGER = enum.auto() + ORGANISATION = enum.auto() + + @staticmethod + def get_enum_from_name(relation_name: str | None) -> "RelationTypeEnum": + """Translates relation name into a RelationTypeEnum. + + If no translation can be found a warning is given. + + Args: + relation_name: The name of the relation to be translated. + + Returns: + The identified translation or None if no translation can be found. + """ + if relation_name is None: + raise ValueError("A relation type needs to be given.") + relation_name = ( + relation_name.strip() + .replace("(in)", "") + .replace("(r)", "r") + .strip() + .lower() + ) + name = { + "geschäftsführer": RelationTypeEnum.GESCHAEFTSFUEHRER, + "kommanditist": RelationTypeEnum.KOMMANDITIST, + "vorstand": RelationTypeEnum.VORSTAND, + "vorstandsvorsitzender": RelationTypeEnum.VORSTAND, + "prokurist": RelationTypeEnum.PROKURIST, + "liquidator": RelationTypeEnum.LIQUIDATOR, + "inhaber": RelationTypeEnum.INHABER, + "persönlich haftender gesellschafter": RelationTypeEnum.PERSOENLICH_HAFTENDER_GESELLSCHAFTER, + "organisation": RelationTypeEnum.ORGANISATION, + "partner": RelationTypeEnum.PARTNER, + "direktor": RelationTypeEnum.DIREKTOR, + "geschäftsführender direktor": RelationTypeEnum.DIREKTOR, + "mitglied des leitungsorgans": RelationTypeEnum.VORSTAND, + "rechtsnachfolger": RelationTypeEnum.RECHTSNACHFOLGER, + }.get(relation_name) + if name is not None: + return name + raise ValueError(f'Relation type "{relation_name}" is not yet implemented!') + + +class SentimentTypeEnum(enum.Enum): + """SentimentTypeEnum.""" + + employee_voting = "employee_voting" + sustainability = "sustainability" + environmental_aspects = "environmental_aspects" + perception = "perception" diff --git 
a/src/aki_prj23_transparenzregister/utils/enumy_types.py b/src/aki_prj23_transparenzregister/utils/enumy_types.py deleted file mode 100644 index 30901de..0000000 --- a/src/aki_prj23_transparenzregister/utils/enumy_types.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Collection of enumeration types for the whole project.""" -import enum - - -class RelationTypeEnum(enum.IntEnum): - """RelationTypeEnum.""" - - EXECUTIVE = enum.auto() - AUDITOR = enum.auto() - SUPERVISORY_BOARD = enum.auto() - MANAGING_DIRECTOR = enum.auto() - AUTHORIZED_REPRESENTATIVE = enum.auto() - FINAL_AUDITOR = enum.auto() - - PARTICIPATES_WITH = enum.auto() - HAS_SHARES_OF = enum.auto() - IS_SUPPLIED_BY = enum.auto() - WORKS_WITH = enum.auto() - - -class SentimentTypeEnum(enum.Enum): - """SentimentTypeEnum.""" - - employee_voting = "employee_voting" - sustainability = "sustainability" - environmental_aspects = "environmental_aspects" - perception = "perception" diff --git a/src/aki_prj23_transparenzregister/utils/sql/connector.py b/src/aki_prj23_transparenzregister/utils/sql/connector.py index b2ef367..3986d45 100644 --- a/src/aki_prj23_transparenzregister/utils/sql/connector.py +++ b/src/aki_prj23_transparenzregister/utils/sql/connector.py @@ -81,6 +81,13 @@ def init_db(db: Session) -> None: Base.metadata.create_all(db.bind) +def reset_all_tables(db: Session) -> None: + """Drops all SQL tables and recreates them.""" + logger.info("Resetting all PostgreSQL tables.") + Base.metadata.drop_all(db.bind) + init_db(db) + + if __name__ == "__main__": """Main flow creating tables""" init_db(get_session(JsonFileConfigProvider("./secrets.json"))) diff --git a/src/aki_prj23_transparenzregister/utils/sql/entities.py b/src/aki_prj23_transparenzregister/utils/sql/entities.py index 2bb90f1..fefdf83 100644 --- a/src/aki_prj23_transparenzregister/utils/sql/entities.py +++ b/src/aki_prj23_transparenzregister/utils/sql/entities.py @@ -3,7 +3,7 @@ from datetime import datetime import sqlalchemy as sa -from 
aki_prj23_transparenzregister.utils.enumy_types import ( +from aki_prj23_transparenzregister.utils.enum_types import ( RelationTypeEnum, SentimentTypeEnum, ) @@ -16,7 +16,6 @@ class DistrictCourt(Base): """DistrictCourt.""" __tablename__ = "district_court" - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) city = sa.Column(sa.String(100), nullable=False) name = sa.Column(sa.String(100), nullable=False, unique=True) @@ -54,6 +53,7 @@ class Person(Base): __tablename__ = "person" __table_args__ = (sa.UniqueConstraint("name", "surname", "date_of_birth"),) + # TODO add a constraint that asks for a minlength of 2 for name and surname id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String(100), nullable=False) diff --git a/src/aki_prj23_transparenzregister/utils/string_tools.py b/src/aki_prj23_transparenzregister/utils/string_tools.py new file mode 100644 index 0000000..be399f0 --- /dev/null +++ b/src/aki_prj23_transparenzregister/utils/string_tools.py @@ -0,0 +1,18 @@ +"""Contains functions fot string manipulation.""" + + +def simplify_string(string_to_simplify: str | None) -> str | None: + """Simplifies a string to None if no valid sting is found. + + Args: + string_to_simplify: The string to simplify. + + Returns: + The simplified string or None if the string was empty. 
+ """ + if string_to_simplify is not None: + if isinstance(string_to_simplify, str): + string_to_simplify = string_to_simplify.strip() + else: + raise TypeError("The string to simplify is not a string.") + return string_to_simplify if string_to_simplify else None diff --git a/tests/apps/enrich_company_financials_test.py b/tests/apps/enrich_company_financials_test.py index 22232bd..4368fe9 100644 --- a/tests/apps/enrich_company_financials_test.py +++ b/tests/apps/enrich_company_financials_test.py @@ -18,7 +18,8 @@ def test_import_enrich_company_financials() -> None: @patch( "aki_prj23_transparenzregister.apps.enrich_company_financials.CompanyMongoService" ) -def test_work(mock_compnay_service: Mock, mock_bundesanzeiger: Mock) -> None: +def test_work(mock_company_service: Mock, mock_bundesanzeiger: Mock) -> None: + """Tests the readout of the company financials.""" mock_bundesanzeiger.return_value = pd.DataFrame( [ { @@ -28,9 +29,8 @@ def test_work(mock_compnay_service: Mock, mock_bundesanzeiger: Mock) -> None: } ] ) - # mock_compnay_service.add_yearly_resreturn_value enrich_company_financials.work( {"_id": "", "name": "ABC AG", "location": {"city": "Haltern am See"}}, - mock_compnay_service, + mock_company_service, ) assert enrich_company_financials diff --git a/tests/config/config_providers_test.py b/tests/config/config_providers_test.py index 0d00cfd..60ebb3e 100644 --- a/tests/config/config_providers_test.py +++ b/tests/config/config_providers_test.py @@ -1,3 +1,4 @@ +"""Tests the config providers.""" import json from unittest.mock import mock_open, patch @@ -10,11 +11,13 @@ from aki_prj23_transparenzregister.config.config_providers import ( def test_json_provider_init_fail() -> None: + """Tests the file not found error if an unknown filepath is given for the JsonFileConfigProvider.""" with pytest.raises(FileNotFoundError): JsonFileConfigProvider("file-that-does-not-exist") def test_json_provider_init_no_json() -> None: + """Tests if a non json file throws the
correct error.""" with patch("os.path.isfile") as mock_isfile, patch( "builtins.open", mock_open(read_data="fhdaofhdoas") ): @@ -24,6 +27,7 @@ def test_json_provider_init_no_json() -> None: def test_json_provider_init() -> None: + """Tests the JsonFileConfigProvider creation.""" data = {"hello": "world"} input_data = json.dumps(data) with patch("os.path.isfile") as mock_isfile: @@ -34,6 +38,7 @@ def test_json_provider_init() -> None: def test_json_provider_get_postgres() -> None: + """Tests if the config provider can return the postgre config string.""" data = { "postgres": { "username": "user", @@ -56,6 +61,7 @@ def test_json_provider_get_postgres() -> None: def test_json_provider_get_mongo() -> None: + """Tests the JsonConfigProvider for the mongo db.""" data = { "mongo": { "username": "user", diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..b8da9d9 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,116 @@ +"""Global configurations and definitions for pytest.""" +import datetime +import os +from collections.abc import Generator +from inspect import getmembers, isfunction + +import pytest +from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session + +from aki_prj23_transparenzregister.utils import data_transfer +from aki_prj23_transparenzregister.utils.sql import entities +from aki_prj23_transparenzregister.utils.sql.connector import get_session, init_db + + +@pytest.fixture(autouse=True) +def _clear_caches() -> Generator[None, None, None]: + """A function that clears all caches after each test. + + All the modules containing the cached functions need to be listed in the modules tuple. 
+ """ + yield + # https://stackoverflow.com/a/139198/11003343 + modules = (data_transfer,) + functions = [ + function + for module in modules + for name, function in getmembers(module, isfunction) + if function.__dict__.get("cache") is not None + ] + # https://cachetools.readthedocs.io/en/stable/?highlight=clear#memoizing-decorators + for function in functions: + function.cache.clear() # type: ignore + + +@pytest.fixture() +def empty_db() -> Generator[Session, None, None]: + """Generates a db Session to a sql_lite db.""" + if os.path.exists("test-db.db"): + os.remove("test-db.db") + db = get_session("sqlite:///test-db.db") + init_db(db) + yield db + db.close() + bind = db.bind + assert isinstance(bind, Engine) + bind.dispose() + os.remove("test-db.db") + + +@pytest.fixture() +def full_db(empty_db: Session) -> Session: + """Fills a db with some test data.""" + empty_db.add_all( + [ + entities.DistrictCourt(name="Amtsgericht Bochum", city="Bochum"), + entities.DistrictCourt(name="Amtsgericht Dortmund", city="Dortmund"), + entities.Person( + name="Max", + surname="Mustermann", + date_of_birth=datetime.date(2023, 1, 1), + ), + entities.Person( + name="Sabine", + surname="Mustermann", + date_of_birth=datetime.date(2023, 1, 1), + ), + entities.Person( + name="Some Firstname", + surname="Some Surname", + date_of_birth=datetime.date(2023, 1, 1), + ), + entities.Person( + name="Some Firstname", + surname="Some Surname", + date_of_birth=datetime.date(2023, 1, 2), + ), + entities.Person( + name="Other Firstname", + surname="Other Surname", + date_of_birth=datetime.date(2023, 1, 2), + ), + ] + ) + empty_db.commit() + empty_db.add_all( + [ + entities.Company( + hr="HRB 123", + court_id=2, + name="Some Company GmbH", + street="Sesamstr.", + zip_code="12345", + city="TV City", + last_update=datetime.date.fromisoformat("2023-01-01"), + ), + entities.Company( + hr="HRB 123", + court_id=1, + name="Other Company GmbH", + street="Sesamstr.", + zip_code="12345", + city="TV City", + 
last_update=datetime.date.fromisoformat("2023-01-01"), + ), + entities.Company( + hr="HRB 12", + court_id=2, + name="Third Company GmbH", + last_update=datetime.date.fromisoformat("2023-01-01"), + ), + ] + ) + empty_db.commit() + # print(pd.read_sql_table("company", empty_db.bind).to_string()) + return empty_db diff --git a/tests/utils/data_extraction/bundesanzeiger_test.py b/tests/utils/data_extraction/bundesanzeiger_test.py index 30e8007..8829bbd 100644 --- a/tests/utils/data_extraction/bundesanzeiger_test.py +++ b/tests/utils/data_extraction/bundesanzeiger_test.py @@ -1,3 +1,4 @@ +"""Tests if the bundesanzeiger can be accessed and read.""" from unittest.mock import Mock, patch import pandas as pd diff --git a/tests/utils/data_transfer_test.py b/tests/utils/data_transfer_test.py new file mode 100644 index 0000000..f2deb8d --- /dev/null +++ b/tests/utils/data_transfer_test.py @@ -0,0 +1,744 @@ +"""Test the transfer functions from mongodb to sql.""" +import random +import string +from datetime import date +from typing import Any + +import numpy as np +import pandas as pd +import pytest +import sqlalchemy as sa +from pytest_mock import MockerFixture +from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session + +from aki_prj23_transparenzregister.utils import data_transfer + + +@pytest.mark.parametrize( + ("original", "expected"), + [ + ( + {"name": "Amtsgericht Herne", "city": "Herne"}, + {"name": "Amtsgericht Herne", "city": "Herne"}, + ), + ( + {"name": "Amtsgericht Herne", "city": ""}, + {"name": "Amtsgericht Herne", "city": "Herne"}, + ), + ( + {"name": "Amtsgericht Herne", "city": None}, + {"name": "Amtsgericht Herne", "city": "Herne"}, + ), + ( + {"name": "Amtsgericht Herne", "city": "Something Wrong"}, + {"name": "Amtsgericht Herne", "city": "Herne"}, + ), + ( + {"name": "Amtsgericht Herne", "city": "NoName"}, + {"name": "Amtsgericht Herne", "city": "Herne"}, + ), + ], +) +def test_refine_district_court_entry(original: dict, expected: dict) -> 
None: + """Tests the transformation/the cleaning of the district court entry.""" + assert data_transfer._refine_district_court_entry( + **{"name": "Amtsgericht Herne", "city": "Herne"} + ) == tuple(expected.values()) + + +@pytest.mark.parametrize( + "defect_data", + [ + {"name": "Wrong Herne", "city": "Herne"}, + {"name": "Wrong Herne", "city": "NoName"}, + {"city": "Herne", "name": None}, + {"city": "Herne", "name": ""}, + ], +) +def test_refine_district_court_entry_defect_data(defect_data: dict[str, str]) -> None: + """Tests if an error is thrown if the district court data can't be corrected.""" + with pytest.raises(data_transfer.DataInvalidError): + data_transfer._refine_district_court_entry(**defect_data) + + +@pytest.mark.repeat(3) +def test_empty_db_fixture(empty_db: Session) -> None: + """Checks if the db can be created.""" + assert isinstance(empty_db, Session) + + +@pytest.mark.parametrize( + ("name", "city", "id"), + [ + ("Amtsgericht Bochum", "Bochum", 1), + ("Amtsgericht Dortmund", "Dortmund", 2), + ("Amtsgericht Iserlohn", "Iserlohn", None), + ], +) +def test__read_district_court_id( + name: str, city: str, id: int | None, full_db: Session +) -> None: + """Tests if the district court id can be read.""" + assert data_transfer._read_district_court_id(name, city, full_db) == id + + +@pytest.mark.parametrize( + ("firstname", "surname", "date_str", "id"), + [ + ("Max", "Mustermann", "2023-01-01", 1), + ("Sabine", "Mustermann", "2023-01-01", 2), + ("Some Firstname", "Some Surname", "2023-01-01", 3), + ("Some Firstname", "Some Surname", "2023-01-02", 4), + ("Other Firstname", "Other Surname", "2023-01-02", 5), + (None, "Other Surname", "2023-01-02", None), + ("Does not exist", "Other Surname", "2023-01-02", None), + ("Other Firstname", "Does not exists", "2023-01-02", None), + ("Other Firstname", "Other Surname", "1900-01-02", None), + ("Other Firstname", None, "2023-01-02", None), + ], +) +def test__read_person_id( + firstname: str, surname: str, date_str: 
str, id: int | None, full_db: Session +) -> None: + """Tests if the person id can be read.""" + assert ( + data_transfer._read_person_id( + firstname, surname, date.fromisoformat(date_str), full_db + ) + == id + ) + + +@pytest.mark.parametrize( + ("name", "city", "id"), + [ + ("Amtsgericht Bochum", "Bochum", 1), + ("Amtsgericht Dortmund", "Dortmund", 2), + ("Amtsgericht Iserlohn", "Iserlohn", 3), + ("Amtsgericht Köln", "Köln", 3), + ], +) +def test_get_district_court_id(name: str, city: str, id: int, full_db: Session) -> None: + """Tests if a court id can be returned and the court automatically be added if not yet part of the db.""" + assert data_transfer.get_district_court_id(name, city, full_db) == id + + +@pytest.mark.parametrize( + ("firstname", "surname", "date_str", "id"), + [ + ("Max", "Mustermann", "2023-01-01", 1), + ("Sabine", "Mustermann", "2023-01-01", 2), + ("Some Firstname", "Some Surname", "2023-01-01", 3), + ("Some Firstname", "Some Surname", "2023-01-02", 4), + ("Other Firstname", "Other Surname", "2023-01-02", 5), + ("Does not exist", "Other Surname", "2023-01-02", 6), + ("Other Firstname", "Does not exists", "2023-01-02", 6), + ("Other Firstname", "Other Surname", "1900-01-02", 6), + ], +) +def test_get_person_id( + firstname: str, surname: str, date_str: str, id: int, full_db: Session +) -> None: + """Tests if a person id can be returned and the person automatically be added if not yet part of the db.""" + assert ( + data_transfer.get_person_id( + firstname, surname, date.fromisoformat(date_str), full_db + ) + == id + ) + + +@pytest.mark.parametrize( + ("firstname", "surname", "date_str"), + [ + ("", "Other Surname", "2023-01-02"), + ("Other Firstname", "", "2023-01-02"), + ("Other Firstname", "Other Surname", ""), + ], +) +def test_get_person_id_value_check( + firstname: str, surname: str, date_str: str | None, full_db: Session +) -> None: + """Tests if errors on adding persons can be found.""" + with pytest.raises(
data_transfer.DataInvalidError, match="At least one of the three values name:" + ): + data_transfer.get_person_id( + firstname, + surname, + date.fromisoformat(date_str) if date_str else None, # type: ignore + full_db, + ) + + +@pytest.mark.parametrize( + ("name", "zip_code", "city", "id"), + [ + ("Some Company GmbH", "", "", 1), + ("Some Company GmbH", "12345", "", 1), + ("Some Company GmbH", "12345", "TV City", 1), + ("Some Company GmbH", "", "TV City", 1), + ("Other Company GmbH", "", "", 2), + ("Other Company GmbH", "12345", "", 2), + ("Other Company GmbH", "12345", "TV City", 2), + ("Other Company GmbH", "", "TV City", 2), + ("Third Company GmbH", "", "", 3), + ], +) +def test_get_company_id( + name: str, zip_code: str, city: str, id: int | None, full_db: Session +) -> None: + """Tests if the company id can be returned correctly.""" + assert data_transfer.get_company_id(name, zip_code, city, full_db) == id + + +@pytest.mark.parametrize( + ("name", "zip_code", "city"), + [ + ("Does not exist", "", ""), + ("Does not exist", "41265", ""), + ("Does not exist", "", "Some City"), + ("Other Company GmbH", "TV City", "54321"), + ("Other Company GmbH", "OtherCity", "12345"), + ("Other Company GmbH", "OtherCity", "54321"), + ], +) +def test_get_company_id_not_found( + name: str, + zip_code: str, + city: str, + full_db: Session, +) -> None: + """Test the accessing of missing companies.""" + with pytest.raises(KeyError): + data_transfer.get_company_id(name, zip_code, city, full_db) + + +@pytest.mark.parametrize("name", ["", None]) +def test_get_company_id_nameless(name: str | None, full_db: Session) -> None: + """Test accessing a company without valid name.""" + with pytest.raises(data_transfer.DataInvalidError): + data_transfer.get_company_id(name, "zip_code", "city", full_db) # type: ignore + + +def get_random_string(length: int) -> str: + """Creates a random string of a defined length. + + Args: + length: The length of the string to generate. 
+ + Returns: + The generated string. + """ + letters = string.digits + string.ascii_letters + " " + return "".join(random.choice(letters) for _ in range(length)) + + +def get_random_zip() -> str: + """Creates a random zip.""" + letters = string.digits + return "".join(random.choice(letters) for _ in range(5)) + + +def company_generator(seed: int) -> dict[str, Any]: + """Generates a random company entry.""" + random.seed(seed) + if random.choice([True, False]): + city = "Dortmund" + else: + city = get_random_string(random.randint(5, 30)) + return { + "id": { + "district_court": { + "name": f"Amtsgericht {city}", + "city": city if random.choice([True, False]) else None, + }, + "hr_number": get_random_string(7), + }, + "name": get_random_string(random.randint(3, 150)), + "location": { + "city": city if random.choice([True, False]) else None, + "zip_code": get_random_zip() if random.choice([True, False]) else None, + "street": get_random_string(20) if random.choice([True, False]) else None, + }, + "last_update": date(random.randint(2000, 2023), 1, 1), + } + + +@pytest.mark.parametrize("seed", list(range(70, 75))) +def test_add_company(seed: int, full_db: Session) -> None: + """Tests the addition of a company to the db.""" + company = company_generator(seed) + data_transfer.add_company(company, full_db) + + +@pytest.mark.parametrize("seed", list(range(5))) +@pytest.mark.parametrize("overwrite", ["", None, " "]) +def test_add_company_broken_name( + seed: int, overwrite: str | None, full_db: Session +) -> None: + """Tests what happens if a company has a broken / empty name.""" + company = company_generator(seed) + company["name"] = overwrite + if overwrite is None: + with pytest.raises( + data_transfer.DataInvalidError, + match="The company name needs to be valid ", + ): + data_transfer.add_company(company, full_db) + + +@pytest.mark.parametrize("seed", list(range(5))) +@pytest.mark.parametrize("overwrite", ["", None, " "]) +def test_add_company_broken_city( + seed: int, 
overwrite: str | None, full_db: Session +) -> None: + """Tests a broken / empty city entry.""" + company = company_generator(seed) + company["location"]["city"] = overwrite + data_transfer.add_company(company, full_db) + + +@pytest.mark.parametrize("seed", list(range(5))) +@pytest.mark.parametrize("overwrite", ["", None, " "]) +def test_add_company_broken_zip_code( + seed: int, overwrite: str | None, full_db: Session +) -> None: + """Tests how to add a company if the zip_code is broken / empty.""" + company = company_generator(seed) + company["location"]["zip_code"] = overwrite + data_transfer.add_company(company, full_db) + + +@pytest.mark.parametrize("seed", list(range(5))) +@pytest.mark.parametrize("overwrite", [None]) +def test_add_company_broken_date( + seed: int, overwrite: str | None, full_db: Session +) -> None: + """Tests how the company add function deals with a missing date.""" + company = company_generator(seed) + company["last_update"] = overwrite + with pytest.raises(sa.exc.IntegrityError): + data_transfer.add_company(company, full_db) + + +@pytest.mark.parametrize("seed", list(range(5))) +@pytest.mark.parametrize("overwrite", ["", None, " "]) +def test_add_company_broken_district_court( + seed: int, overwrite: str | None, full_db: Session, mocker: MockerFixture +) -> None: + """Test a broken district court entry.""" + company = company_generator(seed) + company["id"]["district_court"]["name"] = overwrite + company["id"]["district_court"]["city"] = get_random_string(10) + with pytest.raises( + data_transfer.DataInvalidError, + match="There is no court name|The name of the district court does not start correctly", + ): + data_transfer.add_company(company, full_db) + + +@pytest.mark.parametrize("seed", list(range(0, 25, 5))) +def test_add_companies(seed: int, mocker: MockerFixture, full_db: Session) -> None: + """Test to add multiple companies.""" + rnd_generator = np.random.default_rng(seed) + companies: list[dict[str, Any]] = [ + company_generator(_)
+ for _ in set( + rnd_generator.integers(0, 1000, size=rnd_generator.integers(1, 30)).tolist() + ) + ] + spy_warning = mocker.spy(data_transfer.logger, "warning") + spy_info = mocker.spy(data_transfer.logger, "info") + spy_debug = mocker.spy(data_transfer.logger, "debug") + data_transfer.add_companies(companies, full_db) + spy_info.assert_called_once_with("When adding companies no problems occurred.") + spy_warning.assert_not_called() + assert spy_debug.call_count == len(companies) + + +@pytest.mark.parametrize("seed", list(range(1, 25, 5))) +def test_add_companies_duplicate( + seed: int, mocker: MockerFixture, full_db: Session +) -> None: + """Test to add multiple companies.""" + rnd_generator = np.random.default_rng(seed) + companies: list[dict[str, Any]] = [ + company_generator(_) + for _ in set( + rnd_generator.integers(0, 1000, size=rnd_generator.integers(4, 30)).tolist() + ) + ] + unique_companies = len(companies) + companies += companies[-3:] + spy_warning = mocker.spy(data_transfer.logger, "warning") + spy_info = mocker.spy(data_transfer.logger, "info") + spy_debug = mocker.spy(data_transfer.logger, "debug") + data_transfer.add_companies(companies, full_db) + spy_info.assert_not_called() + spy_warning.assert_called_once_with( + "When adding companies 3 problems occurred 0 where caused by invalid data." 
+ ) + assert spy_debug.call_count == unique_companies + + +@pytest.mark.parametrize("seed", list(range(2, 25, 5))) +def test_add_companies_corrupted_data( + seed: int, mocker: MockerFixture, full_db: Session +) -> None: + """Test to add multiple companies.""" + rnd_generator = np.random.default_rng(seed) + companies: list[dict[str, Any]] = [ + company_generator(_) + for _ in set( + rnd_generator.integers(0, 1000, size=rnd_generator.integers(4, 30)).tolist() + ) + ] + companies[len(companies) // 2]["name"] = "" + spy_warning = mocker.spy(data_transfer.logger, "warning") + spy_info = mocker.spy(data_transfer.logger, "info") + spy_debug = mocker.spy(data_transfer.logger, "debug") + data_transfer.add_companies(companies, full_db) + spy_info.assert_not_called() + spy_warning.assert_called_once_with( + "When adding companies 1 problems occurred 1 where caused by invalid data." + ) + assert spy_debug.call_count == len(companies) - 1 + + +@pytest.mark.parametrize("company_id", list(range(5))) +def test_add_relationship_no_relation(company_id: int, full_db: Session) -> None: + """Tests if an error is thrown if the relation type/role is not defined.""" + with pytest.raises(ValueError, match="A relation type needs to be given."): + data_transfer.add_relationship({}, company_id, full_db) + + +@pytest.mark.parametrize("company_id", list(range(5))) +def test_add_relationship_unknown_relation(company_id: int, full_db: Session) -> None: + """Tests if an error is thrown if the relation type/role is unknown.""" + with pytest.raises(ValueError, match="Relation type .* is not yet implemented!"): + data_transfer.add_relationship( + {"role": "something strange"}, company_id, full_db + ) + + +@pytest.mark.parametrize("company_id", [1, 2, 3]) +@pytest.mark.parametrize( + ("firstname", "surname", "date_of_birth"), + [ + ("Max", "Mustermann", "2023-01-01"), + ("Some Firstname", "Some Surname", "2023-01-01"), + ("Other Firstname", "Other Surname", "1900-01-02"), + ], +) 
+@pytest.mark.parametrize("role", ["Partner", "direktor", "liquidator"]) +def test_add_relationship_person( # noqa: PLR0913 + firstname: str, + surname: str, + date_of_birth: str, + full_db: Session, + company_id: int, + role: str, +) -> None: + """Tests if a personal relation can be added.""" + relation = { + "name": { + "firstname": firstname, + "lastname": surname, + }, + "date_of_birth": date.fromisoformat(date_of_birth), + "role": role, + } + data_transfer.add_relationship(relation, company_id, full_db) + + +@pytest.mark.parametrize("company_id", [1, 2, 3]) +@pytest.mark.parametrize( + ("firstname", "surname", "date_of_birth"), + [ + ("Max", None, "2023-01-01"), + (None, "Some Surname", "2023-01-01"), + ("Other Firstname", "Other Surname", None), + ], +) +@pytest.mark.parametrize("role", ["Partner"]) +def test_add_relationship_person_missing_data( # noqa: PLR0913 + firstname: str, + surname: str, + date_of_birth: str, + full_db: Session, + company_id: int, + role: str, + mocker: MockerFixture, +) -> None: + """Tests if a personal relation can be added.""" + mocker.spy(data_transfer.logger, "warning") + relation = { + "name": { + "firstname": firstname, + "lastname": surname, + }, + "date_of_birth": date_of_birth if date_of_birth else None, + "role": role, + } + with pytest.raises( + data_transfer.DataInvalidError, match="At least one of the three values name:" + ): + data_transfer.add_relationship(relation, company_id, full_db) + + +@pytest.mark.parametrize( + ("company_name", "city", "zip_code", "company_id"), + [ + ("Some Company GmbH", None, None, 2), + ("Some Company GmbH", None, "12345", 2), + ("Some Company GmbH", "TV City", None, 3), + ("Some Company GmbH", "TV City", "12345", 2), + ("Some Company GmbH", "Strange City", "12345", 2), + ("Some Company GmbH", "TV City", "?????", 2), + ("Third Company GmbH", None, None, 1), + ], +) +def test_add_relationship_company( + company_id: int, + company_name: str, + city: str | None, + zip_code: str | None, + 
full_db: Session, +) -> None: + """Tests if a relationship to another company can be added.""" + data_transfer.add_relationship( + { + "description": company_name, + "location": { + "zip_code": zip_code, + "city": city, + }, + "role": "organisation", + }, + company_id, + full_db, + ) + + +@pytest.mark.parametrize( + ("company_name", "city", "zip_code", "company_id"), + [ + ("Some Company GmbH", None, None, 1), + ("Some Company GmbH", "TV City", "12345", 1), + ("Some Company GmbH", "TV City", None, 1), + ("Third Company GmbH", None, None, 3), + ], +) +def test_add_relationship_company_self_reference( + company_id: int, + company_name: str, + city: str | None, + zip_code: str | None, + full_db: Session, +) -> None: + """Tests if a company referencing a relationship with itself throws an error.""" + with pytest.raises( + data_transfer.DataInvalidError, + match="For a valid relation both parties can't be the same entity.", + ): + data_transfer.add_relationship( + { + "description": company_name, + "location": { + "zip_code": zip_code, + "city": city, + }, + "role": "organisation", + }, + company_id, + full_db, + ) + + +@pytest.mark.parametrize( + ("company_name", "city", "zip_code", "company_id"), + [ + ("Unknown GmbH", None, None, 2), + ("Some Company GmbH", "Strange city", "?????", 2), + ], +) +def test_add_relationship_company_unknown( + company_id: int, + company_name: str, + city: str | None, + zip_code: str | None, + full_db: Session, +) -> None: + """Tests if a relationship to another company can be added.""" + with pytest.raises( + KeyError, match=f"No corresponding company could be found to {company_name}." 
+ ): + data_transfer.add_relationship( + { + "description": company_name, + "location": { + "zip_code": zip_code, + "city": city, + }, + "role": "organisation", + }, + company_id, + full_db, + ) + + +@pytest.mark.parametrize("empty_relations", [[], [{}], [{"relationship": []}]]) +def test_add_relationships_none(empty_relations: list, full_db: Session) -> None: + """Testing what happens if an empty relation is added.""" + data_transfer.add_relationships([], full_db) + + +@pytest.mark.working_on() +@pytest.mark.parametrize( + "documents", + [ + [ + { + "_id": {"$oid": "649f16a2ecc"}, + "id": { + "hr_number": "HRB 123", + "district_court": { + "name": "Amtsgericht Dortmund", + "city": "Dortmund", + }, + }, + "location": { + "city": "TV City", + "zip_code": "12345", + "street": "Sesamstr.", + "house_number": "1", + }, + "name": "Some Company GmbH", + "last_update": "2023-05-04", + "relationships": [ + { + "name": {"firstname": "Second person", "lastname": "Köstser"}, + "date_of_birth": "1961-02-09", + "location": {"city": "Stuttgart"}, + "role": "Geschäftsführer", + }, + { + "name": {"firstname": "First Person", "lastname": "Jifpa"}, + "date_of_birth": "1976-04-20", + "location": {"city": "Stuttgart"}, + "role": "Geschäftsführer", + }, + { + "name": {"firstname": "", "lastname": "Jiapa"}, + "date_of_birth": "1976-04-20", + "location": {"city": "Stuttgart"}, + "role": "Geschäftsführer", + }, + { + "name": {"firstname": "Something", "lastname": ""}, + "date_of_birth": "12i3u", + "location": {"city": "Stuttgart"}, + "role": "Geschäftsführer", + }, + { + "name": {"firstname": "First Person", "lastname": "Jipha"}, + "date_of_birth": "1976-04-20", + }, + ], + "yearly_results": {}, + } + ] + ], +) +def test_relationships(documents: list[dict[str, Any]], full_db: Session) -> None: + """Testing to add lots of relations.""" + data_transfer.add_relationships(documents, full_db) + bind = full_db.bind + assert isinstance(bind, Engine) + pd.testing.assert_frame_equal( + 
pd.read_sql_table("company", bind), + pd.DataFrame( + { + "id": {0: 1, 1: 2, 2: 3}, + "hr": {0: "HRB 123", 1: "HRB 123", 2: "HRB 12"}, + "court_id": {0: 2, 1: 1, 2: 2}, + "name": { + 0: "Some Company GmbH", + 1: "Other Company GmbH", + 2: "Third Company GmbH", + }, + "street": {0: "Sesamstr.", 1: "Sesamstr.", 2: None}, + "zip_code": {0: "12345", 1: "12345", 2: None}, + "city": {0: "TV City", 1: "TV City", 2: None}, + "last_update": { + 0: pd.Timestamp("2023-01-01 00:00:00"), + 1: pd.Timestamp("2023-01-01 00:00:00"), + 2: pd.Timestamp("2023-01-01 00:00:00"), + }, + "sector": {0: None, 1: None, 2: None}, + } + ), + ) + assert len(pd.read_sql_table("company_relation", bind).index) == 0 + pd.testing.assert_frame_equal( + pd.read_sql_table("person_relation", bind), + pd.DataFrame({"id": {0: 1, 1: 2}, "person_id": {0: 6, 1: 7}}), + ) + pd.testing.assert_frame_equal( + pd.read_sql_table("relation", bind), + pd.DataFrame( + { + "id": {0: 1, 1: 2}, + "company_id": {0: 1, 1: 1}, + "date_from": {0: pd.NaT, 1: pd.NaT}, + "date_to": {0: pd.NaT, 1: pd.NaT}, + "relation": {0: "GESCHAEFTSFUEHRER", 1: "GESCHAEFTSFUEHRER"}, + } + ), + ) + pd.testing.assert_frame_equal( + pd.read_sql_table("person", bind), + pd.DataFrame( + { + "id": {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7}, + "name": { + 0: "Max", + 1: "Sabine", + 2: "Some Firstname", + 3: "Some Firstname", + 4: "Other Firstname", + 5: "Second person", + 6: "First Person", + }, + "surname": { + 0: "Mustermann", + 1: "Mustermann", + 2: "Some Surname", + 3: "Some Surname", + 4: "Other Surname", + 5: "Köstser", + 6: "Jifpa", + }, + "date_of_birth": { + 0: pd.Timestamp("2023-01-01 00:00:00"), + 1: pd.Timestamp("2023-01-01 00:00:00"), + 2: pd.Timestamp("2023-01-01 00:00:00"), + 3: pd.Timestamp("2023-01-02 00:00:00"), + 4: pd.Timestamp("2023-01-02 00:00:00"), + 5: pd.Timestamp("1961-02-09 00:00:00"), + 6: pd.Timestamp("1976-04-20 00:00:00"), + }, + "works_for": { + 0: None, + 1: None, + 2: None, + 3: None, + 4: None, + 5: None, + 6: 
None, + }, + } + ), + ) diff --git a/tests/utils/enum_types_test.py b/tests/utils/enum_types_test.py new file mode 100644 index 0000000..f9f744c --- /dev/null +++ b/tests/utils/enum_types_test.py @@ -0,0 +1,40 @@ +"""Tests for the enumeration types.""" +import pytest + +from aki_prj23_transparenzregister.utils import enum_types + + +def test_import() -> None: + """Tests if enum_types can be imported.""" + assert enum_types + + +@pytest.mark.parametrize("relation_name", ["Vorstand", "Prokurist", "Direktor"]) +@pytest.mark.parametrize("changes", ["lower", "upper", None]) +def test_relation_type_enum_from_string( + relation_name: str, changes: str | None +) -> None: + """Tests the transformation of a name to an enumeration type.""" + if changes == "lower": + relation_name = relation_name.lower() + elif changes == "upper": + relation_name = relation_name.upper() + + assert isinstance( + enum_types.RelationTypeEnum.get_enum_from_name(relation_name), + enum_types.RelationTypeEnum, + ) + + +@pytest.mark.parametrize("relation_name", ["does Not Exists", "Also not"]) +@pytest.mark.parametrize("changes", ["lower", "upper", None]) +def test_relation_type_enum_from_string_wrong( + relation_name: str, changes: str | None +) -> None: + """Tests the transformation of a name to an enumeration type if no equivalent can be found.""" + if changes == "lower": + relation_name = relation_name.lower() + elif changes == "upper": + relation_name = relation_name.upper() + with pytest.raises(ValueError, match='Relation type ".*" is not yet implemented!'): + enum_types.RelationTypeEnum.get_enum_from_name(relation_name) diff --git a/tests/utils/mongo/mongo_test.py b/tests/utils/mongo/mongo_test.py index 51a1abe..eb63a4b 100644 --- a/tests/utils/mongo/mongo_test.py +++ b/tests/utils/mongo/mongo_test.py @@ -1,3 +1,4 @@ +"""Tests for connecting to the mongodb.""" from unittest.mock import patch from aki_prj23_transparenzregister.utils.mongo.connector import ( @@ -7,21 +8,25 @@ from 
aki_prj23_transparenzregister.utils.mongo.connector import ( def test_get_conn_string_no_credentials() -> None: + """Tests the mongo connection string generation.""" conn = MongoConnection("localhost", "", 27017, None, None) assert conn.get_conn_string() == "mongodb://localhost:27017" def test_get_conn_string_no_port_but_credentials() -> None: + """Tests the mongo connection string generation.""" conn = MongoConnection("localhost", "", None, "admin", "password") assert conn.get_conn_string() == "mongodb+srv://admin:password@localhost" def test_get_conn_simple() -> None: + """Tests the mongo connection string generation.""" conn = MongoConnection("localhost", "", None, None, None) assert conn.get_conn_string() == "mongodb+srv://localhost" def test_mongo_connector() -> None: + """Tests the MongoConnector.""" with patch("pymongo.MongoClient") as mock_mongo_client: expected_result = 42 mock_mongo_client.return_value = {"db": expected_result} diff --git a/tests/utils/mongo/news_mongo_service_test.py b/tests/utils/mongo/news_mongo_service_test.py index ddf1564..9257c51 100644 --- a/tests/utils/mongo/news_mongo_service_test.py +++ b/tests/utils/mongo/news_mongo_service_test.py @@ -1,3 +1,4 @@ +"""Tests for the mongo news service.""" from unittest.mock import Mock, patch import pytest @@ -50,6 +51,7 @@ def test_init(mock_mongo_connector: Mock, mock_collection: Mock) -> None: def test_get_all(mock_mongo_connector: Mock, mock_collection: Mock) -> None: + """Tests the get_all function from the mongo connector.""" mock_mongo_connector.database = {"news": mock_collection} service = MongoNewsService(mock_mongo_connector) @@ -60,6 +62,7 @@ def test_get_all(mock_mongo_connector: Mock, mock_collection: Mock) -> None: def test_get_by_id_with_result( mock_mongo_connector: Mock, mock_collection: Mock ) -> None: + """Tests the get_by_id_with_result function from the mongo connector.""" mock_mongo_connector.database = {"news": mock_collection} service = 
MongoNewsService(mock_mongo_connector) @@ -72,6 +75,7 @@ def test_get_by_id_with_result( def test_get_by_id_no_result(mock_mongo_connector: Mock, mock_collection: Mock) -> None: + """Test if the mongo connector can get an object by id.""" mock_mongo_connector.database = {"news": mock_collection} service = MongoNewsService(mock_mongo_connector) @@ -80,6 +84,7 @@ def test_get_by_id_no_result(mock_mongo_connector: Mock, mock_collection: Mock) def test_insert(mock_mongo_connector: Mock, mock_collection: Mock) -> None: + """Tests the insert function from the mongo connector.""" mock_mongo_connector.database = {"news": mock_collection} service = MongoNewsService(mock_mongo_connector) @@ -92,6 +97,7 @@ def test_insert(mock_mongo_connector: Mock, mock_collection: Mock) -> None: def test_transform_ingoing() -> None: + """Tests the transform_ingoing function from the mongo connector.""" news = News("42", None, None, None, None) # type: ignore result = MongoEntryTransformer.transform_ingoing(news) assert result["_id"] == "42" @@ -99,6 +105,7 @@ def test_transform_ingoing() -> None: def test_transform_outgoing() -> None: + """Tests the transform_outgoing function from the mongo connector.""" data = { "_id": "4711", "title": "Hello", diff --git a/tests/utils/sql/connector_test.py b/tests/utils/sql/connector_test.py index 658ba57..a671883 100644 --- a/tests/utils/sql/connector_test.py +++ b/tests/utils/sql/connector_test.py @@ -1,3 +1,4 @@ +"""Tests the sql connector.""" import os.path from collections.abc import Generator from typing import Any @@ -16,6 +17,7 @@ from aki_prj23_transparenzregister.utils.sql.connector import ( def test_get_engine_pg() -> None: + """Tests the creation of a PostgreSQL engine.""" conn_args = PostgreConnectionString("", "", "", "", 42) with patch( "aki_prj23_transparenzregister.utils.sql.connector.sa.create_engine" diff --git a/tests/utils/sql/entities_test.py b/tests/utils/sql/entities_test.py index 14bc361..bcd3068 100644 ---
a/tests/utils/sql/entities_test.py +++ b/tests/utils/sql/entities_test.py @@ -1,4 +1,8 @@ -def test_import() -> None: - from aki_prj23_transparenzregister.utils.sql import entities +"""Tests for the sql entities.""" +from aki_prj23_transparenzregister.utils.sql import entities + + +def test_import() -> None: # + """Tests if the entities can be imported.""" assert entities diff --git a/tests/utils/string_tools_test.py b/tests/utils/string_tools_test.py new file mode 100644 index 0000000..26a7b1b --- /dev/null +++ b/tests/utils/string_tools_test.py @@ -0,0 +1,35 @@ +"""Tests for the string tool module.""" +from typing import Any + +import pytest + +from aki_prj23_transparenzregister.utils import string_tools + + +def test_import() -> None: + """Tests if the import is possible.""" + assert string_tools + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + ("None ", "None"), + (" ", None), + ("", None), + ("\t", None), + ("\n", None), + (" Some String ", "Some String"), + ("Some String", "Some String"), + ], +) +def test_simplify_string(value: str | None, expected: str | None) -> None: + """Tests the string simplification.""" + assert string_tools.simplify_string(value) == expected + + +@pytest.mark.parametrize("value", [0, 0.1, True, ("1",), {}, set()]) +def test_simplify_string_type_error(value: Any) -> None: + """Tests if the type error is thrown when the value is the wrong type.""" + with pytest.raises(TypeError): + assert string_tools.simplify_string(value)