1
0

Added pyrate as a direct dependency.

This commit is contained in:
2022-07-11 23:07:33 +02:00
parent 8c4532dad4
commit c99d517f6f
230 changed files with 21114 additions and 0 deletions

View File

@ -0,0 +1 @@
"""This file is required to mark the directory as a package for the documentation to be able to include it."""

View File

@ -0,0 +1,149 @@
#!/usr/bin/env python3
"""This scripts benchmarks both the database reading accesses as well as the
projections between polar and cartesian representations.
The function :func:`_query_database` queries the database for some location and radius which were deemed a
realistic load scenario.
The function :func:`_project_to_cartesian_and_back` takes the result of such a query and projects it to the
cartesian representation and back.
Keep in mind, that these operations might only be performed very seldom on an actual Atlantic crossing
(maybe every couple of hours).
The results and details are now included in :ref:`benchmarking-db-and-local-projections`
(in the documentation of the :mod:`pyrate.plan` module).
This script was initially developed as part of
`issue #40 <https://gitlab.sailingteam.hg.tu-darmstadt.de/informatik/pyrate/-/issues/40>`__,
in order to evaluate whether "custom" local projections are a feasible option on a Raspberry Pi 3B/4B.
Since then, the implementation has changed.
In particular, the database creation as been moved into a separate script
(see :ref:`script-s57_charts_to_db`).
"""
# Standard library
from argparse import ArgumentDefaultsHelpFormatter
from argparse import ArgumentParser
from time import perf_counter
# Typing
from typing import Any
from typing import Callable
from typing import List
# Data modeling
import numpy
# Geospatial
from pyrate.plan.geometry import PolarGeometry
from pyrate.plan.geometry import PolarLocation
# Database
from pyrate.common.charts import SpatialiteDatabase
#: Around Miami, Florida, US.
#: The point was chosen simply because charts are available nearby.
QUERY_AROUND = PolarLocation(longitude=-80.10955810546875, latitude=25.851808634972723)
def _query_database(path_to_db: str, around: PolarLocation, radius: float) -> List[PolarGeometry]:
"""Queries some polygons from the database.
Args:
path_to_db: The path to the database
around: The location around which to query for chart objects
radius: The radius within which to query for chart objects in meters
Returns:
The resulting polygons
"""
with SpatialiteDatabase(path_to_db) as database:
return list(database.read_geometries_around(around=around, radius=radius))
def _project_to_cartesian_and_back(data: List[PolarGeometry]) -> None:
"""Projects some PolarPolygons to their cartesian representation and back to test the performance.
Args:
data: Some polygons to project
"""
assert data # non-emptiness
first = data[0]
center = first if isinstance(first, PolarLocation) else first.locations[0]
for polygon in data:
polygon.to_cartesian(center).to_polar()
def _measure_func(func: Callable[..., Any], name: str, iterations: int, *params, **kw_params) -> None:
"""Measures and prints the running time of a given method.
Args:
func: The callable to execute
name: The name to use for printing
iterations: The number of iterations to average over
*params: Positional arguments to be passed to the callable
**kw_params: Keyword arguments to be passed to the callable
"""
results = numpy.empty((iterations,))
for i in range(iterations):
start = perf_counter()
func(*params, **kw_params)
end = perf_counter()
results[i] = end - start
print(f'Executed "{name}" {iterations} times:')
print(f"\taverage:\t {numpy.mean(results):.6f} seconds")
print(f"\tstd dev:\t {numpy.std(results):.6f} seconds")
print(f"\tvariance:\t {numpy.var(results):.6f} seconds")
def benchmark(path_to_db: str, iterations: int, around: PolarLocation, radius: float) -> None:
"""Performs the benchmark and prints the results to the console.
Args:
path_to_db: The path to the database
iterations: The number of iterations to average over
around: The location around which to query for chart objects
radius: The radius within which to query for chart objects in meters
"""
print("Information on the setting:")
with SpatialiteDatabase(path_to_db) as database:
print(f"\tnumber of rows/polygons in database:\t\t\t {len(database)}")
print(f"\tsum of vertices of all rows/polygons of in database:\t {database.count_vertices()}")
data = _query_database(path_to_db, around, radius)
print(f"\textracted number of polygons:\t\t\t\t {len(data)}")
vertex_count = sum(1 if isinstance(poly, PolarLocation) else len(poly.locations) for poly in data)
print(f"\textracted total number of vertices:\t\t\t {vertex_count}")
print() # newline
_measure_func(_query_database, "query_database", iterations, path_to_db, around, radius)
print() # newline
_measure_func(_project_to_cartesian_and_back, "project_to_cartesian_and_back", iterations, data)
def _main() -> None:
"""The main function."""
parser = ArgumentParser(
description="Benchmark DB queries and projections for a fixed location (see docs/scripts).",
formatter_class=ArgumentDefaultsHelpFormatter,
)
parser.add_argument("path_to_db", type=str)
parser.add_argument("--iterations", type=int, default=10)
parser.add_argument("--radius", type=float, default=100, help="The query radius in kilometers")
args = parser.parse_args()
benchmark(
path_to_db=args.path_to_db, iterations=args.iterations, around=QUERY_AROUND, radius=args.radius * 1000
)
if __name__ == "__main__":
_main()

View File

@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""Benchmark the neighbor search in graphs.
Initially written as part of
`Issue #90 <https://gitlab.sailingteam.hg.tu-darmstadt.de/informatik/pyrate/-/issues/90>`__ to determine
whether a faster implementation is needed.
That issue also contains a draft which *might* make it faster if that is required in the future.
Examples:
These are the benchmark results when run on a `Lenovo ThinkPad T560 laptop <https://thinkwiki.de/T560>`__
with an `Intel(R) Core(TM) i5-6300U CPU @ 2.40GHz
<https://ark.intel.com/content/www/de/de/ark/products/88190/intel-core-i5-6300u-processor-3m-cache-up-to-3-00-ghz.html>`__
and 16GB RAM (at commit ``9a8177326dc0d82d0aea4559e6c85071ceebf56f``):
.. code-block:: bash
./scripts/benchmark_graph_neighbor_search.py --iterations 100
frequency = 2 for distance 5000 km
generated graph in 0.018192768096923828 seconds
number of nodes = 42, number of edges = 120
non-empty entries in neighbor table = 240
computation time = 0.0003081770000221695 (avg. over 100 samples)
frequency = 8 for distance 1000 km
generated graph in 0.0327601432800293 seconds
number of nodes = 642, number of edges = 1920
non-empty entries in neighbor table = 3840
computation time = 0.0033796210000218707 (avg. over 100 samples)
frequency = 71 for distance 100 km
generated graph in 1.8711962699890137 seconds
number of nodes = 50412, number of edges = 151230
non-empty entries in neighbor table = 302460
computation time = 0.30925760600001695 (avg. over 100 samples)
frequency = 142 for distance 50 km
generated graph in 7.630561828613281 seconds
number of nodes = 201642, number of edges = 604920
non-empty entries in neighbor table = 1209840
computation time = 1.1302456550000102 (avg. over 100 samples)
frequency = 706 for distance 10 km
generated graph in 260.7689461708069 seconds
number of nodes = 4984362, number of edges = 14953080
non-empty entries in neighbor table = 29906160
computation time = 27.382845137000004 (avg. over 100 samples)
"""
# Standard library
from argparse import ArgumentDefaultsHelpFormatter
from argparse import ArgumentParser
from time import time
from timeit import timeit
# Scientific
import numpy
# Graph
from pyrate.plan.graph.generate import create_earth_graph
from pyrate.plan.graph.generate import min_required_frequency
from pyrate.plan.graph import NavigationGraph
def _main() -> None:
"""The main function."""
parser = ArgumentParser(
description="Benchmark the neighbor search in graphs.", formatter_class=ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--iterations", type=int, default=100, help="the number of timing samples to collect per graph size"
)
args = parser.parse_args()
for distance_km in [5000, 1000, 100, 50, 10]:
frequency = min_required_frequency(distance_km * 1000, in_meters=True)
print(f"frequency = {frequency} for distance {distance_km} km")
time_before_generation = time()
graph = create_earth_graph(frequency, print_status=False)
print(f"generated graph in {time() - time_before_generation} seconds")
print(f"number of nodes = {len(graph)}, number of edges = {graph.num_edges}")
def setup(local_graph: NavigationGraph = graph) -> None:
local_graph._neighbors = None # pylint: disable=protected-access
def statement(local_graph: NavigationGraph = graph) -> None:
_ = local_graph.neighbors
avg_time = timeit(setup=setup, stmt=statement, number=args.iterations)
print(f"non-empty entries in neighbor table = {numpy.count_nonzero(graph.neighbors != -1)}")
print(f"computation time = {avg_time} (avg. over {args.iterations} samples)")
print() # empty line between distances
if __name__ == "__main__":
_main()

View File

@ -0,0 +1,64 @@
#!/usr/bin/env python3
"""Create a database from a given GeoJSON input file. Intended to quickly create test databases.
It assumes the same structure as the one generated by `geojson.io <https://geojson.io/>`__ and only supports
polygons.
"""
# Standard library
from argparse import ArgumentDefaultsHelpFormatter
from argparse import ArgumentParser
import json
# Database and charts
from typing import Generator
# Math
from numpy import array
# Pyrate
from pyrate.common.charts import SpatialiteDatabase
from pyrate.plan.geometry import LocationType
from pyrate.plan.geometry import PolarPolygon
def read_geojson(
path: str, location_type: LocationType = LocationType.LAND
) -> Generator[PolarPolygon, None, None]:
"""Reads a GeoJSON file (only supports specific constructs, see module documentation).
Args:
path: the input file
location_type: the location type of all chart objects
"""
with open(path, "r", encoding="utf-8") as input_file:
json_data = json.load(input_file)
assert json_data["type"] == "FeatureCollection"
for feature in json_data["features"]:
assert feature["type"] == "Feature"
geometry = feature["geometry"]
assert geometry["type"] == "Polygon"
coordinates = geometry["coordinates"]
assert len(coordinates) == 1, "the polygon may have exactly one exterior and zero interior rings"
exterior = coordinates[0]
yield PolarPolygon.from_numpy(array(exterior), location_type=location_type)
def _main() -> None:
"""The main function."""
parser = ArgumentParser(description=__doc__, formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument(
"path_to_geojson", type=str, help='The input file, usually ends with ".json", UTF-8 encoding'
)
parser.add_argument("path_to_db", type=str, help='The output file, usually ends with ".sqlite"')
args = parser.parse_args()
with SpatialiteDatabase(args.path_to_db) as database:
database.write_geometries(read_geojson(args.path_to_geojson), update=True)
if __name__ == "__main__":
_main()

View File

@ -0,0 +1,145 @@
#!/usr/bin/env python3
"""
Generated the spherical graph that can then be used to navigate.
The graph gets serialized to disk at the end of the calculation.
Examples:
Generate a graph and visualize it. Because it makes using it in search algorithms faster, we also include
the neighbor table. the also prunes by default using the *Earth2014* dataset (variant *TBI*, 1 arc-min
resolution).
.. code-block:: bash
./scripts/create_earth_graph.py 500000 earth_graph_500_km.hdf5
"""
# Standard library
from argparse import ArgumentDefaultsHelpFormatter
from argparse import ArgumentParser
import os.path
# Data set access
from pyrate.common.raster_datasets import DataSetAccess
from pyrate.common.raster_datasets import transformers_concrete
# Graph generation
from pyrate.plan.graph import create_earth_graph
from pyrate.plan.graph import min_required_frequency
# Script visualize_earth_graph
try:
from visualize_earth_graph import dump_2d_plots
except ImportError:
# add scripts folder
import sys
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
del sys
from visualize_earth_graph import dump_2d_plots
def calculate_and_save( # pylint: disable=too-many-arguments
requested_distance: float,
out_graph_file: str,
prune_land_areas: bool,
bathymetric_dataset: str,
generate_neighbors: bool,
dump_plots: bool,
out_visualization_directory: str,
) -> None:
"""Calculates and saves the graph with the given node distance while performing some logging.
Args:
requested_distance: the maximum distance between two neighboring nodes, in meters
out_graph_file: the target file where to save the graph to; usually end in ``.hdf5``
prune_land_areas: whether to prune by land areas
bathymetric_dataset: the path to the bathymetric dataset if parameter ``prune_land_areas`` is set to
``True``; e.g. the Earth2014 dataset with depth in meters
generate_neighbors: whether to generate and serialize all neighbors too
dump_plots: whether to dump all plots with the default config using
:mod:`scripts.visualize_earth_graph` after completing the graph generation
out_visualization_directory: the target directory (may not yet exist) where to save the visualizations
to if parameter ``dump_plots`` is set to ``True``
"""
print("Starting generation of earth graph")
graph = create_earth_graph(min_required_frequency(requested_distance, in_meters=True))
if prune_land_areas:
print("Pruning graph")
# generate the "keep condition" and then remove the property afterwards
data_set = DataSetAccess(bathymetric_dataset)
mode = transformers_concrete.BathymetricTransformer.Modes.FRACTION_NAVIGABLE
graph.append_properties(
[transformers_concrete.BathymetricTransformer(data_set, [mode])], show_progress=True
)
# keep all nodes that have more than 0% (i.e. that have any) navigable locations
keep_condition = graph.node_properties[mode.column_name] >= 0.0
graph.clear_node_properties()
graph.prune_nodes(keep_condition.to_numpy())
if generate_neighbors:
print("Generating neighbor table")
_ = graph.neighbors
print("Completed generation of earth graph")
print(f'Serializing to disk: "{out_graph_file}"')
os.makedirs(os.path.dirname(out_graph_file) or ".", exist_ok=True)
graph.to_disk(out_graph_file)
if dump_plots:
print(f"Dumping visualizations: {out_visualization_directory}")
dump_2d_plots(graph, out_visualization_directory)
def _main() -> None:
"""The main function."""
parser = ArgumentParser(
description="Create and serialize a graph of the earth. "
"Optionally perform pruning of land area, neighbor discovery and "
"dumping of visualizations.",
formatter_class=ArgumentDefaultsHelpFormatter,
)
parser.add_argument("requested_distance", type=float, help="The max. distance between nodes in meters")
parser.add_argument("out_graph_file", type=str)
parser.add_argument("--prune_land_areas", type=bool, default=True)
parser.add_argument(
"--bathymetric_dataset",
type=str,
default="../data/topography/earth2014/Earth2014.TBI2014.1min.geod.geo.tif",
)
parser.add_argument("--generate_neighbors", type=bool, default=True)
parser.add_argument("--dump_plots", type=bool, default=True)
parser.add_argument(
"--visualization_directory",
type=str,
default=None,
help='default: "dirname(out_graph_file)/visualization/"',
)
args = parser.parse_args()
out_visualization_directory = (
args.visualization_directory
if args.visualization_directory is not None
else os.path.join(os.path.dirname(args.out_graph_file), "visualization")
)
calculate_and_save(
requested_distance=args.requested_distance,
out_graph_file=args.out_graph_file,
prune_land_areas=args.prune_land_areas,
bathymetric_dataset=args.bathymetric_dataset,
generate_neighbors=args.generate_neighbors,
dump_plots=args.dump_plots,
out_visualization_directory=out_visualization_directory,
)
if __name__ == "__main__":
_main()

View File

@ -0,0 +1,115 @@
#!/usr/bin/env python3
"""
Compute a statistics table for earth-generation with multiple frequencies.
.. _script-earth_graph_frequency_statistics-example:
Examples:
.. code-block:: bash
./scripts/earth_graph_frequency_statistics.py 300 --step 10
Frequency Great Circle Distance (km) # nodes # edges Computation time (sec)
10 705.365422 1002 3000 0.229929
20 352.682711 4002 12000 0.531473
30 235.121807 9002 27000 0.582267
40 176.341356 16002 48000 0.998187
50 141.073084 25002 75000 1.559254
60 117.560904 36002 108000 2.264499
70 100.766489 49002 147000 3.095882
80 88.170678 64002 192000 6.268704
90 78.373936 81002 243000 5.790993
100 70.536542 100002 300000 6.933162
110 64.124129 121002 363000 9.229682
120 58.780452 144002 432000 10.617164
130 54.258879 169002 507000 16.519117
140 50.383244 196002 588000 15.838989
150 47.024361 225002 675000 21.517596
160 44.085339 256002 768000 24.864843
170 41.492084 289002 867000 30.145898
180 39.186968 324002 972000 28.684602
190 37.124496 361002 1083000 27.561629
200 35.268271 400002 1200000 33.667006
210 33.58883 441002 1323000 34.471024
220 32.062065 484002 1452000 37.554208
230 30.668062 529002 1587000 43.1782
240 29.390226 576002 1728000 46.112764
250 28.214617 625002 1875000 44.472765
260 27.129439 676002 2028000 53.084822
270 26.124645 729002 2187000 60.2798
280 25.191622 784002 2352000 63.117184
290 24.322946 841002 2523000 64.914421
300 23.512181 900002 2700000 68.890113
"""
# Standard library
from argparse import ArgumentDefaultsHelpFormatter
from argparse import ArgumentParser
from time import perf_counter
# Typing
from typing import List
from typing import Tuple
# Scientific
import numpy
import pandas
# Earth graph calculation
from pyrate.plan.graph import create_earth_graph
from pyrate.plan.graph import great_circle_distance_distance_for
def _main() -> None:
"""The main function."""
parser = ArgumentParser(
description="Compute a statistics table for earth-generation with multiple frequencies.",
formatter_class=ArgumentDefaultsHelpFormatter,
)
parser.add_argument("max", type=int, help="the maximum frequency to test")
parser.add_argument(
"--step", type=int, default=10, help="how large the steps while increasing the frequencies should be"
)
args = parser.parse_args()
pandas.set_option("display.max_columns", None)
pandas.set_option("display.max_rows", None)
columns = [
("Frequency", numpy.uint),
("Great Circle Distance (km)", numpy.float64),
("# nodes", numpy.uint),
("# edges", numpy.uint),
("Computation time (sec)", numpy.float64),
]
records: List[Tuple] = []
for frequency in range(0, args.max + 1, args.step):
if frequency == 0:
continue # better steps when starting at zero
start = perf_counter()
graph = create_earth_graph(frequency)
end = perf_counter()
records.append(
(
frequency,
great_circle_distance_distance_for(frequency) / 1000,
len(graph),
graph.num_edges,
end - start,
)
)
# re-creating this is inefficient, but it does not matter for small sizes
data_frame = pandas.DataFrame.from_records(numpy.array(records, dtype=columns))
string = data_frame.iloc[[-1]].to_string(index=False, justify="right")
if len(data_frame) == 1: # only the first time
print(string)
else:
print(string.splitlines()[-1])
if __name__ == "__main__":
_main()

View File

@ -0,0 +1,72 @@
#!/usr/bin/env python3
"""Create a database containing all S-57 charts in the given directory.
Optionally simplifies the geometries before saving them.
See "S57ChartHandler" for supported nautical chart features.
"""
# Standard library
from argparse import ArgumentDefaultsHelpFormatter
from argparse import ArgumentParser
# Typing
from typing import Optional
# Progress bar
from tqdm import tqdm
# Database and charts
from pyrate.common.charts import S57ChartHandler
from pyrate.common.charts import SpatialiteDatabase
def create_db(path_to_raw_charts: str, path_to_db: str, simplify_tolerance: Optional[float] = None) -> None:
"""Creates a database from all charts in the given directory.
Args:
path_to_raw_charts: the path where to look for the source chart files
path_to_db: the file of the target database
simplify_tolerance: the tolerance within all new points shall lie wrt. to the old ones, in meters,
non-negative
"""
files = list(S57ChartHandler.find_chart_files(path_to_raw_charts))
print("Scanned for relevant files")
with SpatialiteDatabase(path_to_db) as database:
with database.disable_synchronization():
if len(database) != 0:
raise RuntimeError("writing to an already existing database, which might be an error")
print("Created database")
handler = S57ChartHandler()
for file in tqdm(files, unit=" files"):
database.write_geometries(handler.read_chart_file(file), update=False, raise_on_failure=False)
if simplify_tolerance is not None:
vertices_before = database.count_vertices()
database.simplify_contents(simplify_tolerance)
vertices_after = database.count_vertices()
change = (vertices_before - vertices_after) / vertices_before * 100
print(f"Reduced the vertex count from {vertices_before} to {vertices_after} (-{change:.3f}%)")
def _main() -> None:
"""The main function."""
parser = ArgumentParser(description=__doc__, formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument("path_to_raw_charts", type=str, help="will be searched recursively")
parser.add_argument("path_to_db", type=str, help='usually ends with ".sqlite"')
parser.add_argument(
"--simplify_tolerance",
type=float,
default=25.0,
help="the simplification tolerance in meters, set to zero to disable",
)
args = parser.parse_args()
simplify_tolerance = None if args.simplify_tolerance == 0.0 else args.simplify_tolerance
create_db(args.path_to_raw_charts, args.path_to_db, simplify_tolerance)
if __name__ == "__main__":
_main()

View File

@ -0,0 +1,252 @@
#!/usr/bin/env python3
"""
Visualizes a graph generated by :func:`pyrate.plan.graph.generate.create_earth_graph`, like the one created by
the script :ref:`script-create_earth_graph`.
Examples:
Simply plot the node positions as well as graphs of all properties into a local folder.
This will overwrite existing plots and create the target directory if it does not already exist:
.. code-block:: bash
./scripts/visualize_earth_graph.py my_graph.hdf5 visualization/
"""
# Standard library
from argparse import ArgumentParser
import os.path
# Typing
from typing import Any
from typing import Dict
from typing import Optional
from typing import Sequence
# Scientific
import matplotlib.pyplot as plt
import numpy as np
import scipy.interpolate
# Progress bars
from tqdm import tqdm
# Geography
from cartopy.crs import PlateCarree
from cartopy.crs import Robinson
# Own Pyrate code
from pyrate.plan.graph import GeoNavigationGraph
# pylint: disable=too-many-arguments,too-many-locals
def _prepare_2d_plot(
central_longitude: float = 0.0,
show_gridlines: bool = True,
show_gridline_labels: bool = False,
show_coastlines: bool = False,
) -> plt.Axes:
"""Prepares a 2D plot for visualizing the graph positions or property data.
Args:
central_longitude: the central longitude of the projection, in degrees in ``[-180, +180)``
show_gridlines: whether to overlay a grid
show_gridline_labels: whether to print labels to a grid (if drawn at all)
show_coastlines: whether to outline coastlines
Returns:
The correctly configured axes object
"""
# create frame plot
if show_gridlines and show_gridline_labels:
# Robinson would be nicer but is not supported by matplotlib.Axis.gridlines()
coordinate_reference = PlateCarree(central_longitude=central_longitude)
else:
coordinate_reference = Robinson(central_longitude=central_longitude)
axes = plt.axes(projection=coordinate_reference)
# create background
if show_coastlines:
# the resolution may allow only a few specific values:
# https://scitools.org.uk/cartopy/docs/latest/matplotlib/geoaxes.html#cartopy.mpl.geoaxes.GeoAxes.coastlines
axes.coastlines(resolution="110m", color="black")
if show_gridlines:
axes.gridlines(crs=coordinate_reference, draw_labels=show_gridline_labels)
axes.set_xlim(-180.0, +180.0)
axes.set_ylim(-90.0, +90.0)
return axes
def plot_node_positions_2d(
graph: GeoNavigationGraph,
central_longitude: float = 0.0,
show_gridlines: bool = True,
show_gridline_labels: bool = False,
show_coastlines: bool = False,
) -> plt.Axes:
"""Visualizes the positions of the nodes on the globe.
Args:
graph: the graph to be visualized
central_longitude: the central longitude of the projection, in degrees in ``[-180, +180)``
show_gridlines: whether to overlay a grid
show_gridline_labels: whether to print labels to a grid (if drawn at all)
show_coastlines: whether to outline coastlines
Returns:
The axes object containing the visualization
"""
axes = _prepare_2d_plot(
central_longitude=central_longitude,
show_gridlines=show_gridlines,
show_gridline_labels=show_gridline_labels,
show_coastlines=show_coastlines,
)
point_size = 1000 / len(graph) * plt.rcParams["lines.markersize"] ** 2
# idea: plt.plot() is faster according to the docs of axes.scatter, so it could be used instead
axes.scatter(
graph.longitudes_degrees.to_numpy(),
graph.latitudes_degrees.to_numpy(),
s=point_size,
linewidths=0,
alpha=0.9,
)
return axes
def plot_properties_2d(
graph: GeoNavigationGraph,
property_column: str,
central_longitude: float = 0.0,
resolution: int = 10,
show_gridlines: bool = True,
show_gridline_labels: bool = False,
show_coastlines: bool = False,
show_legend: bool = True,
interpolation_method: str = "nearest",
shading_method: str = "nearest",
) -> plt.Axes:
"""Creates a 2D plot of the graph and associated data.
Args:
graph: the graph to be visualized
property_column: the name of the property/node column data frame to plot
central_longitude: the central longitude of the projection, in degrees in ``[-180, +180)``
resolution: the number of points/pixels per degree latitude/longitude
show_gridlines: whether to overlay a grid
show_gridline_labels: whether to print labels to a grid (if drawn at all)
show_coastlines: whether to outline coastlines
show_legend: whether to show a legend for the color values of the property
interpolation_method: passed to :func:`scipy.interpolate.griddata`; ``"nearest"`` best reflects the
nature of the discretized nodes
shading_method: passed to :func:`matplotlib.pyplot.pcolormesh`; ``"nearest"`` best reflects the nature
of the discretized nodes
Returns:
The axes object containing the visualization
"""
axes = _prepare_2d_plot(
central_longitude=central_longitude,
show_gridlines=show_gridlines,
show_gridline_labels=show_gridline_labels,
show_coastlines=show_coastlines,
)
# re-interpolate data
lat = np.linspace(-90, +90, 180 * resolution)
lon = np.linspace(-180, +180, 360 * resolution)
lat, lon = np.meshgrid(lat, lon)
node_coordinates = np.column_stack((graph.latitudes_degrees, graph.longitudes_degrees))
grid_data = scipy.interpolate.griddata(
node_coordinates, graph.nodes[property_column], (lat, lon), method=interpolation_method
)
# print data
axes.pcolormesh(lon, lat, grid_data, alpha=0.9, cmap="seismic", shading=shading_method)
if show_legend:
plt.colorbar(ax=axes)
return axes
def dump_2d_plots(
graph: GeoNavigationGraph,
path: str,
formats: Sequence[str] = ("png",),
dpi: int = 500,
show_progress: bool = False,
kwargs_node_positions: Optional[Dict[str, Any]] = None,
kwargs_properties: Optional[Dict[str, Any]] = None,
) -> None:
"""Dump 2D plots of the graph positions and all property data into the given directory.
Args:
graph: the graph to be visualized
path: the directory where to dump the plots into; is created if not yet existing; overwrites
existing plots
formats: the file formats to save in, can be for example be *png*, *svg* or *pdf* (as it must be
supported by matplotlib). Keep in mind However, that usually only raster images work
reasonably fast
dpi: the dots per inch resolution of the resulting (raster) visualizations
show_progress: whether to print a simple progress bar
kwargs_node_positions: passed directly to :meth:`~plot_node_positions_2d`
kwargs_properties: passed directly to :meth:`~plot_properties_2d`
"""
# create the target directory if it does not already exist
assert not os.path.isfile(
path
), "the visualization target path must be (not yet existing) directory and not a regular file"
os.makedirs(path, exist_ok=True)
file_pattern = os.path.join(path, "plot_{name}.{suffix}")
node_properties = graph.node_properties
number_of_properties = node_properties.size
with tqdm(
total=(number_of_properties + 1) * len(formats), unit=" plots", disable=not show_progress
) as progress_bar:
figure: plt.Figure = plt.figure() # Reuse it in save_plot
def save_plot(prepared_axes: plt.Axes, file_path: str) -> None:
figure.add_axes(prepared_axes)
figure.savefig(file_path, transparent=True, dpi=dpi)
figure.clf()
progress_bar.update() # increment by one
# plot the node positions
for viz_format in formats:
axes = plot_node_positions_2d(graph, **(kwargs_node_positions or {}))
final_path = file_pattern.format(name="node_positions", suffix=viz_format)
save_plot(axes, final_path)
# plot the properties of the node
for property_name in node_properties.columns:
# this operation might be expensive, so only do it once per property
axes = plot_properties_2d(property_name, **(kwargs_properties or {}))
for viz_format in formats:
final_path = file_pattern.format(name=f"property_{property_name}", suffix=viz_format)
save_plot(axes, final_path)
def _main() -> None:
"""The main function."""
parser = ArgumentParser(description="Visualize a geo-referenced graph.")
parser.add_argument("path_to_graph", type=str)
parser.add_argument("visualization_output_directory", type=str)
args = parser.parse_args()
graph = GeoNavigationGraph.from_disk(args.path_to_graph)
dump_2d_plots(graph, args.visualization_output_directory)
if __name__ == "__main__":
_main()