From 2233b12468c143744b464bd91ccc08a09448885e Mon Sep 17 00:00:00 2001 From: Philipp Horstenkamp Date: Thu, 7 Sep 2023 18:41:10 +0200 Subject: [PATCH] Add sql lite session (#71) I added an sql lite session generator. Changes The function `def get_session() -> Session` has changed to `def get_session(connect_to: ConfigProvider | str) -> Session` If a JsonFileConfig Is given the postgress conection is checked. If a string is given that starts with `sqlite:///` an sql db is created. The use should otherwise be the same. --- .gitignore | 1 + .pre-commit-config.yaml | 4 + .../ui/company_finance_dash.py | 8 +- .../ui/company_stats_dash.py | 7 +- .../utils/postgres/connector.py | 49 ----------- .../utils/{postgres => sql}/__init__.py | 0 .../utils/sql/connector.py | 86 +++++++++++++++++++ .../utils/{postgres => sql}/entities.py | 2 +- tests/utils/postgres/connector_test.py | 35 -------- tests/utils/postgres/entities_test.py | 4 - tests/utils/sql/connector_test.py | 80 +++++++++++++++++ tests/utils/sql/entities_test.py | 4 + 12 files changed, 183 insertions(+), 97 deletions(-) delete mode 100644 src/aki_prj23_transparenzregister/utils/postgres/connector.py rename src/aki_prj23_transparenzregister/utils/{postgres => sql}/__init__.py (100%) create mode 100644 src/aki_prj23_transparenzregister/utils/sql/connector.py rename src/aki_prj23_transparenzregister/utils/{postgres => sql}/entities.py (98%) delete mode 100644 tests/utils/postgres/connector_test.py delete mode 100644 tests/utils/postgres/entities_test.py create mode 100644 tests/utils/sql/connector_test.py create mode 100644 tests/utils/sql/entities_test.py diff --git a/.gitignore b/.gitignore index 1735aea..5371c9d 100644 --- a/.gitignore +++ b/.gitignore @@ -216,3 +216,4 @@ replay_pid* /documentations/modules.rst /unit-test-results.xml /lbr-audit.md +/.ruff_cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 651f9b4..fed9ed5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -61,6 +61,10 @@ repos: - pandas-stubs==2.0.* - types-requests - sqlalchemy[mypy]==1.4.49 + - types-tqdm + - types-setuptools + - types-requests + - types-pyOpenSSL - repo: https://github.com/frnmst/md-toc rev: 8.2.0 diff --git a/src/aki_prj23_transparenzregister/ui/company_finance_dash.py b/src/aki_prj23_transparenzregister/ui/company_finance_dash.py index 8e5c728..ab6c2fa 100644 --- a/src/aki_prj23_transparenzregister/ui/company_finance_dash.py +++ b/src/aki_prj23_transparenzregister/ui/company_finance_dash.py @@ -7,13 +7,11 @@ from dash import Dash, Input, Output, callback, dash_table, dcc, html from dash.exceptions import PreventUpdate from sqlalchemy.engine import Engine -from aki_prj23_transparenzregister.utils.postgres import entities -from aki_prj23_transparenzregister.utils.postgres.connector import ( - get_session, -) +from aki_prj23_transparenzregister.config.config_providers import JsonFileConfigProvider +from aki_prj23_transparenzregister.utils.sql import connector, entities if __name__ == "__main__": - session = get_session() + session = connector.get_session(JsonFileConfigProvider("./secrets.json")) query_finance = session.query( entities.AnnualFinanceStatement, entities.Company.name, entities.Company.id ).join(entities.Company) diff --git a/src/aki_prj23_transparenzregister/ui/company_stats_dash.py b/src/aki_prj23_transparenzregister/ui/company_stats_dash.py index 2a46147..a48375c 100644 --- a/src/aki_prj23_transparenzregister/ui/company_stats_dash.py +++ b/src/aki_prj23_transparenzregister/ui/company_stats_dash.py @@ -3,13 +3,14 @@ import pandas as pd from dash import Dash, Input, Output, callback, dash_table, dcc, html -from aki_prj23_transparenzregister.utils.postgres import entities -from aki_prj23_transparenzregister.utils.postgres.connector import ( +from aki_prj23_transparenzregister.config.config_providers import JsonFileConfigProvider +from aki_prj23_transparenzregister.utils.sql import entities +from aki_prj23_transparenzregister.utils.sql.connector import ( get_session, ) if __name__ == "__main__": - session = get_session() + session = get_session(JsonFileConfigProvider("./secrets.json")) query = session.query(entities.Company) companies_df: pd.DataFrame = pd.read_sql(str(query), session.bind) # type: ignore diff --git a/src/aki_prj23_transparenzregister/utils/postgres/connector.py b/src/aki_prj23_transparenzregister/utils/postgres/connector.py deleted file mode 100644 index 5b41fac..0000000 --- a/src/aki_prj23_transparenzregister/utils/postgres/connector.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Module containing connection utils for PostgreSQL DB.""" -from sqlalchemy import create_engine -from sqlalchemy.engine import URL, Engine -from sqlalchemy.orm import Session, declarative_base, sessionmaker - -from aki_prj23_transparenzregister.config.config_providers import JsonFileConfigProvider -from aki_prj23_transparenzregister.config.config_template import PostgreConnectionString - - -def get_engine(conn_args: PostgreConnectionString) -> Engine: - """Creates an engine connected to a Postgres instance. - - Returns: - sqlalchemy.engine: connection engine - """ - url = URL.create( - drivername="postgresql", - username=conn_args.username, - password=conn_args.password, - host=conn_args.host, - database=conn_args.database, - port=conn_args.port, - ) - - return create_engine(url) - - -def get_session() -> Session: # pragma: no cover - """Return PG Session.""" - config_provider = JsonFileConfigProvider("./secrets.json") - engine = get_engine(config_provider.get_postgre_connection_string()) - session = sessionmaker(bind=engine) - return session() - - -Base = declarative_base() - - -def init_db() -> None: - """Initialize DB with all defined entities.""" - config_provider = JsonFileConfigProvider("./secrets.json") - engine = get_engine(config_provider.get_postgre_connection_string()) - with engine.connect(): - Base.metadata.create_all(engine) - - -if __name__ == "__main__": - """Main flow creating tables""" - init_db() diff --git a/src/aki_prj23_transparenzregister/utils/postgres/__init__.py b/src/aki_prj23_transparenzregister/utils/sql/__init__.py similarity index 100% rename from src/aki_prj23_transparenzregister/utils/postgres/__init__.py rename to src/aki_prj23_transparenzregister/utils/sql/__init__.py diff --git a/src/aki_prj23_transparenzregister/utils/sql/connector.py b/src/aki_prj23_transparenzregister/utils/sql/connector.py new file mode 100644 index 0000000..b2ef367 --- /dev/null +++ b/src/aki_prj23_transparenzregister/utils/sql/connector.py @@ -0,0 +1,86 @@ +"""Module containing connection utils for PostgreSQL DB.""" +import re + +import sqlalchemy as sa +from loguru import logger +from sqlalchemy.engine import URL, Engine +from sqlalchemy.orm import Session, declarative_base, sessionmaker +from sqlalchemy.pool import SingletonThreadPool + +from aki_prj23_transparenzregister.config.config_providers import ( + ConfigProvider, + JsonFileConfigProvider, +) +from aki_prj23_transparenzregister.config.config_template import PostgreConnectionString + + +def get_pg_engine(conn_args: PostgreConnectionString) -> Engine: + """Creates an engine connected to a Postgres instance. + + Returns: + sqlalchemy.engine: connection engine + """ + url = URL.create( + drivername="postgresql", + username=conn_args.username, + password=conn_args.password, + host=conn_args.host, + database=conn_args.database, + port=conn_args.port, + ) + return sa.create_engine(url) + + +def get_sqlite_engine(connect_to: str) -> Engine: + """Creates an engine connected to a sqlite instance. + + Returns: + sqlalchemy.engine: connection engine + """ + return sa.create_engine( + connect_to, + connect_args={"check_same_thread": True}, + poolclass=SingletonThreadPool, + ) + + +def get_session( + connect_to: JsonFileConfigProvider | str, +) -> Session: # pragma: no cover + """Creates a sql session. + + Args: + connect_to: The sqldb to connect to or the configuration there of. + + Returns: + A session to connect to an SQL db via SQLAlchemy. + """ + engine: Engine + if isinstance(connect_to, ConfigProvider): + engine = get_pg_engine(connect_to.get_postgre_connection_string()) + + elif isinstance(connect_to, str) and re.fullmatch( + r"sqlite:\/{3}[A-Za-z].*", connect_to + ): + engine = get_sqlite_engine(connect_to) + logger.info(f"Connection to sqlite3 {connect_to}") + else: + raise TypeError("No valid connection is defined!") + return sessionmaker(autocommit=False, autoflush=False, bind=engine)() + + +Base = declarative_base() + + +def init_db(db: Session) -> None: + """Initialize DB with all defined entities. + + Args: + db: A session to connect to an SQL db via SQLAlchemy. + """ + Base.metadata.create_all(db.bind) + + +if __name__ == "__main__": + """Main flow creating tables""" + init_db(get_session(JsonFileConfigProvider("./secrets.json"))) diff --git a/src/aki_prj23_transparenzregister/utils/postgres/entities.py b/src/aki_prj23_transparenzregister/utils/sql/entities.py similarity index 98% rename from src/aki_prj23_transparenzregister/utils/postgres/entities.py rename to src/aki_prj23_transparenzregister/utils/sql/entities.py index bda9b67..2bb90f1 100644 --- a/src/aki_prj23_transparenzregister/utils/postgres/entities.py +++ b/src/aki_prj23_transparenzregister/utils/sql/entities.py @@ -7,7 +7,7 @@ from aki_prj23_transparenzregister.utils.enumy_types import ( RelationTypeEnum, SentimentTypeEnum, ) -from aki_prj23_transparenzregister.utils.postgres.connector import Base +from aki_prj23_transparenzregister.utils.sql.connector import Base # # create an object *district_court* which inherits attributes from Base-class diff --git a/tests/utils/postgres/connector_test.py b/tests/utils/postgres/connector_test.py deleted file mode 100644 index 35e0c94..0000000 --- a/tests/utils/postgres/connector_test.py +++ /dev/null @@ -1,35 +0,0 @@ -from unittest.mock import Mock, patch - -from aki_prj23_transparenzregister.config.config_template import PostgreConnectionString -from aki_prj23_transparenzregister.utils.postgres.connector import get_engine, init_db - - -def test_get_engine() -> None: - conn_args = PostgreConnectionString("", "", "", "", 42) - with patch( - "aki_prj23_transparenzregister.utils.postgres.connector.create_engine" - ) as mock_create_engine: - result = "someThing" - mock_create_engine.return_value = result - assert get_engine(conn_args) == result - - -def test_init_db() -> None: - with patch( - "aki_prj23_transparenzregister.utils.postgres.connector.get_engine" - ) as mock_get_engine, patch( - "aki_prj23_transparenzregister.utils.postgres.connector.declarative_base" - ) as mock_declarative_base, patch( - "aki_prj23_transparenzregister.utils.postgres.connector.JsonFileConfigProvider" - ) as mock_provider: - mock_get_engine.connect.return_value = {} - - mock_value = Mock() - mock_value.metadata.create_all.return_value = None - mock_declarative_base.return_value = mock_value - - mock_value = Mock() - mock_provider.return_value = mock_value - mock_value.get_postgre_connection_string.return_value = "" - - init_db() diff --git a/tests/utils/postgres/entities_test.py b/tests/utils/postgres/entities_test.py deleted file mode 100644 index 69d376c..0000000 --- a/tests/utils/postgres/entities_test.py +++ /dev/null @@ -1,4 +0,0 @@ -def test_import() -> None: - from aki_prj23_transparenzregister.utils.postgres import entities - - assert entities diff --git a/tests/utils/sql/connector_test.py b/tests/utils/sql/connector_test.py new file mode 100644 index 0000000..658ba57 --- /dev/null +++ b/tests/utils/sql/connector_test.py @@ -0,0 +1,80 @@ +import os.path +from collections.abc import Generator +from typing import Any +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy.engine import Engine + +from aki_prj23_transparenzregister.config.config_providers import JsonFileConfigProvider +from aki_prj23_transparenzregister.config.config_template import PostgreConnectionString +from aki_prj23_transparenzregister.utils.sql.connector import ( + get_pg_engine, + get_session, + init_db, +) + + +def test_get_engine_pg() -> None: + conn_args = PostgreConnectionString("", "", "", "", 42) + with patch( + "aki_prj23_transparenzregister.utils.sql.connector.sa.create_engine" + ) as mock_create_engine: + result = "someThing" + mock_create_engine.return_value = result + assert get_pg_engine(conn_args) == result + + +@pytest.fixture() +def delete_sqlite_table() -> Generator[str, None, None]: + """Cleans a path before and deletes the table after a test. + + Returns: + The path where the sqlite table is placed. + """ + sqlite_test_path = "test_db.db" + if os.path.exists(sqlite_test_path): + os.remove(sqlite_test_path) + + yield sqlite_test_path + + if os.path.exists(sqlite_test_path): + os.remove(sqlite_test_path) + + +def test_get_sqlite_init(delete_sqlite_table: str) -> None: + """Tests if a sql table file can be initiated.""" + assert not os.path.exists(delete_sqlite_table) + session = get_session(f"sqlite:///{delete_sqlite_table}") + init_db(session) + session.close() + engine = session.bind + assert isinstance(engine, Engine) + engine.dispose() + assert os.path.exists(delete_sqlite_table) + + +@pytest.mark.parametrize("connection", ["faulty-name", 0, 9.2, True]) +def test_get_invalid_connection(connection: Any) -> None: + """Tests if an error is thrown on a faulty connections.""" + with pytest.raises(TypeError): + get_session(connection) + + +def test_init_pd_db() -> None: + """Tests if a pg sql database can be connected and initiated to.""" + with patch( + "aki_prj23_transparenzregister.utils.sql.connector.get_pg_engine" + ) as mock_get_engine, patch( + "aki_prj23_transparenzregister.utils.sql.connector.declarative_base" + ) as mock_declarative_base: + mock_get_engine.connect.return_value = {} + + mock_value = Mock() + mock_value.metadata.create_all.return_value = None + mock_declarative_base.return_value = mock_value + + mock_value = Mock(spec=JsonFileConfigProvider) + mock_value.get_postgre_connection_string.return_value = "" + + init_db(get_session(mock_value)) diff --git a/tests/utils/sql/entities_test.py b/tests/utils/sql/entities_test.py new file mode 100644 index 0000000..14bc361 --- /dev/null +++ b/tests/utils/sql/entities_test.py @@ -0,0 +1,4 @@ +def test_import() -> None: + from aki_prj23_transparenzregister.utils.sql import entities + + assert entities