# pylint: disable= arguments-differ
"""Basic neural network layers."""
__all__ = ['Sequential', 'HybridSequential', 'Dense', 'Dropout', 'Embedding',
-           'BatchNorm', 'InstanceNorm', 'Flatten', 'Lambda', 'HybridLambda']
+           'BatchNorm', 'InstanceNorm', 'LayerNorm', 'Flatten', 'Lambda', 'HybridLambda']

import warnings

import numpy as np
@@ -419,14 +419,18 @@ class InstanceNorm(HybridBlock):

    .. math::

-      out = \frac{x - mean[data]}{ \sqrt{Var[data]} + \epsilon} * gamma + beta
+      \bar{C} = \{i \mid i \neq 0, i \neq axis\}
+
+      out = \frac{x - mean[data, \bar{C}]}{ \sqrt{Var[data, \bar{C}]} + \epsilon}
+      * gamma + beta

    Parameters
    ----------
    axis : int, default 1
-        The axis that should be normalized. This is typically the channels
+        The axis that will be excluded from the normalization process. This is typically the channels
        (C) axis. For instance, after a `Conv2D` layer with `layout='NCHW'`,
-        set `axis=1` in `InstanceNorm`. If `layout='NHWC'`, then set `axis=3`.
+        set `axis=1` in `InstanceNorm`. If `layout='NHWC'`, then set `axis=3`. Data will be
+        normalized along all axes except the first axis and the given axis.
    epsilon: float, default 1e-5
        Small float added to variance to avoid dividing by zero.
    center: bool, default True
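For readers checking the revised axis semantics, here is a small NumPy sketch (not part of the patch; the input shape and helper names are made up for illustration) of what the new formula computes for an NCHW input with the default axis=1: statistics are taken over every axis except the batch axis 0 and the given channel axis, with gamma=1 and beta=0.

import numpy as np

x = np.random.randn(2, 3, 4, 4)           # hypothetical NCHW batch: N=2, C=3, H=W=4
eps = 1e-5

# With axis=1, normalization excludes axis 0 (batch) and axis 1 (channels),
# so mean/variance are computed over the remaining (spatial) axes.
reduce_axes = tuple(i for i in range(x.ndim) if i not in (0, 1))
mean = x.mean(axis=reduce_axes, keepdims=True)
var = x.var(axis=reduce_axes, keepdims=True)
out = (x - mean) / (np.sqrt(var) + eps)   # gamma=1, beta=0 for simplicity

# Each (sample, channel) slice is now approximately zero-mean, unit-variance.
print(out[0, 0].mean(), out[0, 0].std())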
@@ -475,7 +479,7 @@ def __init__(self, axis=1, epsilon=1e-5, center=True, scale=False,
                 beta_initializer='zeros', gamma_initializer='ones',
                 in_channels=0, **kwargs):
        super(InstanceNorm, self).__init__(**kwargs)
-        self._kwargs = {'eps': epsilon, 'axis': axis}
+        self._kwargs = {'eps': epsilon, 'axis': axis, 'center': center, 'scale': scale}
        self._axis = axis
        self._epsilon = epsilon
        self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
@@ -502,6 +506,91 @@ def __repr__(self):
                        content=', '.join(['='.join([k, v.__repr__()])
                                           for k, v in self._kwargs.items()]))

+
+class LayerNorm(HybridBlock):
+    r"""
+    Applies layer normalization to the n-dimensional input array.
+    This operator takes an n-dimensional input array and normalizes
+    the input using the given axis:
+
+    .. math::
+
+        out = \frac{x - mean[data, axis]}{ \sqrt{Var[data, axis]} + \epsilon} * gamma + beta
+
+    Parameters
+    ----------
+    axis : int, default -1
+        The axis that should be normalized. This is typically the axis of the channels.
+    epsilon: float, default 1e-5
+        Small float added to variance to avoid dividing by zero.
+    center: bool, default True
+        If True, add offset of `beta` to normalized tensor.
+        If False, `beta` is ignored.
+    scale: bool, default True
+        If True, multiply by `gamma`. If False, `gamma` is not used.
+    beta_initializer: str or `Initializer`, default 'zeros'
+        Initializer for the beta weight.
+    gamma_initializer: str or `Initializer`, default 'ones'
+        Initializer for the gamma weight.
+    in_channels : int, default 0
+        Number of channels (feature maps) in input data. If not specified,
+        initialization will be deferred to the first time `forward` is called
+        and `in_channels` will be inferred from the shape of input data.
+
+
+    Inputs:
+        - **data**: input tensor with arbitrary shape.
+
+    Outputs:
+        - **out**: output tensor with the same shape as `data`.
+
+    References
+    ----------
+        `Layer Normalization
+        <https://arxiv.org/pdf/1607.06450.pdf>`_
+
+    Examples
+    --------
+    >>> # Input of shape (2, 5)
+    >>> x = mx.nd.array([[1, 2, 3, 4, 5], [1, 1, 2, 2, 2]])
+    >>> # Layer normalization is calculated with the above formula
+    >>> layer = LayerNorm()
+    >>> layer.initialize(ctx=mx.cpu(0))
+    >>> layer(x)
+    [[-1.41421    -0.707105    0.          0.707105    1.41421   ]
+     [-1.2247195  -1.2247195   0.81647956  0.81647956  0.81647956]]
+    <NDArray 2x5 @cpu(0)>
+    """
+    def __init__(self, axis=-1, epsilon=1e-5, center=True, scale=True,
+                 beta_initializer='zeros', gamma_initializer='ones',
+                 in_channels=0, prefix=None, params=None):
+        super(LayerNorm, self).__init__(prefix=prefix, params=params)
+        self._kwargs = {'eps': epsilon, 'axis': axis, 'center': center, 'scale': scale}
+        self._axis = axis
+        self._epsilon = epsilon
+        self._center = center
+        self._scale = scale
+        self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
+                                     shape=(in_channels,), init=gamma_initializer,
+                                     allow_deferred_init=True)
+        self.beta = self.params.get('beta', grad_req='write' if center else 'null',
+                                    shape=(in_channels,), init=beta_initializer,
+                                    allow_deferred_init=True)
+
+    def hybrid_forward(self, F, data, gamma, beta):
+        norm_data = F.LayerNorm(data, gamma=gamma, beta=beta, axis=self._axis, eps=self._epsilon)
+        return norm_data
+
+    def __repr__(self):
+        s = '{name}({content}'
+        in_channels = self.gamma.shape[0]
+        s += ', in_channels={0}'.format(in_channels)
+        s += ')'
+        return s.format(name=self.__class__.__name__,
+                        content=', '.join(['='.join([k, v.__repr__()])
+                                           for k, v in self._kwargs.items()]))
+
+
class Lambda(Block):
    r"""Wraps an operator or an expression as a Block object.
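As a sanity check on the LayerNorm formula above, the numbers in the new docstring example can be reproduced with a short NumPy sketch (illustrative only, assuming the default axis=-1 with gamma=1 and beta=0; this is not part of the patch):

import numpy as np

x = np.array([[1., 2., 3., 4., 5.],
              [1., 1., 2., 2., 2.]])
eps = 1e-5

# Per-row statistics over the last axis, as in LayerNorm(axis=-1).
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
print((x - mean) / (np.sqrt(var) + eps))
# [[-1.4142 -0.7071  0.      0.7071  1.4142]
#  [-1.2247 -1.2247  0.8165  0.8165  0.8165]]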