NER simplifications (#419)

This commit is contained in:
Philipp Horstenkamp 2023-11-25 16:53:10 +01:00 committed by GitHub
parent 3ed756c8e8
commit 0672773551
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 31 additions and 59 deletions

View File

@ -1,7 +1,6 @@
"""Pipeline to get Entities from Staging DB.""" """Pipeline to get Entities from Staging DB."""
import os import os
import sys
from typing import Literal, get_args from typing import Literal, get_args
from loguru import logger from loguru import logger
@ -18,8 +17,6 @@ from aki_prj23_transparenzregister.config.config_providers import (
ner_methods = Literal["spacy", "company_list", "transformer"] ner_methods = Literal["spacy", "company_list", "transformer"]
doc_attribs = Literal["text", "title"] doc_attribs = Literal["text", "title"]
logger.add(sys.stdout, colorize=True)
class EntityPipeline: class EntityPipeline:
"""Class to initialize NER Pipeline.""" """Class to initialize NER Pipeline."""
@ -45,7 +42,7 @@ class EntityPipeline:
{"companies": {"$exists": False}} {"companies": {"$exists": False}}
) )
documents = list(cursor_unprocessed) documents = list(cursor_unprocessed)
logger.info("Dokumente: ", str(cursor_unprocessed)) logger.info(f"Documents to be processed: {cursor_unprocessed}")
# Determine NER service based on config # Determine NER service based on config
# spaCy # spaCy

View File

@ -43,6 +43,7 @@ class SentimentPipeline:
{"sentiment": {"$exists": False}} {"sentiment": {"$exists": False}}
) )
documents = list(cursor_unprocessed) documents = list(cursor_unprocessed)
logger.info(f"Documents to be processed: {cursor_unprocessed}")
if len(documents) > 0: if len(documents) > 0:
for document in tqdm(documents): for document in tqdm(documents):

View File

@ -16,7 +16,6 @@ from aki_prj23_transparenzregister.utils.networkx.network_2d import (
create_2d_graph, create_2d_graph,
) )
from aki_prj23_transparenzregister.utils.networkx.network_3d import ( from aki_prj23_transparenzregister.utils.networkx.network_3d import (
# initialize_network,
create_3d_graph, create_3d_graph,
) )
from aki_prj23_transparenzregister.utils.networkx.network_base import ( from aki_prj23_transparenzregister.utils.networkx.network_base import (
@ -90,7 +89,7 @@ def _update_figure( # noqa: PLR0913
Returns: Returns:
Network Graph(Plotly Figure): Plotly Figure in 3 or 2D Network Graph: Plotly Figure in 3 or 2D
""" """
_ = c_relation_filter_value, p_relation_filter_value _ = c_relation_filter_value, p_relation_filter_value
dims = 3 if layout.endswith("(3d)") else 2 dims = 3 if layout.endswith("(3d)") else 2
@ -103,26 +102,14 @@ def _update_figure( # noqa: PLR0913
table_dict, table_columns = update_table(metric_dropdown_value, metrics) table_dict, table_columns = update_table(metric_dropdown_value, metrics)
if dims == 2: # noqa: PLR2004 plot_graph_dim_function = (
return ( create_2d_graph if dims == 2 else create_3d_graph # noqa: PLR2004
table_dict, )
table_columns, # noinspection PyTypeChecker
create_2d_graph(
graph,
nodes,
edges,
metrics,
selected_metric,
layout,
switch_edge_annotation_value,
slider_value, # type: ignore
),
)
return ( return (
table_dict, table_dict,
table_columns, table_columns,
create_3d_graph( plot_graph_dim_function(
graph, graph,
nodes, nodes,
edges, edges,
@ -139,12 +126,8 @@ def layout() -> list[html]:
"""Generates the Layout of the Homepage.""" """Generates the Layout of the Homepage."""
person_relation_types = person_relation_type_filter() person_relation_types = person_relation_type_filter()
company_relation_types = company_relation_type_filter() company_relation_types = company_relation_type_filter()
selected_company_relation_types: frozenset[str] = ( selected_company_relation_types: frozenset[str] = frozenset(
frozenset( {company_relation_types[1]} if company_relation_types else {}
{company_relation_types[1]} if company_relation_types else frozenset({})
)
if company_relation_types
else frozenset({})
) )
selected_person_relation_types: frozenset[str] = frozenset( selected_person_relation_types: frozenset[str] = frozenset(
{} {}

View File

@ -49,29 +49,29 @@ def create_2d_graph( # noqa PLR0913
raise ValueError(f'Unknown 2d layout "{layout}" requested.') raise ValueError(f'Unknown 2d layout "{layout}" requested.')
# Initialize Variables to set the Position of the Edges. # Initialize Variables to set the Position of the Edges.
edge_x = [] edge_x, edge_y = [], []
edge_y = []
# Initialize Node Position Variables. # Initialize Node Position Variables.
node_x = [] node_x, node_y = [], []
node_y = []
# Initialize Position Variables for the Description Text of the edges. # Initialize Position Variables for the Description Text of the edges.
edge_weight_x = [] edge_weight_x, edge_weight_y = [], []
edge_weight_y = []
# Getting the Positions from NetworkX and assign them to the variables. # Getting the Positions from NetworkX and assign them to the variables.
for edge in graph.edges(): for edge in graph.edges():
x0, y0 = pos[edge[0]] x0, y0 = pos[edge[0]]
x1, y1 = pos[edge[1]] x1, y1 = pos[edge[1]]
edge_x.append(x0) edge_x.extend([x0, x1, None])
edge_x.append(x1) edge_y.extend([y0, y1, None])
edge_x.append(float("NaN"))
edge_y.append(y0)
edge_y.append(y1)
edge_y.append(float("NaN"))
edge_weight_x.append((x1 + x0) / 2) edge_weight_x.append((x1 + x0) / 2)
edge_weight_y.append((y1 + y0) / 2) edge_weight_y.append((y1 + y0) / 2)
# Getting the Positions from NetworkX and assign it to the variables.
for node in graph.nodes():
x, y = pos[node]
node_x.append(x)
node_y.append(y)
# Add the Edges to the scatter plot according to their Positions. # Add the Edges to the scatter plot according to their Positions.
edge_trace = go.Scatter( edge_trace = go.Scatter(
x=edge_x, x=edge_x,
@ -91,11 +91,6 @@ def create_2d_graph( # noqa PLR0913
hovertemplate="Relation: %{text}<extra></extra>", hovertemplate="Relation: %{text}<extra></extra>",
) )
# Getting the Positions from NetworkX and assign it to the variables.
for node in graph.nodes():
x, y = pos[node]
node_x.append(x)
node_y.append(y)
# Add the Nodes to the scatter plot according to their Positions. # Add the Nodes to the scatter plot according to their Positions.
node_trace = go.Scatter( node_trace = go.Scatter(
x=node_x, x=node_x,
@ -112,7 +107,6 @@ def create_2d_graph( # noqa PLR0913
colors = list(nx.get_node_attributes(graph, "color").values()) colors = list(nx.get_node_attributes(graph, "color").values())
node_names = list(nx.get_node_attributes(graph, "name").values()) node_names = list(nx.get_node_attributes(graph, "name").values())
# ids = list(nx.get_node_attributes(graph, "id").values()) # ids = list(nx.get_node_attributes(graph, "id").values())
# print(ids)
# # Get the Node Text # # Get the Node Text
# node_names = [] # node_names = []
@ -133,18 +127,16 @@ def create_2d_graph( # noqa PLR0913
# Add Relation_Type as a Description for the edges. # Add Relation_Type as a Description for the edges.
if edge_annotation: if edge_annotation:
edge_type_list = [] edge_type_list = [
for row in edges: row["type"] for row in edges
edge_type_list.append(row["type"]) ] # this code could be moved and used as hover data
edge_weights_trace.text = edge_type_list edge_weights_trace.text = edge_type_list
# Return the Plotly Figure # Return the Plotly Figure
return go.Figure( return go.Figure(
data=[edge_trace, edge_weights_trace, node_trace], data=[edge_trace, edge_weights_trace, node_trace],
layout=go.Layout( layout=go.Layout(
title="<br>Network graph made with Python", title="Network graph",
titlefont_size=16,
showlegend=False, showlegend=False,
hovermode="closest", hovermode="closest",
margin={"b": 20, "l": 5, "r": 5, "t": 20}, margin={"b": 20, "l": 5, "r": 5, "t": 20},

View File

@ -87,7 +87,7 @@ def create_3d_graph( # noqa : PLR0913
hoverinfo="none", hoverinfo="none",
) )
# Add the Edgedescriptiontext to the scatter plot according to its Position. # Add the edge descriptions to the scatter plot according to its Position.
edge_weights_trace = go.Scatter3d( edge_weights_trace = go.Scatter3d(
x=edge_weight_x, x=edge_weight_x,
y=edge_weight_y, y=edge_weight_y,
@ -186,10 +186,9 @@ def create_3d_graph( # noqa : PLR0913
# Add Relation_Type as a Description for the edges. # Add Relation_Type as a Description for the edges.
if edge_annotation: if edge_annotation:
edge_type_list = [] edge_type_list = [
for row in edges: row["type"] for row in edges
edge_type_list.append(row["type"]) ] # this code could be moved and used as hover data
edge_weights_trace.text = edge_type_list edge_weights_trace.text = edge_type_list
# Set Color by using the nodes DataFrame with its Color Attribute. The sequence matters! # Set Color by using the nodes DataFrame with its Color Attribute. The sequence matters!