From af8a907cf90350c0d8feb2a52be5635667d31b4c Mon Sep 17 00:00:00 2001 From: Philipp Horstenkamp Date: Sun, 12 Nov 2023 14:27:44 +0100 Subject: [PATCH] Stop table reset of better persistent tables. (#373) --- pyproject.toml | 4 +- .../utils/data_transfer.py | 4 +- .../utils/sql/connector.py | 35 ----------- .../utils/sql/copy_sql.py | 4 +- .../utils/sql/reset_sql.py | 61 +++++++++++++++++++ tests/conftest.py | 3 + tests/utils/sql/rest_sql_test.py | 38 ++++++++++++ 7 files changed, 108 insertions(+), 41 deletions(-) create mode 100644 src/aki_prj23_transparenzregister/utils/sql/reset_sql.py create mode 100644 tests/utils/sql/rest_sql_test.py diff --git a/pyproject.toml b/pyproject.toml index be5c77e..fcb7228 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ dashvis = "^0.1.3" datetime = "^5.2" deutschland = {git = "https://github.com/TrisNol/deutschland.git", branch = "hotfix/python-3.11-support"} frozendict = "^2.3.8" +html5lib = "^1.1" loguru = "^0.7.0" matplotlib = "^3.8.1" networkx = "^3.2.1" @@ -82,7 +83,6 @@ torchvision = {version = "*", source = "torch-cpu"} tqdm = "^4.66.1" transformers = {version = "*", extras = ["torch"]} xmltodict = "^0.13.0" -html5lib = "^1.1" [tool.poetry.extras] ingest = ["selenium", "deutschland", "xmltodict", "html5lib"] @@ -143,7 +143,7 @@ copy-sql = "aki_prj23_transparenzregister.utils.sql.copy_sql:copy_db_cli" data-processing = "aki_prj23_transparenzregister.utils.data_processing:cli" data-transformation = "aki_prj23_transparenzregister.utils.data_transfer:transfer_data_cli" fetch-news-schedule = "aki_prj23_transparenzregister.apps.fetch_news:fetch_news_cli" -reset-sql = "aki_prj23_transparenzregister.utils.sql.connector:reset_all_tables_cli" +reset-sql = "aki_prj23_transparenzregister.utils.sql.reset_sql:cli" webserver = "aki_prj23_transparenzregister.ui.app:main" [[tool.poetry.source]] diff --git a/src/aki_prj23_transparenzregister/utils/data_transfer.py b/src/aki_prj23_transparenzregister/utils/data_transfer.py index 148e404..84e747c 100644 --- a/src/aki_prj23_transparenzregister/utils/data_transfer.py +++ b/src/aki_prj23_transparenzregister/utils/data_transfer.py @@ -35,8 +35,8 @@ from aki_prj23_transparenzregister.utils.mongo.connector import MongoConnector from aki_prj23_transparenzregister.utils.sql import entities from aki_prj23_transparenzregister.utils.sql.connector import ( get_session, - reset_all_tables, ) +from aki_prj23_transparenzregister.utils.sql.reset_sql import reset_tables from aki_prj23_transparenzregister.utils.string_tools import simplify_string nomi = pgeocode.Nominatim("de") @@ -639,7 +639,7 @@ def transfer_data(config_provider: ConfigProvider) -> None: companies: list[dict[str, Any]] = mongo_company.get_all() # type: ignore del mongo_company db = get_session(config_provider) - reset_all_tables(db) + reset_tables(db, all_tables=False) add_companies(companies, db) reset_relation_counter(db) diff --git a/src/aki_prj23_transparenzregister/utils/sql/connector.py b/src/aki_prj23_transparenzregister/utils/sql/connector.py index 7b95985..56a6be5 100644 --- a/src/aki_prj23_transparenzregister/utils/sql/connector.py +++ b/src/aki_prj23_transparenzregister/utils/sql/connector.py @@ -1,28 +1,19 @@ """Module containing connection utils for PostgreSQL DB.""" -import argparse -import sys import sqlalchemy as sa -from loguru import logger from sqlalchemy.engine import URL, Engine from sqlalchemy.orm import Session, declarative_base, sessionmaker from sqlalchemy.pool import SingletonThreadPool from aki_prj23_transparenzregister.config.config_providers import ( - HELP_TEXT_CONFIG, ConfigProvider, JsonFileConfigProvider, - get_config_provider, ) from aki_prj23_transparenzregister.config.config_template import ( PostgreConnectionString, SQLConnectionString, SQLiteConnectionString, ) -from aki_prj23_transparenzregister.utils.logger_config import ( - add_logger_options_to_argparse, - configer_logger, -) def get_engine(conn_args: SQLConnectionString) -> Engine: @@ -79,32 +70,6 @@ def init_db(db: Session) -> None: Base.metadata.create_all(db.bind) -def reset_all_tables(db: Session) -> None: - """Drops all SQL tables and recreates them.""" - logger.info("Resetting all SQL tables.") - Base.metadata.drop_all(db.bind) - init_db(db) - - -def reset_all_tables_cli() -> None: - """Resets all tables via a cli.""" - parser = argparse.ArgumentParser( - prog="Reset SQL", - description="Copy data from one SQL database to another.", - epilog="Example: 'reset-sql secrets.json' or 'reset-sql ENV_VARS_'", - ) - parser.add_argument( - "config", - metavar="config", - default="ENV", - help=HELP_TEXT_CONFIG, - ) - add_logger_options_to_argparse(parser) - parsed = parser.parse_args(sys.argv[1:]) - configer_logger(namespace=parsed) - reset_all_tables(get_session(get_config_provider(parsed.config))) - - if __name__ == "__main__": """Main flow creating tables""" init_db(get_session(JsonFileConfigProvider("./secrets.json"))) diff --git a/src/aki_prj23_transparenzregister/utils/sql/copy_sql.py b/src/aki_prj23_transparenzregister/utils/sql/copy_sql.py index 85499fa..8a88c6a 100644 --- a/src/aki_prj23_transparenzregister/utils/sql/copy_sql.py +++ b/src/aki_prj23_transparenzregister/utils/sql/copy_sql.py @@ -15,8 +15,8 @@ from aki_prj23_transparenzregister.utils.logger_config import ( from aki_prj23_transparenzregister.utils.sql.connector import ( Base, get_session, - reset_all_tables, ) +from aki_prj23_transparenzregister.utils.sql.reset_sql import reset_tables @logger.catch(reraise=True) @@ -27,7 +27,7 @@ def transfer_db_function(*, source: Session, destination: Session) -> None: source: A session to a source db data should be copied from. destination: A session to a db where the data should be copied to. """ - reset_all_tables(destination) + reset_tables(destination, all_tables=True) # init_db(destination) sbind = source.bind dbind = destination.bind diff --git a/src/aki_prj23_transparenzregister/utils/sql/reset_sql.py b/src/aki_prj23_transparenzregister/utils/sql/reset_sql.py new file mode 100644 index 0000000..878ec38 --- /dev/null +++ b/src/aki_prj23_transparenzregister/utils/sql/reset_sql.py @@ -0,0 +1,61 @@ +"""Functions to reset the SQL db partially or completely.""" +import argparse +import sys + +from loguru import logger +from sqlalchemy.orm import Session + +from aki_prj23_transparenzregister.config.config_providers import ( + HELP_TEXT_CONFIG, + get_config_provider, +) +from aki_prj23_transparenzregister.utils.logger_config import ( + add_logger_options_to_argparse, + configer_logger, +) +from aki_prj23_transparenzregister.utils.sql import entities +from aki_prj23_transparenzregister.utils.sql.connector import get_session, init_db +from aki_prj23_transparenzregister.utils.sql.entities import Base + + +def reset_tables(db: Session, all_tables: bool = False) -> None: + """Drops all SQL tables and recreates them.""" + if all_tables: + logger.warning(f"Resetting all SQL tables in {db.bind}.") + Base.metadata.drop_all(db.bind) + db.commit() + else: + logger.info(f"Resetting the main SQL tables in {db.bind}.") + for table in Base.metadata.sorted_tables: + if str(table) == entities.MissingCompany.__tablename__: + continue + logger.debug(f"Dropping {table}") + table.drop(db.bind) + db.commit() + init_db(db) + + +def cli() -> None: + """Resets all tables via a cli.""" + parser = argparse.ArgumentParser( + prog="Reset SQL", + description="Copy data from one SQL database to another.", + epilog="Example: 'reset-sql secrets.json' or 'reset-sql ENV_VARS_'", + ) + parser.add_argument( + "-a", + "--all", + default=False, + action="store_true", + help="If set, resets all tables. Default is False.", + ) + parser.add_argument( + "config", + metavar="config", + default="ENV", + help=HELP_TEXT_CONFIG, + ) + add_logger_options_to_argparse(parser) + parsed = parser.parse_args(sys.argv[1:]) + configer_logger(namespace=parsed) + reset_tables(get_session(get_config_provider(parsed.config)), all_tables=parsed.all) diff --git a/tests/conftest.py b/tests/conftest.py index ba0e847..5a7aebc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -183,6 +183,9 @@ def full_db(empty_db: Session, finance_statements: list[dict[str, Any]]) -> Sess for finance_statement in finance_statements ] ) + empty_db.add( + entities.MissingCompany(name="Some company missing", zip_code="", city="") + ) empty_db.commit() # print(pd.read_sql_table("company", empty_db.bind).to_string()) return empty_db diff --git a/tests/utils/sql/rest_sql_test.py b/tests/utils/sql/rest_sql_test.py new file mode 100644 index 0000000..7a3b51f --- /dev/null +++ b/tests/utils/sql/rest_sql_test.py @@ -0,0 +1,38 @@ +"""Tests for sql rests.""" +import sys + +import pandas as pd +import pytest +from _pytest.monkeypatch import MonkeyPatch +from sqlalchemy.orm import Session + +from aki_prj23_transparenzregister.utils.sql import entities, reset_sql + + +def test_reset_sql_all(full_db: Session) -> None: + """Tests if all sql tables are reset.""" + reset_sql.reset_tables(all_tables=True, db=full_db) + assert pd.read_sql_table( + entities.MissingCompany.__tablename__, con=full_db.bind # type:ignore + ).empty + assert pd.read_sql_table( + entities.Company.__tablename__, con=full_db.bind # type:ignore + ).empty + + +def test_reset_sql(full_db: Session) -> None: + """Tests if only most sql tables are reset.""" + reset_sql.reset_tables(all_tables=False, db=full_db) + assert pd.read_sql_table( + entities.Company.__tablename__, con=full_db.bind # type:ignore + ).empty + assert not pd.read_sql_table( + entities.MissingCompany.__tablename__, con=full_db.bind # type:ignore + ).empty + + +def test_reset_help(monkeypatch: MonkeyPatch) -> None: + """Tests if all sql tables are reset.""" + monkeypatch.setattr(sys, "argv", [sys.argv[0], "-h"]) + with pytest.raises(SystemExit): + reset_sql.cli()