refactor: Implement linter feedback

TrisNol 2023-07-11 14:20:16 +02:00
parent 1c621f46a7
commit ed681d7c47


@@ -1,18 +1,27 @@
 """Unternehmensregister Scraping."""
 import glob
+import logging
 import multiprocessing
 import os
 from pathlib import Path
 from selenium import webdriver
 from selenium.webdriver.common.by import By
-from selenium.webdriver.support import expected_conditions as EC
+from selenium.webdriver.support import expected_conditions as ec
 from selenium.webdriver.support.ui import WebDriverWait
 from tqdm import tqdm
+logger = logging.getLogger()
 def scrape(query: str, download_dir: list[str]):
+    """Fetch results from Unternehmensregister for given query.
+    Args:
+        query (str): Search query (RegEx supported)
+        download_dir (list[str]): Directory to place output files in
+    """
     download_path = os.path.join(str(Path.cwd()), *download_dir)
+    print(download_path)
     options = webdriver.ChromeOptions()
     preferences = {
         "profile.default_content_settings.popups": 0,
@@ -52,7 +61,7 @@ def scrape(query: str, download_dir: list[str]):
     processed_companies = []
-    for page_index in tqdm(range(num_pages)):
+    for _ in tqdm(range(num_pages)):
         # Find all "Registerinformationen"
         companies_tab = driver.find_elements(
             By.LINK_TEXT, "Registerinformationen des Registergerichts"
@@ -75,7 +84,7 @@ def scrape(query: str, download_dir: list[str]):
             driver.find_element(By.LINK_TEXT, "SI").click()
             # Show shopping cart
             wait.until(
-                EC.visibility_of_element_located(
+                ec.visibility_of_element_located(
                     (By.LINK_TEXT, "Dokumentenkorb ansehen")
                 )
             )
@@ -85,12 +94,12 @@ def scrape(query: str, download_dir: list[str]):
             elems[-2].click()
             wait.until(
-                EC.visibility_of_element_located((By.ID, "paymentFormOverview:btnNext"))
+                ec.visibility_of_element_located((By.ID, "paymentFormOverview:btnNext"))
             )
             driver.find_element(By.ID, "paymentFormOverview:btnNext").click()
             wait.until(
-                EC.visibility_of_element_located((By.LINK_TEXT, "Zum Dokumentenkorb"))
+                ec.visibility_of_element_located((By.LINK_TEXT, "Zum Dokumentenkorb"))
             )
             driver.find_element(By.LINK_TEXT, "Zum Dokumentenkorb").click()
@@ -98,9 +107,7 @@ def scrape(query: str, download_dir: list[str]):
             driver.find_element(By.CLASS_NAME, "download-wrapper").click()
             try:
-                wait.until(
-                    lambda x: wait_for_download_condition(download_path, num_files)
-                )
+                wait.until(wait_for_download_condition(download_path, num_files))
                 file_name = "".join(e for e in company_name if e.isalnum()) + ".xml"
                 rename_latest_file(
                     download_path,
@@ -108,9 +115,9 @@ def scrape(query: str, download_dir: list[str]):
                 )
                 processed_companies.append(company_name)
             except Exception:
-                pass
+                logger.warning("Exception caught in Scraping")
             finally:
-                for click_counter in range(6):
+                for _ in range(6):
                     driver.back()
         driver.find_element(By.XPATH, '//*[@class="fas fa-angle-right"]').click()
     driver.close()
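
For context on the download-wait change above: WebDriverWait.until repeatedly calls the object it is given, passing it the driver, until the return value is truthy, so a plain file-count predicate such as wait_for_download_condition is typically handed over wrapped in a lambda. A small self-contained sketch; the Chrome driver, the ./downloads path and the 60-second timeout are illustrative assumptions, not values from the commit:

import glob

from selenium import webdriver
from selenium.webdriver.support.ui import WebDriverWait


def wait_for_download_condition(path: str, num_files: int, pattern: str = "*.xml") -> bool:
    # True once more matching files exist than before the download was started.
    return len(glob.glob1(path, pattern)) > num_files


driver = webdriver.Chrome()  # assumes a local Chrome/chromedriver setup
download_path = "./downloads"  # illustrative directory, not taken from the commit
num_files = len(glob.glob1(download_path, "*.xml"))

WebDriverWait(driver, 60).until(
    lambda _driver: wait_for_download_condition(download_path, num_files)
)
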
@@ -119,34 +126,61 @@ def scrape(query: str, download_dir: list[str]):
 def wait_for_download_condition(
     path: str, num_files: int, pattern: str = "*.xml"
 ) -> bool:
+    """Selenium wait condition monitoring the number of files in a dir.
+    Args:
+        path (str): Directory path
+        num_files (int): Current number of files
+        pattern (str, optional): File pattern. Defaults to "*.xml".
+    Returns:
+        bool: True once the current number of files has been exceeded
+    """
     return len(glob.glob1(path, pattern)) > num_files
 def get_num_files(path: str, pattern: str = "*.xml") -> int:
+    """Get the number of files in a directory.
+    Args:
+        path (str): Directory to scan
+        pattern (str, optional): File pattern. Defaults to "*.xml".
+    Returns:
+        int: Number of files matching the pattern
+    """
     return len(glob.glob1(path, pattern))
 def rename_latest_file(path: str, filename: str, pattern: str = "*.xml"):
+    """Rename the file in the dir with the latest change date.
+    Args:
+        path (str): Dir to check
+        filename (str): New name of the file
+        pattern (str, optional): File pattern. Defaults to "*.xml".
+    """
     list_of_files = [os.path.join(path, file) for file in glob.glob1(path, pattern)]
     latest_download = max(list_of_files, key=os.path.getctime)
     os.rename(latest_download, os.path.join(path, filename))
 if __name__ == "__main__":
+    """Main procedure"""
     import pandas as pd
-    df = pd.read_excel(
+    df_relevant_companies = pd.read_excel(
         "./data/study_id42887_top-100-unternehmen-deutschland.xlsx",
         sheet_name="Toplist",
         skiprows=1,
     )
-    df = df[df["Name"].notna()]
+    df_relevant_companies = df_relevant_companies[df_relevant_companies["Name"].notna()]
     batch_size = 5
     pool = multiprocessing.Pool(processes=batch_size)
     params = [
         (query, ["data", "Unternehmensregister", "scraping", query.strip()])
-        for query in df.Name
+        for query in df_relevant_companies.Name
     ]
     # Map the scrape function to the parameter list using the Pool
     pool.starmap(scrape, params)
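
The __main__ block fans scrape out over a small multiprocessing pool. A minimal sketch of the same pattern with the Pool used as a context manager, so worker processes are cleaned up automatically; the query names below are made-up examples rather than entries from the Excel sheet, and scrape is assumed to be importable from the module above:

import multiprocessing

if __name__ == "__main__":
    queries = ["Example GmbH", "Muster AG"]  # illustrative queries only
    params = [
        (query, ["data", "Unternehmensregister", "scraping", query.strip()])
        for query in queries
    ]
    with multiprocessing.Pool(processes=5) as pool:
        # starmap unpacks each tuple into scrape(query, download_dir)
        pool.starmap(scrape, params)
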