diff --git a/src/aki_prj23_transparenzregister/utils/mongo/ner_pipeline.py b/src/aki_prj23_transparenzregister/utils/mongo/ner_pipeline.py index 42d3ee3..2402865 100644 --- a/src/aki_prj23_transparenzregister/utils/mongo/ner_pipeline.py +++ b/src/aki_prj23_transparenzregister/utils/mongo/ner_pipeline.py @@ -22,7 +22,6 @@ class EntityPipeline: def __init__(self, conn_string: conn.MongoConnection) -> None: """Method to connect to StagingDB.""" self.connect_string = conn_string - self.connect_string.database = "transparenzregister_ner" self.connector = conn.MongoConnector(self.connect_string) self.news_obj = news.MongoNewsService(self.connector) diff --git a/src/aki_prj23_transparenzregister/utils/mongo/ner_service.py b/src/aki_prj23_transparenzregister/utils/mongo/ner_service.py index 056d2b3..82906b5 100644 --- a/src/aki_prj23_transparenzregister/utils/mongo/ner_service.py +++ b/src/aki_prj23_transparenzregister/utils/mongo/ner_service.py @@ -44,7 +44,7 @@ class NerAnalysisService: self.classifier = pipeline( "ner", model="fhswf/bert_de_ner", - grouped_entities=True, + aggregation_strategy="simple", tokenizer="dbmdz/bert-base-german-cased", ) @@ -72,15 +72,16 @@ class NerAnalysisService: # init list for entities entities = [] - text = doc[doc_attrib] + text = doc[doc_attrib].strip() + # check if text is a string and not empty + if isinstance(text, str) and text: + # get entities + doc_nlp = self.nlp(text) - # get entities - doc_nlp = self.nlp(text) - - # select company - for ent in doc_nlp.ents: - if ent.label_ == ent_type: - entities.append(ent.text) + # select company + for ent in doc_nlp.ents: + if ent.label_ == ent_type: + entities.append(ent.text) return dict(Counter(entities)) def ner_company_list( @@ -104,17 +105,19 @@ class NerAnalysisService: entities = [] # Search the text for company names - text = doc[doc_attrib] - # Convert title to lowercase - text = text.lower() + text = doc[doc_attrib].strip() + # check if text is a string and not empty + if isinstance(text, str) and text: + # Convert title to lowercase + text = text.lower() - for company_name in self.complist: - start_idx = text.find(company_name) - if start_idx != -1: # Wort gefunden - start_idx + len(company_name) - entity = company_name - if entity not in entities: - entities.append(entity) + for company_name in self.complist: + start_idx = text.find(company_name) + if start_idx != -1: # Wort gefunden + start_idx + len(company_name) + entity = company_name + if entity not in entities: + entities.append(entity) return dict(Counter(entities)) @@ -136,15 +139,18 @@ class NerAnalysisService: # init list for entities entities = [] text = doc[doc_attrib] - sentences = text.split(". ") # Split text into sentences based on '. ' + # check if text is a string and not empty + if isinstance(text, str) and text: + sentences = text.split(". ") # Split text into sentences based on '. ' - # Process each sentence separately - for sentence in sentences: - res = self.classifier( - sentence - ) # Assuming 'classifier' processes a single sentence at a time + # Process each sentence separately + for sentence in sentences: + res = self.classifier( + sentence + ) # Assuming 'classifier' processes a single sentence at a time + + for _ in res: + if _["entity_group"] == ent_type: + entities.append(_["word"]) - for i in range(len(res)): - if res[i]["entity_group"] == ent_type: - entities.append(res[i]["word"]) return dict(Counter(entities))