From 0672773551f0df2ffadf1f1574007771be6a44d4 Mon Sep 17 00:00:00 2001 From: Philipp Horstenkamp Date: Sat, 25 Nov 2023 16:53:10 +0100 Subject: [PATCH] Ner simplifications (#419) --- .../ai/ner_pipeline.py | 5 +-- .../ai/sentiment_pipeline.py | 1 + .../ui/pages/home.py | 33 ++++----------- .../utils/networkx/network_2d.py | 42 ++++++++----------- .../utils/networkx/network_3d.py | 9 ++-- 5 files changed, 31 insertions(+), 59 deletions(-) diff --git a/src/aki_prj23_transparenzregister/ai/ner_pipeline.py b/src/aki_prj23_transparenzregister/ai/ner_pipeline.py index 67ca5cf..c90f7c5 100644 --- a/src/aki_prj23_transparenzregister/ai/ner_pipeline.py +++ b/src/aki_prj23_transparenzregister/ai/ner_pipeline.py @@ -1,7 +1,6 @@ """Pipeline to get Entities from Staging DB.""" import os -import sys from typing import Literal, get_args from loguru import logger @@ -18,8 +17,6 @@ from aki_prj23_transparenzregister.config.config_providers import ( ner_methods = Literal["spacy", "company_list", "transformer"] doc_attribs = Literal["text", "title"] -logger.add(sys.stdout, colorize=True) - class EntityPipeline: """Class to initialize NER Pipeline.""" @@ -45,7 +42,7 @@ class EntityPipeline: {"companies": {"$exists": False}} ) documents = list(cursor_unprocessed) - logger.info("Dokumente: ", str(cursor_unprocessed)) + logger.info(f"Documents to be processed: {cursor_unprocessed}") # Determine NER service based on config # spaCy diff --git a/src/aki_prj23_transparenzregister/ai/sentiment_pipeline.py b/src/aki_prj23_transparenzregister/ai/sentiment_pipeline.py index 6b3deba..88c9a91 100644 --- a/src/aki_prj23_transparenzregister/ai/sentiment_pipeline.py +++ b/src/aki_prj23_transparenzregister/ai/sentiment_pipeline.py @@ -43,6 +43,7 @@ class SentimentPipeline: {"sentiment": {"$exists": False}} ) documents = list(cursor_unprocessed) + logger.info(f"Documents to be processed: {cursor_unprocessed}") if len(documents) > 0: for document in tqdm(documents): diff --git a/src/aki_prj23_transparenzregister/ui/pages/home.py b/src/aki_prj23_transparenzregister/ui/pages/home.py index 5ced1f6..d39e3d3 100644 --- a/src/aki_prj23_transparenzregister/ui/pages/home.py +++ b/src/aki_prj23_transparenzregister/ui/pages/home.py @@ -16,7 +16,6 @@ from aki_prj23_transparenzregister.utils.networkx.network_2d import ( create_2d_graph, ) from aki_prj23_transparenzregister.utils.networkx.network_3d import ( - # initialize_network, create_3d_graph, ) from aki_prj23_transparenzregister.utils.networkx.network_base import ( @@ -90,7 +89,7 @@ def _update_figure( # noqa: PLR0913 Returns: - Network Graph(Plotly Figure): Plotly Figure in 3 or 2D + Network Graph: Plotly Figure in 3 or 2D """ _ = c_relation_filter_value, p_relation_filter_value dims = 3 if layout.endswith("(3d)") else 2 @@ -103,26 +102,14 @@ def _update_figure( # noqa: PLR0913 table_dict, table_columns = update_table(metric_dropdown_value, metrics) - if dims == 2: # noqa: PLR2004 - return ( - table_dict, - table_columns, - create_2d_graph( - graph, - nodes, - edges, - metrics, - selected_metric, - layout, - switch_edge_annotation_value, - slider_value, # type: ignore - ), - ) - + plot_graph_dim_function = ( + create_2d_graph if dims == 2 else create_3d_graph # noqa: PLR2004 + ) + # noinspection PyTypeChecker return ( table_dict, table_columns, - create_3d_graph( + plot_graph_dim_function( graph, nodes, edges, @@ -139,12 +126,8 @@ def layout() -> list[html]: """Generates the Layout of the Homepage.""" person_relation_types = person_relation_type_filter() company_relation_types = company_relation_type_filter() - selected_company_relation_types: frozenset[str] = ( - frozenset( - {company_relation_types[1]} if company_relation_types else frozenset({}) - ) - if company_relation_types - else frozenset({}) + selected_company_relation_types: frozenset[str] = frozenset( + {company_relation_types[1]} if company_relation_types else {} ) selected_person_relation_types: frozenset[str] = frozenset( {} diff --git a/src/aki_prj23_transparenzregister/utils/networkx/network_2d.py b/src/aki_prj23_transparenzregister/utils/networkx/network_2d.py index e1edbef..3534fde 100644 --- a/src/aki_prj23_transparenzregister/utils/networkx/network_2d.py +++ b/src/aki_prj23_transparenzregister/utils/networkx/network_2d.py @@ -49,29 +49,29 @@ def create_2d_graph( # noqa PLR0913 raise ValueError(f'Unknown 2d layout "{layout}" requested.') # Initialize Variables to set the Position of the Edges. - edge_x = [] - edge_y = [] + edge_x, edge_y = [], [] + # Initialize Node Position Variables. - node_x = [] - node_y = [] + node_x, node_y = [], [] # Initialize Position Variables for the Description Text of the edges. - edge_weight_x = [] - edge_weight_y = [] + edge_weight_x, edge_weight_y = [], [] # Getting the Positions from NetworkX and assign them to the variables. for edge in graph.edges(): x0, y0 = pos[edge[0]] x1, y1 = pos[edge[1]] - edge_x.append(x0) - edge_x.append(x1) - edge_x.append(float("NaN")) - - edge_y.append(y0) - edge_y.append(y1) - edge_y.append(float("NaN")) + edge_x.extend([x0, x1, None]) + edge_y.extend([y0, y1, None]) edge_weight_x.append((x1 + x0) / 2) edge_weight_y.append((y1 + y0) / 2) + + # Getting the Positions from NetworkX and assign it to the variables. + for node in graph.nodes(): + x, y = pos[node] + node_x.append(x) + node_y.append(y) + # Add the Edges to the scatter plot according to their Positions. edge_trace = go.Scatter( x=edge_x, @@ -91,11 +91,6 @@ def create_2d_graph( # noqa PLR0913 hovertemplate="Relation: %{text}", ) - # Getting the Positions from NetworkX and assign it to the variables. - for node in graph.nodes(): - x, y = pos[node] - node_x.append(x) - node_y.append(y) # Add the Nodes to the scatter plot according to their Positions. node_trace = go.Scatter( x=node_x, @@ -112,7 +107,6 @@ def create_2d_graph( # noqa PLR0913 colors = list(nx.get_node_attributes(graph, "color").values()) node_names = list(nx.get_node_attributes(graph, "name").values()) # ids = list(nx.get_node_attributes(graph, "id").values()) - # print(ids) # # Get the Node Text # node_names = [] @@ -133,18 +127,16 @@ def create_2d_graph( # noqa PLR0913 # Add Relation_Type as a Description for the edges. if edge_annotation: - edge_type_list = [] - for row in edges: - edge_type_list.append(row["type"]) - + edge_type_list = [ + row["type"] for row in edges + ] # this code be moved and used as hover data edge_weights_trace.text = edge_type_list # Return the Plotly Figure return go.Figure( data=[edge_trace, edge_weights_trace, node_trace], layout=go.Layout( - title="
Network graph made with Python", - titlefont_size=16, + title="Network graph", showlegend=False, hovermode="closest", margin={"b": 20, "l": 5, "r": 5, "t": 20}, diff --git a/src/aki_prj23_transparenzregister/utils/networkx/network_3d.py b/src/aki_prj23_transparenzregister/utils/networkx/network_3d.py index db87c50..3e77da8 100644 --- a/src/aki_prj23_transparenzregister/utils/networkx/network_3d.py +++ b/src/aki_prj23_transparenzregister/utils/networkx/network_3d.py @@ -87,7 +87,7 @@ def create_3d_graph( # noqa : PLR0913 hoverinfo="none", ) - # Add the Edgedescriptiontext to the scatter plot according to its Position. + # Add the edge descriptions to the scatter plot according to its Position. edge_weights_trace = go.Scatter3d( x=edge_weight_x, y=edge_weight_y, @@ -186,10 +186,9 @@ def create_3d_graph( # noqa : PLR0913 # Add Relation_Type as a Description for the edges. if edge_annotation: - edge_type_list = [] - for row in edges: - edge_type_list.append(row["type"]) - + edge_type_list = [ + row["type"] for row in edges + ] # this code be moved and used as hover data edge_weights_trace.text = edge_type_list # Set Color by using the nodes DataFrame with its Color Attribute. The sequence matters!