
Commit d6af8c9

gaurav-gireesh authored and lanking520 committed
[MXNET-1004] Poisson Negative Log Likelihood loss (apache#12697)
* PoissonNLLLoss function to compute negative log likelihood loss
* Removing debugging print statements
* Pylint code formatting problems addressed
* Added Stirling approximation for the factorial term in the denominator, and a test case for the same
* Separated the test cases for the flag values for logits and compute_full
* Added comments for the numpy package inclusion and some pylint formatting
* Trigger CI
* Markdown file updated. Added entry for PoissonNLLLoss
* Fixing pending documentation issue
* Documentation docstring changed
* PR comment addressed: extra newline removed
* Symbol pi corrected
* epsilon spelling correction
* More unit tests added - testing with mod.score() and mod.fit()
* Changed the number of epochs
* PR comments addressed: added mod score tests and a newline
* Empty line added
* Adding hybridized test
* Trigger CI
* Variable names changed
1 parent dabdf2a commit d6af8c9
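
In short, the commit adds a Gluon loss for Poisson regression. Restating the formulas from the docstring added below (nothing new is introduced here), with from_logits=False and a prediction pred for an observed count target, the per-element loss is

    L = pred - target * log(pred + epsilon)
        + 1[target > 1] * ( target * log(target) - target + 0.5 * log(2 * pi * target) )

where the second line is the Stirling approximation of log(target!) and is only added when compute_full=True; with from_logits=True the first line becomes exp(pred) - target * pred. The returned loss is the mean over all elements.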

File tree

3 files changed (+118, -1 lines)


docs/api/python/gluon/loss.md (+1 line)

@@ -25,6 +25,7 @@ This package includes several commonly used loss functions in neural networks.
     LogisticLoss
     TripletLoss
     CTCLoss
+    PoissonNLLLoss
 ```

python/mxnet/gluon/loss.py (+62, -1 lines)

@@ -23,8 +23,9 @@
            'SigmoidBinaryCrossEntropyLoss', 'SigmoidBCELoss',
            'SoftmaxCrossEntropyLoss', 'SoftmaxCELoss',
            'KLDivLoss', 'CTCLoss', 'HuberLoss', 'HingeLoss',
-           'SquaredHingeLoss', 'LogisticLoss', 'TripletLoss']
+           'SquaredHingeLoss', 'LogisticLoss', 'TripletLoss', 'PoissonNLLLoss']
 
+import numpy as np
 from .. import ndarray
 from ..base import numeric_types
 from .block import HybridBlock

@@ -706,3 +707,63 @@ def hybrid_forward(self, F, pred, positive, negative):
                            axis=self._batch_axis, exclude=True)
         loss = F.relu(loss + self._margin)
         return _apply_weighting(F, loss, self._weight, None)
+
+
+class PoissonNLLLoss(Loss):
+    r"""For a target (random variable) drawn from a Poisson distribution, computes the negative
+    log likelihood loss.
+    PoissonNLLLoss measures the loss accrued from a Poisson regression prediction made by the model.
+
+    .. math::
+        L = \text{pred} - \text{target} * \log(\text{pred}) + \log(\text{target!})
+
+    `pred` and `target` can have arbitrary shape as long as they have the same number of elements.
+
+    Parameters
+    ----------
+    from_logits : boolean, default True
+        Indicates whether log(predicted) has already been computed. If True, the loss is computed as
+        :math:`\exp(\text{pred}) - \text{target} * \text{pred}`, and if False, the loss is computed as
+        :math:`\text{pred} - \text{target} * \log(\text{pred} + \text{epsilon})`. The default value is True.
+    weight : float or None
+        Global scalar weight for loss.
+    batch_axis : int, default 0
+        The axis that represents the mini-batch.
+    compute_full : boolean, default False
+        Indicates whether to add an approximation (Stirling factor) for the factorial term in the formula for the loss.
+        The Stirling factor is:
+        :math:`\text{target} * \log(\text{target}) - \text{target} + 0.5 * \log(2 * \pi * \text{target})`
+    epsilon : float, default 1e-08
+        This is to avoid calculating log(0), which is not defined.
+
+
+    Inputs:
+        - **pred**: Predicted value
+        - **target**: Random variable (count or number) which belongs to a Poisson distribution.
+        - **sample_weight**: element-wise weighting tensor. Must be broadcastable
+          to the same shape as pred. For example, if pred has shape (64, 10)
+          and you want to weigh each sample in the batch separately,
+          sample_weight should have shape (64, 1).
+
+    Outputs:
+        - **loss**: Average loss (shape=(1,)) of the loss tensor with shape (batch_size,).
+    """
+    def __init__(self, weight=None, from_logits=True, batch_axis=0, compute_full=False, **kwargs):
+        super(PoissonNLLLoss, self).__init__(weight, batch_axis, **kwargs)
+        self._from_logits = from_logits
+        self._compute_full = compute_full
+
+    def hybrid_forward(self, F, pred, target, sample_weight=None, epsilon=1e-08):
+        target = _reshape_like(F, target, pred)
+        if self._from_logits:
+            loss = F.exp(pred) - target * pred
+        else:
+            loss = pred - target * F.log(pred + epsilon)
+        if self._compute_full:
+            # Stirling approximation of log(target!), using numpy's pi value; applied only where target > 1
+            stirling_factor = target * F.log(target) - target + 0.5 * F.log(2 * target * np.pi)
+            target_gt_1 = target > 1
+            stirling_factor *= target_gt_1
+            loss += stirling_factor
+        loss = _apply_weighting(F, loss, self._weight, sample_weight)
+        return F.mean(loss)
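
For orientation, here is a minimal usage sketch of the new class (not part of the diff; the input values are made up for illustration):

    import mxnet as mx
    from mxnet import gluon

    # Raw, positive rate predictions and observed Poisson counts (illustrative values)
    pred = mx.nd.array([[1.0, 2.0], [3.0, 4.0]])
    target = mx.nd.array([[1.0, 2.0], [2.0, 5.0]])

    # from_logits=False: loss = pred - target * log(pred + epsilon), plus the
    # Stirling term for elements with target > 1 because compute_full=True
    loss_fn = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
    print(loss_fn(pred, target))  # a single averaged value, shape (1,)

With from_logits=True (the default), pred is interpreted as a log-rate and the element-wise loss becomes exp(pred) - target * pred.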

tests/python/unittest/test_loss.py (+55 lines)

@@ -348,6 +348,61 @@ def test_triplet_loss():
             optimizer='adam')
     assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05
 
+@with_seed()
+def test_poisson_nllloss():
+    pred = mx.nd.random.normal(shape=(3, 4))
+    min_pred = mx.nd.min(pred)
+    # This is necessary to ensure only positive random values are generated for the prediction,
+    # to avoid an invalid log calculation
+    pred[:] = pred + mx.nd.abs(min_pred)
+    target = mx.nd.random.normal(shape=(3, 4))
+    min_target = mx.nd.min(target)
+    # This is necessary to ensure only positive random values are generated for the target,
+    # to avoid an invalid log calculation
+    target[:] += mx.nd.abs(min_target)
+
+    Loss = gluon.loss.PoissonNLLLoss(from_logits=True)
+    Loss_no_logits = gluon.loss.PoissonNLLLoss(from_logits=False)
+    # Calculating by brute-force formula for the default value of from_logits = True
+
+    # 1) Testing for flag from_logits = True
+    brute_loss = np.mean(np.exp(pred.asnumpy()) - target.asnumpy() * pred.asnumpy())
+    loss_withlogits = Loss(pred, target)
+    assert_almost_equal(brute_loss, loss_withlogits.asscalar())
+
+    # 2) Testing for flag from_logits = False
+    loss_no_logits = Loss_no_logits(pred, target)
+    np_loss_no_logits = np.mean(pred.asnumpy() - target.asnumpy() * np.log(pred.asnumpy() + 1e-08))
+    if np.isnan(loss_no_logits.asscalar()):
+        assert_almost_equal(np.isnan(np_loss_no_logits), np.isnan(loss_no_logits.asscalar()))
+    else:
+        assert_almost_equal(np_loss_no_logits, loss_no_logits.asscalar())
+
+    # 3) Testing the Stirling approximation (compute_full=True)
+    np_pred = np.random.uniform(1, 5, (2, 3))
+    np_target = np.random.uniform(1, 5, (2, 3))
+    np_compute_full = np.mean((np_pred - np_target * np.log(np_pred + 1e-08)) + ((np_target * np.log(np_target) -
+                               np_target + 0.5 * np.log(2 * np_target * np.pi)) * (np_target > 1)))
+    Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
+    loss_compute_full = Loss_compute_full(mx.nd.array(np_pred), mx.nd.array(np_target))
+    assert_almost_equal(np_compute_full, loss_compute_full.asscalar())
+
+@with_seed()
+def test_poisson_nllloss_mod():
+    N = 1000
+    data = mx.random.poisson(shape=(N, 2))
+    label = mx.random.poisson(lam=4, shape=(N, 1))
+    data_iter = mx.io.NDArrayIter(data, label, batch_size=20, label_name='label', shuffle=True)
+    output = mx.sym.exp(get_net(1))
+    l = mx.symbol.Variable('label')
+    Loss = gluon.loss.PoissonNLLLoss(from_logits=False)
+    loss = Loss(output, l)
+    loss = mx.sym.make_loss(loss)
+    mod = mx.mod.Module(loss, data_names=('data',), label_names=('label',))
+    mod.fit(data_iter, num_epoch=20, optimizer_params={'learning_rate': 0.01},
+            initializer=mx.init.Normal(sigma=0.1), eval_metric=mx.metric.Loss(),
+            optimizer='adam')
+    assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05
 
 
 if __name__ == '__main__':
     import nose
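
A side note on the compute_full path exercised above: the expression target * log(target) - target + 0.5 * log(2 * pi * target) is the standard Stirling approximation of log(target!). A quick standalone check (illustrative only, not part of the test suite) shows how closely it tracks the exact value once the count exceeds 1:

    import math

    for k in [2.0, 5.0, 10.0, 50.0]:
        exact = math.lgamma(k + 1)  # log(k!)
        approx = k * math.log(k) - k + 0.5 * math.log(2 * math.pi * k)
        print(k, exact, approx)     # the absolute gap shrinks as k grows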
