diff --git a/src/aki_prj23_transparenzregister/utils/data_transfer.py b/src/aki_prj23_transparenzregister/utils/data_transfer.py index c4eeda9..bcc6dde 100644 --- a/src/aki_prj23_transparenzregister/utils/data_transfer.py +++ b/src/aki_prj23_transparenzregister/utils/data_transfer.py @@ -169,9 +169,13 @@ def get_person_id( return person.id # type: ignore -@cached(cache=LRUCache(maxsize=5000), key=lambda name, zip_code, city, db: hash((name, zip_code, city))) # type: ignore +@cached(cache=LRUCache(maxsize=5000), key=lambda name, zip_code, city, db, **_: hash((name, zip_code, city))) # type: ignore def get_company_id( - name: str, zip_code: str | None, city: str | None, db: Session + name: str, + zip_code: str | None, + city: str | None, + db: Session, + **_, ) -> int: """Queries the id of a company. @@ -325,6 +329,37 @@ def add_companies(companies: list[dict[str, Any]], db: Session) -> None: logger.info("When adding companies no problems occurred.") +def reset_relation_counter(db: Session) -> None: + """Resets the counters about missing company relations. + + Args: + db: A session to connect to an SQL db via SQLAlchemy. + """ + db.query(entities.MissingCompany).update({"number_of_links": 0}) + db.commit() + + +def company_relation_missing( + name: str, zip_code: str | None, city: str | None, db: Session, **_: Any +) -> None: + """Adds a relation to a search que. + + Args: + name: The name of the company missing. + zip_code: The zip code of the address where the company is placed. + city: The city where the company is placed. + db: A session to connect to an SQL db via SQLAlchemy. + """ + if missing_relation := db.query(entities.MissingCompany).get(name): + missing_relation.number_of_links += 1 + if not missing_relation.city: + missing_relation.city = city + if not missing_relation.zip_code: + missing_relation.zip_code = zip_code + else: + db.add(entities.MissingCompany(name=name, city=city, zip_code=zip_code)) + + @logger.catch(level="WARNING", reraise=True) def add_relationship( relationship: dict[str, Any], company_id: int, db: Session @@ -356,11 +391,13 @@ def add_relationship( try: relation_to: int = get_company_id( relationship["description"], - relationship["location"]["zip_code"], - relationship["location"]["city"], + **relationship["location"], db=db, ) except KeyError as err: + company_relation_missing( + relationship["description"], **relationship["location"], db=db + ) logger.warning(err) return if company_id == relation_to: @@ -395,8 +432,7 @@ def add_relationships(companies: list[dict[str, dict]], db: Session) -> None: try: company_id: int = get_company_id( company["name"], # type: ignore - company["location"]["zip_code"], - company["location"]["city"], + **company["location"], db=db, ) except Exception: @@ -463,8 +499,7 @@ def add_annual_financial_reports(companies: list[dict], db: Session) -> None: try: company_id: int = get_company_id( company["name"], - company["location"]["zip_code"], - company["location"]["city"], + **company["location"], db=db, ) except Exception: @@ -503,6 +538,7 @@ def transfer_data(config_provider: ConfigProvider) -> None: reset_all_tables(db) add_companies(companies, db) + reset_relation_counter(db) add_relationships(companies, db) add_annual_financial_reports(companies, db) db.close() diff --git a/src/aki_prj23_transparenzregister/utils/sql/entities.py b/src/aki_prj23_transparenzregister/utils/sql/entities.py index d9af4c0..61cebea 100644 --- a/src/aki_prj23_transparenzregister/utils/sql/entities.py +++ b/src/aki_prj23_transparenzregister/utils/sql/entities.py @@ -68,6 +68,17 @@ class Company(Base): sector = sa.Column(sa.String(100), nullable=True) +class MissingCompany(Base): + """Collects missing links that should be searched for.""" + + __tablename__ = "missing_company" + name = sa.Column(sa.String(150), nullable=False, primary_key=True) + zip_code = sa.Column(sa.String(5), nullable=True) + city = sa.Column(sa.String(100), nullable=True) + number_of_links = sa.Column(sa.Integer, nullable=False, default=1) + searched_for = sa.Column(sa.Boolean, nullable=False, default=False) + + class Person(Base): """Person.""" diff --git a/tests/utils/data_transfer_test.py b/tests/utils/data_transfer_test.py index e919d0e..fc0bf4b 100644 --- a/tests/utils/data_transfer_test.py +++ b/tests/utils/data_transfer_test.py @@ -1050,6 +1050,92 @@ def test_add_annual_report_financial_key_error(full_db: Session) -> None: ) +def test_company_relation_missing(empty_db: Session) -> None: + """Check if adding missing company to a query list works.""" + data_transfer.company_relation_missing("Some_company", None, None, empty_db) + empty_db.commit() + data_transfer.company_relation_missing("Other_company", None, "some city", empty_db) + empty_db.commit() + data_transfer.company_relation_missing( + "Some_company", + **{"city": "some city", "zip_code": "12345", "street": "some-street"}, + db=empty_db, + ) + empty_db.commit() + + pd.testing.assert_frame_equal( + pd.read_sql_table( + entities.MissingCompany.__tablename__, empty_db.bind # type: ignore + ).set_index("name"), + pd.DataFrame( + [ + { + "name": "Some_company", + "zip_code": "12345", + "city": "some city", + "number_of_links": 2, + "searched_for": False, + }, + { + "name": "Other_company", + "zip_code": None, + "city": "some city", + "number_of_links": 1, + "searched_for": False, + }, + ] + ).set_index("name"), + ) + + +def test_company_relation_missing_reset(empty_db: Session) -> None: + """Tests the reset of missing company relation counts.""" + empty_db.add_all( + [ + entities.MissingCompany( + name="Some Company", + city="city", + zip_code="12345", + number_of_links=5, + searched_for=True, + ), + entities.MissingCompany( + name="Other Company", + city="city2", + zip_code="98765", + number_of_links=1, + searched_for=False, + ), + ] + ) + empty_db.commit() + data_transfer.reset_relation_counter(empty_db) + queried_df = pd.read_sql_table( + entities.MissingCompany.__tablename__, empty_db.bind # type: ignore + ).set_index("name") + pd.testing.assert_frame_equal( + queried_df, + pd.DataFrame( + [ + { + "name": "Some Company", + "zip_code": "12345", + "city": "city", + "number_of_links": 0, + "searched_for": True, + }, + { + "name": "Other Company", + "zip_code": "98765", + "city": "city2", + "number_of_links": 0, + "searched_for": False, + }, + ] + ).set_index("name"), + ) + + @pytest.mark.parametrize("capital_type", [_.value for _ in CapitalTypeEnum]) @pytest.mark.parametrize("currency", ["€", "EUR"]) def test_norm_capital_eur(currency: str, capital_type: str) -> None: