test: Adapt existing unit tests to refactored imports

This commit is contained in:
TrisNol 2023-11-03 23:26:08 +01:00
parent 042a019628
commit d6b07431e7
6 changed files with 123 additions and 94 deletions

View File

@ -1,32 +1,20 @@
"""Retrieve missing companies from unternehmensregister."""
import argparse
import dataclasses
import glob
import json
import multiprocessing
import os
import sys
import json
import glob
import argparse
import tempfile
import dataclasses
import multiprocessing
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from loguru import logger
from tqdm import tqdm
from aki_prj23_transparenzregister.config.config_providers import (
HELP_TEXT_CONFIG,
ConfigProvider,
get_config_provider,
)
from aki_prj23_transparenzregister.utils.logger_config import (
add_logger_options_to_argparse,
configer_logger,
)
from aki_prj23_transparenzregister.utils.sql import connector
from aki_prj23_transparenzregister.utils.sql import entities
from aki_prj23_transparenzregister.utils.mongo.connector import MongoConnector
from aki_prj23_transparenzregister.utils.mongo.company_mongo_service import (
CompanyMongoService,
)
from aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister import (
extract,
load,
@ -34,13 +22,29 @@ from aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister im
from aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform import (
main as transform,
)
from aki_prj23_transparenzregister.utils.logger_config import (
add_logger_options_to_argparse,
configer_logger,
)
from aki_prj23_transparenzregister.utils.mongo.company_mongo_service import (
CompanyMongoService,
)
from aki_prj23_transparenzregister.utils.mongo.connector import MongoConnector
from aki_prj23_transparenzregister.utils.sql import connector, entities
def work(company: entities.Company, configProvider) -> None:
def work(company: entities.Company, config_provider: ConfigProvider) -> None:
"""Main method.
Args:
company (entities.Company): Company to be searched for
config_provider (ConfigProvider): Provider supplying the Mongo and SQL connection configuration
"""
with tempfile.TemporaryDirectory() as tmp_dir:
xml_dir = os.path.join(*[tmp_dir, "xml"])
os.makedirs(xml_dir, exist_ok=True)
try:
extract.scrape(company.name, xml_dir, True)
extract.scrape(company.name, xml_dir, True, True) # type: ignore
except Exception as e:
logger.error(e)
return
@ -57,37 +61,41 @@ def work(company: entities.Company, configProvider) -> None:
try:
path = os.path.join(json_dir, file)
with open(path, encoding="utf-8") as file_object:
company_mapped = transform.map_unternehmensregister_json(
json.loads(file_object.read())
company_mapped = transform.map_unternehmensregister_json(
json.loads(file_object.read())
)
name = "".join(e for e in company_mapped.name if e.isalnum())[:50]
with open(
os.path.join(output_path, f"{name}.json"),
"w+",
encoding="utf-8",
) as export_file:
json.dump(
dataclasses.asdict(company_mapped),
export_file,
ensure_ascii=False,
)
name = "".join(e for e in company_mapped.name if e.isalnum())[:50]
with open(
os.path.join(output_path, f"{name}.json"),
"w+",
encoding="utf-8",
) as export_file:
json.dump(
dataclasses.asdict(company_mapped), export_file, ensure_ascii=False
)
except Exception as e:
logger.error(e)
return
mongoConnector = MongoConnector(configProvider.get_mongo_connection_string())
companyMongoService = CompanyMongoService(
mongoConnector
)
num_processed = load.load_directory_to_mongo(output_path, companyMongoService)
mongoConnector.client.close()
mongo_connector = MongoConnector(config_provider.get_mongo_connection_string())
company_mongo_service = CompanyMongoService(mongo_connector)
num_processed = load.load_directory_to_mongo(output_path, company_mongo_service)
mongo_connector.client.close()
try:
if num_processed > 0:
with connector.get_session(configProvider) as session:
company = session.query(entities.MissingCompany).where(entities.MissingCompany.name == company.name).first()
company.searched_for = True
with connector.get_session(config_provider) as session:
company = (
session.query(entities.MissingCompany) # type: ignore
.where(entities.MissingCompany.name == company.name)
.first()
)
company.searched_for = True # type: ignore
session.commit()
print(f"Processed {company.name}")
logger.info(f"Processed {company.name}")
except Exception as e:
logger.error(e)
return
@ -109,22 +117,23 @@ if __name__ == "__main__":
parsed = parser.parse_args(sys.argv[1:])
configer_logger(namespace=parsed)
config = parsed.config
configProvider = get_config_provider(config)
session = connector.get_session(configProvider)
config_provider = get_config_provider(config)
session = connector.get_session(config_provider)
companyMongoService = CompanyMongoService(
MongoConnector(configProvider.get_mongo_connection_string())
company_mongo_service = CompanyMongoService(
MongoConnector(config_provider.get_mongo_connection_string())
)
missing_companies = session.query(entities.MissingCompany).where(entities.MissingCompany.searched_for == False).all()
missing_companies = (
session.query(entities.MissingCompany)
.where(entities.MissingCompany.searched_for is False)
.all()
)
batch_size = 5
pool = multiprocessing.Pool(processes=batch_size)
# Scrape data from unternehmensregister
params = [
(company, configProvider)
for company in missing_companies
]
params = [(company, config_provider) for company in missing_companies]
# Map the process_handler function to the parameter list using the Pool
pool.starmap(work, params)
@ -134,4 +143,3 @@ if __name__ == "__main__":
# Wait for all the processes to complete
pool.join()
# for company in tqdm(missing_companies):

View File

@ -3,7 +3,6 @@
import glob
import multiprocessing
import os
from pathlib import Path
from loguru import logger
from selenium import webdriver
@ -13,12 +12,19 @@ from selenium.webdriver.support.ui import WebDriverWait
from tqdm import tqdm
def scrape(query: str, download_dir: str, full_match: bool = False) -> None:
def scrape(
query: str,
download_dir: str,
full_match: bool = False,
early_stopping: bool = False,
) -> None:
"""Fetch results from Unternehmensregister for given query.
Args:
query (str): Search Query (RegEx supported)
download_dir (str): Directory to place output files in
full_match (bool, optional): Only scrape first result. Defaults to False.
early_stopping (bool, optional): Stop scraping after first page. Defaults to False.
"""
# download_path = os.path.join(str(Path.cwd()), *download_dir)
download_path = download_dir
@ -75,7 +81,9 @@ def scrape(query: str, download_dir: str, full_match: bool = False) -> None:
]
for index, company_link in enumerate(companies_tab):
company_name = company_names[index]
if company_name in processed_companies or (full_match == True and company_name != query):
if company_name in processed_companies or (
full_match is True and company_name != query
):
continue
# Go to intermediary page
company_link.click()
@ -122,8 +130,10 @@ def scrape(query: str, download_dir: str, full_match: bool = False) -> None:
finally:
for _ in range(6):
driver.back()
if company_name == query and full_match == True:
break
if company_name == query and full_match is True:
break # noqa: B012
if early_stopping is True:
break
driver.find_element(By.XPATH, '//*[@class="fas fa-angle-right"]').click()
driver.close()

View File

@ -0,0 +1,6 @@
"""Testing find_missing_companies.py."""
from aki_prj23_transparenzregister.apps import find_missing_companies
def test_import_find_missing_companies() -> None:
    """Smoke test: the refactored module path imports and is truthy (a module object)."""
    assert find_missing_companies

View File

@ -86,4 +86,4 @@ def test_wait_for_download_condition() -> None:
def test_scrape() -> None:
with TemporaryDirectory(dir="./") as temp_dir:
extract.scrape("GEA Farm Technologies GmbH", [temp_dir])
extract.scrape("GEA Farm Technologies GmbH", temp_dir)

View File

@ -0,0 +1,24 @@
"""Testing main.py."""
import json
import os
from tempfile import TemporaryDirectory
from aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform import (
main,
)
def test_transform_xml_to_json() -> None:
with TemporaryDirectory(dir="./") as temp_source_dir:
with open(os.path.join(temp_source_dir, "test.xml"), "w") as file:
xml_input = """<?xml version="1.0" encoding="UTF-8"?>
<test>
<message>Hello World!</message>
</test>
"""
file.write(xml_input)
with TemporaryDirectory(dir="./") as temp_target_dir:
main.transform_xml_to_json(temp_source_dir, temp_target_dir)
with open(os.path.join(temp_target_dir, "test.json")) as file:
json_output = json.load(file)
assert json_output == {"test": {"message": "Hello World!"}}

View File

@ -1,7 +1,4 @@
"""Testing utils/data_extraction/unternehmensregister/transform.py."""
import json
import os
from tempfile import TemporaryDirectory
from unittest.mock import Mock, patch
import pytest
@ -21,27 +18,11 @@ from aki_prj23_transparenzregister.models.company import (
PersonToCompanyRelationship,
RelationshipRoleEnum,
)
from aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister import (
transform,
from aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.v1 import (
v1 as transform,
)
def test_transform_xml_to_json() -> None:
with TemporaryDirectory(dir="./") as temp_source_dir:
with open(os.path.join(temp_source_dir, "test.xml"), "w") as file:
xml_input = """<?xml version="1.0" encoding="UTF-8"?>
<test>
<message>Hello World!</message>
</test>
"""
file.write(xml_input)
with TemporaryDirectory(dir="./") as temp_target_dir:
transform.transform_xml_to_json(temp_source_dir, temp_target_dir)
with open(os.path.join(temp_target_dir, "test.json")) as file:
json_output = json.load(file)
assert json_output == {"test": {"message": "Hello World!"}}
def test_parse_stakeholder_org_hidden_in_person() -> None:
data = {
"Beteiligter": {
@ -787,34 +768,34 @@ def test_map_co_relation(value: dict, expected_result: dict) -> None:
@patch(
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.map_co_relation"
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.v1.v1.map_co_relation"
)
@patch(
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.map_company_id"
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.v1.v1.map_company_id"
)
@patch(
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.name_from_beteiligung"
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.v1.v1.name_from_beteiligung"
)
@patch(
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.loc_from_beteiligung"
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.v1.v1.loc_from_beteiligung"
)
@patch(
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.map_last_update"
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.v1.v1.map_last_update"
)
@patch(
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.map_rechtsform"
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.v1.v1.map_rechtsform"
)
@patch(
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.map_capital"
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.v1.v1.map_capital"
)
@patch(
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.map_business_purpose"
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.v1.v1.map_business_purpose"
)
@patch(
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.map_founding_date"
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.v1.v1.map_founding_date"
)
@patch(
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.parse_stakeholder"
"aki_prj23_transparenzregister.utils.data_extraction.unternehmensregister.transform.v1.v1.parse_stakeholder"
)
def test_map_unternehmensregister_json( # noqa: PLR0913
mock_map_parse_stakeholder: Mock,