"""Global configurations and definitions for pytest."""
import datetime
import os
from collections.abc import Generator
from inspect import getmembers, isfunction
from pathlib import Path
from typing import Any

import pytest
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session

from aki_prj23_transparenzregister.utils import data_transfer
from aki_prj23_transparenzregister.utils.sql import entities
from aki_prj23_transparenzregister.utils.sql.connector import get_session, init_db


@pytest.fixture(autouse=True)
def _clear_caches() -> Generator[None, None, None]:
    """Clear all function caches after each test.

    This fixture runs automatically for every test. When the test finishes, it
    looks at the module-level functions of every module in ``modules``. Each
    one that carries a non-``None`` ``cache`` attribute (as attached by the
    cachetools memoizing decorators) has that cache cleared. This stops cached
    results from leaking from one test into the next.

    All the modules containing cached functions need to be listed in the
    ``modules`` tuple.
    """
    yield
    # https://stackoverflow.com/a/139198/11003343
    modules = (data_transfer,)
    functions = [
        function
        for module in modules
        for _name, function in getmembers(module, isfunction)
        if function.__dict__.get("cache") is not None
    ]
    # https://cachetools.readthedocs.io/en/stable/?highlight=clear#memoizing-decorators
    for function in functions:
        function.cache.clear()  # type: ignore


@pytest.fixture()
def empty_db(tmp_path: Path) -> Generator[Session, None, None]:
    """Yield a session to a newly created SQLite database initialised via ``init_db``.

    The database file is created inside pytest's per-test ``tmp_path``
    directory rather than the current working directory. As a result:

    * its location does not depend on where pytest is invoked;
    * no stale file from an earlier run can be picked up;
    * concurrently running tests (e.g. pytest-xdist workers) never share or
      delete each other's database.

    Args:
        tmp_path: Unique temporary directory provided by pytest for this test.

    Yields:
        A session bound to the new database. After the test the session is
        closed, its engine is disposed and the database file is removed.
    """
    db_file = tmp_path / "test-db.db"
    # as_posix() keeps the URL in forward-slash form on every platform.
    db = get_session(f"sqlite:///{db_file.as_posix()}")
    init_db(db)
    yield db
    db.close()
    bind = db.bind
    assert isinstance(bind, Engine)
    # Dispose the engine so its connection pool releases the file before it is
    # deleted; os.remove fails loudly if a connection is still holding it.
    bind.dispose()
    os.remove(db_file)


@pytest.fixture()
def finance_statements() -> list[dict[str, Any]]:
    """Return two annual finance statements belonging to company id 1.

    The statements are dated 2023-01-01 (id 1, all numeric values 1000.0) and
    2022-01-01 (id 2, all numeric values 1100.0). In both, ``dividends`` and
    ``net_income`` are NaN. Each dict is passed as keyword arguments to
    ``entities.AnnualFinanceStatement`` by ``full_db``.
    """
    return [
        {
            "id": 1,
            "company_id": 1,
            "date": datetime.date.fromisoformat("2023-01-01"),
            "total_volume": 1000.0,
            "ebit": 1000.0,
            "ebitda": 1000.0,
            "ebit_margin": 1000.0,
            "total_balance": 1000.0,
            "equity": 1000.0,
            "debt": 1000.0,
            "return_on_equity": 1000.0,
            "capital_turnover_rate": 1000.0,
            "current_liabilities": 1000.0,
            "dividends": float("NaN"),
            "net_income": float("NaN"),
            "assets": 1000.0,
            "long_term_debt": 1000.0,
            "short_term_debt": 1000.0,
            "revenue": 1000.0,
            "cash_flow": 1000.0,
            "current_assets": 1000.0,
        },
        {
            "id": 2,
            "company_id": 1,
            "date": datetime.date.fromisoformat("2022-01-01"),
            "total_volume": 1100.0,
            "ebit": 1100.0,
            "ebitda": 1100.0,
            "ebit_margin": 1100.0,
            "total_balance": 1100.0,
            "equity": 1100.0,
            "debt": 1100.0,
            "return_on_equity": 1100.0,
            "capital_turnover_rate": 1100.0,
            "current_liabilities": 1100.0,
            "dividends": float("NaN"),
            "net_income": float("NaN"),
            "assets": 1100.0,
            "long_term_debt": 1100.0,
            "short_term_debt": 1100.0,
            "revenue": 1100.0,
            "cash_flow": 1100.0,
            "current_assets": 1100.0,
        },
    ]


@pytest.fixture()
def full_db(empty_db: Session, finance_statements: list[dict[str, Any]]) -> Session:
    """Fill the empty test database with a small, fixed data set.

    Rows are added in three separately committed batches. This ensures that
    referenced rows exist before the rows that point at them:

    1. two district courts (Bochum, Dortmund) and five persons;
    2. three companies, which reference court ids 1 and 2;
    3. the statements from ``finance_statements``, which reference company
       id 1.

    Courts and companies are created without explicit ids. The hard-coded
    ``court_id``/``company_id`` values therefore rely on the database assigning
    auto-incremented ids starting at 1 in insertion order.

    Args:
        empty_db: Session to the freshly initialised test database.
        finance_statements: Keyword arguments for each
            ``AnnualFinanceStatement`` row.

    Returns:
        The ``empty_db`` session, now containing the test data.
    """
    empty_db.add_all(
        [
            entities.DistrictCourt(name="Amtsgericht Bochum", city="Bochum"),
            entities.DistrictCourt(name="Amtsgericht Dortmund", city="Dortmund"),
            entities.Person(
                name="Max",
                surname="Mustermann",
                date_of_birth=datetime.date(2023, 1, 1),
            ),
            entities.Person(
                name="Sabine",
                surname="Mustermann",
                date_of_birth=datetime.date(2023, 1, 1),
            ),
            entities.Person(
                name="Some Firstname",
                surname="Some Surname",
                date_of_birth=datetime.date(2023, 1, 1),
            ),
            entities.Person(
                name="Some Firstname",
                surname="Some Surname",
                date_of_birth=datetime.date(2023, 1, 2),
            ),
            entities.Person(
                name="Other Firstname",
                surname="Other Surname",
                date_of_birth=datetime.date(2023, 1, 2),
            ),
        ]
    )
    empty_db.commit()
    empty_db.add_all(
        [
            # The first two companies deliberately share the HR number but are
            # registered at different courts.
            entities.Company(
                hr="HRB 123",
                court_id=2,
                name="Some Company GmbH",
                street="Sesamstr.",
                zip_code="58644",
                city="TV City",
                last_update=datetime.date.fromisoformat("2023-01-01"),
                latitude=51.3246,
                longitude=7.6968,
                pos_accuracy=4.0,
            ),
            entities.Company(
                hr="HRB 123",
                court_id=1,
                name="Other Company GmbH",
                street="Sesamstr.",
                zip_code="58636",
                city="TV City",
                last_update=datetime.date.fromisoformat("2023-01-01"),
                latitude=51.38,
                longitude=7.7032,
                pos_accuracy=4.0,
            ),
            entities.Company(
                hr="HRB 12",
                court_id=2,
                name="Third Company GmbH",
                last_update=datetime.date.fromisoformat("2023-01-01"),
            ),
        ]
    )
    empty_db.commit()
    empty_db.add_all(
        [
            entities.AnnualFinanceStatement(**finance_statement)
            for finance_statement in finance_statements
        ]
    )
    empty_db.commit()
    return empty_db