Removed the subdir.
This commit is contained in:
120
pyrate/tests/plan/graph/test_graph.py
Normal file
120
pyrate/tests/plan/graph/test_graph.py
Normal file
@ -0,0 +1,120 @@
|
||||
"""Asserts correct behaviour of the base classes for graph navigation.
|
||||
|
||||
See Also:
|
||||
tests/common/raster_datasets/test_transformers_concrete.py
|
||||
"""
|
||||
|
||||
# Standard library
|
||||
from copy import deepcopy
|
||||
import os.path
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest import TestCase
|
||||
|
||||
# Scientific
|
||||
from numpy import array
|
||||
from numpy import empty
|
||||
from numpy import full
|
||||
from numpy.testing import assert_array_equal
|
||||
from pandas import DataFrame
|
||||
from pandas.testing import assert_frame_equal
|
||||
|
||||
# Module under test
|
||||
from pyrate.plan.graph import NavigationGraph
|
||||
|
||||
|
||||
# Some examples:
|
||||
_NODES = DataFrame(data={"property_1": [1, 2, 3], "property_2": [10, 20, 30]})
|
||||
_EDGES = array([[0, 1], [1, 2]])
|
||||
_NEIGHBORS = array([[1, -1], [0, 2], [1, -1]])
|
||||
|
||||
|
||||
class TestNavigationGraph(TestCase):
|
||||
"""Tests the very basic functionality like initialization, (de)serialization and finding neighbors."""
|
||||
|
||||
def test_empty(self) -> None:
|
||||
"""Tests that a new instance can be created with and without neighbors."""
|
||||
graph = NavigationGraph(DataFrame(), empty((0, 2)))
|
||||
self.assertEqual(len(graph), 0)
|
||||
self.assertEqual(graph.num_edges, 0)
|
||||
|
||||
# check that the correct neighbor table is returned
|
||||
self.assertEqual(graph.neighbors.shape, (0, 0))
|
||||
|
||||
def test_create(self) -> None:
|
||||
"""Tests that a new instance can be created with and without neighbors."""
|
||||
|
||||
for given_neighbors in [_NEIGHBORS, None]:
|
||||
with self.subTest(f"neighbors given = {given_neighbors is not None}"):
|
||||
graph = NavigationGraph(_NODES, _EDGES, given_neighbors)
|
||||
assert_array_equal(graph.neighbors, _NEIGHBORS)
|
||||
|
||||
# repeated queries should return the same neighbors
|
||||
assert_array_equal(graph.neighbors, graph.neighbors)
|
||||
|
||||
def test_read_write(self) -> None:
|
||||
"""Tests that a navigation graph can be serialized and deserialized again."""
|
||||
|
||||
# `graph.neighbors` is cached, so we want to try it with and without the cached neighbors being set
|
||||
for set_neighbors in [True, False]:
|
||||
with self.subTest(f"neighbors set = {set_neighbors}"):
|
||||
graph = NavigationGraph(_NODES, _EDGES, max_neighbors=42)
|
||||
if set_neighbors:
|
||||
_ = graph.neighbors
|
||||
|
||||
with TemporaryDirectory() as directory:
|
||||
path = os.path.join(directory, "some_file.hdf5")
|
||||
graph.to_disk(path)
|
||||
new_graph = NavigationGraph.from_disk(path)
|
||||
|
||||
self.assertEqual(graph, new_graph)
|
||||
assert_array_equal(new_graph.neighbors, _NEIGHBORS)
|
||||
|
||||
def test_max_neighbors_constructor(self) -> None:
|
||||
"""Tests that only invalid inputs to max_neighbors raise exceptions."""
|
||||
NavigationGraph(DataFrame(), empty((0, 2)), max_neighbors=0)
|
||||
NavigationGraph(DataFrame(), empty((0, 2)), max_neighbors=10)
|
||||
|
||||
with self.assertRaises(Exception): # noqa: H202
|
||||
NavigationGraph(DataFrame(), empty((0, 2)), max_neighbors=-2)
|
||||
|
||||
|
||||
class TestNavigationGraphPruningArtificial(TestCase):
|
||||
"""Tests that simple toy navigation graphs can be pruned."""
|
||||
|
||||
def test_pruning_no_nodes(self) -> None:
|
||||
"""Tests that pruning no nodes works."""
|
||||
old_graph = NavigationGraph(_NODES, _EDGES, _NEIGHBORS)
|
||||
|
||||
pruned_graph = deepcopy(old_graph)
|
||||
retain_all = full((len(_NODES),), True)
|
||||
pruned_graph.prune_nodes(retain_all)
|
||||
|
||||
self.assertEqual(old_graph, pruned_graph)
|
||||
|
||||
def test_pruning_all(self) -> None:
|
||||
"""Tests that pruning all nodes works."""
|
||||
old_graph = NavigationGraph(_NODES, _EDGES, _NEIGHBORS)
|
||||
|
||||
pruned_graph = deepcopy(old_graph)
|
||||
retain_all = full((len(_NODES),), False)
|
||||
pruned_graph.prune_nodes(retain_all)
|
||||
|
||||
self.assertNotEqual(old_graph, pruned_graph)
|
||||
self.assertEqual(len(pruned_graph.nodes), 0)
|
||||
self.assertEqual(len(pruned_graph.nodes.columns), 2, "the properties must be retained")
|
||||
self.assertEqual(pruned_graph.edges.shape, (0, 2))
|
||||
self.assertEqual(pruned_graph.neighbors.shape, (0, 0))
|
||||
|
||||
def test_pruning_very_simple(self) -> None:
|
||||
"""Tests that pruning some nodes works as expected."""
|
||||
|
||||
old_graph = NavigationGraph(_NODES, _EDGES, _NEIGHBORS)
|
||||
|
||||
pruned_graph = deepcopy(old_graph)
|
||||
keep_condition = array([True, True, False]) # only prune the last node
|
||||
pruned_graph.prune_nodes(keep_condition)
|
||||
|
||||
self.assertNotEqual(old_graph, pruned_graph)
|
||||
assert_frame_equal(pruned_graph.nodes, _NODES[:2])
|
||||
assert_array_equal(pruned_graph.edges, _EDGES[:1])
|
||||
assert_array_equal(pruned_graph.neighbors, _NEIGHBORS[:2, :1])
|
Reference in New Issue
Block a user