diff --git a/.gitignore b/.gitignore index f536711..c6cbd81 100644 --- a/.gitignore +++ b/.gitignore @@ -221,4 +221,6 @@ replay_pid* /lbr-audit.md /.ruff_cache/ /Jupyter/test.ipynb -/secrets*.json +secrets*.json +*.db-journal +*.db diff --git a/README.md b/README.md index 1f11c25..679df46 100644 --- a/README.md +++ b/README.md @@ -16,8 +16,8 @@ See the [CONTRIBUTING.md](CONTRIBUTING.md) about how code should be formatted an The project has currently the following entrypoint available: -- data-transfer > Transfers all the data from the mongodb into the sql db to make it available as production data. -- reset-sql > Resets all sql tables in the connected db. +- **data-transfer** > Transfers all the data from the mongodb into the sql db to make it available as production data. +- **reset-sql** > Resets all sql tables in the connected db. ## DB Connection settings diff --git a/pyproject.toml b/pyproject.toml index 3c01d8c..0d730db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,7 @@ pytest-mock = "^3.11.1" pytest-repeat = "^0.9.1" [tool.poetry.scripts] +copy-sql = "aki_prj23_transparenzregister.utils.sql.copy_sql:copy_db_cli" data-transfer = "aki_prj23_transparenzregister.utils.data_transfer:transfer_data" reset-sql = "aki_prj23_transparenzregister.utils.sql.connector:reset_all_tables" @@ -136,7 +137,7 @@ unfixable = ["B"] builtins-ignorelist = ["id"] [tool.ruff.per-file-ignores] -"tests/*.py" = ["S101", "SLF001", "S311", "D103"] +"tests/*.py" = ["S101", "SLF001", "S311", "D103", "PLR0913"] [tool.ruff.pydocstyle] convention = "google" diff --git a/src/aki_prj23_transparenzregister/config/config_providers.py b/src/aki_prj23_transparenzregister/config/config_providers.py index be818c2..51a76fc 100644 --- a/src/aki_prj23_transparenzregister/config/config_providers.py +++ b/src/aki_prj23_transparenzregister/config/config_providers.py @@ -1,8 +1,10 @@ """Wrappers for config providers.""" import abc +import errno import json import os +from pathlib import Path from dotenv import load_dotenv @@ -41,7 +43,7 @@ class JsonFileConfigProvider(ConfigProvider): __data__: dict = {} - def __init__(self, file_path: str): + def __init__(self, file_path: str | Path): """Constructor reading its data from a given .json file. Args: @@ -52,7 +54,7 @@ class JsonFileConfigProvider(ConfigProvider): TypeError: File could not be read or is malformed """ if not os.path.isfile(file_path): - raise FileNotFoundError + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file_path) with open(file_path) as file: try: data = json.loads(file.read()) diff --git a/src/aki_prj23_transparenzregister/utils/data_transfer.py b/src/aki_prj23_transparenzregister/utils/data_transfer.py index f28712c..65be641 100644 --- a/src/aki_prj23_transparenzregister/utils/data_transfer.py +++ b/src/aki_prj23_transparenzregister/utils/data_transfer.py @@ -1,5 +1,4 @@ """This module contains the data transfer and refinement functionalities between staging and production DB.""" -import sys from datetime import date from typing import Any @@ -11,6 +10,7 @@ from tqdm import tqdm from aki_prj23_transparenzregister.config.config_providers import JsonFileConfigProvider from aki_prj23_transparenzregister.utils.enum_types import RelationTypeEnum +from aki_prj23_transparenzregister.utils.logger_config import configer_logger from aki_prj23_transparenzregister.utils.mongo.company_mongo_service import ( CompanyMongoService, ) @@ -266,10 +266,11 @@ def add_relationship( relation: entities.CompanyRelation | entities.PersonRelation if "date_of_birth" in relationship: name = relationship["name"] + date_of_brith: str = relationship["date_of_birth"] person_id = get_person_id( name["firstname"], name["lastname"], - relationship["date_of_birth"], + date_of_brith, db, ) relation = entities.PersonRelation( @@ -278,12 +279,16 @@ def add_relationship( relation=relation_type, ) else: - relation_to: int = get_company_id( - relationship["description"], - relationship["location"]["zip_code"], - relationship["location"]["city"], - db=db, - ) + try: + relation_to: int = get_company_id( + relationship["description"], + relationship["location"]["zip_code"], + relationship["location"]["city"], + db=db, + ) + except KeyError as err: + logger.warning(err) + return if company_id == relation_to: raise DataInvalidError( "For a valid relation both parties can't be the same entity." @@ -414,13 +419,7 @@ def add_annual_financial_reports(companies: list[dict], db: Session) -> None: def transfer_data(db: Session | None = None) -> None: """This functions transfers all the data from a production environment to a staging environment.""" - logger.remove() - logger.add( - sys.stdout, - level="INFO", - catch=True, - ) - logger.add("data-transfer.log", level="INFO", retention=5) + configer_logger("info", "data-transfer.log") mongo_connector = MongoConnector( JsonFileConfigProvider("./secrets.json").get_mongo_connection_string() diff --git a/src/aki_prj23_transparenzregister/utils/logger_config.py b/src/aki_prj23_transparenzregister/utils/logger_config.py new file mode 100644 index 0000000..c639e1e --- /dev/null +++ b/src/aki_prj23_transparenzregister/utils/logger_config.py @@ -0,0 +1,22 @@ +"""Configures the logger.""" +import sys +from pathlib import Path +from typing import Literal + +from loguru import logger + + +def configer_logger( + level: Literal["info", "debug", "warning", "error"], + path: str | Path, +) -> None: + """Configures the logger. + + Args: + level: Defines the logging level that should be used. + path: The path where the logs should be saved. + """ + logger.remove() + logger.add(sys.stdout, level=level.upper(), catch=True) + if path: + logger.add(path, level=level.upper(), retention=5) diff --git a/src/aki_prj23_transparenzregister/utils/sql/connector.py b/src/aki_prj23_transparenzregister/utils/sql/connector.py index b3e5d23..922bc24 100644 --- a/src/aki_prj23_transparenzregister/utils/sql/connector.py +++ b/src/aki_prj23_transparenzregister/utils/sql/connector.py @@ -1,13 +1,11 @@ """Module containing connection utils for PostgreSQL DB.""" import re -import pandas as pd import sqlalchemy as sa from loguru import logger from sqlalchemy.engine import URL, Engine from sqlalchemy.orm import Session, declarative_base, sessionmaker from sqlalchemy.pool import SingletonThreadPool -from tqdm import tqdm from aki_prj23_transparenzregister.config.config_providers import ( ConfigProvider, @@ -58,6 +56,9 @@ def get_session( A session to connect to an SQL db via SQLAlchemy. """ engine: Engine + if isinstance(connect_to, str) and re.fullmatch(r".*\.json$", connect_to): + logger.debug(connect_to) + connect_to = JsonFileConfigProvider(connect_to) if isinstance(connect_to, ConfigProvider): engine = get_pg_engine(connect_to.get_postgre_connection_string()) @@ -90,26 +91,6 @@ def reset_all_tables(db: Session) -> None: init_db(db) -@logger.catch(reraise=True) -def transfer_db(*, source: Session, destination: Session) -> None: - """Transfers the data from on db to another db. - - Args: - source: A session to a source db data should be copied from. - destination: A session to a db where the data should be copied to. - """ - reset_all_tables(destination) - init_db(destination) - sbind = source.bind - dbind = destination.bind - assert isinstance(sbind, Engine) # noqa: S101 - assert isinstance(dbind, Engine) # noqa: S101 - for table in tqdm(Base.metadata.sorted_tables): - pd.read_sql_table(str(table), sbind).to_sql( - str(table), dbind, if_exists="append", index=False - ) - - if __name__ == "__main__": """Main flow creating tables""" init_db(get_session(JsonFileConfigProvider("./secrets.json"))) diff --git a/src/aki_prj23_transparenzregister/utils/sql/copy_sql.py b/src/aki_prj23_transparenzregister/utils/sql/copy_sql.py new file mode 100644 index 0000000..37a7ab8 --- /dev/null +++ b/src/aki_prj23_transparenzregister/utils/sql/copy_sql.py @@ -0,0 +1,85 @@ +"""Functions to copy a sql table.""" +import argparse +import sys + +import pandas as pd +from loguru import logger +from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session + +from aki_prj23_transparenzregister.utils.logger_config import configer_logger +from aki_prj23_transparenzregister.utils.sql.connector import ( + Base, + get_session, + reset_all_tables, +) + + +@logger.catch(reraise=True) +def transfer_db_function(*, source: Session, destination: Session) -> None: + """Transfers the data from on db to another db. + + Args: + source: A session to a source db data should be copied from. + destination: A session to a db where the data should be copied to. + """ + reset_all_tables(destination) + # init_db(destination) + sbind = source.bind + dbind = destination.bind + assert isinstance(sbind, Engine) # noqa: S101 + assert isinstance(dbind, Engine) # noqa: S101 + for table in Base.metadata.sorted_tables: + logger.info(f"Transferring table {table} from source to destination db.") + pd.read_sql_table(str(table), sbind).to_sql( + str(table), dbind, if_exists="append", index=False + ) + + +def copy_db_cli(args: list[str] | None = None) -> None: + """CLI interfaces to copy a db from source to destination. + + Args: + args: The args ar automaticlly collected from the cli if none are given. They should only be given for testing. + """ + parser = argparse.ArgumentParser( + prog="copy-sql", + description="Copy data from one SQL database to another.", + epilog="Example: copy-sql source.db destination.json", + ) + parser.add_argument( + "source", + metavar="source", + help="Source database configuration.", + ) + parser.add_argument( + "destination", + metavar="destination", + help="Destination database configuration.", + ) + parser.add_argument( + "--log-level", + choices=["info", "debug", "error", "warning"], + default="info", + metavar="log-level", + help="The log level for the output.", + ) + parser.add_argument( + "--log-path", + metavar="log_path", + help="A path to write the log to.", + ) + + if not args: + args = sys.argv[1:] + + parsed = parser.parse_args(args) + configer_logger(level=parsed.log_level, path=parsed.log_path) + source = get_session(parsed.source) + logger.info(f"Connecting to {source.bind} as a source to copy from.") + destination = get_session(parsed.destination) + logger.info(f"Connecting to {destination.bind} as a destination to copy to.") + transfer_db_function( + source=source, + destination=destination, + ) diff --git a/tests/utils/data_transfer_test.py b/tests/utils/data_transfer_test.py index b92bb28..61722db 100644 --- a/tests/utils/data_transfer_test.py +++ b/tests/utils/data_transfer_test.py @@ -571,23 +571,25 @@ def test_add_relationship_company_unknown( city: str | None, zip_code: str | None, full_db: Session, + mocker: MockerFixture, ) -> None: """Tests if a relationship to another company can be added.""" - with pytest.raises( - KeyError, match=f"No corresponding company could be found to {company_name}." - ): - data_transfer.add_relationship( - { - "description": company_name, - "location": { - "zip_code": zip_code, - "city": city, - }, - "role": "organisation", + spy_warning = mocker.spy(data_transfer.logger, "warning") + spy_info = mocker.spy(data_transfer.logger, "info") + data_transfer.add_relationship( + { + "description": company_name, + "location": { + "zip_code": zip_code, + "city": city, }, - company_id, - full_db, - ) + "role": "organisation", + }, + company_id, + full_db, + ) + spy_warning.assert_called_once() + spy_info.assert_not_called() @pytest.mark.parametrize("empty_relations", [[], [{}], [{"relationship": []}]]) @@ -778,7 +780,7 @@ def test_add_annual_financial_reports_no_call( ) -> None: """Testing if financial reports are added correctly to the db.""" spy_warning = mocker.spy(data_transfer.logger, "warning") - info_warning = mocker.spy(data_transfer.logger, "info") + spy_info = mocker.spy(data_transfer.logger, "info") mocker.patch("aki_prj23_transparenzregister.utils.data_transfer.add_annual_report") data_transfer.add_annual_financial_reports(companies, full_db) @@ -786,7 +788,7 @@ def test_add_annual_financial_reports_no_call( input_kwargs = mocker.call.kwargs assert len(input_args) == len(input_kwargs) spy_warning.assert_not_called() - info_warning.assert_called_once() + spy_info.assert_called_once() @pytest.mark.parametrize( @@ -821,7 +823,7 @@ def test_add_annual_financial_reports_defect_year( ) -> None: """Testing if financial reports are added correctly to the db.""" spy_warning = mocker.spy(data_transfer.logger, "warning") - info_warning = mocker.spy(data_transfer.logger, "info") + spy_info = mocker.spy(data_transfer.logger, "info") mocker.patch("aki_prj23_transparenzregister.utils.data_transfer.add_annual_report") data_transfer.add_annual_financial_reports(companies, full_db) @@ -829,7 +831,7 @@ def test_add_annual_financial_reports_defect_year( input_kwargs = mocker.call.kwargs assert len(input_args) == len(input_kwargs) spy_warning.assert_called_once() - info_warning.assert_called_once() + spy_info.assert_called_once() def test_add_annual_financial_reports(full_db: Session, mocker: MockerFixture) -> None: @@ -864,7 +866,7 @@ def test_add_annual_financial_reports(full_db: Session, mocker: MockerFixture) - ] spy_warning = mocker.spy(data_transfer.logger, "warning") - info_warning = mocker.spy(data_transfer.logger, "info") + spy_info = mocker.spy(data_transfer.logger, "info") mocked = mocker.patch( "aki_prj23_transparenzregister.utils.data_transfer.add_annual_report" ) @@ -890,7 +892,7 @@ def test_add_annual_financial_reports(full_db: Session, mocker: MockerFixture) - for input_args in mocked.call_args_list: assert isinstance(input_args.kwargs["db"], Session) - info_warning.assert_called_once() + spy_info.assert_called_once() @pytest.mark.parametrize("year", list(range(2000, 2025, 5))) diff --git a/tests/utils/logger_config_test.py b/tests/utils/logger_config_test.py new file mode 100644 index 0000000..03263f9 --- /dev/null +++ b/tests/utils/logger_config_test.py @@ -0,0 +1,26 @@ +"""Smoke-test over the logger config.""" +from pathlib import Path + +import pytest + +from aki_prj23_transparenzregister.utils.logger_config import configer_logger + + +@pytest.mark.parametrize("path", [None, "test-log.log", ""]) +@pytest.mark.parametrize("upper", [True, False]) +@pytest.mark.parametrize("level", ["info", "debug", "error", "warning"]) +def test_configer_logger( + level: str, + upper: bool, + path: Path | str | None, +) -> None: + """Tests the configuration of the logger. + + Args: + level: The log-level to configure. + upper: If the upper variant of the level should be used. + path: The path where to save the log. + """ + if level.upper(): + level = level.upper() + configer_logger(level, path) # type: ignore diff --git a/tests/utils/sql/connector_test.py b/tests/utils/sql/connector_test.py index 6dbd9ba..a671883 100644 --- a/tests/utils/sql/connector_test.py +++ b/tests/utils/sql/connector_test.py @@ -4,19 +4,15 @@ from collections.abc import Generator from typing import Any from unittest.mock import Mock, patch -import pandas as pd import pytest from sqlalchemy.engine import Engine -from sqlalchemy.orm import Session from aki_prj23_transparenzregister.config.config_providers import JsonFileConfigProvider from aki_prj23_transparenzregister.config.config_template import PostgreConnectionString from aki_prj23_transparenzregister.utils.sql.connector import ( - Base, get_pg_engine, get_session, init_db, - transfer_db, ) @@ -31,36 +27,6 @@ def test_get_engine_pg() -> None: assert get_pg_engine(conn_args) == result -@pytest.fixture() -def destination_db() -> Generator[Session, None, None]: - """Generates a db Session to a sqlite db to copy data to.""" - if os.path.exists("secondary.db"): - os.remove("secondary.db") - db = get_session("sqlite:///secondary.db") - init_db(db) - yield db - db.close() - bind = db.bind - assert isinstance(bind, Engine) - bind.dispose() - os.remove("secondary.db") - - -def test_transfer_db(full_db: Session, destination_db: Session) -> None: - """Tests if the data transfer between two sql tables works.""" - transfer_db(source=full_db, destination=destination_db) - sbind = full_db.bind - dbind = destination_db.bind - assert isinstance(sbind, Engine) - assert isinstance(dbind, Engine) - - for table in Base.metadata.sorted_tables: - pd.testing.assert_frame_equal( - pd.read_sql_table(str(table), dbind), - pd.read_sql_table(str(table), sbind), - ) - - @pytest.fixture() def delete_sqlite_table() -> Generator[str, None, None]: """Cleans a path before and deletes the table after a test. diff --git a/tests/utils/sql/copy_sql_test.py b/tests/utils/sql/copy_sql_test.py new file mode 100644 index 0000000..ebe3cea --- /dev/null +++ b/tests/utils/sql/copy_sql_test.py @@ -0,0 +1,56 @@ +"""Test if the sql db can be copied.""" +import os +from collections.abc import Generator + +import pandas as pd +import pytest +from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session + +from aki_prj23_transparenzregister.utils.sql.connector import Base, get_session, init_db +from aki_prj23_transparenzregister.utils.sql.copy_sql import ( + copy_db_cli, + transfer_db_function, +) + + +@pytest.fixture() +def destination_db() -> Generator[Session, None, None]: + """Generates a db Session to a sqlite db to copy data to.""" + if os.path.exists("secondary.db"): + os.remove("secondary.db") + db = get_session("sqlite:///secondary.db") + init_db(db) + yield db + db.close() + bind = db.bind + assert isinstance(bind, Engine) + bind.dispose() + os.remove("secondary.db") + + +def test_transfer_db(full_db: Session, destination_db: Session) -> None: + """Tests if the data transfer between two sql tables works.""" + transfer_db_function(source=full_db, destination=destination_db) + sbind = full_db.bind + dbind = destination_db.bind + assert isinstance(sbind, Engine) + assert isinstance(dbind, Engine) + assert Base.metadata.sorted_tables + for table in Base.metadata.sorted_tables + ["company"]: + pd.testing.assert_frame_equal( + pd.read_sql_table(str(table), dbind), + pd.read_sql_table(str(table), sbind), + ) + + +def test_copy_db_cli_help1() -> None: + """Tests if the help argument exits the software gracefully.""" + with pytest.raises(SystemExit): + copy_db_cli(["-h"]) + + +def test_copy_db_cli_help2() -> None: + """Tests if the help argument exits the software gracefully.""" + with pytest.raises(SystemExit): + copy_db_cli(["eskse", "-h", "asdf"])