diff --git a/src/aki_prj23_transparenzregister/ui/data_elements.py b/src/aki_prj23_transparenzregister/ui/data_elements.py index 679ac40..106eecc 100644 --- a/src/aki_prj23_transparenzregister/ui/data_elements.py +++ b/src/aki_prj23_transparenzregister/ui/data_elements.py @@ -4,14 +4,13 @@ import pandas as pd import sqlalchemy as sa from cachetools import TTLCache, cached from loguru import logger -from sqlalchemy.engine import Engine from sqlalchemy.orm import Session from aki_prj23_transparenzregister.utils.sql import entities def get_company_data(session: Session) -> pd.DataFrame: - """Creates a session to the database and get's all available company data. + """Creates a session to the database and collects all available company data. Args: session: A session connecting to the database. @@ -23,14 +22,14 @@ def get_company_data(session: Session) -> pd.DataFrame: entities.DistrictCourt ) engine = session.bind - if not isinstance(engine, Engine): + if not isinstance(engine, sa.engine.Engine): raise TypeError return pd.read_sql(str(query_company), engine, index_col="company_id") def get_person_data(session: Session) -> pd.DataFrame: - """Creates a session to the database and get's all available company data. + """Creates a session to the database and collects all available company data. Args: session: A session connecting to the database. @@ -40,7 +39,7 @@ def get_person_data(session: Session) -> pd.DataFrame: """ query_person = session.query(entities.Person) engine = session.bind - if not isinstance(engine, Engine): + if not isinstance(engine, sa.engine.Engine): raise TypeError return pd.read_sql(str(query_person), engine, index_col="person_id") @@ -60,7 +59,7 @@ def get_finance_data(session: Session) -> pd.DataFrame: ).join(entities.Company) engine = session.bind - if not isinstance(engine, Engine): + if not isinstance(engine, sa.engine.Engine): raise TypeError return pd.read_sql(str(query_finance), engine) @@ -85,9 +84,12 @@ def get_finance_data_of_one_company(session: Session, company_id: int) -> pd.Dat logger.warning("SQL rollback after operational Error!") session.rollback() annual_finance_data = query.all() - + except sa.exc.PendingRollbackError: + logger.warning("SQL rollback when demanded!") + session.rollback() + annual_finance_data = query.all() engine = session.bind - if not isinstance(engine, Engine): + if not isinstance(engine, sa.engine.Engine): raise TypeError data = [row.__dict__ for row in annual_finance_data]