133 lines
3.9 KiB
Python
133 lines
3.9 KiB
Python
"""This module includes an abstraction of gaussian distributions."""
|
|
|
|
# Typing
|
|
from typing import cast
|
|
|
|
# Mathematics
|
|
from numpy import ndarray
|
|
from scipy.stats import multivariate_normal
|
|
|
|
|
|
class Gaussian:
    """A weighted multivariate gaussian distribution.

    Examples:
        A Gaussian can be simply created from a mean vector and a covariance matrix
        (and an optional weight):

        >>> from numpy import array
        >>> from numpy import vstack
        >>> mean = vstack([0.0, 0.0])
        >>> covariance = array([[1.0, 0.0], [0.0, 1.0]])
        >>> N = Gaussian(mean, covariance, weight=1.0)
        >>> N(vstack([0.0, 0.0]))  # doctest: +ELLIPSIS
        0.159...

        Two Gaussians are equal if and only if all attributes are equal:

        >>> N == N
        True
        >>> other_covariance = array([[99.0, 0.0], [0.0, 99.0]])
        >>> other_N = Gaussian(mean, other_covariance, weight=1.0)
        >>> other_N(vstack([10.0, 10.0]))  # doctest: +ELLIPSIS
        0.000585...
        >>> N == other_N
        False

        Comparing with an object that is not a Gaussian is simply unequal:

        >>> N == "not a gaussian"
        False

    Args:
        mean: The mean of the distribution as column vector, of dimension ``(n, 1)``
        covariance: The covariance matrix of the distribution, of dimension ``(n, n)``
        weight: The weight of the distribution, e.g. within a mixture model

    References:
        - https://en.wikipedia.org/wiki/Multivariate_normal_distribution
    """

    def __init__(self, mean: ndarray, covariance: ndarray, weight: float = 1.0) -> None:
        # Sanity checks on given parameters
        assert len(mean.shape) == 2 and mean.shape[1] == 1, "Mean needs to be a column vector!"
        assert len(covariance.shape) == 2, "Covariance needs to be a 2D matrix!"
        assert covariance.shape[0] == covariance.shape[1], "Covariance needs to be a square matrix!"
        assert covariance.shape[0] == mean.shape[0], "Dimensions of mean and covariance don't fit!"

        # Assign values
        self.mean = mean
        self.covariance = covariance
        self.weight = weight

    # ######################################
    # Properties following a common filter notation
    # pylint: disable=invalid-name

    @property
    def x(self) -> ndarray:
        """A shorthand for the distribution's mean.

        Returns:
            The mean, of dimension ``(n, 1)``
        """

        return self.mean

    @x.setter
    def x(self, value: ndarray) -> None:
        self.mean = value

    @property
    def P(self) -> ndarray:
        """A shorthand for the distribution's covariance matrix.

        Returns:
            The covariance, of dimension ``(n, n)``
        """

        return self.covariance

    @P.setter
    def P(self, value: ndarray) -> None:
        self.covariance = value

    @property
    def w(self) -> float:
        """A shorthand for the distribution's weight.

        Returns:
            The weight of this distribution
        """

        return self.weight

    @w.setter
    def w(self, value: float) -> None:
        self.weight = value

    def __call__(self, value: ndarray) -> float:
        """Evaluate the gaussian at the given location.

        Args:
            value: Where to evaluate the gaussian, of dimension ``(n, 1)``

        Returns:
            The probability density at the given location, scaled by this distribution's weight
        """

        # Compute weighted probability density function; scipy expects flat 1D vectors,
        # so the ``(n, 1)`` column vectors are flattened via ``.T[0]``
        distribution = multivariate_normal(mean=self.mean.T[0], cov=self.covariance)

        return self.weight * cast(float, distribution.pdf(value.T[0]))

    def __eq__(self, other: object) -> bool:
        """Checks if two multivariate normal distributions are equal.

        Two distributions are equal only if their means and covariances have identical
        shapes and entries and their weights are equal. If ``other`` is not a
        :class:`Gaussian`, ``NotImplemented`` is returned, so that ``==`` falls back to
        Python's default comparison (yielding ``False``) instead of raising.

        Args:
            other: The object to compare with

        Returns:
            Whether the two distributions are the same
        """

        if not isinstance(other, Gaussian):
            return NotImplemented

        # Compare shapes explicitly first: numpy's ``==`` broadcasts, so e.g. a ``(1, 1)``
        # mean would otherwise "match" a ``(2, 1)`` mean element-wise, and incompatible
        # shapes would raise instead of simply being unequal
        return (
            self.mean.shape == other.mean.shape
            and self.covariance.shape == other.covariance.shape
            and cast(bool, (self.mean == other.mean).all())
            and cast(bool, (self.covariance == other.covariance).all())
            and self.weight == other.weight
        )
|