Ner simplifications (#419)

This commit is contained in:
Philipp Horstenkamp 2023-11-25 16:53:10 +01:00 committed by GitHub
parent 3ed756c8e8
commit 0672773551
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 31 additions and 59 deletions

View File

@ -1,7 +1,6 @@
"""Pipeline to get Entities from Staging DB."""
import os
import sys
from typing import Literal, get_args
from loguru import logger
@ -18,8 +17,6 @@ from aki_prj23_transparenzregister.config.config_providers import (
ner_methods = Literal["spacy", "company_list", "transformer"]
doc_attribs = Literal["text", "title"]
logger.add(sys.stdout, colorize=True)
class EntityPipeline:
"""Class to initialize NER Pipeline."""
@ -45,7 +42,7 @@ class EntityPipeline:
{"companies": {"$exists": False}}
)
documents = list(cursor_unprocessed)
logger.info("Dokumente: ", str(cursor_unprocessed))
logger.info(f"Documents to be processed: {cursor_unprocessed}")
# Determine NER service based on config
# spaCy

View File

@ -43,6 +43,7 @@ class SentimentPipeline:
{"sentiment": {"$exists": False}}
)
documents = list(cursor_unprocessed)
logger.info(f"Documents to be processed: {cursor_unprocessed}")
if len(documents) > 0:
for document in tqdm(documents):

View File

@ -16,7 +16,6 @@ from aki_prj23_transparenzregister.utils.networkx.network_2d import (
create_2d_graph,
)
from aki_prj23_transparenzregister.utils.networkx.network_3d import (
# initialize_network,
create_3d_graph,
)
from aki_prj23_transparenzregister.utils.networkx.network_base import (
@ -90,7 +89,7 @@ def _update_figure( # noqa: PLR0913
Returns:
Network Graph(Plotly Figure): Plotly Figure in 3 or 2D
Network Graph: Plotly Figure in 3 or 2D
"""
_ = c_relation_filter_value, p_relation_filter_value
dims = 3 if layout.endswith("(3d)") else 2
@ -103,26 +102,14 @@ def _update_figure( # noqa: PLR0913
table_dict, table_columns = update_table(metric_dropdown_value, metrics)
if dims == 2: # noqa: PLR2004
return (
table_dict,
table_columns,
create_2d_graph(
graph,
nodes,
edges,
metrics,
selected_metric,
layout,
switch_edge_annotation_value,
slider_value, # type: ignore
),
)
plot_graph_dim_function = (
create_2d_graph if dims == 2 else create_3d_graph # noqa: PLR2004
)
# noinspection PyTypeChecker
return (
table_dict,
table_columns,
create_3d_graph(
plot_graph_dim_function(
graph,
nodes,
edges,
@ -139,12 +126,8 @@ def layout() -> list[html]:
"""Generates the Layout of the Homepage."""
person_relation_types = person_relation_type_filter()
company_relation_types = company_relation_type_filter()
selected_company_relation_types: frozenset[str] = (
frozenset(
{company_relation_types[1]} if company_relation_types else frozenset({})
)
if company_relation_types
else frozenset({})
selected_company_relation_types: frozenset[str] = frozenset(
{company_relation_types[1]} if company_relation_types else {}
)
selected_person_relation_types: frozenset[str] = frozenset(
{}

View File

@ -49,29 +49,29 @@ def create_2d_graph( # noqa PLR0913
raise ValueError(f'Unknown 2d layout "{layout}" requested.')
# Initialize Variables to set the Position of the Edges.
edge_x = []
edge_y = []
edge_x, edge_y = [], []
# Initialize Node Position Variables.
node_x = []
node_y = []
node_x, node_y = [], []
# Initialize Position Variables for the Description Text of the edges.
edge_weight_x = []
edge_weight_y = []
edge_weight_x, edge_weight_y = [], []
# Getting the Positions from NetworkX and assign them to the variables.
for edge in graph.edges():
x0, y0 = pos[edge[0]]
x1, y1 = pos[edge[1]]
edge_x.append(x0)
edge_x.append(x1)
edge_x.append(float("NaN"))
edge_y.append(y0)
edge_y.append(y1)
edge_y.append(float("NaN"))
edge_x.extend([x0, x1, None])
edge_y.extend([y0, y1, None])
edge_weight_x.append((x1 + x0) / 2)
edge_weight_y.append((y1 + y0) / 2)
# Getting the Positions from NetworkX and assign it to the variables.
for node in graph.nodes():
x, y = pos[node]
node_x.append(x)
node_y.append(y)
# Add the Edges to the scatter plot according to their Positions.
edge_trace = go.Scatter(
x=edge_x,
@ -91,11 +91,6 @@ def create_2d_graph( # noqa PLR0913
hovertemplate="Relation: %{text}<extra></extra>",
)
# Getting the Positions from NetworkX and assign it to the variables.
for node in graph.nodes():
x, y = pos[node]
node_x.append(x)
node_y.append(y)
# Add the Nodes to the scatter plot according to their Positions.
node_trace = go.Scatter(
x=node_x,
@ -112,7 +107,6 @@ def create_2d_graph( # noqa PLR0913
colors = list(nx.get_node_attributes(graph, "color").values())
node_names = list(nx.get_node_attributes(graph, "name").values())
# ids = list(nx.get_node_attributes(graph, "id").values())
# print(ids)
# # Get the Node Text
# node_names = []
@ -133,18 +127,16 @@ def create_2d_graph( # noqa PLR0913
# Add Relation_Type as a Description for the edges.
if edge_annotation:
edge_type_list = []
for row in edges:
edge_type_list.append(row["type"])
edge_type_list = [
row["type"] for row in edges
] # this code be moved and used as hover data
edge_weights_trace.text = edge_type_list
# Return the Plotly Figure
return go.Figure(
data=[edge_trace, edge_weights_trace, node_trace],
layout=go.Layout(
title="<br>Network graph made with Python",
titlefont_size=16,
title="Network graph",
showlegend=False,
hovermode="closest",
margin={"b": 20, "l": 5, "r": 5, "t": 20},

View File

@ -87,7 +87,7 @@ def create_3d_graph( # noqa : PLR0913
hoverinfo="none",
)
# Add the Edgedescriptiontext to the scatter plot according to its Position.
# Add the edge descriptions to the scatter plot according to its Position.
edge_weights_trace = go.Scatter3d(
x=edge_weight_x,
y=edge_weight_y,
@ -186,10 +186,9 @@ def create_3d_graph( # noqa : PLR0913
# Add Relation_Type as a Description for the edges.
if edge_annotation:
edge_type_list = []
for row in edges:
edge_type_list.append(row["type"])
edge_type_list = [
row["type"] for row in edges
] # this code be moved and used as hover data
edge_weights_trace.text = edge_type_list
# Set Color by using the nodes DataFrame with its Color Attribute. The sequence matters!