133 lines
3.9 KiB
Python

"""This module includes an abstraction of gaussian distributions."""
# Typing
from typing import cast
# Mathematics
from numpy import ndarray
from scipy.stats import multivariate_normal
class Gaussian:
    """A weighted multivariate gaussian distribution.

    Examples:
        A Gaussian can be simply created from a mean and covariance matrix (and an optional weight):

        >>> from numpy import array
        >>> from numpy import vstack
        >>> mean = vstack([0.0, 0.0])
        >>> covariance = array([[1.0, 0.0], [0.0, 1.0]])
        >>> N = Gaussian(mean, covariance, weight=1.0)
        >>> N(vstack([0.0, 0.0]))  # doctest: +ELLIPSIS
        0.159...

        Two Gaussians are equal if and only if all attributes (including the shapes of mean and
        covariance) are equal:

        >>> N == N
        True
        >>> other_covariance = array([[99.0, 0.0], [0.0, 99.0]])
        >>> other_N = Gaussian(mean, other_covariance, weight=1.0)
        >>> other_N(vstack([10.0, 10.0]))  # doctest: +ELLIPSIS
        0.000585...
        >>> N == other_N
        False

    Args:
        mean: The mean of the distribution as column vector, of dimension ``(n, 1)``
        covariance: The covariance matrix of the distribution, of dimension ``(n, n)``
        weight: The weight of the distribution, e.g. within a mixture model

    References:
        - https://en.wikipedia.org/wiki/Multivariate_normal_distribution
    """

    def __init__(self, mean: ndarray, covariance: ndarray, weight: float = 1.0) -> None:
        # Sanity checks on given parameters
        # NOTE: these are ``assert`` statements, so they are skipped when Python runs with ``-O``
        assert len(mean.shape) == 2 and mean.shape[1] == 1, "Mean needs to be a column vector!"
        assert len(covariance.shape) == 2, "Covariance needs to be a 2D matrix!"
        assert covariance.shape[0] == covariance.shape[1], "Covariance needs to be a square matrix!"
        assert covariance.shape[0] == mean.shape[0], "Dimensions of mean and covariance don't fit!"

        # Assign values
        self.mean = mean
        self.covariance = covariance
        self.weight = weight

    # ######################################
    # Properties following a common filter notation
    # pylint: disable=invalid-name

    @property
    def x(self) -> ndarray:
        """A shorthand for the distribution's mean.

        Returns:
            The mean, of dimension ``(n, 1)``
        """
        return self.mean

    @x.setter
    def x(self, value: ndarray) -> None:
        self.mean = value

    @property
    def P(self) -> ndarray:
        """A shorthand for the distribution's covariance matrix.

        Returns:
            The covariance, of dimension ``(n, n)``
        """
        return self.covariance

    @P.setter
    def P(self, value: ndarray) -> None:
        self.covariance = value

    @property
    def w(self) -> float:
        """A shorthand for the distribution's weight.

        Returns:
            The weight of this distribution
        """
        return self.weight

    @w.setter
    def w(self, value: float) -> None:
        self.weight = value

    def __call__(self, value: ndarray) -> float:
        """Evaluate the gaussian at the given location.

        Args:
            value: Where to evaluate the gaussian, of dimension ``(n, 1)``

        Returns:
            The probability density at the given location, multiplied by this distribution's weight
        """
        # scipy expects flat 1-D vectors, so the column vectors are transposed and their only row taken
        distribution = multivariate_normal(mean=self.mean.T[0], cov=self.covariance)
        return self.weight * cast(float, distribution.pdf(value.T[0]))

    def __eq__(self, other: object) -> bool:
        """Checks if two multivariate normal distributions are equal.

        Args:
            other: The object to compare with

        Returns:
            Whether the two distributions are the same; ``NotImplemented`` if ``other`` is not a
            :class:`Gaussian`, so that Python falls back to its default comparison
        """
        if not isinstance(other, Gaussian):
            return NotImplemented

        # Compare shapes first: ``==`` on NumPy arrays broadcasts, so e.g. a ``(1, 1)`` mean would
        # otherwise be compared against every entry of an ``(n, 1)`` mean and could wrongly match
        return (
            self.mean.shape == other.mean.shape
            and self.covariance.shape == other.covariance.shape
            and cast(bool, (self.mean == other.mean).all())
            and cast(bool, (self.covariance == other.covariance).all())
            and self.weight == other.weight
        )

    def __repr__(self) -> str:
        """Returns a compact, single-line representation of this distribution for debugging."""
        return (
            f"{type(self).__name__}(mean={self.mean.tolist()!r}, "
            f"covariance={self.covariance.tolist()!r}, weight={self.weight!r})"
        )