121 lines
4.6 KiB
Python

"""Asserts correct behaviour of the base classes for graph navigation.
See Also:
tests/common/raster_datasets/test_transformers_concrete.py
"""
# Standard library
from copy import deepcopy
import os.path
from tempfile import TemporaryDirectory
from unittest import TestCase
# Scientific
from numpy import array
from numpy import empty
from numpy import full
from numpy.testing import assert_array_equal
from pandas import DataFrame
from pandas.testing import assert_frame_equal
# Module under test
from pyrate.plan.graph import NavigationGraph
# Some examples:
_NODES = DataFrame(data={"property_1": [1, 2, 3], "property_2": [10, 20, 30]})
_EDGES = array([[0, 1], [1, 2]])
_NEIGHBORS = array([[1, -1], [0, 2], [1, -1]])
class TestNavigationGraph(TestCase):
    """Tests the very basic functionality like initialization, (de)serialization and finding neighbors."""

    def test_empty(self) -> None:
        """Tests that a graph without any nodes or edges can be created and reports an empty neighbor table."""
        graph = NavigationGraph(DataFrame(), empty((0, 2)))
        self.assertEqual(len(graph), 0)
        self.assertEqual(graph.num_edges, 0)

        # with no nodes at all, the neighbor table must have neither rows nor columns
        self.assertEqual(graph.neighbors.shape, (0, 0))

    def test_create(self) -> None:
        """Tests that a new instance can be created with and without neighbors."""
        for given_neighbors in [_NEIGHBORS, None]:
            with self.subTest(f"neighbors given = {given_neighbors is not None}"):
                graph = NavigationGraph(_NODES, _EDGES, given_neighbors)

                # the table must be correct both when passed in and when derived from the edges
                first_query = graph.neighbors
                assert_array_equal(first_query, _NEIGHBORS)

                # repeated queries should return the same neighbors
                assert_array_equal(graph.neighbors, first_query)

    def test_read_write(self) -> None:
        """Tests that a navigation graph can be serialized and deserialized again."""
        # `graph.neighbors` is cached, so we want to try it with and without the cached neighbors being set
        for set_neighbors in [True, False]:
            with self.subTest(f"neighbors set = {set_neighbors}"):
                graph = NavigationGraph(_NODES, _EDGES, max_neighbors=42)
                if set_neighbors:
                    _ = graph.neighbors  # accessed only to populate the cache before writing

                # the round trip must complete before the temporary directory is removed
                with TemporaryDirectory() as directory:
                    path = os.path.join(directory, "some_file.hdf5")
                    graph.to_disk(path)
                    new_graph = NavigationGraph.from_disk(path)

                    self.assertEqual(graph, new_graph)
                    assert_array_equal(new_graph.neighbors, _NEIGHBORS)

    def test_max_neighbors_constructor(self) -> None:
        """Tests that only invalid inputs to max_neighbors raise exceptions."""
        # zero and positive values are accepted
        NavigationGraph(DataFrame(), empty((0, 2)), max_neighbors=0)
        NavigationGraph(DataFrame(), empty((0, 2)), max_neighbors=10)

        # NOTE(review): the concrete exception type is not visible here, hence the broad `Exception`
        with self.assertRaises(Exception):  # noqa: H202
            NavigationGraph(DataFrame(), empty((0, 2)), max_neighbors=-2)
class TestNavigationGraphPruningArtificial(TestCase):
    """Tests that simple toy navigation graphs can be pruned."""

    def test_pruning_no_nodes(self) -> None:
        """Tests that pruning with a keep mask that retains every node leaves the graph unchanged."""
        old_graph = NavigationGraph(_NODES, _EDGES, _NEIGHBORS)
        # prune a copy so that `old_graph` stays available as the reference
        pruned_graph = deepcopy(old_graph)

        retain_all = full((len(_NODES),), True)
        pruned_graph.prune_nodes(retain_all)

        self.assertEqual(old_graph, pruned_graph)

    def test_pruning_all(self) -> None:
        """Tests that pruning every node yields an empty graph that still keeps the node property columns."""
        old_graph = NavigationGraph(_NODES, _EDGES, _NEIGHBORS)
        # prune a copy so that `old_graph` stays available as the reference
        pruned_graph = deepcopy(old_graph)

        # an all-False keep mask retains no node at all
        retain_none = full((len(_NODES),), False)
        pruned_graph.prune_nodes(retain_none)

        self.assertNotEqual(old_graph, pruned_graph)
        self.assertEqual(len(pruned_graph.nodes), 0)
        self.assertEqual(len(pruned_graph.nodes.columns), 2, "the properties must be retained")
        self.assertEqual(pruned_graph.edges.shape, (0, 2))
        self.assertEqual(pruned_graph.neighbors.shape, (0, 0))

    def test_pruning_very_simple(self) -> None:
        """Tests that pruning some nodes works as expected."""
        old_graph = NavigationGraph(_NODES, _EDGES, _NEIGHBORS)
        # prune a copy so that `old_graph` stays available as the reference
        pruned_graph = deepcopy(old_graph)

        keep_condition = array([True, True, False])  # only prune the last node
        pruned_graph.prune_nodes(keep_condition)

        self.assertNotEqual(old_graph, pruned_graph)
        # nodes 0 and 1 survive, so only the edge 0-1 remains
        assert_frame_equal(pruned_graph.nodes, _NODES[:2])
        assert_array_equal(pruned_graph.edges, _EDGES[:1])
        # each surviving node now has exactly one neighbor, so the table shrinks to one column
        assert_array_equal(pruned_graph.neighbors, _NEIGHBORS[:2, :1])