"""Asserts correct behaviour of the base classes for graph navigation. See Also: tests/common/raster_datasets/test_transformers_concrete.py """ # Standard library from copy import deepcopy import os.path from tempfile import TemporaryDirectory from unittest import TestCase # Scientific from numpy import array from numpy import empty from numpy import full from numpy.testing import assert_array_equal from pandas import DataFrame from pandas.testing import assert_frame_equal # Module under test from pyrate.plan.graph import NavigationGraph # Some examples: _NODES = DataFrame(data={"property_1": [1, 2, 3], "property_2": [10, 20, 30]}) _EDGES = array([[0, 1], [1, 2]]) _NEIGHBORS = array([[1, -1], [0, 2], [1, -1]]) class TestNavigationGraph(TestCase): """Tests the very basic functionality like initialization, (de)serialization and finding neighbors.""" def test_empty(self) -> None: """Tests that a new instance can be created with and without neighbors.""" graph = NavigationGraph(DataFrame(), empty((0, 2))) self.assertEqual(len(graph), 0) self.assertEqual(graph.num_edges, 0) # check that the correct neighbor table is returned self.assertEqual(graph.neighbors.shape, (0, 0)) def test_create(self) -> None: """Tests that a new instance can be created with and without neighbors.""" for given_neighbors in [_NEIGHBORS, None]: with self.subTest(f"neighbors given = {given_neighbors is not None}"): graph = NavigationGraph(_NODES, _EDGES, given_neighbors) assert_array_equal(graph.neighbors, _NEIGHBORS) # repeated queries should return the same neighbors assert_array_equal(graph.neighbors, graph.neighbors) def test_read_write(self) -> None: """Tests that a navigation graph can be serialized and deserialized again.""" # `graph.neighbors` is cached, so we want to try it with and without the cached neighbors being set for set_neighbors in [True, False]: with self.subTest(f"neighbors set = {set_neighbors}"): graph = NavigationGraph(_NODES, _EDGES, max_neighbors=42) if set_neighbors: _ = graph.neighbors with TemporaryDirectory() as directory: path = os.path.join(directory, "some_file.hdf5") graph.to_disk(path) new_graph = NavigationGraph.from_disk(path) self.assertEqual(graph, new_graph) assert_array_equal(new_graph.neighbors, _NEIGHBORS) def test_max_neighbors_constructor(self) -> None: """Tests that only invalid inputs to max_neighbors raise exceptions.""" NavigationGraph(DataFrame(), empty((0, 2)), max_neighbors=0) NavigationGraph(DataFrame(), empty((0, 2)), max_neighbors=10) with self.assertRaises(Exception): # noqa: H202 NavigationGraph(DataFrame(), empty((0, 2)), max_neighbors=-2) class TestNavigationGraphPruningArtificial(TestCase): """Tests that simple toy navigation graphs can be pruned.""" def test_pruning_no_nodes(self) -> None: """Tests that pruning no nodes works.""" old_graph = NavigationGraph(_NODES, _EDGES, _NEIGHBORS) pruned_graph = deepcopy(old_graph) retain_all = full((len(_NODES),), True) pruned_graph.prune_nodes(retain_all) self.assertEqual(old_graph, pruned_graph) def test_pruning_all(self) -> None: """Tests that pruning all nodes works.""" old_graph = NavigationGraph(_NODES, _EDGES, _NEIGHBORS) pruned_graph = deepcopy(old_graph) retain_all = full((len(_NODES),), False) pruned_graph.prune_nodes(retain_all) self.assertNotEqual(old_graph, pruned_graph) self.assertEqual(len(pruned_graph.nodes), 0) self.assertEqual(len(pruned_graph.nodes.columns), 2, "the properties must be retained") self.assertEqual(pruned_graph.edges.shape, (0, 2)) self.assertEqual(pruned_graph.neighbors.shape, (0, 0)) def test_pruning_very_simple(self) -> None: """Tests that pruning some nodes works as expected.""" old_graph = NavigationGraph(_NODES, _EDGES, _NEIGHBORS) pruned_graph = deepcopy(old_graph) keep_condition = array([True, True, False]) # only prune the last node pruned_graph.prune_nodes(keep_condition) self.assertNotEqual(old_graph, pruned_graph) assert_frame_equal(pruned_graph.nodes, _NODES[:2]) assert_array_equal(pruned_graph.edges, _EDGES[:1]) assert_array_equal(pruned_graph.neighbors, _NEIGHBORS[:2, :1])