Add a list of missing relation partners to be searched (#171)

- [x] Add a new table
- [x] Add a field to the table that can register if the company was
already queried
- [x] Add a field to the table that counts how many times a relation
partner was missing
- [x] Add a function that resets the counter

Also:
- Reworked the get_company_id function to accept the location dict as kwargs
This commit is contained in:
2023-10-05 19:57:30 +02:00
committed by GitHub
parent c6f2c7467c
commit 09c36960e3
3 changed files with 141 additions and 8 deletions

View File

@ -169,9 +169,13 @@ def get_person_id(
return person.id # type: ignore
@cached(cache=LRUCache(maxsize=5000), key=lambda name, zip_code, city, db: hash((name, zip_code, city))) # type: ignore
@cached(cache=LRUCache(maxsize=5000), key=lambda name, zip_code, city, db, **_: hash((name, zip_code, city))) # type: ignore
def get_company_id(
name: str, zip_code: str | None, city: str | None, db: Session
name: str,
zip_code: str | None,
city: str | None,
db: Session,
**_,
) -> int:
"""Queries the id of a company.
@ -325,6 +329,37 @@ def add_companies(companies: list[dict[str, Any]], db: Session) -> None:
logger.info("When adding companies no problems occurred.")
def reset_relation_counter(db: Session) -> None:
    """Set the link counter of every missing-company entry back to zero.

    The bulk update is committed right away.

    Args:
        db: A session to connect to an SQL db via SQLAlchemy.
    """
    all_missing_companies = db.query(entities.MissingCompany)
    all_missing_companies.update({"number_of_links": 0})
    db.commit()
def company_relation_missing(
    name: str, zip_code: str | None, city: str | None, db: Session, **_: Any
) -> None:
    """Register a missing relation partner in the search queue.

    A company that is not queued yet is added to the session. For a company
    that is already queued, the link counter is incremented and location
    fields that are still empty are filled from the given values. This
    function does not commit; that is left to the caller.

    Args:
        name: The name of the company missing.
        zip_code: The zip code of the address where the company is placed.
        city: The city where the company is placed.
        db: A session to connect to an SQL db via SQLAlchemy.
        **_: Further keyword arguments are accepted and ignored, so a complete
            location dict can be unpacked into the call.
    """
    queued_entry = db.query(entities.MissingCompany).get(name)
    if not queued_entry:
        db.add(entities.MissingCompany(name=name, city=city, zip_code=zip_code))
        return

    queued_entry.number_of_links += 1
    # Only fill in gaps; location details that are already stored are kept.
    if not queued_entry.city:
        queued_entry.city = city
    if not queued_entry.zip_code:
        queued_entry.zip_code = zip_code
@logger.catch(level="WARNING", reraise=True)
def add_relationship(
relationship: dict[str, Any], company_id: int, db: Session
@ -356,11 +391,13 @@ def add_relationship(
try:
relation_to: int = get_company_id(
relationship["description"],
relationship["location"]["zip_code"],
relationship["location"]["city"],
**relationship["location"],
db=db,
)
except KeyError as err:
company_relation_missing(
relationship["description"], **relationship["location"], db=db
)
logger.warning(err)
return
if company_id == relation_to:
@ -395,8 +432,7 @@ def add_relationships(companies: list[dict[str, dict]], db: Session) -> None:
try:
company_id: int = get_company_id(
company["name"], # type: ignore
company["location"]["zip_code"],
company["location"]["city"],
**company["location"],
db=db,
)
except Exception:
@ -463,8 +499,7 @@ def add_annual_financial_reports(companies: list[dict], db: Session) -> None:
try:
company_id: int = get_company_id(
company["name"],
company["location"]["zip_code"],
company["location"]["city"],
**company["location"],
db=db,
)
except Exception:
@ -503,6 +538,7 @@ def transfer_data(config_provider: ConfigProvider) -> None:
reset_all_tables(db)
add_companies(companies, db)
reset_relation_counter(db)
add_relationships(companies, db)
add_annual_financial_reports(companies, db)
db.close()

View File

@ -68,6 +68,17 @@ class Company(Base):
sector = sa.Column(sa.String(100), nullable=True)
class MissingCompany(Base):
    """Collects missing links that should be searched for.

    Each row stands for one company that was referenced as a relation partner
    but could not be found when the relationship was added.
    """

    __tablename__ = "missing_company"

    # Company name; it is the primary key, so there is one row per name.
    name = sa.Column(sa.String(150), nullable=False, primary_key=True)
    # Optional location details of the missing company.
    zip_code = sa.Column(sa.String(5), nullable=True)
    city = sa.Column(sa.String(100), nullable=True)
    # Counts how many times this relation partner was found to be missing.
    number_of_links = sa.Column(sa.Integer, nullable=False, default=1)
    # Registers whether the company was already queried.
    searched_for = sa.Column(sa.Boolean, nullable=False, default=False)
class Person(Base):
"""Person."""

View File

@ -1050,6 +1050,92 @@ def test_add_annual_report_financial_key_error(full_db: Session) -> None:
)
def test_company_relation_missing(empty_db: Session) -> None:
    """Check that missing companies are queued, counted and completed."""
    # First sighting of a company without any location details.
    data_transfer.company_relation_missing("Some_company", None, None, empty_db)
    empty_db.commit()

    # A second, unrelated company that only comes with a city.
    data_transfer.company_relation_missing("Other_company", None, "some city", empty_db)
    empty_db.commit()

    # Second sighting of the first company: the location is unpacked as kwargs
    # and carries an extra key the function has to ignore.
    location = {"city": "some city", "zip_code": "12345", "street": "some-street"}
    data_transfer.company_relation_missing("Some_company", **location, db=empty_db)
    empty_db.commit()

    expected_df = pd.DataFrame(
        [
            {
                "name": "Some_company",
                "zip_code": "12345",
                "city": "some city",
                "number_of_links": 2,
                "searched_for": False,
            },
            {
                "name": "Other_company",
                "zip_code": None,
                "city": "some city",
                "number_of_links": 1,
                "searched_for": False,
            },
        ]
    ).set_index("name")
    stored_df = pd.read_sql_table(
        entities.MissingCompany.__tablename__, empty_db.bind  # type: ignore
    ).set_index("name")

    pd.testing.assert_frame_equal(stored_df, expected_df)
def test_company_relation_missing_reset(empty_db: Session) -> None:
    """Check that resetting zeroes every counter and leaves other fields alone."""
    already_searched = entities.MissingCompany(
        name="Some Company",
        city="city",
        zip_code="12345",
        number_of_links=5,
        searched_for=True,
    )
    not_yet_searched = entities.MissingCompany(
        name="Other Company",
        city="city2",
        zip_code="98765",
        number_of_links=1,
        searched_for=False,
    )
    empty_db.add_all([already_searched, not_yet_searched])
    empty_db.commit()

    data_transfer.reset_relation_counter(empty_db)

    queried_df = pd.read_sql_table(
        entities.MissingCompany.__tablename__, empty_db.bind  # type: ignore
    ).set_index("name")
    # Only number_of_links may differ from the rows inserted above.
    expected_df = pd.DataFrame(
        [
            {
                "name": "Some Company",
                "zip_code": "12345",
                "city": "city",
                "number_of_links": 0,
                "searched_for": True,
            },
            {
                "name": "Other Company",
                "zip_code": "98765",
                "city": "city2",
                "number_of_links": 0,
                "searched_for": False,
            },
        ]
    ).set_index("name")

    pd.testing.assert_frame_equal(queried_df, expected_df)
@pytest.mark.parametrize("capital_type", [_.value for _ in CapitalTypeEnum])
@pytest.mark.parametrize("currency", ["", "EUR"])
def test_norm_capital_eur(currency: str, capital_type: str) -> None: