253 lines
8.9 KiB
Python
253 lines
8.9 KiB
Python
#!/usr/bin/env python3
|
|
|
|
"""
|
|
Visualizes a graph generated by :func:`pyrate.plan.graph.generate.create_earth_graph`, like the one created by
|
|
the script :ref:`script-create_earth_graph`.
|
|
|
|
Examples:
|
|
Simply plot the node positions as well as graphs of all properties into a local folder.
|
|
This will overwrite existing plots and create the target directory if it does not already exist:
|
|
|
|
.. code-block:: bash
|
|
|
|
./scripts/visualize_earth_graph.py my_graph.hdf5 visualization/
|
|
|
|
"""
|
|
|
|
# Standard library
|
|
from argparse import ArgumentParser
|
|
import os.path
|
|
|
|
# Typing
|
|
from typing import Any
|
|
from typing import Dict
|
|
from typing import Optional
|
|
from typing import Sequence
|
|
|
|
# Scientific
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import scipy.interpolate
|
|
|
|
# Progress bars
|
|
from tqdm import tqdm
|
|
|
|
# Geography
|
|
from cartopy.crs import PlateCarree
|
|
from cartopy.crs import Robinson
|
|
|
|
# Own Pyrate code
|
|
from pyrate.plan.graph import GeoNavigationGraph
|
|
|
|
# pylint: disable=too-many-arguments,too-many-locals
|
|
|
|
|
|
def _prepare_2d_plot(
    central_longitude: float = 0.0,
    show_gridlines: bool = True,
    show_gridline_labels: bool = False,
    show_coastlines: bool = False,
) -> plt.Axes:
    """Prepares a 2D world map for visualizing the graph positions or property data.

    The map uses the Robinson projection, unless labeled gridlines are requested, in which case an
    equirectangular (Plate Carrée) projection is used instead. The extent is set to the whole globe.

    The returned axes use the native coordinates of the chosen projection (e.g. meters for Robinson).
    Geographic data given as longitudes/latitudes in degrees must therefore be plotted with
    ``transform=PlateCarree()``.

    Args:
        central_longitude: the central longitude of the projection, in degrees in ``[-180, +180)``
        show_gridlines: whether to overlay a grid
        show_gridline_labels: whether to print labels to a grid (if drawn at all)
        show_coastlines: whether to outline coastlines

    Returns:
        The correctly configured axes object
    """
    # create frame plot
    if show_gridlines and show_gridline_labels:
        # Robinson would be nicer, but cartopy's GeoAxes.gridlines() cannot label gridlines for every
        # projection (older cartopy versions only support Plate Carrée and Mercator)
        coordinate_reference = PlateCarree(central_longitude=central_longitude)
    else:
        coordinate_reference = Robinson(central_longitude=central_longitude)
    axes = plt.axes(projection=coordinate_reference)

    # create background
    if show_coastlines:
        # the resolution may allow only a few specific values:
        # https://scitools.org.uk/cartopy/docs/latest/matplotlib/geoaxes.html#cartopy.mpl.geoaxes.GeoAxes.coastlines
        axes.coastlines(resolution="110m", color="black")
    if show_gridlines:
        # Draw meridians and parallels, i.e. a grid in geographic coordinates. Passing the map projection
        # itself would instead draw straight lines in projected (e.g. meter) coordinates.
        axes.gridlines(crs=PlateCarree(), draw_labels=show_gridline_labels)

    # Axis limits are expressed in projected coordinates (meters for Robinson), so fixed (-180, 180) /
    # (-90, 90) limits would only be right for Plate Carrée. Let cartopy use the projection's full extent.
    axes.set_global()

    return axes


def plot_node_positions_2d(
    graph: GeoNavigationGraph,
    central_longitude: float = 0.0,
    show_gridlines: bool = True,
    show_gridline_labels: bool = False,
    show_coastlines: bool = False,
) -> plt.Axes:
    """Visualizes the positions of the nodes on the globe as a scatter plot.

    Args:
        graph: the graph to be visualized
        central_longitude: the central longitude of the projection, in degrees in ``[-180, +180)``
        show_gridlines: whether to overlay a grid
        show_gridline_labels: whether to print labels to a grid (if drawn at all)
        show_coastlines: whether to outline coastlines

    Returns:
        The axes object containing the visualization
    """
    axes = _prepare_2d_plot(
        central_longitude=central_longitude,
        show_gridlines=show_gridlines,
        show_gridline_labels=show_gridline_labels,
        show_coastlines=show_coastlines,
    )

    # The marker area shrinks inversely with the number of nodes; max(..., 1) avoids a division by zero
    # for an empty graph
    point_size = 1000 / max(len(graph), 1) * plt.rcParams["lines.markersize"] ** 2
    # idea: plt.plot() is faster according to the docs of axes.scatter, so it could be used instead
    axes.scatter(
        graph.longitudes_degrees.to_numpy(),
        graph.latitudes_degrees.to_numpy(),
        s=point_size,
        linewidths=0,
        alpha=0.9,
        # the coordinates are longitudes/latitudes in degrees, not native map coordinates
        transform=PlateCarree(),
    )

    return axes


def plot_properties_2d(
    graph: GeoNavigationGraph,
    property_column: str,
    central_longitude: float = 0.0,
    resolution: int = 10,
    show_gridlines: bool = True,
    show_gridline_labels: bool = False,
    show_coastlines: bool = False,
    show_legend: bool = True,
    interpolation_method: str = "nearest",
    shading_method: str = "nearest",
) -> plt.Axes:
    """Creates a 2D plot of the graph and associated data.

    The node values are interpolated onto a regular longitude/latitude grid, which is then drawn as a
    color mesh onto the map.

    Args:
        graph: the graph to be visualized
        property_column: the name of the property/node column data frame to plot
        central_longitude: the central longitude of the projection, in degrees in ``[-180, +180)``
        resolution: the number of points/pixels per degree latitude/longitude
        show_gridlines: whether to overlay a grid
        show_gridline_labels: whether to print labels to a grid (if drawn at all)
        show_coastlines: whether to outline coastlines
        show_legend: whether to show a legend for the color values of the property
        interpolation_method: passed to :func:`scipy.interpolate.griddata`; ``"nearest"`` best reflects the
            nature of the discretized nodes
        shading_method: passed to :func:`matplotlib.pyplot.pcolormesh`; ``"nearest"`` best reflects the nature
            of the discretized nodes

    Returns:
        The axes object containing the visualization
    """
    axes = _prepare_2d_plot(
        central_longitude=central_longitude,
        show_gridlines=show_gridlines,
        show_gridline_labels=show_gridline_labels,
        show_coastlines=show_coastlines,
    )

    # Re-interpolate the data onto a regular grid whose cells are exactly 1 / resolution degrees wide.
    # The cell centers lie half a cell inside the globe's boundaries. With "nearest" shading, the cell
    # edges derived by pcolormesh then fall exactly on +/-90° latitude and +/-180° longitude and remain
    # projectable, whereas edges beyond the poles cannot be projected.
    half_cell = 0.5 / resolution
    lat = np.linspace(-90 + half_cell, +90 - half_cell, 180 * resolution)
    lon = np.linspace(-180 + half_cell, +180 - half_cell, 360 * resolution)
    lat, lon = np.meshgrid(lat, lon)
    # NOTE: the interpolation works in the plane of the (lat, lon) coordinates in degrees and thus
    # ignores the wrap-around at the antimeridian
    node_coordinates = np.column_stack((graph.latitudes_degrees, graph.longitudes_degrees))
    grid_data = scipy.interpolate.griddata(
        node_coordinates, graph.nodes[property_column], (lat, lon), method=interpolation_method
    )

    # Draw the data. The grid is given in degrees and has to be transformed into the map projection.
    mesh = axes.pcolormesh(
        lon, lat, grid_data, alpha=0.9, cmap="seismic", shading=shading_method, transform=PlateCarree()
    )

    if show_legend:
        # Pass the mesh explicitly. Otherwise plt.colorbar() looks for a pyplot "current image", which
        # calling the Axes method (instead of pyplot.pcolormesh) does not register.
        plt.colorbar(mesh, ax=axes)

    return axes
|
|
|
|
|
|
def dump_2d_plots(
    graph: GeoNavigationGraph,
    path: str,
    formats: Sequence[str] = ("png",),
    dpi: int = 500,
    show_progress: bool = False,
    kwargs_node_positions: Optional[Dict[str, Any]] = None,
    kwargs_properties: Optional[Dict[str, Any]] = None,
) -> None:
    """Dump 2D plots of the graph positions and all property data into the given directory.

    The node positions are saved as ``plot_node_positions.<format>``. Each property column is saved as
    ``plot_property_<column>.<format>``. Both are written once per requested format.

    Args:
        graph: the graph to be visualized
        path: the directory where to dump the plots into; is created if not yet existing; overwrites
            existing plots
        formats: the file formats to save in, for example *png*, *svg* or *pdf* (each must be
            supported by matplotlib). Keep in mind, however, that usually only raster images are
            produced reasonably fast
        dpi: the dots per inch resolution of the resulting (raster) visualizations
        show_progress: whether to print a simple progress bar
        kwargs_node_positions: passed directly to :meth:`~plot_node_positions_2d`
        kwargs_properties: passed directly to :meth:`~plot_properties_2d`

    Raises:
        NotADirectoryError: if ``path`` refers to an existing regular file
    """
    # An explicit check instead of an ``assert``, which would be stripped when running with ``python -O``
    if os.path.isfile(path):
        raise NotADirectoryError(
            "the visualization target path must be (not yet existing) directory and not a regular file"
        )
    # create the target directory if it does not already exist
    os.makedirs(path, exist_ok=True)
    file_pattern = os.path.join(path, "plot_{name}.{suffix}")

    node_properties = graph.node_properties
    # One plot for the node positions plus one per property column. DataFrame.size would count
    # rows * columns instead of the number of columns.
    number_of_plots = len(node_properties.columns) + 1

    # A single figure is reused for all plots. plt.figure() makes it the current pyplot figure, which is
    # where the plotting functions create their axes.
    figure: plt.Figure = plt.figure()
    try:
        with tqdm(
            total=number_of_plots * len(formats), unit=" plots", disable=not show_progress
        ) as progress_bar:

            def save_plot(name: str) -> None:
                """Saves the figure's current contents in every requested format, then clears it."""
                # Save in *all* formats before clearing, since clearing also wipes the drawn axes
                for viz_format in formats:
                    final_path = file_pattern.format(name=name, suffix=viz_format)
                    figure.savefig(final_path, transparent=True, dpi=dpi)
                    progress_bar.update()  # increment by one
                figure.clf()

            # plot the node positions
            plot_node_positions_2d(graph, **(kwargs_node_positions or {}))
            save_plot("node_positions")

            # plot the properties of the nodes
            for property_name in node_properties.columns:
                # this operation might be expensive, so only do it once per property
                plot_properties_2d(graph, property_name, **(kwargs_properties or {}))
                save_plot(f"property_{property_name}")
    finally:
        # release the figure so that pyplot does not keep it alive after dumping
        plt.close(figure)
|
|
|
|
|
|
def _main() -> None:
    """Script entry point: parse the command line, load the graph from disk and dump all 2D plots."""
    argument_parser = ArgumentParser(description="Visualize a geo-referenced graph.")
    # both arguments are mandatory positional paths, registered in this order
    for positional_name in ("path_to_graph", "visualization_output_directory"):
        argument_parser.add_argument(positional_name, type=str)
    arguments = argument_parser.parse_args()

    loaded_graph = GeoNavigationGraph.from_disk(arguments.path_to_graph)
    dump_2d_plots(loaded_graph, arguments.visualization_output_directory)


if __name__ == "__main__":
    _main()
|