
Commit 396d7e9

sxjscience authored and Jin Huang committed

[MXNET-58] Layer Normalization in C++ (apache#10029)

* add layer_norm + fix batch_norm doc
* add test
* add layer normalization in Gluon
* update
* fix __repr__ + lint
* fix doc
* fix threshold
* fix doc
* fix bug
* enable inplace + fix test
* try to fix test
* fix doc

1 parent 2445bd4, commit 396d7e9

File tree: 11 files changed, +627 -7 lines

docs/api/python/gluon/nn.md (+1)

@@ -20,6 +20,7 @@ This document lists the neural network blocks in Gluon:
     Dropout
     BatchNorm
     InstanceNorm
+    LayerNorm
     Embedding
     Flatten
 ```

docs/api/python/ndarray/ndarray.md (+1)

@@ -640,6 +640,7 @@ The `ndarray` package provides several classes:
     Embedding
     LeakyReLU
     InstanceNorm
+    LayerNorm
     L2Normalization
     LRN
     ROIPooling

docs/api/python/symbol/symbol.md (+1)

@@ -641,6 +641,7 @@ Composite multiple symbols into a new one by an operator.
     Embedding
     LeakyReLU
     InstanceNorm
+    LayerNorm
     L2Normalization
     LRN
     ROIPooling

python/mxnet/gluon/nn/basic_layers.py (+94 -5)

@@ -19,7 +19,7 @@
 # pylint: disable= arguments-differ
 """Basic neural network layers."""
 __all__ = ['Sequential', 'HybridSequential', 'Dense', 'Dropout', 'Embedding',
-           'BatchNorm', 'InstanceNorm', 'Flatten', 'Lambda', 'HybridLambda']
+           'BatchNorm', 'InstanceNorm', 'LayerNorm', 'Flatten', 'Lambda', 'HybridLambda']
 import warnings
 import numpy as np

@@ -419,14 +419,18 @@ class InstanceNorm(HybridBlock):

     .. math::

-        out = \frac{x - mean[data]}{ \sqrt{Var[data]} + \epsilon} * gamma + beta
+        \bar{C} = \{i \mid i \neq 0, i \neq axis\}
+
+        out = \frac{x - mean[data, \bar{C}]}{ \sqrt{Var[data, \bar{C}]} + \epsilon}
+        * gamma + beta

     Parameters
     ----------
     axis : int, default 1
-        The axis that should be normalized. This is typically the channels
+        The axis that will be excluded in the normalization process. This is typically the channels
         (C) axis. For instance, after a `Conv2D` layer with `layout='NCHW'`,
-        set `axis=1` in `InstanceNorm`. If `layout='NHWC'`, then set `axis=3`.
+        set `axis=1` in `InstanceNorm`. If `layout='NHWC'`, then set `axis=3`. Data will be
+        normalized along axes excluding the first axis and the axis given.
     epsilon: float, default 1e-5
         Small float added to variance to avoid dividing by zero.
     center: bool, default True
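The revised docstring above changes what `axis` means for `InstanceNorm`: data is normalized over every axis except the batch axis (0) and the given `axis`. Below is a minimal NumPy sketch of that behaviour, following the documented formula (epsilon added to the standard deviation); the name `instance_norm_ref` and the sample shapes are illustrative only, not part of the patch.

```python
import numpy as np

def instance_norm_ref(x, gamma, beta, axis=1, eps=1e-5):
    """Normalize over all axes except the batch axis (0) and `axis`."""
    reduce_axes = tuple(i for i in range(x.ndim) if i not in (0, axis))
    mean = x.mean(axis=reduce_axes, keepdims=True)
    var = x.var(axis=reduce_axes, keepdims=True)
    # gamma/beta have shape (C,); broadcast them along `axis`.
    bshape = [1] * x.ndim
    bshape[axis] = -1
    return (x - mean) / (np.sqrt(var) + eps) * gamma.reshape(bshape) + beta.reshape(bshape)

x = np.random.randn(2, 3, 4, 4).astype(np.float32)   # NCHW input
y = instance_norm_ref(x, np.ones(3, np.float32), np.zeros(3, np.float32))
print(y.mean(axis=(2, 3)))   # per-(sample, channel) means, all close to 0
```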
@@ -475,7 +479,7 @@ def __init__(self, axis=1, epsilon=1e-5, center=True, scale=False,
                  beta_initializer='zeros', gamma_initializer='ones',
                  in_channels=0, **kwargs):
         super(InstanceNorm, self).__init__(**kwargs)
-        self._kwargs = {'eps': epsilon, 'axis': axis}
+        self._kwargs = {'eps': epsilon, 'axis': axis, 'center': center, 'scale': scale}
         self._axis = axis
         self._epsilon = epsilon
         self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
@@ -502,6 +506,91 @@ def __repr__(self):
                         content=', '.join(['='.join([k, v.__repr__()])
                                            for k, v in self._kwargs.items()]))

+
+class LayerNorm(HybridBlock):
+    r"""
+    Applies layer normalization to the n-dimensional input array.
+    This operator takes an n-dimensional input array and normalizes
+    the input using the given axis:
+
+    .. math::
+
+        out = \frac{x - mean[data, axis]}{ \sqrt{Var[data, axis]} + \epsilon} * gamma + beta
+
+    Parameters
+    ----------
+    axis : int, default -1
+        The axis that should be normalized. This is typically the axis of the channels.
+    epsilon: float, default 1e-5
+        Small float added to variance to avoid dividing by zero.
+    center: bool, default True
+        If True, add offset of `beta` to normalized tensor.
+        If False, `beta` is ignored.
+    scale: bool, default True
+        If True, multiply by `gamma`. If False, `gamma` is not used.
+    beta_initializer: str or `Initializer`, default 'zeros'
+        Initializer for the beta weight.
+    gamma_initializer: str or `Initializer`, default 'ones'
+        Initializer for the gamma weight.
+    in_channels : int, default 0
+        Number of channels (feature maps) in input data. If not specified,
+        initialization will be deferred to the first time `forward` is called
+        and `in_channels` will be inferred from the shape of input data.
+
+
+    Inputs:
+        - **data**: input tensor with arbitrary shape.
+
+    Outputs:
+        - **out**: output tensor with the same shape as `data`.
+
+    References
+    ----------
+        `Layer Normalization
+        <https://arxiv.org/pdf/1607.06450.pdf>`_
+
+    Examples
+    --------
+    >>> # Input of shape (2, 5)
+    >>> x = mx.nd.array([[1, 2, 3, 4, 5], [1, 1, 2, 2, 2]])
+    >>> # Layer normalization is calculated with the above formula
+    >>> layer = LayerNorm()
+    >>> layer.initialize(ctx=mx.cpu(0))
+    >>> layer(x)
+    [[-1.41421    -0.707105    0.          0.707105    1.41421   ]
+     [-1.2247195  -1.2247195   0.81647956  0.81647956  0.81647956]]
+    <NDArray 2x5 @cpu(0)>
+    """
+    def __init__(self, axis=-1, epsilon=1e-5, center=True, scale=True,
+                 beta_initializer='zeros', gamma_initializer='ones',
+                 in_channels=0, prefix=None, params=None):
+        super(LayerNorm, self).__init__(prefix=prefix, params=params)
+        self._kwargs = {'eps': epsilon, 'axis': axis, 'center': center, 'scale': scale}
+        self._axis = axis
+        self._epsilon = epsilon
+        self._center = center
+        self._scale = scale
+        self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
+                                     shape=(in_channels,), init=gamma_initializer,
+                                     allow_deferred_init=True)
+        self.beta = self.params.get('beta', grad_req='write' if center else 'null',
+                                    shape=(in_channels,), init=beta_initializer,
+                                    allow_deferred_init=True)
+
+    def hybrid_forward(self, F, data, gamma, beta):
+        norm_data = F.LayerNorm(data, gamma=gamma, beta=beta, axis=self._axis, eps=self._epsilon)
+        return norm_data
+
+    def __repr__(self):
+        s = '{name}({content}'
+        in_channels = self.gamma.shape[0]
+        s += ', in_channels={0}'.format(in_channels)
+        s += ')'
+        return s.format(name=self.__class__.__name__,
+                        content=', '.join(['='.join([k, v.__repr__()])
+                                           for k, v in self._kwargs.items()]))
+
+
 class Lambda(Block):
     r"""Wraps an operator or an expression as a Block object.
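To make the new block's formula concrete, here is a small NumPy reproduction of the docstring example, under the assumption that the normalization follows the documented expression `out = (x - mean) / (sqrt(var) + eps) * gamma + beta` along `axis`; `layer_norm_ref` is an illustrative name, not part of the patch.

```python
import numpy as np

def layer_norm_ref(x, gamma, beta, axis=-1, eps=1e-5):
    """Reference layer normalization along a single axis."""
    mean = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    return (x - mean) / (np.sqrt(var) + eps) * gamma + beta

x = np.array([[1, 2, 3, 4, 5], [1, 1, 2, 2, 2]], dtype=np.float32)
out = layer_norm_ref(x, gamma=np.ones(5, dtype=np.float32),
                     beta=np.zeros(5, dtype=np.float32))
print(out)
# Should be close to the docstring example:
# [[-1.4142  -0.7071   0.       0.7071   1.4142 ]
#  [-1.2247  -1.2247   0.8165   0.8165   0.8165 ]]
```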

src/operator/nn/batch_norm-inl.h (+1 -1)

@@ -79,7 +79,7 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
     .describe("Whether use global moving statistics instead of local batch-norm. "
               "This will force change batch-norm into a scale shift operator.");
     DMLC_DECLARE_FIELD(output_mean_var).set_default(false)
-    .describe("Output All,normal mean and var");
+    .describe("Output the mean and inverse std ");
     DMLC_DECLARE_FIELD(axis).set_default(mxnet::op::batchnorm::DEFAULT_AXIS)
     .describe("Specify which shape axis the channel is specified");
     DMLC_DECLARE_FIELD(cudnn_off).set_default(false)

src/operator/nn/batch_norm.cc (+2 -1)

@@ -510,7 +510,8 @@ Both *mean* and *var* returns a scalar by treating the input as a vector.

 Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
 have shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both ``data_mean`` and
-``data_var`` as well, which are needed for the backward pass.
+the inverse of ``data_var``, which are needed for the backward pass. Note that gradient of these
+two outputs are blocked.

 Besides the inputs and the outputs, this operator accepts two auxiliary
 states, ``moving_mean`` and ``moving_var``, which are *k*-length
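The doc change above says that with ``output_mean_var`` enabled the operator now returns the per-channel mean together with the inverse of the variance-derived statistic rather than the raw variance, and that gradients do not flow through these extra outputs. A hedged NumPy sketch of those statistics, assuming the conventional definition `inv_std = 1 / sqrt(var + eps)`; the epsilon placement and the helper name `batch_norm_stats` are assumptions, not taken from the patch.

```python
import numpy as np

def batch_norm_stats(x, axis=1, eps=1e-3):
    """Per-channel mean and inverse standard deviation over all other axes."""
    reduce_axes = tuple(i for i in range(x.ndim) if i != axis)
    mean = x.mean(axis=reduce_axes)
    inv_std = 1.0 / np.sqrt(x.var(axis=reduce_axes) + eps)
    return mean, inv_std

x = np.random.randn(2, 3, 4, 4).astype(np.float32)   # NCHW batch
mean, inv_std = batch_norm_stats(x)
print(mean.shape, inv_std.shape)                      # (3,) (3,): one value per channel
```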
