mirror of
https://github.com/fhswf/aki_prj23_transparenzregister.git
synced 2025-06-21 23:13:55 +02:00
Replaced the bind with the connection method (#567)
Der syntax den ich euch gezeigt habe der falsch. Jetzt ist er richtig. Sorry.
This commit is contained in:
@ -24,11 +24,7 @@ def get_company_data(session: Session) -> pd.DataFrame:
|
|||||||
query_company = session.query(entities.Company, entities.DistrictCourt.name).join(
|
query_company = session.query(entities.Company, entities.DistrictCourt.name).join(
|
||||||
entities.DistrictCourt
|
entities.DistrictCourt
|
||||||
)
|
)
|
||||||
engine = session.bind
|
return pd.read_sql(str(query_company), session.connection(), index_col="company_id")
|
||||||
if not isinstance(engine, sa.engine.Engine):
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
return pd.read_sql(str(query_company), engine, index_col="company_id")
|
|
||||||
|
|
||||||
|
|
||||||
def get_person_data(session: Session) -> pd.DataFrame:
|
def get_person_data(session: Session) -> pd.DataFrame:
|
||||||
@ -41,11 +37,7 @@ def get_person_data(session: Session) -> pd.DataFrame:
|
|||||||
A dataframe containing all available company data including the corresponding district court.
|
A dataframe containing all available company data including the corresponding district court.
|
||||||
"""
|
"""
|
||||||
query_person = session.query(entities.Person)
|
query_person = session.query(entities.Person)
|
||||||
engine = session.bind
|
return pd.read_sql(str(query_person), session.connection(), index_col="person_id")
|
||||||
if not isinstance(engine, sa.engine.Engine):
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
return pd.read_sql(str(query_person), engine, index_col="person_id")
|
|
||||||
|
|
||||||
|
|
||||||
def get_finance_data(session: Session) -> pd.DataFrame:
|
def get_finance_data(session: Session) -> pd.DataFrame:
|
||||||
@ -61,11 +53,7 @@ def get_finance_data(session: Session) -> pd.DataFrame:
|
|||||||
entities.AnnualFinanceStatement, entities.Company.name, entities.Company.id
|
entities.AnnualFinanceStatement, entities.Company.name, entities.Company.id
|
||||||
).join(entities.Company)
|
).join(entities.Company)
|
||||||
|
|
||||||
engine = session.bind
|
return pd.read_sql(str(query_finance), session.connection())
|
||||||
if not isinstance(engine, sa.engine.Engine):
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
return pd.read_sql(str(query_finance), engine)
|
|
||||||
|
|
||||||
|
|
||||||
@cached( # type: ignore
|
@cached( # type: ignore
|
||||||
@ -95,10 +83,6 @@ def get_finance_data_of_one_company(session: Session, company_id: int) -> pd.Dat
|
|||||||
logger.warning("SQL rollback when demanded!")
|
logger.warning("SQL rollback when demanded!")
|
||||||
session.rollback()
|
session.rollback()
|
||||||
annual_finance_data = query.all()
|
annual_finance_data = query.all()
|
||||||
engine = session.bind
|
|
||||||
if not isinstance(engine, sa.engine.Engine):
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
data = [row.__dict__ for row in annual_finance_data]
|
data = [row.__dict__ for row in annual_finance_data]
|
||||||
if "_sa_instance_state" not in pd.DataFrame(data).columns:
|
if "_sa_instance_state" not in pd.DataFrame(data).columns:
|
||||||
return pd.DataFrame(data)
|
return pd.DataFrame(data)
|
||||||
@ -107,7 +91,7 @@ def get_finance_data_of_one_company(session: Session, company_id: int) -> pd.Dat
|
|||||||
|
|
||||||
@cached( # type: ignore
|
@cached( # type: ignore
|
||||||
cache=TTLCache(maxsize=1, ttl=300),
|
cache=TTLCache(maxsize=1, ttl=300),
|
||||||
key=lambda session: 0 if session is None else str(session.bind),
|
key=lambda session: 0 if session is None else str(session.connection()),
|
||||||
)
|
)
|
||||||
def get_options(session: Session | None) -> dict[int, str]:
|
def get_options(session: Session | None) -> dict[int, str]:
|
||||||
"""Collects the search options for the companies and persons.
|
"""Collects the search options for the companies and persons.
|
||||||
|
@ -22,7 +22,7 @@ def reset_tables(db: Session, all_tables: bool = False) -> None:
|
|||||||
"""Drops all SQL tables and recreates them."""
|
"""Drops all SQL tables and recreates them."""
|
||||||
if all_tables:
|
if all_tables:
|
||||||
logger.warning(f"Resetting all SQL tables in {db.bind}.")
|
logger.warning(f"Resetting all SQL tables in {db.bind}.")
|
||||||
Base.metadata.drop_all(db.bind)
|
Base.metadata.drop_all(db.connection())
|
||||||
db.commit()
|
db.commit()
|
||||||
else:
|
else:
|
||||||
logger.info(f"Resetting the main SQL tables in {db.bind}.")
|
logger.info(f"Resetting the main SQL tables in {db.bind}.")
|
||||||
@ -32,7 +32,7 @@ def reset_tables(db: Session, all_tables: bool = False) -> None:
|
|||||||
if str(table) != entities.MissingCompany.__tablename__
|
if str(table) != entities.MissingCompany.__tablename__
|
||||||
]
|
]
|
||||||
logger.debug(f"Dropping tables: {', '.join([str(_) for _ in tables])}")
|
logger.debug(f"Dropping tables: {', '.join([str(_) for _ in tables])}")
|
||||||
Base.metadata.drop_all(db.bind, tables=tables)
|
Base.metadata.drop_all(db.connection(), tables=tables)
|
||||||
db.commit()
|
db.commit()
|
||||||
init_db(db)
|
init_db(db)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
@ -188,7 +188,6 @@ def full_db(empty_db: Session, finance_statements: list[dict[str, Any]]) -> Sess
|
|||||||
entities.MissingCompany(name="Some company missing", zip_code="", city="")
|
entities.MissingCompany(name="Some company missing", zip_code="", city="")
|
||||||
)
|
)
|
||||||
empty_db.commit()
|
empty_db.commit()
|
||||||
# print(pd.read_sql_table("company", empty_db.bind).to_string())
|
|
||||||
return empty_db
|
return empty_db
|
||||||
|
|
||||||
|
|
||||||
|
@ -11,7 +11,6 @@ import pytest
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from _pytest.monkeypatch import MonkeyPatch
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
from sqlalchemy.engine import Engine
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from aki_prj23_transparenzregister.models.company import (
|
from aki_prj23_transparenzregister.models.company import (
|
||||||
@ -714,10 +713,8 @@ def test_add_relationships_none(empty_relations: list, full_db: Session) -> None
|
|||||||
def test_relationships(documents: list[dict[str, Any]], full_db: Session) -> None:
|
def test_relationships(documents: list[dict[str, Any]], full_db: Session) -> None:
|
||||||
"""Testing to add lots of relations."""
|
"""Testing to add lots of relations."""
|
||||||
data_transfer.add_relationships(documents, full_db)
|
data_transfer.add_relationships(documents, full_db)
|
||||||
bind = full_db.bind
|
|
||||||
assert isinstance(bind, Engine)
|
|
||||||
pd.testing.assert_frame_equal(
|
pd.testing.assert_frame_equal(
|
||||||
pd.read_sql_table("company", bind),
|
pd.read_sql_table("company", full_db.connection()),
|
||||||
pd.DataFrame(
|
pd.DataFrame(
|
||||||
{
|
{
|
||||||
"id": {0: 1, 1: 2, 2: 3},
|
"id": {0: 1, 1: 2, 2: 3},
|
||||||
@ -750,13 +747,13 @@ def test_relationships(documents: list[dict[str, Any]], full_db: Session) -> Non
|
|||||||
}
|
}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
assert len(pd.read_sql_table("company_relation", bind).index) == 0
|
assert len(pd.read_sql_table("company_relation", full_db.connection()).index) == 0
|
||||||
pd.testing.assert_frame_equal(
|
pd.testing.assert_frame_equal(
|
||||||
pd.read_sql_table("person_relation", bind),
|
pd.read_sql_table("person_relation", full_db.connection()),
|
||||||
pd.DataFrame({"id": {0: 1, 1: 2}, "person_id": {0: 6, 1: 7}}),
|
pd.DataFrame({"id": {0: 1, 1: 2}, "person_id": {0: 6, 1: 7}}),
|
||||||
)
|
)
|
||||||
pd.testing.assert_frame_equal(
|
pd.testing.assert_frame_equal(
|
||||||
pd.read_sql_table("relation", bind),
|
pd.read_sql_table("relation", full_db.connection()),
|
||||||
pd.DataFrame(
|
pd.DataFrame(
|
||||||
{
|
{
|
||||||
"id": {0: 1, 1: 2},
|
"id": {0: 1, 1: 2},
|
||||||
@ -768,7 +765,7 @@ def test_relationships(documents: list[dict[str, Any]], full_db: Session) -> Non
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
pd.testing.assert_frame_equal(
|
pd.testing.assert_frame_equal(
|
||||||
pd.read_sql_table("person", bind),
|
pd.read_sql_table("person", full_db.connection()),
|
||||||
pd.DataFrame(
|
pd.DataFrame(
|
||||||
{
|
{
|
||||||
"id": {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7},
|
"id": {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7},
|
||||||
@ -976,7 +973,7 @@ def test_add_annual_report_empty(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Testing if the correct warning is thrown when the financial and auditor records are empty."""
|
"""Testing if the correct warning is thrown when the financial and auditor records are empty."""
|
||||||
df_prior = pd.read_sql_table(
|
df_prior = pd.read_sql_table(
|
||||||
entities.AnnualFinanceStatement.__tablename__, full_db.bind # type: ignore
|
str(entities.AnnualFinanceStatement.__tablename__), full_db.connection() # type: ignore
|
||||||
)
|
)
|
||||||
spy_warning = mocker.spy(data_transfer.logger, "debug")
|
spy_warning = mocker.spy(data_transfer.logger, "debug")
|
||||||
|
|
||||||
@ -985,7 +982,9 @@ def test_add_annual_report_empty(
|
|||||||
spy_warning.assert_called_once()
|
spy_warning.assert_called_once()
|
||||||
pd.testing.assert_frame_equal(
|
pd.testing.assert_frame_equal(
|
||||||
df_prior,
|
df_prior,
|
||||||
pd.read_sql_table(entities.AnnualFinanceStatement.__tablename__, full_db.bind), # type: ignore
|
pd.read_sql_table(
|
||||||
|
str(entities.AnnualFinanceStatement.__tablename__), full_db.connection() # type: ignore
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -1031,7 +1030,7 @@ def test_add_annual_report(
|
|||||||
)
|
)
|
||||||
full_db.commit()
|
full_db.commit()
|
||||||
df_prior = pd.read_sql_table(
|
df_prior = pd.read_sql_table(
|
||||||
entities.AnnualFinanceStatement.__tablename__, full_db.bind # type: ignore
|
str(entities.AnnualFinanceStatement.__tablename__), full_db.connection() # type: ignore
|
||||||
)
|
)
|
||||||
expected_results = pd.DataFrame(
|
expected_results = pd.DataFrame(
|
||||||
finance_statements
|
finance_statements
|
||||||
@ -1158,7 +1157,7 @@ def test_company_relation_missing(empty_db: Session) -> None:
|
|||||||
|
|
||||||
pd.testing.assert_frame_equal(
|
pd.testing.assert_frame_equal(
|
||||||
pd.read_sql_table(
|
pd.read_sql_table(
|
||||||
entities.MissingCompany.__tablename__, empty_db.bind # type: ignore
|
entities.MissingCompany.__tablename__, empty_db.connection()
|
||||||
).set_index("name"),
|
).set_index("name"),
|
||||||
pd.DataFrame(
|
pd.DataFrame(
|
||||||
[
|
[
|
||||||
@ -1204,7 +1203,7 @@ def test_company_relation_missing_reset(empty_db: Session) -> None:
|
|||||||
empty_db.commit()
|
empty_db.commit()
|
||||||
data_transfer.reset_relation_counter(empty_db)
|
data_transfer.reset_relation_counter(empty_db)
|
||||||
queried_df = pd.read_sql_table(
|
queried_df = pd.read_sql_table(
|
||||||
entities.MissingCompany.__tablename__, empty_db.bind # type: ignore
|
entities.MissingCompany.__tablename__, empty_db.connection()
|
||||||
).set_index("name")
|
).set_index("name")
|
||||||
pd.testing.assert_frame_equal(
|
pd.testing.assert_frame_equal(
|
||||||
queried_df,
|
queried_df,
|
||||||
|
@ -45,15 +45,11 @@ def destination_db() -> Generator[Session, None, None]:
|
|||||||
def test_transfer_db(full_db: Session, destination_db: Session) -> None:
|
def test_transfer_db(full_db: Session, destination_db: Session) -> None:
|
||||||
"""Tests if the data transfer between two sql tables works."""
|
"""Tests if the data transfer between two sql tables works."""
|
||||||
transfer_db_function(source=full_db, destination=destination_db)
|
transfer_db_function(source=full_db, destination=destination_db)
|
||||||
sbind = full_db.bind
|
|
||||||
dbind = destination_db.bind
|
|
||||||
assert isinstance(sbind, Engine)
|
|
||||||
assert isinstance(dbind, Engine)
|
|
||||||
assert Base.metadata.sorted_tables
|
assert Base.metadata.sorted_tables
|
||||||
for table in Base.metadata.sorted_tables + ["company"]:
|
for table in Base.metadata.sorted_tables + ["company"]:
|
||||||
pd.testing.assert_frame_equal(
|
pd.testing.assert_frame_equal(
|
||||||
pd.read_sql_table(str(table), dbind),
|
pd.read_sql_table(str(table), destination_db.connection()),
|
||||||
pd.read_sql_table(str(table), sbind),
|
pd.read_sql_table(str(table), full_db.connection()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,10 +13,10 @@ def test_reset_sql_all(full_db: Session) -> None:
|
|||||||
"""Tests if all sql tables are reset."""
|
"""Tests if all sql tables are reset."""
|
||||||
reset_sql.reset_tables(all_tables=True, db=full_db)
|
reset_sql.reset_tables(all_tables=True, db=full_db)
|
||||||
assert pd.read_sql_table(
|
assert pd.read_sql_table(
|
||||||
entities.MissingCompany.__tablename__, con=full_db.bind # type:ignore
|
entities.MissingCompany.__tablename__, con=full_db.connection()
|
||||||
).empty
|
).empty
|
||||||
assert pd.read_sql_table(
|
assert pd.read_sql_table(
|
||||||
entities.Company.__tablename__, con=full_db.bind # type:ignore
|
entities.Company.__tablename__, con=full_db.connection()
|
||||||
).empty
|
).empty
|
||||||
|
|
||||||
|
|
||||||
@ -24,10 +24,10 @@ def test_reset_sql(full_db: Session) -> None:
|
|||||||
"""Tests if only most sql tables are reset."""
|
"""Tests if only most sql tables are reset."""
|
||||||
reset_sql.reset_tables(all_tables=False, db=full_db)
|
reset_sql.reset_tables(all_tables=False, db=full_db)
|
||||||
assert pd.read_sql_table(
|
assert pd.read_sql_table(
|
||||||
entities.Company.__tablename__, con=full_db.bind # type:ignore
|
entities.Company.__tablename__, con=full_db.connection()
|
||||||
).empty
|
).empty
|
||||||
assert not pd.read_sql_table(
|
assert not pd.read_sql_table(
|
||||||
entities.MissingCompany.__tablename__, con=full_db.bind # type:ignore
|
entities.MissingCompany.__tablename__, con=full_db.connection()
|
||||||
).empty
|
).empty
|
||||||
|
|
||||||
|
|
||||||
|
@ -130,7 +130,7 @@ def test_transfer_news_to_sql(full_db: Session, monkeypatch: MonkeyPatch) -> Non
|
|||||||
lambda _: NEWS_TEXTS,
|
lambda _: NEWS_TEXTS,
|
||||||
)
|
)
|
||||||
transfer_news._transfer_news_to_sql(None, full_db) # type: ignore
|
transfer_news._transfer_news_to_sql(None, full_db) # type: ignore
|
||||||
articles = pd.read_sql_table(entities.News.__tablename__, full_db.bind) # type: ignore
|
articles = pd.read_sql_table(entities.News.__tablename__, full_db.connection()) # type: ignore
|
||||||
assert "text" in articles.columns
|
assert "text" in articles.columns
|
||||||
del articles["text"]
|
del articles["text"]
|
||||||
assert articles.to_dict(orient="records") == [
|
assert articles.to_dict(orient="records") == [
|
||||||
@ -158,7 +158,7 @@ def test_transfer_news_to_sql(full_db: Session, monkeypatch: MonkeyPatch) -> Non
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
pd.testing.assert_frame_equal(
|
pd.testing.assert_frame_equal(
|
||||||
pd.read_sql_table(entities.Sentiment.__tablename__, full_db.bind), # type: ignore
|
pd.read_sql_table(entities.Sentiment.__tablename__, full_db.connection()),
|
||||||
pd.DataFrame(
|
pd.DataFrame(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
|
Reference in New Issue
Block a user