From 9cc58ba8bec5aba2eeaadfb29bbdaa3015cee2c7 Mon Sep 17 00:00:00 2001 From: TrisNol Date: Sat, 7 Oct 2023 09:11:43 +0200 Subject: [PATCH] fix: Add script to fix malformed yearly_result entries --- .../apps/fix_company_financials.py | 12 +++---- .../utils/mongo/company_mongo_service.py | 16 ++++++++++ .../utils/mongo/company_mongo_service_test.py | 32 +++++++++++++++++++ 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/src/aki_prj23_transparenzregister/apps/fix_company_financials.py b/src/aki_prj23_transparenzregister/apps/fix_company_financials.py index 0fb2aed..d39fabc 100644 --- a/src/aki_prj23_transparenzregister/apps/fix_company_financials.py +++ b/src/aki_prj23_transparenzregister/apps/fix_company_financials.py @@ -1,4 +1,6 @@ """Fix fincancial data of particular companies identified by their ID.""" +from loguru import logger + from aki_prj23_transparenzregister.apps.enrich_company_financials import work from aki_prj23_transparenzregister.config.config_providers import JsonFileConfigProvider from aki_prj23_transparenzregister.utils.mongo.company_mongo_service import ( @@ -12,12 +14,8 @@ if __name__ == "__main__": mongo_connector = MongoConnector(config_provider.get_mongo_connection_string()) company_service = CompanyMongoService(mongo_connector) - entries = [ - "649f16a4e198338c3b442ab1", - "649f16a5e198338c3b442b0a", - "649f16a5e198338c3b442ac6", - ] + entries = company_service.get_where_malformed_yearly_results() - companies = [company_service.get_by_object_id(entry) for entry in entries] - for company in companies: + for company in entries: work(company, company_service) + logger.info(f"Processed {company['name']}") diff --git a/src/aki_prj23_transparenzregister/utils/mongo/company_mongo_service.py b/src/aki_prj23_transparenzregister/utils/mongo/company_mongo_service.py index c328c27..d84ffdb 100644 --- a/src/aki_prj23_transparenzregister/utils/mongo/company_mongo_service.py +++ b/src/aki_prj23_transparenzregister/utils/mongo/company_mongo_service.py @@ -1,4 +1,5 @@ """CompanyMongoService.""" +import re from threading import Lock from bson.objectid import ObjectId @@ -87,6 +88,21 @@ class CompanyMongoService: with self.lock: return list(self.collection.find({"yearly_results": {"$gt": {}}})) + def get_where_malformed_yearly_results(self) -> list[dict]: + """Finds all entries with malformed yearly_results (e.g., key is not a year). + + Returns: + list[dict]: List of companies + """ + preliminary_results = self.get_where_yearly_results() + malformed_entries = [] + # TODO There should be a cleaner solution using pure MongoDB queries/aggregations + for entry in preliminary_results: + for key in entry["yearly_results"]: + if not re.match(r"^[0-9]{4}$", key): + malformed_entries.append(entry) + return malformed_entries + def insert(self, company: Company) -> InsertOneResult: """Insert a new Company document. diff --git a/tests/utils/mongo/company_mongo_service_test.py b/tests/utils/mongo/company_mongo_service_test.py index 7b99a07..52b58af 100644 --- a/tests/utils/mongo/company_mongo_service_test.py +++ b/tests/utils/mongo/company_mongo_service_test.py @@ -160,3 +160,35 @@ def test_add_yearly_reslults(mock_mongo_connector: Mock, mock_collection: Mock) mock_result: list = [{"_id": "abc", "brille?": "Fielmann", "Hotel?": "Trivago"}] mock_collection.update_one.return_value = mock_result assert service.add_yearly_results("612316a1e198338c3b44299e", {}) == mock_result + + +def test_get_where_malformed_yearly_results( + mock_mongo_connector: Mock, mock_collection: Mock +) -> None: + mock_mongo_connector.database = {"companies": mock_collection} + service = CompanyMongoService(mock_mongo_connector) + mock_result: list = [ + { + "_id": "abc", + "name": "Fielmann", + "Hotel?": "Trivago", + "yearly_results": {"Vor Aeonen": 42, "2022": 4711}, + }, + { + "_id": "abc", + "name": "Fielmann", + "Hotel?": "Trivago", + "yearly_results": {"1998": 42, "2022": 4711}, + }, + { + "_id": "abc", + "name": "Fielmann", + "Hotel?": "Trivago", + "yearly_results": {"19": 42, "2022": 4711}, + }, + ] + mock_collection.find.return_value = mock_result + assert service.get_where_malformed_yearly_results() == [ + mock_result[0], + mock_result[2], + ]