diff --git a/src/aki_prj23_transparenzregister/ui/data_elements.py b/src/aki_prj23_transparenzregister/ui/data_elements.py index 4a93bc3..52588f2 100644 --- a/src/aki_prj23_transparenzregister/ui/data_elements.py +++ b/src/aki_prj23_transparenzregister/ui/data_elements.py @@ -24,11 +24,7 @@ def get_company_data(session: Session) -> pd.DataFrame: query_company = session.query(entities.Company, entities.DistrictCourt.name).join( entities.DistrictCourt ) - engine = session.bind - if not isinstance(engine, sa.engine.Engine): - raise TypeError - - return pd.read_sql(str(query_company), engine, index_col="company_id") + return pd.read_sql(str(query_company), session.connection(), index_col="company_id") def get_person_data(session: Session) -> pd.DataFrame: @@ -41,11 +37,7 @@ def get_person_data(session: Session) -> pd.DataFrame: A dataframe containing all available company data including the corresponding district court. """ query_person = session.query(entities.Person) - engine = session.bind - if not isinstance(engine, sa.engine.Engine): - raise TypeError - - return pd.read_sql(str(query_person), engine, index_col="person_id") + return pd.read_sql(str(query_person), session.connection(), index_col="person_id") def get_finance_data(session: Session) -> pd.DataFrame: @@ -61,11 +53,7 @@ def get_finance_data(session: Session) -> pd.DataFrame: entities.AnnualFinanceStatement, entities.Company.name, entities.Company.id ).join(entities.Company) - engine = session.bind - if not isinstance(engine, sa.engine.Engine): - raise TypeError - - return pd.read_sql(str(query_finance), engine) + return pd.read_sql(str(query_finance), session.connection()) @cached( # type: ignore @@ -95,10 +83,6 @@ def get_finance_data_of_one_company(session: Session, company_id: int) -> pd.Dat logger.warning("SQL rollback when demanded!") session.rollback() annual_finance_data = query.all() - engine = session.bind - if not isinstance(engine, sa.engine.Engine): - raise TypeError - data = [row.__dict__ for row in annual_finance_data] if "_sa_instance_state" not in pd.DataFrame(data).columns: return pd.DataFrame(data) @@ -107,7 +91,7 @@ def get_finance_data_of_one_company(session: Session, company_id: int) -> pd.Dat @cached( # type: ignore cache=TTLCache(maxsize=1, ttl=300), - key=lambda session: 0 if session is None else str(session.bind), + key=lambda session: 0 if session is None else str(session.connection()), ) def get_options(session: Session | None) -> dict[int, str]: """Collects the search options for the companies and persons. diff --git a/src/aki_prj23_transparenzregister/utils/sql/reset_sql.py b/src/aki_prj23_transparenzregister/utils/sql/reset_sql.py index 045a572..dcba98b 100644 --- a/src/aki_prj23_transparenzregister/utils/sql/reset_sql.py +++ b/src/aki_prj23_transparenzregister/utils/sql/reset_sql.py @@ -22,7 +22,7 @@ def reset_tables(db: Session, all_tables: bool = False) -> None: """Drops all SQL tables and recreates them.""" if all_tables: logger.warning(f"Resetting all SQL tables in {db.bind}.") - Base.metadata.drop_all(db.bind) + Base.metadata.drop_all(db.connection()) db.commit() else: logger.info(f"Resetting the main SQL tables in {db.bind}.") @@ -32,7 +32,7 @@ def reset_tables(db: Session, all_tables: bool = False) -> None: if str(table) != entities.MissingCompany.__tablename__ ] logger.debug(f"Dropping tables: {', '.join([str(_) for _ in tables])}") - Base.metadata.drop_all(db.bind, tables=tables) + Base.metadata.drop_all(db.connection(), tables=tables) db.commit() init_db(db) db.commit() diff --git a/tests/conftest.py b/tests/conftest.py index eac13d2..03ee383 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -188,7 +188,6 @@ def full_db(empty_db: Session, finance_statements: list[dict[str, Any]]) -> Sess entities.MissingCompany(name="Some company missing", zip_code="", city="") ) empty_db.commit() - # print(pd.read_sql_table("company", empty_db.bind).to_string()) return empty_db diff --git a/tests/utils/data_transfer_test.py b/tests/utils/data_transfer_test.py index aef4ca0..3a0ff31 100644 --- a/tests/utils/data_transfer_test.py +++ b/tests/utils/data_transfer_test.py @@ -11,7 +11,6 @@ import pytest import sqlalchemy as sa from _pytest.monkeypatch import MonkeyPatch from pytest_mock import MockerFixture -from sqlalchemy.engine import Engine from sqlalchemy.orm import Session from aki_prj23_transparenzregister.models.company import ( @@ -714,10 +713,8 @@ def test_add_relationships_none(empty_relations: list, full_db: Session) -> None def test_relationships(documents: list[dict[str, Any]], full_db: Session) -> None: """Testing to add lots of relations.""" data_transfer.add_relationships(documents, full_db) - bind = full_db.bind - assert isinstance(bind, Engine) pd.testing.assert_frame_equal( - pd.read_sql_table("company", bind), + pd.read_sql_table("company", full_db.connection()), pd.DataFrame( { "id": {0: 1, 1: 2, 2: 3}, @@ -750,13 +747,13 @@ def test_relationships(documents: list[dict[str, Any]], full_db: Session) -> Non } ), ) - assert len(pd.read_sql_table("company_relation", bind).index) == 0 + assert len(pd.read_sql_table("company_relation", full_db.connection()).index) == 0 pd.testing.assert_frame_equal( - pd.read_sql_table("person_relation", bind), + pd.read_sql_table("person_relation", full_db.connection()), pd.DataFrame({"id": {0: 1, 1: 2}, "person_id": {0: 6, 1: 7}}), ) pd.testing.assert_frame_equal( - pd.read_sql_table("relation", bind), + pd.read_sql_table("relation", full_db.connection()), pd.DataFrame( { "id": {0: 1, 1: 2}, @@ -768,7 +765,7 @@ def test_relationships(documents: list[dict[str, Any]], full_db: Session) -> Non ), ) pd.testing.assert_frame_equal( - pd.read_sql_table("person", bind), + pd.read_sql_table("person", full_db.connection()), pd.DataFrame( { "id": {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7}, @@ -976,7 +973,7 @@ def test_add_annual_report_empty( ) -> None: """Testing if the correct warning is thrown when the financial and auditor records are empty.""" df_prior = pd.read_sql_table( - entities.AnnualFinanceStatement.__tablename__, full_db.bind # type: ignore + str(entities.AnnualFinanceStatement.__tablename__), full_db.connection() # type: ignore ) spy_warning = mocker.spy(data_transfer.logger, "debug") @@ -985,7 +982,9 @@ def test_add_annual_report_empty( spy_warning.assert_called_once() pd.testing.assert_frame_equal( df_prior, - pd.read_sql_table(entities.AnnualFinanceStatement.__tablename__, full_db.bind), # type: ignore + pd.read_sql_table( + str(entities.AnnualFinanceStatement.__tablename__), full_db.connection() # type: ignore + ), ) @@ -1031,7 +1030,7 @@ def test_add_annual_report( ) full_db.commit() df_prior = pd.read_sql_table( - entities.AnnualFinanceStatement.__tablename__, full_db.bind # type: ignore + str(entities.AnnualFinanceStatement.__tablename__), full_db.connection() # type: ignore ) expected_results = pd.DataFrame( finance_statements @@ -1158,7 +1157,7 @@ def test_company_relation_missing(empty_db: Session) -> None: pd.testing.assert_frame_equal( pd.read_sql_table( - entities.MissingCompany.__tablename__, empty_db.bind # type: ignore + entities.MissingCompany.__tablename__, empty_db.connection() ).set_index("name"), pd.DataFrame( [ @@ -1204,7 +1203,7 @@ def test_company_relation_missing_reset(empty_db: Session) -> None: empty_db.commit() data_transfer.reset_relation_counter(empty_db) queried_df = pd.read_sql_table( - entities.MissingCompany.__tablename__, empty_db.bind # type: ignore + entities.MissingCompany.__tablename__, empty_db.connection() ).set_index("name") pd.testing.assert_frame_equal( queried_df, diff --git a/tests/utils/sql/copy_sql_test.py b/tests/utils/sql/copy_sql_test.py index 28fad91..8516dca 100644 --- a/tests/utils/sql/copy_sql_test.py +++ b/tests/utils/sql/copy_sql_test.py @@ -45,15 +45,11 @@ def destination_db() -> Generator[Session, None, None]: def test_transfer_db(full_db: Session, destination_db: Session) -> None: """Tests if the data transfer between two sql tables works.""" transfer_db_function(source=full_db, destination=destination_db) - sbind = full_db.bind - dbind = destination_db.bind - assert isinstance(sbind, Engine) - assert isinstance(dbind, Engine) assert Base.metadata.sorted_tables for table in Base.metadata.sorted_tables + ["company"]: pd.testing.assert_frame_equal( - pd.read_sql_table(str(table), dbind), - pd.read_sql_table(str(table), sbind), + pd.read_sql_table(str(table), destination_db.connection()), + pd.read_sql_table(str(table), full_db.connection()), ) diff --git a/tests/utils/sql/rest_sql_test.py b/tests/utils/sql/rest_sql_test.py index 7a3b51f..5d1b782 100644 --- a/tests/utils/sql/rest_sql_test.py +++ b/tests/utils/sql/rest_sql_test.py @@ -13,10 +13,10 @@ def test_reset_sql_all(full_db: Session) -> None: """Tests if all sql tables are reset.""" reset_sql.reset_tables(all_tables=True, db=full_db) assert pd.read_sql_table( - entities.MissingCompany.__tablename__, con=full_db.bind # type:ignore + entities.MissingCompany.__tablename__, con=full_db.connection() ).empty assert pd.read_sql_table( - entities.Company.__tablename__, con=full_db.bind # type:ignore + entities.Company.__tablename__, con=full_db.connection() ).empty @@ -24,10 +24,10 @@ def test_reset_sql(full_db: Session) -> None: """Tests if only most sql tables are reset.""" reset_sql.reset_tables(all_tables=False, db=full_db) assert pd.read_sql_table( - entities.Company.__tablename__, con=full_db.bind # type:ignore + entities.Company.__tablename__, con=full_db.connection() ).empty assert not pd.read_sql_table( - entities.MissingCompany.__tablename__, con=full_db.bind # type:ignore + entities.MissingCompany.__tablename__, con=full_db.connection() ).empty diff --git a/tests/utils/transfer_news_test.py b/tests/utils/transfer_news_test.py index 935f3f9..5b132b4 100644 --- a/tests/utils/transfer_news_test.py +++ b/tests/utils/transfer_news_test.py @@ -130,7 +130,7 @@ def test_transfer_news_to_sql(full_db: Session, monkeypatch: MonkeyPatch) -> Non lambda _: NEWS_TEXTS, ) transfer_news._transfer_news_to_sql(None, full_db) # type: ignore - articles = pd.read_sql_table(entities.News.__tablename__, full_db.bind) # type: ignore + articles = pd.read_sql_table(entities.News.__tablename__, full_db.connection()) # type: ignore assert "text" in articles.columns del articles["text"] assert articles.to_dict(orient="records") == [ @@ -158,7 +158,7 @@ def test_transfer_news_to_sql(full_db: Session, monkeypatch: MonkeyPatch) -> Non }, ] pd.testing.assert_frame_equal( - pd.read_sql_table(entities.Sentiment.__tablename__, full_db.bind), # type: ignore + pd.read_sql_table(entities.Sentiment.__tablename__, full_db.connection()), pd.DataFrame( [ {