167 lines
		
	
	
		
			7.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			167 lines
		
	
	
		
			7.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""Asserts correct behaviour of the geo-referenced graph navigation.
 | 
						|
 | 
						|
See Also:
 | 
						|
    tests/common/raster_datasets/test_transformers_concrete.py
 | 
						|
"""
 | 
						|
 | 
						|
# Standard library
 | 
						|
from copy import deepcopy
 | 
						|
import os.path
 | 
						|
from tempfile import TemporaryDirectory
 | 
						|
from unittest import TestCase
 | 
						|
 | 
						|
# Scientific
 | 
						|
import numpy
 | 
						|
from numpy import arange
 | 
						|
from numpy import array
 | 
						|
from numpy import empty
 | 
						|
from numpy.testing import assert_array_equal
 | 
						|
from pandas import DataFrame
 | 
						|
 | 
						|
# Graph generation / Module under test
 | 
						|
from pyrate.common.raster_datasets import transformers_concrete
 | 
						|
from pyrate.plan.graph import create_earth_graph
 | 
						|
from pyrate.plan.graph import GeoNavigationGraph
 | 
						|
from pyrate.plan.graph import min_required_frequency
 | 
						|
 | 
						|
# CI/Testing helpers
 | 
						|
from ... import _open_test_geo_dataset
 | 
						|
 | 
						|
 | 
						|
from .generate.test_graph_generation import EXAMPLE_DISTANCES_KILOMETERS
 | 
						|
 | 
						|
 | 
						|
class TestGeoNavigationGraph(TestCase):
    """Checks behaviour that :class:`pyrate.plan.graph.GeoNavigationGraph` adds on top of plain navigation graphs."""

    def test_create_invalid_duplicate_argument_nodes(self) -> None:
        """Both coordinate-based constructors must refuse an explicitly supplied ``nodes`` argument."""
        constructors = (
            GeoNavigationGraph.from_coordinates_degrees,
            GeoNavigationGraph.from_coordinates_radians,
        )
        for constructor in constructors:
            # the constructors take coordinates, so an additional ``nodes`` frame is expected to be rejected
            with self.subTest(msg=f"function {str(constructor)}"), self.assertRaises(Exception):  # noqa: H202
                constructor(  # type: ignore
                    latitudes=empty((0,)), longitudes=empty((0,)), edges=empty((0, 2)), nodes=DataFrame()
                )

    def test_node_radius_constructor(self) -> None:
        """Zero and positive ``node_radius`` values are accepted, a negative one raises."""

        def build(radius: float) -> None:
            # an empty graph suffices, only the radius validation is of interest here
            GeoNavigationGraph.from_coordinates_degrees(
                latitudes=empty((0,)), longitudes=empty((0,)), edges=empty((0, 2)), node_radius=radius
            )

        for valid_radius in (0, 100_000):
            build(valid_radius)

        # even a minimal negative radius must be refused
        with self.assertRaises(Exception):  # noqa: H202
            build(-1e-9)

    def test_set_node_properties(self) -> None:
        """A ``node_properties`` frame passed to the constructor is stored on the resulting graph."""
        properties = DataFrame(data={"col1": [99], "col2": ["text"]})
        # NOTE(review): 42 rad is outside the valid latitude range; this test only inspects
        # the stored properties and radius, not the coordinates
        graph = GeoNavigationGraph.from_coordinates_radians(
            latitudes=array([42]),
            longitudes=array([21]),
            edges=empty((0, 2)),
            node_radius=100,
            node_properties=properties,
        )

        self.assertEqual(graph.node_radius, 100)
        for column, expected in (("col1", [99]), ("col2", ["text"])):
            assert_array_equal(graph.node_properties[column], expected)

    def test_read_write(self) -> None:
        """A *geo* navigation graph survives a round trip through ``to_disk`` and ``from_disk``."""
        latitudes, longitudes = array([49.8725144]), array([8.6528707])
        edges = empty((0, 2))

        # ``graph.neighbors`` is cached: serialize once with and once without the cache populated
        for set_neighbors in (True, False):
            with self.subTest(f"neighbors set = {set_neighbors}"):
                original = GeoNavigationGraph.from_coordinates_degrees(
                    latitudes, longitudes, edges=edges, max_neighbors=42, node_radius=1000
                )
                if set_neighbors:
                    _ = original.neighbors  # touching the property fills the cache

                with TemporaryDirectory() as scratch:
                    file_path = os.path.join(scratch, "some_file.hdf5")
                    original.to_disk(file_path)
                    restored = GeoNavigationGraph.from_disk(file_path)

                self.assertEqual(original, restored)
                assert_array_equal(restored.neighbors, original.neighbors)
 | 
						|
 | 
						|
 | 
						|
class TestNavigationGraphPruningGeo(TestCase):
    """Tests that navigation graphs can be pruned, using earth graphs as test subjects."""

    def test_pruning_artificial(self) -> None:
        """Tests that pruning every second node works as expected."""

        for distance_km in EXAMPLE_DISTANCES_KILOMETERS:
            with self.subTest(f"Test with distance {distance_km} km"):
                # create a grid
                graph = create_earth_graph(min_required_frequency(distance_km * 1000, in_meters=True))

                # keep all nodes with an even *index* (this is unrelated to their latitude)
                keep_condition = arange(0, len(graph)) % 2 == 0
                pruned_graph = deepcopy(graph)
                pruned_graph.prune_nodes(keep_condition)

                self.assertGreater(len(pruned_graph), 0, "some node must remain")

                # test the reduction ratios, i.e. the fractions of nodes / edges that *remain* after pruning
                delta_nodes = len(pruned_graph) / len(graph)
                delta_edges = pruned_graph.num_edges / graph.num_edges
                self.assertAlmostEqual(delta_nodes, 0.5, msg="suspicious node count reduction")
                # presumably an edge only survives if both of its endpoints are kept; with every second node
                # removed, only a small fraction of the edges (roughly a fifth) is expected to remain
                self.assertAlmostEqual(delta_edges, 1 / 5, delta=0.15, msg="suspicious edge count reduction")

                # test the values in the edges, since they were rewritten as they point to new indices
                pruned_edges = pruned_graph.edges
                self.assertTrue(numpy.all(pruned_edges >= 0), "indices must be non-negative")
                self.assertTrue(
                    numpy.all(pruned_edges < len(pruned_graph)),
                    "some filtered edges reference (now) non-existent points",
                )

    def test_pruning_depth(self) -> None:
        """Supplements :meth:`~test_pruning_artificial` by a real-world application.

        Only checks application-specific properties and not, for example, the general shapes of the result.
        """
        # create a grid
        distance_meters = 500_000
        graph = create_earth_graph(min_required_frequency(distance_meters, in_meters=True))

        # fetch properties
        mode = transformers_concrete.BathymetricTransformer.Modes.AVERAGE_DEPTH
        graph.append_property(transformers_concrete.BathymetricTransformer(_open_test_geo_dataset(), [mode]))

        # keep all nodes that are below sea level
        keep_condition = (graph.node_properties[mode.column_name] < 0.0).to_numpy()

        # Remove the now useless property
        graph.clear_node_properties()

        # perform pruning
        pruned_graph = deepcopy(graph)
        pruned_graph.prune_nodes(keep_condition)

        # test the reduction ratios (fractions of nodes / edges that remain after pruning)
        delta_nodes = len(pruned_graph) / len(graph)
        delta_edges = pruned_graph.num_edges / graph.num_edges
        earth_fraction_water = 0.708  # see https://en.wikipedia.org/wiki/World_Ocean
        # although we go by topography and not water coverage, this should still be fairly correct
        self.assertAlmostEqual(
            delta_nodes, earth_fraction_water, delta=0.1, msg="suspicious node count reduction"
        )
        self.assertAlmostEqual(
            delta_edges, earth_fraction_water, delta=0.1, msg="suspicious edge count reduction"
        )
 |