forked from deepspeedai/Megatron-DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy path__init__.py
420 lines (401 loc) · 16.2 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from deepspeed.accelerator import get_accelerator
import torch
from typing import Callable, Any, Iterable, Union
from megatron import get_args
from .distrib_optimizer import DistributedOptimizer
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
import ezpz as ez
# Rank of this process as reported by ezpz, evaluated once at import time.
# NOTE(review): no function in this module reads RANK; it may exist for
# other modules to import — confirm before removing.
RANK = ez.get_rank()
def get_param_groups(
    modules: Union[torch.nn.Module, Iterable[torch.nn.Module]],
    no_weight_decay_cond: Union[Callable[[str, torch.Tensor], bool], None],
    scale_lr_cond: Union[Callable[[str, torch.Tensor], bool], None],
    lr_mult: Any,
    use_galore: bool = False,
):
    """
    Partition trainable parameters into optimizer param groups.

    Every parameter with ``requires_grad=True`` lands in one of four buckets,
    chosen along two axes:

    - weight decay: regularized (``wd_mult=1.0``) or not (``wd_mult=0.0``);
    - learning-rate scaling: base LR (``lr_mult=1.0``) or ``lr_mult`` times
      the base LR. ``scale_lr_cond`` is used during finetuning, where the
      head of the network requires a scaled version of the base LR.

    Args:
        modules: A single module or an iterable of modules. A lone
            ``nn.Module`` is treated as a one-element list.
        no_weight_decay_cond: ``f(name, param) -> bool``; truthy means "do not
            apply weight decay". When None, parameters whose name ends in
            ``.bias`` and all 1-D tensors (e.g. norm weights) are exempt.
        scale_lr_cond: ``f(name, param) -> bool``; truthy means "use
            ``lr_mult``". When None, no parameter is LR-scaled.
        lr_mult: Value stored as ``lr_mult`` on the LR-scaled groups.
        use_galore: Accepted for API compatibility; currently ignored (no
            GaLore projection group is built here).

    Returns:
        A list of dicts with keys ``name``, ``params``, ``wd_mult`` and
        ``lr_mult``, in the order wd_no_scale_lr, wd_scale_lr,
        no_wd_no_scale_lr, no_wd_scale_lr. Empty buckets are omitted.
    """
    # A plain nn.Module is not iterable; wrap it so the loop below works for
    # both forms advertised in the signature.
    if isinstance(modules, torch.nn.Module):
        modules = [modules]

    wd_no_scale_lr = []
    wd_scale_lr = []
    no_wd_no_scale_lr = []
    no_wd_scale_lr = []
    for module in modules:
        for name, param in module.named_parameters():
            # Frozen parameters are never handed to the optimizer.
            if not param.requires_grad:
                continue
            if no_weight_decay_cond is not None:
                no_wd = no_weight_decay_cond(name, param)
            else:
                # Do not regularize biases nor norm parameters.
                no_wd = name.endswith(".bias") or len(param.shape) == 1
            if scale_lr_cond is not None:
                scale_lr = scale_lr_cond(name, param)
            else:
                scale_lr = False

            if no_wd:
                bucket = no_wd_scale_lr if scale_lr else no_wd_no_scale_lr
            else:
                bucket = wd_scale_lr if scale_lr else wd_no_scale_lr
            bucket.append(param)

    candidates = (
        ('wd_no_scale_lr', wd_no_scale_lr, 1.0, 1.0),
        ('wd_scale_lr', wd_scale_lr, 1.0, lr_mult),
        ('no_wd_no_scale_lr', no_wd_no_scale_lr, 0.0, 1.0),
        ('no_wd_scale_lr', no_wd_scale_lr, 0.0, lr_mult),
    )
    return [
        {'name': group_name, 'params': params, 'wd_mult': wd_mult, 'lr_mult': group_lr_mult}
        for group_name, params, wd_mult, group_lr_mult in candidates
        if params
    ]
def _model_parameters(model):
    """Return every parameter of ``model`` as a list.

    ``model`` may be a single ``torch.nn.Module`` or an iterable of modules
    (the same forms ``get_param_groups`` accepts). Parameters are not
    filtered on ``requires_grad``.
    """
    if isinstance(model, torch.nn.Module):
        return list(model.parameters())
    return [param for module in model for param in module.parameters()]


def _build_base_optimizer(args, param_groups, model):
    """Instantiate the underlying optimizer selected by ``args.optimizer``.

    The name is matched case-insensitively. Most optimizers receive
    ``param_groups``; Shampoo receives all raw parameters of ``model``.

    Raises:
        AssertionError: CPU offloading requested for a non-Adam optimizer,
            or ``apex.adam`` requested on a non-CUDA accelerator.
        NotImplementedError: for ``galore_adamw8bit_per_layer``.
        TypeError: if the name matches no supported optimizer.
    """
    opt_name = str(args.optimizer).lower()

    def adam_kwargs():
        # Built lazily so optimizers that do not use Adam hyperparameters
        # never touch args.adam_* attributes.
        return dict(
            lr=args.lr,
            weight_decay=args.weight_decay,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_eps,
        )

    # ---- CPU Optimizer --------------------------------------
    if args.cpu_optimizer:
        assert opt_name == 'adam', 'CPU offloading is for Adam'
        if args.cpu_torch_adam:
            cpu_adam_optimizer = torch.optim.AdamW
        else:
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            cpu_adam_optimizer = DeepSpeedCPUAdam
        return cpu_adam_optimizer(param_groups, **adam_kwargs())

    # ---- Adam --------------------------------------
    if opt_name == 'adam':
        if args.ds_fused_adam:
            # global Adam
            from deepspeed.ops.adam import FusedAdam as Adam
        else:
            Adam = torch.optim.Adam
        return Adam(param_groups, **adam_kwargs())

    # ---- apex.Adam --------------------------------------------
    if opt_name == 'apex.adam':
        assert get_accelerator().device_name() == 'cuda'
        from apex.optimizers import FusedAdam as ApexFusedAdam
        return ApexFusedAdam(param_groups, **adam_kwargs())

    # ---- Adam8Bit --------------------------------------
    if opt_name == 'adam8bit':
        import bitsandbytes as bnb
        # NOTE(review): betas/eps from args are not forwarded here.
        return bnb.optim.Adam8bit(param_groups, lr=args.lr, weight_decay=args.weight_decay)

    # ---- AdamW --------------------------------------
    if opt_name == 'adamw':
        return torch.optim.AdamW(param_groups, **adam_kwargs())

    # ---- AdamW: ScheduleFree -------------------------------------
    if opt_name == 'adamwschedulefree':
        import schedulefree
        return schedulefree.AdamWScheduleFree(
            param_groups,
            warmup_steps=args.lr_warmup_iters,
            foreach=args.schedulefree_for_each,
            **adam_kwargs(),
        )

    # ---- AdamW: GaLore ------------------------------------------
    if opt_name == 'galore_adamw':
        from galore_torch import GaLoreAdamW
        return GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)

    # ---- AdamW: GaLore 8Bit --------------------------------------
    if opt_name == 'galore_adamw8bit':
        from galore_torch import GaLoreAdamW8bit
        return GaLoreAdamW8bit(param_groups, lr=args.lr, weight_decay=args.weight_decay)

    # ---- AdamW8bitPerLayer: GaLore ----------------------------
    if opt_name == 'galore_adamw8bit_per_layer':
        # The former layer-wise implementation (one optimizer + scheduler per
        # parameter, stepped from post-accumulate-grad hooks) could never run:
        # it used undefined names (`id_galore_params`, `bnb`), treated `model`
        # as a single module and never produced an optimizer object. Fail
        # fast with a clear message instead of an opaque NameError.
        raise NotImplementedError(
            "galore_adamw8bit_per_layer is not supported by "
            "get_megatron_optimizer; use 'galore_adamw8bit' instead."
        )

    # ---- AdaFactor / GaLore AdaFactor ------------------------------------
    if opt_name in ('adafactor', 'galore_adafactor'):
        if opt_name == 'adafactor':
            from transformers.optimization import Adafactor as AdafactorCls
        else:
            from galore_torch import GaLoreAdafactor as AdafactorCls
        # A beta1 of 0.0 is forwarded as None. Kept local so the global args
        # object is not mutated.
        beta1 = None if args.beta1 == 0.0 else args.beta1
        return AdafactorCls(
            param_groups,
            lr=args.lr,
            eps=(1e-30, 1e-3),
            clip_threshold=1.0,
            decay_rate=-0.8,
            beta1=beta1,
            weight_decay=args.weight_decay,
            relative_step=False,
            scale_parameter=False,
            warmup_init=False,
        )

    # ---- Apex: SGD ---------------------------------------------
    if opt_name == 'apex.sgd':
        from apex.optimizers import FusedSGD
        return FusedSGD(
            param_groups,
            lr=args.lr,
            weight_decay=args.weight_decay,
            momentum=args.sgd_momentum,
        )

    # ---- ScheduleFree: SGD -------------------------------
    if opt_name == 'sgdschedulefree':
        import schedulefree
        return schedulefree.SGDScheduleFree(
            param_groups,
            lr=args.lr,
            momentum=args.sgd_momentum,
            weight_decay=args.weight_decay,
            warmup_steps=args.lr_warmup_iters,
            foreach=args.schedulefree_for_each,
        )

    # ---- Lamb / fused Lamb: IPEX --------------------------------------------
    if opt_name in ('ipex.lamb', 'ipex.fusedlamb'):
        from intel_extension_for_pytorch.optim._lamb import Lamb
        fused_kwargs = {'fused': True} if opt_name == 'ipex.fusedlamb' else {}
        return Lamb(param_groups, **adam_kwargs(), **fused_kwargs)

    # ---- Lamb(Fused): DeepSpeed ------------------------------------------
    if opt_name == 'ds.fusedlamb':
        from deepspeed.ops.lamb import FusedLamb
        return FusedLamb(param_groups, **adam_kwargs())

    # ---- Shampoo ----------------------------------------
    if opt_name == 'shampoo':
        from distributed_shampoo.distributed_shampoo import DistributedShampoo
        from distributed_shampoo.shampoo_types import AdamGraftingConfig
        # NOTE(review): hyperparameters are hard-coded and ignore args.lr,
        # args.weight_decay and param_groups — confirm this is intended.
        return DistributedShampoo(
            _model_parameters(model),
            lr=0.001,
            betas=(0.9, 0.999),
            epsilon=1e-12,
            weight_decay=1e-05,
            max_preconditioner_dim=8192,
            precondition_frequency=100,
            use_decoupled_weight_decay=True,
            grafting_config=AdamGraftingConfig(
                beta2=0.999,
                epsilon=1e-08,
            ),
        )

    # ---- SGD ----------------------------------------
    if opt_name == 'sgd':
        return torch.optim.SGD(
            param_groups,
            lr=args.lr,
            weight_decay=args.weight_decay,
            momentum=args.sgd_momentum,
        )

    # ---- SophiaG ----------------------------------------
    if opt_name == 'sophiag':
        from .sophia import SophiaG
        return SophiaG(
            param_groups,
            lr=args.lr,
            betas=(args.sophiag_beta1, args.sophiag_beta2),
            rho=args.sophiag_rho,
            weight_decay=args.weight_decay,
        )

    raise TypeError(f'{args.optimizer} optimizer is not supported.')


def _build_grad_scaler(args):
    """Pick the loss scaler for the mixed-precision Megatron optimizers.

    - If ``args.loss_scale`` is set, use a constant scaler.
    - Otherwise, with fp16, use a dynamic scaler.
    - Otherwise (bf16 or fp32 with the distributed optimizer), return None.
    """
    if args.loss_scale:
        return ConstantGradScaler(args.loss_scale)
    if args.fp16:
        return DynamicGradScaler(
            initial_scale=args.initial_loss_scale,
            min_scale=args.min_loss_scale,
            growth_factor=2.0,
            backoff_factor=0.5,
            growth_interval=args.loss_scale_window,
            hysteresis=args.hysteresis,
        )
    return None


def get_megatron_optimizer(
    model,
    no_weight_decay_cond=None,
    scale_lr_cond=None,
    lr_mult=1.0
):
    """
    Build the training optimizer selected by the global Megatron args.

    1. Split the trainable parameters of ``model`` into param groups with
       ``get_param_groups``. If ``args.create_moe_param_group`` is set, they
       are re-split with DeepSpeed's MoE helper.
    2. Instantiate the base optimizer named by ``args.optimizer``
       (case-insensitive).
    3. If ``args.deepspeed`` is set, return that optimizer unwrapped.
       Otherwise wrap it in a Megatron optimizer:
       - ``DistributedOptimizer`` when ``args.use_distributed_optimizer``;
       - ``Float16OptimizerWithFloat16Params`` for fp16/bf16;
       - ``FP32Optimizer`` in all other cases.

    Args:
        model: The module(s) to optimize. Passed to ``get_param_groups`` and
            to the Megatron optimizer wrapper.
        no_weight_decay_cond: Forwarded to ``get_param_groups``.
        scale_lr_cond: Forwarded to ``get_param_groups``.
        lr_mult: Forwarded to ``get_param_groups``.

    Raises:
        TypeError: ``args.optimizer`` names no supported optimizer.
        NotImplementedError: ``args.optimizer`` is
            ``galore_adamw8bit_per_layer``.
    """
    args = get_args()
    assert args is not None

    # Base optimizer.
    param_groups = get_param_groups(
        model,
        no_weight_decay_cond,
        scale_lr_cond,
        lr_mult
    )
    if args.create_moe_param_group:
        from deepspeed.moe.utils import (
            split_params_into_different_moe_groups_for_optimizer
        )
        param_groups = split_params_into_different_moe_groups_for_optimizer(
            param_groups
        )

    optimizer = _build_base_optimizer(args, param_groups, model)
    assert optimizer is not None

    # Under DeepSpeed the raw optimizer is returned without a Megatron wrapper.
    if args.deepspeed:
        return optimizer

    # Params carry a separate main-grad field only with contiguous DDP buffers.
    params_have_main_grad = bool(args.use_contiguous_buffers_in_local_ddp)

    # Mixed precision optimizer.
    # - Note: both the Float16Optimizer and the DistributedOptimizer inherit
    #   from the MixedPrecisionOptimizer, which manages any optimizer where
    #   the model params and main params are distinct.
    if args.fp16 or args.bf16 or args.use_distributed_optimizer:
        grad_scaler = _build_grad_scaler(args)
        opt_ty = (
            DistributedOptimizer if args.use_distributed_optimizer
            else Float16OptimizerWithFloat16Params
        )
        return opt_ty(optimizer,
                      args.clip_grad,
                      args.log_num_zeros_in_grad,
                      params_have_main_grad,
                      args.use_contiguous_buffers_in_local_ddp,
                      args.fp16,
                      args.bf16,
                      args.params_dtype,
                      grad_scaler,
                      model)

    # FP32.
    return FP32Optimizer(
        optimizer,
        args.clip_grad,
        args.log_num_zeros_in_grad,
        params_have_main_grad,
        args.use_contiguous_buffers_in_local_ddp,
        model
    )