This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-58]Layer Normalization in C++ #10029

Merged
merged 12 commits into from
Mar 10, 2018

Conversation

@sxjscience (Member)

Description

  1. Directly implement layer normalization in C++. Both the speed and the memory cost are better than stacking broadcast/reduce ops. Solves [OP] LayerNorm in MXNet #9950
  2. Add LayerNorm in Gluon (a short usage sketch follows the list).
  3. Fix the doc of InstanceNorm. In InstanceNorm, normalization is actually performed over all axes except the 0th axis and the given axis (e.g., for an (N, C, H, W) input with axis=1, the mean and variance are computed over H and W).
  4. Fix the doc of BatchNorm: the inverse std, not the variance, is what is actually stored in the output. Should fix Loss of Precision in BatchNorm and output_var may be wrong #9216
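
For illustration only, a minimal usage sketch, assuming the Gluon block is exposed as mxnet.gluon.nn.LayerNorm with the arguments described above:

import mxnet as mx
from mxnet.gluon import nn

net = nn.LayerNorm()                        # normalizes over the last axis by default
net.initialize()
x = mx.nd.random.uniform(shape=(2, 5, 4))
y = net(x)                                  # same shape as x; zero mean / unit variance along axis=-1, then gamma/beta applied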

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with it

Changes

  • LayerNorm in C++/Gluon, tests
  • Fix Doc of InstanceNorm
  • Fix Doc of BatchNorm

Comments

We can improve the speed further by fusing the operators. This is left as future work.

@sxjscience (Member, Author)

@fhieber @tdomhan You could try this after it gets merged.

@sxjscience sxjscience changed the title Layer Norm [MXNET-58]Layer Norm Mar 8, 2018
@sxjscience sxjscience changed the title [MXNET-58]Layer Norm [MXNET-58]Layer Normalization in C++ Mar 8, 2018
@szha szha self-assigned this Mar 8, 2018
@sxjscience (Member, Author) commented Mar 8, 2018

@fhieber (Contributor) commented Mar 8, 2018

@sxjscience fantastic, thank you! We will definitely try this as soon as it's available!

@sxjscience (Member, Author)

Does anyone have time to review it? The doc page for the latest build is at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-10029/7/index.html

@zhanghang1989 (Contributor)

The docs look good to me 👍

using namespace mshadow;
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
const TShape &dshape = in_shape->at(layernorm::kData);
int axis = param.axis;
Member

nit: const int


def test_layer_norm():
    for dtype in [np.float16, np.float32, np.float64]:
        check_layer_normalization((10, 12, 5), -1, 1E-3)
Member

Is any axis allowed?
Can you check all possibilities (even if they theoretically overlap)? -2, -1, 0, 1, 2 (for 3D)
How about 1D and 2D? Are those relevant for this operator?

@@ -2413,6 +2413,47 @@ def test_l2_normalization():
check_l2_normalization((nbatch, nchannel, height, width), mode)


def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
Member

Can this be a nested function in check_layer_normalization?

exe.arg_dict['beta'][:] = beta
out_nd = exe.forward()[0]
out = npy_layer_norm(data, gamma, beta, axis, eps)
assert_allclose(out, out_nd.asnumpy(), 1E-4, 1E-4)
Member

Is this the correctness test?

Member (Author)

Yes, it compares the operator output with a NumPy reference implementation.
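
For context, a minimal sketch of what such a NumPy reference could look like; the name and signature follow the npy_layer_norm helper in the diff above, but the actual implementation in the PR may differ:

import numpy as np

def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
    # Normalize over the given axis, then apply the affine transform.
    mean = data.mean(axis=axis, keepdims=True)
    std = np.sqrt(data.var(axis=axis, keepdims=True) + eps)
    out = (data - mean) / std
    # gamma/beta have the length of the normalized axis; reshape for broadcasting.
    bcast_shape = [1] * data.ndim
    bcast_shape[axis] = data.shape[axis]
    return out * gamma.reshape(bcast_shape) + beta.reshape(bcast_shape)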

check_layer_normalization((10, 12, 5), -1, 1E-3)
check_layer_normalization((10, 12, 5), 0, 1E-3)
check_layer_normalization((10, 12, 5), 1, 1E-3)
for in_shape in [(10, 6, 5), (5, 5), (2, 3, 3, 3)]:
Member

Thank you!

beta_initializer='zeros', gamma_initializer='ones',
in_channels=0, prefix=None, params=None):
super(LayerNorm, self).__init__(prefix=prefix, params=params)
self._kwargs = {'eps': epsilon, 'axis': axis}
Member

nit: center and scale should also be included here.

DMLC_DECLARE_FIELD(axis).set_default(-1)
.describe("The axis to perform layer normalization. "
"Usually, this should be be axis of the channel dimension. "
"Negative values means indexing from right to left. ");
Member

extra space at the end

DMLC_DECLARE_FIELD(eps).set_default(1e-5f)
.describe("An `epsilon` parameter to prevent division by 0.");
DMLC_DECLARE_FIELD(output_mean_var).set_default(false)
.describe("Output the mean and std calculated along the given axis");
Member

nit: missing period at the end of the description.
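
For illustration only, a sketch of how output_mean_var is expected to be used; the call below assumes the operator is exposed to the NDArray front end as mx.nd.LayerNorm:

import mxnet as mx

x = mx.nd.random.uniform(shape=(2, 5, 4))
gamma = mx.nd.ones((4,))
beta = mx.nd.zeros((4,))
# With output_mean_var=True the operator also returns the mean and std
# computed along `axis` (here the last axis), per the field description.
out, mean, std = mx.nd.LayerNorm(x, gamma, beta, axis=-1, eps=1e-5,
                                 output_mean_var=True)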

@marcoabreu (Contributor)

Do you have any benchmarks regarding statement 1?

@szha szha merged commit 279ccb1 into apache:master Mar 10, 2018
@sxjscience (Member, Author)

@marcoabreu Yes, here is the benchmark result. My reference implementation is the following LayerNorm, implemented by stacking broadcast/reduce operators:

from mxnet.gluon import HybridBlock


class LayerNormStackSmallOp(HybridBlock):
    """Applies layer normalization to the n-dimensional input array.
    Reference implementation that stacks broadcast/reduce operators.
    """
    def __init__(self, axis=1, epsilon=1e-5, center=True, scale=True,
                 beta_initializer='zeros', gamma_initializer='ones',
                 in_channels=0, prefix=None, params=None):
        super(LayerNormStackSmallOp, self).__init__(prefix=prefix, params=params)
        self._kwargs = {'eps': epsilon, 'axis': axis}
        self._axis = axis
        self._epsilon = epsilon
        self._center = center
        self._scale = scale
        assert in_channels != 0, "in_channels == 0 is currently not supported"
        if self._center:
            self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
                                         shape=(in_channels,), init=gamma_initializer,
                                         allow_deferred_init=True)
        if self._scale:
            self.beta = self.params.get('beta', grad_req='write' if center else 'null',
                                        shape=(in_channels,), init=beta_initializer,
                                        allow_deferred_init=True)

    def moments(self, F, data):
        # Mean and (biased) variance along the normalization axis.
        mean = F.mean(data=data, axis=self._axis, keepdims=True)
        var = F.mean(F.square(F.broadcast_minus(data, mean)),
                     axis=self._axis, keepdims=True)
        return mean, var

    def hybrid_forward(self, F, data, gamma, beta):
        if not self._center and not self._scale:
            return data
        mean, var = self.moments(F, data)
        norm_data = F.broadcast_minus(data, mean)
        norm_data = F.broadcast_mul(norm_data, F.rsqrt(var + self._epsilon))  # F.rsqrt so it works for both nd and sym
        norm_data = F.broadcast_mul(norm_data, gamma)
        norm_data = F.broadcast_add(norm_data, beta)
        return norm_data

I ran layer normalization on data with shape=(128, 1024, 100) and axis=-1:

Forward-only                 Time                          Peak GPU Memory
Layer Norm (Stack Small)     6.859ms                       105MB
Layer Norm (Implemented)     4.784ms                       53MB

Forward + Backward           Time                          Peak GPU Memory
Layer Norm (Stack Small)     7.741 + 17.682 = 25.423ms     367MB
Layer Norm (Implemented)     5.137 + 10.943 = 16.08ms      53MB
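
Not part of the PR: a rough sketch, under assumed settings, of a timing harness that could produce numbers like the ones above; it synchronizes with mx.nd.waitall() around the timed loops, and peak GPU memory would be read externally (e.g. with nvidia-smi):

import time
import mxnet as mx

def benchmark(block, shape=(128, 1024, 100), ctx=mx.gpu(), repeat=100):
    # Time forward-only and forward+backward passes of a hybridized block.
    x = mx.nd.random.uniform(shape=shape, ctx=ctx)
    block.initialize(ctx=ctx)
    block.hybridize()
    block(x).wait_to_read()                 # warm-up / deferred shape inference

    mx.nd.waitall()
    start = time.time()
    for _ in range(repeat):
        y = block(x)
    mx.nd.waitall()
    fwd_ms = (time.time() - start) / repeat * 1000

    x.attach_grad()
    mx.nd.waitall()
    start = time.time()
    for _ in range(repeat):
        with mx.autograd.record():
            y = block(x)
        y.backward()
    mx.nd.waitall()
    fwd_bwd_ms = (time.time() - start) / repeat * 1000
    return fwd_ms, fwd_bwd_ms

# e.g. benchmark(LayerNormStackSmallOp(axis=-1, in_channels=100))
#      benchmark(mx.gluon.nn.LayerNorm(axis=-1))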

@marcoabreu (Contributor) commented Mar 10, 2018 via email

jinhuang415 pushed a commit to jinhuang415/incubator-mxnet that referenced this pull request Mar 30, 2018
* add layer_norm + fix batch_norm doc

* add test

* add layer normaliation in Gluon

* update

* fix __repr__ + lint

* fix doc

* fix threshold

* fix doc

* fix bug

* enable inplace + fix test

* try to fix test

* fix doc
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018

@marvis commented Sep 5, 2018

Is there a way to infer in_channels? I am implementing a Scale layer, which has the same problem.

assert in_channels != 0, "in_channels == 0 is currently not supported"

@sxjscience (Member, Author) commented Sep 5, 2018 via email
