Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add operators to normalize Distributions #203

Merged
merged 1 commit into from
Aug 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 96 additions & 5 deletions mystic/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,42 @@ def __init__(self, generator=None, *args, **kwds):
note::
generator may be a method object or a string of 'module.object';
similarly, rng may be a random_state object or a string of 'module'

note::
Distributions d1,d2 may be combined by adding data (i.e. d1(n) + d2(n)),
or by adding probabilitiies as Distribution(d1,d2); the former uses
the addition operator and produces a new unnormalized Distribution,
while the latter produces a new Distribution which randomly chooses from
the Distributions provided

note::
a normalization factor can be incorporated through the multiplication
or division operator, and is stored in the Distribution as 'norm'
""" #XXX: generate Distribution from list of Distributions?
self.norm = kwds.pop('norm', 1) + 0
if isinstance(generator, Distribution):
if kwds:
msg = 'keyword arguments are invalid with {0} instance'.format(self.__class__.__name__)
raise TypeError(msg)
if not args:
self._type = generator._type
self.rvs = generator.rvs
self.repr = generator.repr
self.norm *= generator.norm
return
# args can only support additional distribution instances
for arg in args:
if not isinstance(arg, Distribution): # raise TypeError
generator += arg
# use choice from multiple distributions
import numpy as np
generator = (generator,) + args
rep = lambda di: "{0}".format(di).split("(",1)[-1][:-1] if di._type == 'join' else "{0}".format(di)
sig = ', '.join(rep(i) for i in generator)
self.repr = lambda cls,fac: ("{0}({1}".format(cls, sig) + (')' if fac == 1 else ', norm={0})'.format(fac)))
self.rvs = lambda size=None: np.choose(np.random.choice(range(len(generator)), size=size), tuple(d(size) for d in generator))
self._type = 'join'
return
from mystic.tools import random_state
rng = kwds.pop('rng', random_state(module='numpy.random'))
if isinstance(rng, str): rng = random_state(module=rng)
Expand Down Expand Up @@ -97,16 +132,72 @@ def __init__(self, generator=None, *args, **kwds):
name = generator.__name__
mod = rng.__name__
name = "'{0}.{1}'".format(mod, name) if name else ""
sig = ', '.join(str(i) for i in args) + ', '.join("{0}={1}".format(i,j) for i,j in kwds.items())
sig = ', '.join(str(i) for i in args)
kwd = ', '.join("{0}={1}".format(i,j) for i,j in kwds.items())
#nrm = '' if self.norm == 1 else 'norm={0}'.format(self.norm)
#kwd = '{0}, {1}'.format(kwd, nrm) if (kwd and nrm) else (kwd or nrm)
sig = '{0}, {1}'.format(sig, kwd) if (sig and kwd) else (sig or kwd)
if name and sig: name += ", "
sig = (sig + ")") if sig else ")"
#sig = ", rng='{0}')".format(rng.__name__)
self.repr = lambda cls: ("{0}({1}".format(cls, name) + sig)
self.repr = lambda cls,fac: ("{0}({1}".format(cls, name) + sig + ('' if fac == 1 else ((', ' if (name or sig) else '') + 'norm={0}'.format(fac))) + ')')
self._type = 'base'
return
def __call__(self, size=None):
"""generate a sample of given size (tuple) from the distribution"""
return self.rvs(size)
return self.norm * self.rvs(size)
def __repr__(self):
return self.repr(self.__class__.__name__)
return self.repr(self.__class__.__name__, self.norm)
def __add__(self, dist):
if not isinstance(dist, Distribution):
msg = "unsupported operand type(s) for +: '{0}' and '{1}'".format(self.__class__.__name__, type(dist))
raise TypeError(msg)
# add data from multiple distributions
new = Distribution()
first = "{0}".format(self)
second = "{0}".format(dist)
if self._type == 'add': first = first.split("(",1)[-1][:-1]
if dist._type == 'add': second = second.split("(",1)[-1][:-1]
new.repr = lambda cls,fac: ("{0}({1} + {2}".format(cls, first, second) + (')' if fac == 1 else ', norm={0})'.format(fac)))
new.rvs = lambda size=None: (self(size) + dist(size))
new._type = 'add'
new.norm = 1
return new
def __mul__(self, norm):
new = Distribution()
new.repr = self.repr
new.rvs = self.rvs
new._type = 'base'
new.norm = self.norm * norm
return new
__rmul__ = __mul__
def __truediv__(self, denom):
new = Distribution()
new.repr = self.repr
new.rvs = self.rvs
new._type = 'base'
new.norm = self.norm / denom
return new
def __floordiv__(self, denom):
new = Distribution()
new.repr = self.repr
new.rvs = self.rvs
new._type = 'base'
new.norm = self.norm // denom
return new
"""
def __mul__(self, dist):
if not isinstance(dist, Distribution):
msg = "unsupported operand type(s) for *: '{0}' and '{1}'".format(self.__class__.__name__, type(dist))
raise TypeError(msg)
# use conflation of multiple distributions
new = Distribution()
norm = lambda x: x/sum(x) #FIXME: what is the formula...?
#func = lambda x,y: (x*y)/(x+y)
#new.rvs = lambda size=None: func(self(size),dist(size))
new.rvs = lambda size=None: norm(self(size) * dist(size))
new.repr = lambda cls: "{0}({1} * {2})".format(cls, self, dist)
return new
"""


# end of file
19 changes: 19 additions & 0 deletions mystic/tests/test_distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from mystic.math import Distribution, almostEqual

N = 100000
a = Distribution('normal', 0, 1)
b = Distribution('normal', 5, 3)
c = Distribution('normal', 6, 6)

apb = a+b
a2pb2 = a/2+b/2
apb2= (a+b)/2
anb = Distribution(a,b)
bnc = Distribution(b,c)
bpc2 = (b+c)/2

assert almostEqual(a2pb2(N).mean(), apb2(N).mean(), tol=.1)
assert almostEqual(anb(N).mean(), .5*apb(N).mean(), tol=.1)
assert almostEqual((a(N).mean() + b(N).mean())/2, apb2(N).mean(), tol=.1)
assert almostEqual((b(N).mean() + c(N).mean())/2, bpc2(N).mean(), tol=.1)
assert almostEqual((a+b+c)(N).mean(), (c+b+a)(N).mean(), tol=.1)