mirror of
https://github.com/fhswf/aki_prj23_transparenzregister.git
synced 2025-06-22 07:53:55 +02:00
Repaired the SQL copy and reduced the log volume a bit (#141)
- Added a cli interface to the SQL copy - Repaired the SQL copy function - Added the SQL copy function to the scripts - Reduced the logging verbosity
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@ -221,4 +221,6 @@ replay_pid*
|
|||||||
/lbr-audit.md
|
/lbr-audit.md
|
||||||
/.ruff_cache/
|
/.ruff_cache/
|
||||||
/Jupyter/test.ipynb
|
/Jupyter/test.ipynb
|
||||||
/secrets*.json
|
secrets*.json
|
||||||
|
*.db-journal
|
||||||
|
*.db
|
||||||
|
@ -16,8 +16,8 @@ See the [CONTRIBUTING.md](CONTRIBUTING.md) about how code should be formatted an
|
|||||||
|
|
||||||
The project has currently the following entrypoint available:
|
The project has currently the following entrypoint available:
|
||||||
|
|
||||||
- data-transfer > Transfers all the data from the mongodb into the sql db to make it available as production data.
|
- **data-transfer** > Transfers all the data from the mongodb into the sql db to make it available as production data.
|
||||||
- reset-sql > Resets all sql tables in the connected db.
|
- **reset-sql** > Resets all sql tables in the connected db.
|
||||||
|
|
||||||
## DB Connection settings
|
## DB Connection settings
|
||||||
|
|
||||||
|
@ -96,6 +96,7 @@ pytest-mock = "^3.11.1"
|
|||||||
pytest-repeat = "^0.9.1"
|
pytest-repeat = "^0.9.1"
|
||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
|
copy-sql = "aki_prj23_transparenzregister.utils.sql.copy_sql:copy_db_cli"
|
||||||
data-transfer = "aki_prj23_transparenzregister.utils.data_transfer:transfer_data"
|
data-transfer = "aki_prj23_transparenzregister.utils.data_transfer:transfer_data"
|
||||||
reset-sql = "aki_prj23_transparenzregister.utils.sql.connector:reset_all_tables"
|
reset-sql = "aki_prj23_transparenzregister.utils.sql.connector:reset_all_tables"
|
||||||
|
|
||||||
@ -136,7 +137,7 @@ unfixable = ["B"]
|
|||||||
builtins-ignorelist = ["id"]
|
builtins-ignorelist = ["id"]
|
||||||
|
|
||||||
[tool.ruff.per-file-ignores]
|
[tool.ruff.per-file-ignores]
|
||||||
"tests/*.py" = ["S101", "SLF001", "S311", "D103"]
|
"tests/*.py" = ["S101", "SLF001", "S311", "D103", "PLR0913"]
|
||||||
|
|
||||||
[tool.ruff.pydocstyle]
|
[tool.ruff.pydocstyle]
|
||||||
convention = "google"
|
convention = "google"
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
"""Wrappers for config providers."""
|
"""Wrappers for config providers."""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
import errno
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
@ -41,7 +43,7 @@ class JsonFileConfigProvider(ConfigProvider):
|
|||||||
|
|
||||||
__data__: dict = {}
|
__data__: dict = {}
|
||||||
|
|
||||||
def __init__(self, file_path: str):
|
def __init__(self, file_path: str | Path):
|
||||||
"""Constructor reading its data from a given .json file.
|
"""Constructor reading its data from a given .json file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -52,7 +54,7 @@ class JsonFileConfigProvider(ConfigProvider):
|
|||||||
TypeError: File could not be read or is malformed
|
TypeError: File could not be read or is malformed
|
||||||
"""
|
"""
|
||||||
if not os.path.isfile(file_path):
|
if not os.path.isfile(file_path):
|
||||||
raise FileNotFoundError
|
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file_path)
|
||||||
with open(file_path) as file:
|
with open(file_path) as file:
|
||||||
try:
|
try:
|
||||||
data = json.loads(file.read())
|
data = json.loads(file.read())
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
"""This module contains the data transfer and refinement functionalities between staging and production DB."""
|
"""This module contains the data transfer and refinement functionalities between staging and production DB."""
|
||||||
import sys
|
|
||||||
from datetime import date
|
from datetime import date
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -11,6 +10,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from aki_prj23_transparenzregister.config.config_providers import JsonFileConfigProvider
|
from aki_prj23_transparenzregister.config.config_providers import JsonFileConfigProvider
|
||||||
from aki_prj23_transparenzregister.utils.enum_types import RelationTypeEnum
|
from aki_prj23_transparenzregister.utils.enum_types import RelationTypeEnum
|
||||||
|
from aki_prj23_transparenzregister.utils.logger_config import configer_logger
|
||||||
from aki_prj23_transparenzregister.utils.mongo.company_mongo_service import (
|
from aki_prj23_transparenzregister.utils.mongo.company_mongo_service import (
|
||||||
CompanyMongoService,
|
CompanyMongoService,
|
||||||
)
|
)
|
||||||
@ -266,10 +266,11 @@ def add_relationship(
|
|||||||
relation: entities.CompanyRelation | entities.PersonRelation
|
relation: entities.CompanyRelation | entities.PersonRelation
|
||||||
if "date_of_birth" in relationship:
|
if "date_of_birth" in relationship:
|
||||||
name = relationship["name"]
|
name = relationship["name"]
|
||||||
|
date_of_brith: str = relationship["date_of_birth"]
|
||||||
person_id = get_person_id(
|
person_id = get_person_id(
|
||||||
name["firstname"],
|
name["firstname"],
|
||||||
name["lastname"],
|
name["lastname"],
|
||||||
relationship["date_of_birth"],
|
date_of_brith,
|
||||||
db,
|
db,
|
||||||
)
|
)
|
||||||
relation = entities.PersonRelation(
|
relation = entities.PersonRelation(
|
||||||
@ -278,12 +279,16 @@ def add_relationship(
|
|||||||
relation=relation_type,
|
relation=relation_type,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
relation_to: int = get_company_id(
|
try:
|
||||||
relationship["description"],
|
relation_to: int = get_company_id(
|
||||||
relationship["location"]["zip_code"],
|
relationship["description"],
|
||||||
relationship["location"]["city"],
|
relationship["location"]["zip_code"],
|
||||||
db=db,
|
relationship["location"]["city"],
|
||||||
)
|
db=db,
|
||||||
|
)
|
||||||
|
except KeyError as err:
|
||||||
|
logger.warning(err)
|
||||||
|
return
|
||||||
if company_id == relation_to:
|
if company_id == relation_to:
|
||||||
raise DataInvalidError(
|
raise DataInvalidError(
|
||||||
"For a valid relation both parties can't be the same entity."
|
"For a valid relation both parties can't be the same entity."
|
||||||
@ -414,13 +419,7 @@ def add_annual_financial_reports(companies: list[dict], db: Session) -> None:
|
|||||||
|
|
||||||
def transfer_data(db: Session | None = None) -> None:
|
def transfer_data(db: Session | None = None) -> None:
|
||||||
"""This functions transfers all the data from a production environment to a staging environment."""
|
"""This functions transfers all the data from a production environment to a staging environment."""
|
||||||
logger.remove()
|
configer_logger("info", "data-transfer.log")
|
||||||
logger.add(
|
|
||||||
sys.stdout,
|
|
||||||
level="INFO",
|
|
||||||
catch=True,
|
|
||||||
)
|
|
||||||
logger.add("data-transfer.log", level="INFO", retention=5)
|
|
||||||
|
|
||||||
mongo_connector = MongoConnector(
|
mongo_connector = MongoConnector(
|
||||||
JsonFileConfigProvider("./secrets.json").get_mongo_connection_string()
|
JsonFileConfigProvider("./secrets.json").get_mongo_connection_string()
|
||||||
|
22
src/aki_prj23_transparenzregister/utils/logger_config.py
Normal file
22
src/aki_prj23_transparenzregister/utils/logger_config.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
"""Configures the logger."""
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
def configer_logger(
|
||||||
|
level: Literal["info", "debug", "warning", "error"],
|
||||||
|
path: str | Path,
|
||||||
|
) -> None:
|
||||||
|
"""Configures the logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Defines the logging level that should be used.
|
||||||
|
path: The path where the logs should be saved.
|
||||||
|
"""
|
||||||
|
logger.remove()
|
||||||
|
logger.add(sys.stdout, level=level.upper(), catch=True)
|
||||||
|
if path:
|
||||||
|
logger.add(path, level=level.upper(), retention=5)
|
@ -1,13 +1,11 @@
|
|||||||
"""Module containing connection utils for PostgreSQL DB."""
|
"""Module containing connection utils for PostgreSQL DB."""
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from sqlalchemy.engine import URL, Engine
|
from sqlalchemy.engine import URL, Engine
|
||||||
from sqlalchemy.orm import Session, declarative_base, sessionmaker
|
from sqlalchemy.orm import Session, declarative_base, sessionmaker
|
||||||
from sqlalchemy.pool import SingletonThreadPool
|
from sqlalchemy.pool import SingletonThreadPool
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from aki_prj23_transparenzregister.config.config_providers import (
|
from aki_prj23_transparenzregister.config.config_providers import (
|
||||||
ConfigProvider,
|
ConfigProvider,
|
||||||
@ -58,6 +56,9 @@ def get_session(
|
|||||||
A session to connect to an SQL db via SQLAlchemy.
|
A session to connect to an SQL db via SQLAlchemy.
|
||||||
"""
|
"""
|
||||||
engine: Engine
|
engine: Engine
|
||||||
|
if isinstance(connect_to, str) and re.fullmatch(r".*\.json$", connect_to):
|
||||||
|
logger.debug(connect_to)
|
||||||
|
connect_to = JsonFileConfigProvider(connect_to)
|
||||||
if isinstance(connect_to, ConfigProvider):
|
if isinstance(connect_to, ConfigProvider):
|
||||||
engine = get_pg_engine(connect_to.get_postgre_connection_string())
|
engine = get_pg_engine(connect_to.get_postgre_connection_string())
|
||||||
|
|
||||||
@ -90,26 +91,6 @@ def reset_all_tables(db: Session) -> None:
|
|||||||
init_db(db)
|
init_db(db)
|
||||||
|
|
||||||
|
|
||||||
@logger.catch(reraise=True)
|
|
||||||
def transfer_db(*, source: Session, destination: Session) -> None:
|
|
||||||
"""Transfers the data from on db to another db.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source: A session to a source db data should be copied from.
|
|
||||||
destination: A session to a db where the data should be copied to.
|
|
||||||
"""
|
|
||||||
reset_all_tables(destination)
|
|
||||||
init_db(destination)
|
|
||||||
sbind = source.bind
|
|
||||||
dbind = destination.bind
|
|
||||||
assert isinstance(sbind, Engine) # noqa: S101
|
|
||||||
assert isinstance(dbind, Engine) # noqa: S101
|
|
||||||
for table in tqdm(Base.metadata.sorted_tables):
|
|
||||||
pd.read_sql_table(str(table), sbind).to_sql(
|
|
||||||
str(table), dbind, if_exists="append", index=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
"""Main flow creating tables"""
|
"""Main flow creating tables"""
|
||||||
init_db(get_session(JsonFileConfigProvider("./secrets.json")))
|
init_db(get_session(JsonFileConfigProvider("./secrets.json")))
|
||||||
|
85
src/aki_prj23_transparenzregister/utils/sql/copy_sql.py
Normal file
85
src/aki_prj23_transparenzregister/utils/sql/copy_sql.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
"""Functions to copy a sql table."""
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from loguru import logger
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from aki_prj23_transparenzregister.utils.logger_config import configer_logger
|
||||||
|
from aki_prj23_transparenzregister.utils.sql.connector import (
|
||||||
|
Base,
|
||||||
|
get_session,
|
||||||
|
reset_all_tables,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@logger.catch(reraise=True)
|
||||||
|
def transfer_db_function(*, source: Session, destination: Session) -> None:
|
||||||
|
"""Transfers the data from on db to another db.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: A session to a source db data should be copied from.
|
||||||
|
destination: A session to a db where the data should be copied to.
|
||||||
|
"""
|
||||||
|
reset_all_tables(destination)
|
||||||
|
# init_db(destination)
|
||||||
|
sbind = source.bind
|
||||||
|
dbind = destination.bind
|
||||||
|
assert isinstance(sbind, Engine) # noqa: S101
|
||||||
|
assert isinstance(dbind, Engine) # noqa: S101
|
||||||
|
for table in Base.metadata.sorted_tables:
|
||||||
|
logger.info(f"Transferring table {table} from source to destination db.")
|
||||||
|
pd.read_sql_table(str(table), sbind).to_sql(
|
||||||
|
str(table), dbind, if_exists="append", index=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def copy_db_cli(args: list[str] | None = None) -> None:
|
||||||
|
"""CLI interfaces to copy a db from source to destination.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: The args ar automaticlly collected from the cli if none are given. They should only be given for testing.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="copy-sql",
|
||||||
|
description="Copy data from one SQL database to another.",
|
||||||
|
epilog="Example: copy-sql source.db destination.json",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"source",
|
||||||
|
metavar="source",
|
||||||
|
help="Source database configuration.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"destination",
|
||||||
|
metavar="destination",
|
||||||
|
help="Destination database configuration.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-level",
|
||||||
|
choices=["info", "debug", "error", "warning"],
|
||||||
|
default="info",
|
||||||
|
metavar="log-level",
|
||||||
|
help="The log level for the output.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-path",
|
||||||
|
metavar="log_path",
|
||||||
|
help="A path to write the log to.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not args:
|
||||||
|
args = sys.argv[1:]
|
||||||
|
|
||||||
|
parsed = parser.parse_args(args)
|
||||||
|
configer_logger(level=parsed.log_level, path=parsed.log_path)
|
||||||
|
source = get_session(parsed.source)
|
||||||
|
logger.info(f"Connecting to {source.bind} as a source to copy from.")
|
||||||
|
destination = get_session(parsed.destination)
|
||||||
|
logger.info(f"Connecting to {destination.bind} as a destination to copy to.")
|
||||||
|
transfer_db_function(
|
||||||
|
source=source,
|
||||||
|
destination=destination,
|
||||||
|
)
|
@ -571,23 +571,25 @@ def test_add_relationship_company_unknown(
|
|||||||
city: str | None,
|
city: str | None,
|
||||||
zip_code: str | None,
|
zip_code: str | None,
|
||||||
full_db: Session,
|
full_db: Session,
|
||||||
|
mocker: MockerFixture,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Tests if a relationship to another company can be added."""
|
"""Tests if a relationship to another company can be added."""
|
||||||
with pytest.raises(
|
spy_warning = mocker.spy(data_transfer.logger, "warning")
|
||||||
KeyError, match=f"No corresponding company could be found to {company_name}."
|
spy_info = mocker.spy(data_transfer.logger, "info")
|
||||||
):
|
data_transfer.add_relationship(
|
||||||
data_transfer.add_relationship(
|
{
|
||||||
{
|
"description": company_name,
|
||||||
"description": company_name,
|
"location": {
|
||||||
"location": {
|
"zip_code": zip_code,
|
||||||
"zip_code": zip_code,
|
"city": city,
|
||||||
"city": city,
|
|
||||||
},
|
|
||||||
"role": "organisation",
|
|
||||||
},
|
},
|
||||||
company_id,
|
"role": "organisation",
|
||||||
full_db,
|
},
|
||||||
)
|
company_id,
|
||||||
|
full_db,
|
||||||
|
)
|
||||||
|
spy_warning.assert_called_once()
|
||||||
|
spy_info.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("empty_relations", [[], [{}], [{"relationship": []}]])
|
@pytest.mark.parametrize("empty_relations", [[], [{}], [{"relationship": []}]])
|
||||||
@ -778,7 +780,7 @@ def test_add_annual_financial_reports_no_call(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Testing if financial reports are added correctly to the db."""
|
"""Testing if financial reports are added correctly to the db."""
|
||||||
spy_warning = mocker.spy(data_transfer.logger, "warning")
|
spy_warning = mocker.spy(data_transfer.logger, "warning")
|
||||||
info_warning = mocker.spy(data_transfer.logger, "info")
|
spy_info = mocker.spy(data_transfer.logger, "info")
|
||||||
mocker.patch("aki_prj23_transparenzregister.utils.data_transfer.add_annual_report")
|
mocker.patch("aki_prj23_transparenzregister.utils.data_transfer.add_annual_report")
|
||||||
data_transfer.add_annual_financial_reports(companies, full_db)
|
data_transfer.add_annual_financial_reports(companies, full_db)
|
||||||
|
|
||||||
@ -786,7 +788,7 @@ def test_add_annual_financial_reports_no_call(
|
|||||||
input_kwargs = mocker.call.kwargs
|
input_kwargs = mocker.call.kwargs
|
||||||
assert len(input_args) == len(input_kwargs)
|
assert len(input_args) == len(input_kwargs)
|
||||||
spy_warning.assert_not_called()
|
spy_warning.assert_not_called()
|
||||||
info_warning.assert_called_once()
|
spy_info.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -821,7 +823,7 @@ def test_add_annual_financial_reports_defect_year(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Testing if financial reports are added correctly to the db."""
|
"""Testing if financial reports are added correctly to the db."""
|
||||||
spy_warning = mocker.spy(data_transfer.logger, "warning")
|
spy_warning = mocker.spy(data_transfer.logger, "warning")
|
||||||
info_warning = mocker.spy(data_transfer.logger, "info")
|
spy_info = mocker.spy(data_transfer.logger, "info")
|
||||||
mocker.patch("aki_prj23_transparenzregister.utils.data_transfer.add_annual_report")
|
mocker.patch("aki_prj23_transparenzregister.utils.data_transfer.add_annual_report")
|
||||||
data_transfer.add_annual_financial_reports(companies, full_db)
|
data_transfer.add_annual_financial_reports(companies, full_db)
|
||||||
|
|
||||||
@ -829,7 +831,7 @@ def test_add_annual_financial_reports_defect_year(
|
|||||||
input_kwargs = mocker.call.kwargs
|
input_kwargs = mocker.call.kwargs
|
||||||
assert len(input_args) == len(input_kwargs)
|
assert len(input_args) == len(input_kwargs)
|
||||||
spy_warning.assert_called_once()
|
spy_warning.assert_called_once()
|
||||||
info_warning.assert_called_once()
|
spy_info.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_add_annual_financial_reports(full_db: Session, mocker: MockerFixture) -> None:
|
def test_add_annual_financial_reports(full_db: Session, mocker: MockerFixture) -> None:
|
||||||
@ -864,7 +866,7 @@ def test_add_annual_financial_reports(full_db: Session, mocker: MockerFixture) -
|
|||||||
]
|
]
|
||||||
|
|
||||||
spy_warning = mocker.spy(data_transfer.logger, "warning")
|
spy_warning = mocker.spy(data_transfer.logger, "warning")
|
||||||
info_warning = mocker.spy(data_transfer.logger, "info")
|
spy_info = mocker.spy(data_transfer.logger, "info")
|
||||||
mocked = mocker.patch(
|
mocked = mocker.patch(
|
||||||
"aki_prj23_transparenzregister.utils.data_transfer.add_annual_report"
|
"aki_prj23_transparenzregister.utils.data_transfer.add_annual_report"
|
||||||
)
|
)
|
||||||
@ -890,7 +892,7 @@ def test_add_annual_financial_reports(full_db: Session, mocker: MockerFixture) -
|
|||||||
for input_args in mocked.call_args_list:
|
for input_args in mocked.call_args_list:
|
||||||
assert isinstance(input_args.kwargs["db"], Session)
|
assert isinstance(input_args.kwargs["db"], Session)
|
||||||
|
|
||||||
info_warning.assert_called_once()
|
spy_info.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("year", list(range(2000, 2025, 5)))
|
@pytest.mark.parametrize("year", list(range(2000, 2025, 5)))
|
||||||
|
26
tests/utils/logger_config_test.py
Normal file
26
tests/utils/logger_config_test.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
"""Smoke-test over the logger config."""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from aki_prj23_transparenzregister.utils.logger_config import configer_logger
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("path", [None, "test-log.log", ""])
|
||||||
|
@pytest.mark.parametrize("upper", [True, False])
|
||||||
|
@pytest.mark.parametrize("level", ["info", "debug", "error", "warning"])
|
||||||
|
def test_configer_logger(
|
||||||
|
level: str,
|
||||||
|
upper: bool,
|
||||||
|
path: Path | str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Tests the configuration of the logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: The log-level to configure.
|
||||||
|
upper: If the upper variant of the level should be used.
|
||||||
|
path: The path where to save the log.
|
||||||
|
"""
|
||||||
|
if level.upper():
|
||||||
|
level = level.upper()
|
||||||
|
configer_logger(level, path) # type: ignore
|
@ -4,19 +4,15 @@ from collections.abc import Generator
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from aki_prj23_transparenzregister.config.config_providers import JsonFileConfigProvider
|
from aki_prj23_transparenzregister.config.config_providers import JsonFileConfigProvider
|
||||||
from aki_prj23_transparenzregister.config.config_template import PostgreConnectionString
|
from aki_prj23_transparenzregister.config.config_template import PostgreConnectionString
|
||||||
from aki_prj23_transparenzregister.utils.sql.connector import (
|
from aki_prj23_transparenzregister.utils.sql.connector import (
|
||||||
Base,
|
|
||||||
get_pg_engine,
|
get_pg_engine,
|
||||||
get_session,
|
get_session,
|
||||||
init_db,
|
init_db,
|
||||||
transfer_db,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -31,36 +27,6 @@ def test_get_engine_pg() -> None:
|
|||||||
assert get_pg_engine(conn_args) == result
|
assert get_pg_engine(conn_args) == result
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def destination_db() -> Generator[Session, None, None]:
|
|
||||||
"""Generates a db Session to a sqlite db to copy data to."""
|
|
||||||
if os.path.exists("secondary.db"):
|
|
||||||
os.remove("secondary.db")
|
|
||||||
db = get_session("sqlite:///secondary.db")
|
|
||||||
init_db(db)
|
|
||||||
yield db
|
|
||||||
db.close()
|
|
||||||
bind = db.bind
|
|
||||||
assert isinstance(bind, Engine)
|
|
||||||
bind.dispose()
|
|
||||||
os.remove("secondary.db")
|
|
||||||
|
|
||||||
|
|
||||||
def test_transfer_db(full_db: Session, destination_db: Session) -> None:
|
|
||||||
"""Tests if the data transfer between two sql tables works."""
|
|
||||||
transfer_db(source=full_db, destination=destination_db)
|
|
||||||
sbind = full_db.bind
|
|
||||||
dbind = destination_db.bind
|
|
||||||
assert isinstance(sbind, Engine)
|
|
||||||
assert isinstance(dbind, Engine)
|
|
||||||
|
|
||||||
for table in Base.metadata.sorted_tables:
|
|
||||||
pd.testing.assert_frame_equal(
|
|
||||||
pd.read_sql_table(str(table), dbind),
|
|
||||||
pd.read_sql_table(str(table), sbind),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def delete_sqlite_table() -> Generator[str, None, None]:
|
def delete_sqlite_table() -> Generator[str, None, None]:
|
||||||
"""Cleans a path before and deletes the table after a test.
|
"""Cleans a path before and deletes the table after a test.
|
||||||
|
56
tests/utils/sql/copy_sql_test.py
Normal file
56
tests/utils/sql/copy_sql_test.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
"""Test if the sql db can be copied."""
|
||||||
|
import os
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from aki_prj23_transparenzregister.utils.sql.connector import Base, get_session, init_db
|
||||||
|
from aki_prj23_transparenzregister.utils.sql.copy_sql import (
|
||||||
|
copy_db_cli,
|
||||||
|
transfer_db_function,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def destination_db() -> Generator[Session, None, None]:
|
||||||
|
"""Generates a db Session to a sqlite db to copy data to."""
|
||||||
|
if os.path.exists("secondary.db"):
|
||||||
|
os.remove("secondary.db")
|
||||||
|
db = get_session("sqlite:///secondary.db")
|
||||||
|
init_db(db)
|
||||||
|
yield db
|
||||||
|
db.close()
|
||||||
|
bind = db.bind
|
||||||
|
assert isinstance(bind, Engine)
|
||||||
|
bind.dispose()
|
||||||
|
os.remove("secondary.db")
|
||||||
|
|
||||||
|
|
||||||
|
def test_transfer_db(full_db: Session, destination_db: Session) -> None:
|
||||||
|
"""Tests if the data transfer between two sql tables works."""
|
||||||
|
transfer_db_function(source=full_db, destination=destination_db)
|
||||||
|
sbind = full_db.bind
|
||||||
|
dbind = destination_db.bind
|
||||||
|
assert isinstance(sbind, Engine)
|
||||||
|
assert isinstance(dbind, Engine)
|
||||||
|
assert Base.metadata.sorted_tables
|
||||||
|
for table in Base.metadata.sorted_tables + ["company"]:
|
||||||
|
pd.testing.assert_frame_equal(
|
||||||
|
pd.read_sql_table(str(table), dbind),
|
||||||
|
pd.read_sql_table(str(table), sbind),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_copy_db_cli_help1() -> None:
|
||||||
|
"""Tests if the help argument exits the software gracefully."""
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
copy_db_cli(["-h"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_copy_db_cli_help2() -> None:
|
||||||
|
"""Tests if the help argument exits the software gracefully."""
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
copy_db_cli(["eskse", "-h", "asdf"])
|
Reference in New Issue
Block a user