121 lines
4.6 KiB
Python
121 lines
4.6 KiB
Python
"""Asserts correct behaviour of the base classes for graph navigation.
|
|
|
|
See Also:
|
|
tests/common/raster_datasets/test_transformers_concrete.py
|
|
"""
|
|
|
|
# Standard library
|
|
from copy import deepcopy
|
|
import os.path
|
|
from tempfile import TemporaryDirectory
|
|
from unittest import TestCase
|
|
|
|
# Scientific
|
|
from numpy import array
|
|
from numpy import empty
|
|
from numpy import full
|
|
from numpy.testing import assert_array_equal
|
|
from pandas import DataFrame
|
|
from pandas.testing import assert_frame_equal
|
|
|
|
# Module under test
|
|
from pyrate.plan.graph import NavigationGraph
|
|
|
|
|
|
# Shared toy fixture used by the tests below: three nodes carrying two integer properties each,
# connected as a chain 0 - 1 - 2.
_NODES = DataFrame({"property_1": [1, 2, 3], "property_2": [10, 20, 30]})

# One row per edge, given as a pair of node indices
_EDGES = array([(0, 1), (1, 2)])

# One row per node listing the indices of its neighbors; rows with fewer neighbors are padded with -1
_NEIGHBORS = array([(1, -1), (0, 2), (1, -1)])
|
|
|
|
|
|
class TestNavigationGraph(TestCase):
    """Covers the basics of the navigation graph: construction, disk round-trips and neighbor lookup."""

    def test_empty(self) -> None:
        """Tests that a graph without any nodes or edges can be created and reports itself as empty."""
        no_nodes = DataFrame()
        no_edges = empty((0, 2))
        graph = NavigationGraph(no_nodes, no_edges)

        self.assertEqual(len(graph), 0)
        self.assertEqual(graph.num_edges, 0)

        # even without any nodes, the neighbor table must be available (and empty)
        self.assertEqual(graph.neighbors.shape, (0, 0))

    def test_create(self) -> None:
        """Tests construction both with an explicit neighbor table and without one."""
        for candidate in (_NEIGHBORS, None):
            was_given = candidate is not None
            with self.subTest(f"neighbors given = {was_given}"):
                graph = NavigationGraph(_NODES, _EDGES, candidate)
                assert_array_equal(graph.neighbors, _NEIGHBORS)

                # querying the property twice must yield the same table both times
                first_query = graph.neighbors
                second_query = graph.neighbors
                assert_array_equal(first_query, second_query)

    def test_read_write(self) -> None:
        """Tests that a graph written to disk is equal to the one read back from that file."""
        # `graph.neighbors` is cached, so the round-trip is exercised once with the cache populated
        # and once with it still untouched
        for warm_cache in (True, False):
            with self.subTest(f"neighbors set = {warm_cache}"):
                original = NavigationGraph(_NODES, _EDGES, max_neighbors=42)
                if warm_cache:
                    _ = original.neighbors  # accessed only to trigger the computation

                with TemporaryDirectory() as scratch_dir:
                    target = os.path.join(scratch_dir, "some_file.hdf5")
                    original.to_disk(target)
                    restored = NavigationGraph.from_disk(target)

                self.assertEqual(original, restored)
                assert_array_equal(restored.neighbors, _NEIGHBORS)

    def test_max_neighbors_constructor(self) -> None:
        """Tests that ``max_neighbors`` accepts 0 and positive values but rejects -2."""
        # these constructions must simply succeed
        for accepted in (0, 10):
            NavigationGraph(DataFrame(), empty((0, 2)), max_neighbors=accepted)

        # the concrete exception type is deliberately left open here
        with self.assertRaises(Exception):  # noqa: H202
            NavigationGraph(DataFrame(), empty((0, 2)), max_neighbors=-2)
|
|
|
|
|
|
class TestNavigationGraphPruningArtificial(TestCase):
    """Exercises ``prune_nodes`` on the small hand-written three-node example graph."""

    def test_pruning_no_nodes(self) -> None:
        """Tests that a mask retaining every node leaves the graph unchanged."""
        reference = NavigationGraph(_NODES, _EDGES, _NEIGHBORS)

        # prune a copy so that the untouched graph remains available for comparison
        candidate = deepcopy(reference)
        retain_everything = full(shape=(len(_NODES),), fill_value=True)
        candidate.prune_nodes(retain_everything)

        self.assertEqual(reference, candidate)

    def test_pruning_all(self) -> None:
        """Tests that a mask retaining no node yields an empty graph that keeps its node properties."""
        reference = NavigationGraph(_NODES, _EDGES, _NEIGHBORS)

        candidate = deepcopy(reference)
        retain_none = full(shape=(len(_NODES),), fill_value=False)
        candidate.prune_nodes(retain_none)

        self.assertNotEqual(reference, candidate)

        # no rows are left, but both property columns must still exist
        remaining_nodes = candidate.nodes
        self.assertEqual(len(remaining_nodes), 0)
        self.assertEqual(len(candidate.nodes.columns), 2, "the properties must be retained")

        self.assertEqual(candidate.edges.shape, (0, 2))
        self.assertEqual(candidate.neighbors.shape, (0, 0))

    def test_pruning_very_simple(self) -> None:
        """Tests that dropping only the last node removes its edge and its neighbor entries."""
        reference = NavigationGraph(_NODES, _EDGES, _NEIGHBORS)

        candidate = deepcopy(reference)
        mask = array([True, True, False])  # keep nodes 0 and 1, drop node 2
        candidate.prune_nodes(mask)

        self.assertNotEqual(reference, candidate)

        # expected: the first two nodes, the single edge between them, and one neighbor column
        expected_nodes = _NODES[:2]
        expected_edges = _EDGES[:1]
        expected_neighbors = _NEIGHBORS[:2, :1]
        assert_frame_equal(candidate.nodes, expected_nodes)
        assert_array_equal(candidate.edges, expected_edges)
        assert_array_equal(candidate.neighbors, expected_neighbors)
|