AdaMod.py
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops


class AdaMod(optimizer_v2.OptimizerV2):
  """AdaMod optimizer.

  Adam variant that additionally keeps an exponential moving average of the
  adaptive per-parameter step sizes (controlled by `beta_3`) and clips each
  step against that average, bounding unexpectedly large learning rates.
  Reference: "An Adaptive and Momental Bound Method for Stochastic Learning"
  (Ding et al., 2019).
  """

  def __init__(self,
               learning_rate=0.001,
               beta_1=0.9,
               beta_2=0.999,
               beta_3=0.9995,
               epsilon=1e-8,
               name='AdaMod',
               **kwargs):
    super(AdaMod, self).__init__(name, **kwargs)
    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('beta_1', beta_1)
    self._set_hyper('beta_2', beta_2)
    self._set_hyper('beta_3', beta_3)
    self.epsilon = epsilon or backend_config.epsilon()

  def _create_slots(self, var_list):
    # Create slots for the first and second moments, plus the moving average
    # of the per-parameter step sizes used by the momental bound.
    # Separate for-loops to respect the ordering of slot variables from v1.
    for var in var_list:
      self.add_slot(var, 'm')
    for var in var_list:
      self.add_slot(var, 'v')
    for var in var_list:
      self.add_slot(var, 'exp_avg_lr')

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(AdaMod, self)._prepare_local(var_device, var_dtype, apply_state)

    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
    beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
    beta_3_t = array_ops.identity(self._get_hyper('beta_3', var_dtype))
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    beta_2_power = math_ops.pow(beta_2_t, local_step)
    # Fold the Adam bias corrections into the learning rate.
    lr = (apply_state[(var_device, var_dtype)]['lr_t'] *
          (math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)))
    apply_state[(var_device, var_dtype)].update(dict(
        lr=lr,
        epsilon=ops.convert_to_tensor(self.epsilon, var_dtype),
        beta_1_t=beta_1_t,
        beta_1_power=beta_1_power,
        one_minus_beta_1_t=1 - beta_1_t,
        beta_2_t=beta_2_t,
        beta_2_power=beta_2_power,
        one_minus_beta_2_t=1 - beta_2_t,
        beta_3_t=beta_3_t))

  def set_weights(self, weights):
    params = self.weights
    # Mirrors the compatibility handling in the built-in Keras Adam optimizer:
    # if the incoming weight list carries extra trailing slots (e.g. from a
    # Keras V1 checkpoint), truncate it so the lengths match.
    num_vars = int((len(params) - 1) / 2)
    if len(weights) == 3 * num_vars + 1:
      weights = weights[:len(params)]
    super(AdaMod, self).set_weights(weights)

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, 'm')
    m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
    m_t = state_ops.assign(m, m * coefficients['beta_1_t'] + m_scaled_g_values,
                           use_locking=self._use_locking)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, 'v')
    v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
    v_t = state_ops.assign(v, v * coefficients['beta_2_t'] + v_scaled_g_values,
                           use_locking=self._use_locking)

    # Per-parameter Adam step size (bias correction is folded into 'lr').
    denom = math_ops.sqrt(v_t) + coefficients['epsilon']
    step_size = coefficients['lr'] / denom

    # Momental bound: exp_avg_lr_t = beta3 * exp_avg_lr + (1 - beta3) * step_size,
    # then clip the current step size by this running average.
    exp_avg_lr = self.get_slot(var, 'exp_avg_lr')
    exp_avg_lr_t = state_ops.assign(
        exp_avg_lr,
        exp_avg_lr * coefficients['beta_3_t'] +
        (1.0 - coefficients['beta_3_t']) * step_size,
        use_locking=self._use_locking)
    step_size = math_ops.minimum(step_size, exp_avg_lr_t)

    var_update = state_ops.assign_sub(
        var, m_t * step_size, use_locking=self._use_locking)
    return control_flow_ops.group(*[var_update, m_t, v_t, exp_avg_lr_t])

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    raise RuntimeError('This optimizer does not support sparse gradients.')

  def get_config(self):
    config = super(AdaMod, self).get_config()
    config.update({
        'learning_rate': self._serialize_hyperparameter('learning_rate'),
        'decay': self._serialize_hyperparameter('decay'),
        'beta_1': self._serialize_hyperparameter('beta_1'),
        'beta_2': self._serialize_hyperparameter('beta_2'),
        'beta_3': self._serialize_hyperparameter('beta_3'),
        'epsilon': self.epsilon,
    })
    return config
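

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the optimizer itself. It assumes a TF 2.x
# environment where tf.keras models accept OptimizerV2 instances; the model
# architecture and the random data below are illustrative placeholders only.
if __name__ == '__main__':
  import numpy as np
  import tensorflow as tf

  model = tf.keras.Sequential([
      tf.keras.layers.Dense(32, activation='relu', input_shape=(16,)),
      tf.keras.layers.Dense(1),
  ])
  # beta_3 sets the memory of the step-size average: values closer to 1.0
  # smooth the bound over a longer horizon and damp large steps more.
  model.compile(optimizer=AdaMod(learning_rate=1e-3, beta_3=0.999),
                loss='mse')

  x = np.random.randn(256, 16).astype('float32')
  y = np.random.randn(256, 1).astype('float32')
  model.fit(x, y, epochs=1, batch_size=32, verbose=0)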