Replaced the bind with the connection method (#567)

Der syntax den ich euch gezeigt habe der falsch.
Jetzt ist er richtig.
Sorry.
This commit is contained in:
2024-01-15 21:13:47 +01:00
committed by GitHub
7 changed files with 26 additions and 48 deletions

View File

@ -24,11 +24,7 @@ def get_company_data(session: Session) -> pd.DataFrame:
query_company = session.query(entities.Company, entities.DistrictCourt.name).join( query_company = session.query(entities.Company, entities.DistrictCourt.name).join(
entities.DistrictCourt entities.DistrictCourt
) )
engine = session.bind return pd.read_sql(str(query_company), session.connection(), index_col="company_id")
if not isinstance(engine, sa.engine.Engine):
raise TypeError
return pd.read_sql(str(query_company), engine, index_col="company_id")
def get_person_data(session: Session) -> pd.DataFrame: def get_person_data(session: Session) -> pd.DataFrame:
@ -41,11 +37,7 @@ def get_person_data(session: Session) -> pd.DataFrame:
A dataframe containing all available company data including the corresponding district court. A dataframe containing all available company data including the corresponding district court.
""" """
query_person = session.query(entities.Person) query_person = session.query(entities.Person)
engine = session.bind return pd.read_sql(str(query_person), session.connection(), index_col="person_id")
if not isinstance(engine, sa.engine.Engine):
raise TypeError
return pd.read_sql(str(query_person), engine, index_col="person_id")
def get_finance_data(session: Session) -> pd.DataFrame: def get_finance_data(session: Session) -> pd.DataFrame:
@ -61,11 +53,7 @@ def get_finance_data(session: Session) -> pd.DataFrame:
entities.AnnualFinanceStatement, entities.Company.name, entities.Company.id entities.AnnualFinanceStatement, entities.Company.name, entities.Company.id
).join(entities.Company) ).join(entities.Company)
engine = session.bind return pd.read_sql(str(query_finance), session.connection())
if not isinstance(engine, sa.engine.Engine):
raise TypeError
return pd.read_sql(str(query_finance), engine)
@cached( # type: ignore @cached( # type: ignore
@ -95,10 +83,6 @@ def get_finance_data_of_one_company(session: Session, company_id: int) -> pd.Dat
logger.warning("SQL rollback when demanded!") logger.warning("SQL rollback when demanded!")
session.rollback() session.rollback()
annual_finance_data = query.all() annual_finance_data = query.all()
engine = session.bind
if not isinstance(engine, sa.engine.Engine):
raise TypeError
data = [row.__dict__ for row in annual_finance_data] data = [row.__dict__ for row in annual_finance_data]
if "_sa_instance_state" not in pd.DataFrame(data).columns: if "_sa_instance_state" not in pd.DataFrame(data).columns:
return pd.DataFrame(data) return pd.DataFrame(data)
@ -107,7 +91,7 @@ def get_finance_data_of_one_company(session: Session, company_id: int) -> pd.Dat
@cached( # type: ignore @cached( # type: ignore
cache=TTLCache(maxsize=1, ttl=300), cache=TTLCache(maxsize=1, ttl=300),
key=lambda session: 0 if session is None else str(session.bind), key=lambda session: 0 if session is None else str(session.connection()),
) )
def get_options(session: Session | None) -> dict[int, str]: def get_options(session: Session | None) -> dict[int, str]:
"""Collects the search options for the companies and persons. """Collects the search options for the companies and persons.

View File

@ -22,7 +22,7 @@ def reset_tables(db: Session, all_tables: bool = False) -> None:
"""Drops all SQL tables and recreates them.""" """Drops all SQL tables and recreates them."""
if all_tables: if all_tables:
logger.warning(f"Resetting all SQL tables in {db.bind}.") logger.warning(f"Resetting all SQL tables in {db.bind}.")
Base.metadata.drop_all(db.bind) Base.metadata.drop_all(db.connection())
db.commit() db.commit()
else: else:
logger.info(f"Resetting the main SQL tables in {db.bind}.") logger.info(f"Resetting the main SQL tables in {db.bind}.")
@ -32,7 +32,7 @@ def reset_tables(db: Session, all_tables: bool = False) -> None:
if str(table) != entities.MissingCompany.__tablename__ if str(table) != entities.MissingCompany.__tablename__
] ]
logger.debug(f"Dropping tables: {', '.join([str(_) for _ in tables])}") logger.debug(f"Dropping tables: {', '.join([str(_) for _ in tables])}")
Base.metadata.drop_all(db.bind, tables=tables) Base.metadata.drop_all(db.connection(), tables=tables)
db.commit() db.commit()
init_db(db) init_db(db)
db.commit() db.commit()

View File

@ -188,7 +188,6 @@ def full_db(empty_db: Session, finance_statements: list[dict[str, Any]]) -> Sess
entities.MissingCompany(name="Some company missing", zip_code="", city="") entities.MissingCompany(name="Some company missing", zip_code="", city="")
) )
empty_db.commit() empty_db.commit()
# print(pd.read_sql_table("company", empty_db.bind).to_string())
return empty_db return empty_db

View File

@ -11,7 +11,6 @@ import pytest
import sqlalchemy as sa import sqlalchemy as sa
from _pytest.monkeypatch import MonkeyPatch from _pytest.monkeypatch import MonkeyPatch
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from aki_prj23_transparenzregister.models.company import ( from aki_prj23_transparenzregister.models.company import (
@ -714,10 +713,8 @@ def test_add_relationships_none(empty_relations: list, full_db: Session) -> None
def test_relationships(documents: list[dict[str, Any]], full_db: Session) -> None: def test_relationships(documents: list[dict[str, Any]], full_db: Session) -> None:
"""Testing to add lots of relations.""" """Testing to add lots of relations."""
data_transfer.add_relationships(documents, full_db) data_transfer.add_relationships(documents, full_db)
bind = full_db.bind
assert isinstance(bind, Engine)
pd.testing.assert_frame_equal( pd.testing.assert_frame_equal(
pd.read_sql_table("company", bind), pd.read_sql_table("company", full_db.connection()),
pd.DataFrame( pd.DataFrame(
{ {
"id": {0: 1, 1: 2, 2: 3}, "id": {0: 1, 1: 2, 2: 3},
@ -750,13 +747,13 @@ def test_relationships(documents: list[dict[str, Any]], full_db: Session) -> Non
} }
), ),
) )
assert len(pd.read_sql_table("company_relation", bind).index) == 0 assert len(pd.read_sql_table("company_relation", full_db.connection()).index) == 0
pd.testing.assert_frame_equal( pd.testing.assert_frame_equal(
pd.read_sql_table("person_relation", bind), pd.read_sql_table("person_relation", full_db.connection()),
pd.DataFrame({"id": {0: 1, 1: 2}, "person_id": {0: 6, 1: 7}}), pd.DataFrame({"id": {0: 1, 1: 2}, "person_id": {0: 6, 1: 7}}),
) )
pd.testing.assert_frame_equal( pd.testing.assert_frame_equal(
pd.read_sql_table("relation", bind), pd.read_sql_table("relation", full_db.connection()),
pd.DataFrame( pd.DataFrame(
{ {
"id": {0: 1, 1: 2}, "id": {0: 1, 1: 2},
@ -768,7 +765,7 @@ def test_relationships(documents: list[dict[str, Any]], full_db: Session) -> Non
), ),
) )
pd.testing.assert_frame_equal( pd.testing.assert_frame_equal(
pd.read_sql_table("person", bind), pd.read_sql_table("person", full_db.connection()),
pd.DataFrame( pd.DataFrame(
{ {
"id": {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7}, "id": {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7},
@ -976,7 +973,7 @@ def test_add_annual_report_empty(
) -> None: ) -> None:
"""Testing if the correct warning is thrown when the financial and auditor records are empty.""" """Testing if the correct warning is thrown when the financial and auditor records are empty."""
df_prior = pd.read_sql_table( df_prior = pd.read_sql_table(
entities.AnnualFinanceStatement.__tablename__, full_db.bind # type: ignore str(entities.AnnualFinanceStatement.__tablename__), full_db.connection() # type: ignore
) )
spy_warning = mocker.spy(data_transfer.logger, "debug") spy_warning = mocker.spy(data_transfer.logger, "debug")
@ -985,7 +982,9 @@ def test_add_annual_report_empty(
spy_warning.assert_called_once() spy_warning.assert_called_once()
pd.testing.assert_frame_equal( pd.testing.assert_frame_equal(
df_prior, df_prior,
pd.read_sql_table(entities.AnnualFinanceStatement.__tablename__, full_db.bind), # type: ignore pd.read_sql_table(
str(entities.AnnualFinanceStatement.__tablename__), full_db.connection() # type: ignore
),
) )
@ -1031,7 +1030,7 @@ def test_add_annual_report(
) )
full_db.commit() full_db.commit()
df_prior = pd.read_sql_table( df_prior = pd.read_sql_table(
entities.AnnualFinanceStatement.__tablename__, full_db.bind # type: ignore str(entities.AnnualFinanceStatement.__tablename__), full_db.connection() # type: ignore
) )
expected_results = pd.DataFrame( expected_results = pd.DataFrame(
finance_statements finance_statements
@ -1158,7 +1157,7 @@ def test_company_relation_missing(empty_db: Session) -> None:
pd.testing.assert_frame_equal( pd.testing.assert_frame_equal(
pd.read_sql_table( pd.read_sql_table(
entities.MissingCompany.__tablename__, empty_db.bind # type: ignore entities.MissingCompany.__tablename__, empty_db.connection()
).set_index("name"), ).set_index("name"),
pd.DataFrame( pd.DataFrame(
[ [
@ -1204,7 +1203,7 @@ def test_company_relation_missing_reset(empty_db: Session) -> None:
empty_db.commit() empty_db.commit()
data_transfer.reset_relation_counter(empty_db) data_transfer.reset_relation_counter(empty_db)
queried_df = pd.read_sql_table( queried_df = pd.read_sql_table(
entities.MissingCompany.__tablename__, empty_db.bind # type: ignore entities.MissingCompany.__tablename__, empty_db.connection()
).set_index("name") ).set_index("name")
pd.testing.assert_frame_equal( pd.testing.assert_frame_equal(
queried_df, queried_df,

View File

@ -45,15 +45,11 @@ def destination_db() -> Generator[Session, None, None]:
def test_transfer_db(full_db: Session, destination_db: Session) -> None: def test_transfer_db(full_db: Session, destination_db: Session) -> None:
"""Tests if the data transfer between two sql tables works.""" """Tests if the data transfer between two sql tables works."""
transfer_db_function(source=full_db, destination=destination_db) transfer_db_function(source=full_db, destination=destination_db)
sbind = full_db.bind
dbind = destination_db.bind
assert isinstance(sbind, Engine)
assert isinstance(dbind, Engine)
assert Base.metadata.sorted_tables assert Base.metadata.sorted_tables
for table in Base.metadata.sorted_tables + ["company"]: for table in Base.metadata.sorted_tables + ["company"]:
pd.testing.assert_frame_equal( pd.testing.assert_frame_equal(
pd.read_sql_table(str(table), dbind), pd.read_sql_table(str(table), destination_db.connection()),
pd.read_sql_table(str(table), sbind), pd.read_sql_table(str(table), full_db.connection()),
) )

View File

@ -13,10 +13,10 @@ def test_reset_sql_all(full_db: Session) -> None:
"""Tests if all sql tables are reset.""" """Tests if all sql tables are reset."""
reset_sql.reset_tables(all_tables=True, db=full_db) reset_sql.reset_tables(all_tables=True, db=full_db)
assert pd.read_sql_table( assert pd.read_sql_table(
entities.MissingCompany.__tablename__, con=full_db.bind # type:ignore entities.MissingCompany.__tablename__, con=full_db.connection()
).empty ).empty
assert pd.read_sql_table( assert pd.read_sql_table(
entities.Company.__tablename__, con=full_db.bind # type:ignore entities.Company.__tablename__, con=full_db.connection()
).empty ).empty
@ -24,10 +24,10 @@ def test_reset_sql(full_db: Session) -> None:
"""Tests if only most sql tables are reset.""" """Tests if only most sql tables are reset."""
reset_sql.reset_tables(all_tables=False, db=full_db) reset_sql.reset_tables(all_tables=False, db=full_db)
assert pd.read_sql_table( assert pd.read_sql_table(
entities.Company.__tablename__, con=full_db.bind # type:ignore entities.Company.__tablename__, con=full_db.connection()
).empty ).empty
assert not pd.read_sql_table( assert not pd.read_sql_table(
entities.MissingCompany.__tablename__, con=full_db.bind # type:ignore entities.MissingCompany.__tablename__, con=full_db.connection()
).empty ).empty

View File

@ -130,7 +130,7 @@ def test_transfer_news_to_sql(full_db: Session, monkeypatch: MonkeyPatch) -> Non
lambda _: NEWS_TEXTS, lambda _: NEWS_TEXTS,
) )
transfer_news._transfer_news_to_sql(None, full_db) # type: ignore transfer_news._transfer_news_to_sql(None, full_db) # type: ignore
articles = pd.read_sql_table(entities.News.__tablename__, full_db.bind) # type: ignore articles = pd.read_sql_table(entities.News.__tablename__, full_db.connection()) # type: ignore
assert "text" in articles.columns assert "text" in articles.columns
del articles["text"] del articles["text"]
assert articles.to_dict(orient="records") == [ assert articles.to_dict(orient="records") == [
@ -158,7 +158,7 @@ def test_transfer_news_to_sql(full_db: Session, monkeypatch: MonkeyPatch) -> Non
}, },
] ]
pd.testing.assert_frame_equal( pd.testing.assert_frame_equal(
pd.read_sql_table(entities.Sentiment.__tablename__, full_db.bind), # type: ignore pd.read_sql_table(entities.Sentiment.__tablename__, full_db.connection()),
pd.DataFrame( pd.DataFrame(
[ [
{ {