
[BUG] DeepSpeed ZeRO 3 taking too much memory for FLAN-T5-XL (3B) #2797

Closed
philschmid opened this issue Feb 7, 2023 · 7 comments
Labels: bug, training

@philschmid commented Feb 7, 2023

Describe the bug
I am trying to train FLAN-T5-XL with DeepSpeed ZeRO 3 and transformers, and ZeRO-3 / CPU offload seems to use far more GPU memory than expected. I am running on 4x V100 16GB. I ran the estimate_zero3_model_states_mem_needs_all_cold utility, which gave me the following results:

SW: Model with 2849M total params, 65M largest layer params.
  per CPU  |  per GPU |   Options
   71.66GB |   0.25GB | offload_param=cpu , offload_optimizer=cpu , zero_init=1
   71.66GB |   0.25GB | offload_param=cpu , offload_optimizer=cpu , zero_init=0
   63.70GB |   5.55GB | offload_param=none, offload_optimizer=cpu , zero_init=1
   63.70GB |   5.55GB | offload_param=none, offload_optimizer=cpu , zero_init=0
    0.37GB |  48.02GB | offload_param=none, offload_optimizer=none, zero_init=1
   15.92GB |  48.02GB | offload_param=none, offload_optimizer=none, zero_init=0

I know that's only for weights + grads + optimizer states, but my assumption was that with offloading it should fit on 4x V100 16GB with a batch size of 1.
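For reference, a minimal sketch of the estimator call that produced the numbers above (the parameter counts are taken from its output; the GPU/node counts below are my assumption):

# Sketch: DeepSpeed's ZeRO-3 estimator that produced the table above. The
# parameter counts are read off its output; pass num_gpus_per_node=4 to see
# the estimate for the 4x V100 setup.
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_cold

estimate_zero3_model_states_mem_needs_all_cold(
    total_params=2849e6,
    largest_layer_params=65e6,
    num_gpus_per_node=1,
    num_nodes=1,
)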

To Reproduce
Steps to reproduce the behavior:

1. Clone the scripts and create a working directory
mkdir philipp
cd philipp
wget https://raw.githubusercontent.com/philschmid/deep-learning-pytorch-huggingface/deepspeed-example/training/scripts/run_seq2seq_deepspeed.py
wget https://raw.githubusercontent.com/philschmid/deep-learning-pytorch-huggingface/deepspeed-example/training/configs/ds_flan_t5_z3_config.json
2. Create the dataset with the following script:
import os

import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer

# experiment config
model_id = "google/flan-t5-base"
repository_id = "flan-t5-base-cnn"

# Dataset
dataset_id = "cnn_dailymail"
dataset_config = "3.0.0"
save_dataset_path = "data"
text_column = "article"
summary_column = "highlights"

prompt_start = "Summarize the following news article:\n"
generation_start = "\nSummary:\n"
prompt_template = f"{prompt_start}{{input}}{generation_start}"

max_source_length = 500
max_target_length = 129

# Load dataset from the hub
dataset = load_dataset(dataset_id, name=dataset_config)
# Load tokenizer of FLAN-T5-base
tokenizer = AutoTokenizer.from_pretrained(model_id)

print(f"Train dataset size: {len(dataset['train'])}")
print(f"Test dataset size: {len(dataset['test'])}")

prompt_length = len(tokenizer(prompt_template.format(input=""))["input_ids"])
max_sample_length = tokenizer.model_max_length - prompt_length
print(f"Prompt length: {prompt_length}")
print(f"Max input length: {max_sample_length}")


def preprocess_function(sample, padding="max_length"):
    # create the prompted input
    inputs = [prompt_template.format(input=item) for item in sample[text_column]]

    # tokenize inputs
    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)

    # Tokenize targets with the `text_target` keyword argument
    labels = tokenizer(
        text_target=sample[summary_column], max_length=max_target_length, padding=padding, truncation=True
    )

    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100
    # so that padding is ignored in the loss.
    if padding == "max_length":
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
        ]

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=list(dataset["train"].features))
print(f"Keys of tokenized dataset: {list(tokenized_dataset['train'].features)}")

tokenized_dataset["train"].save_to_disk(os.path.join(save_dataset_path, "train"))
tokenized_dataset["test"].save_to_disk(os.path.join(save_dataset_path, "eval"))
3. Run the script and observe the high GPU memory usage:
deepspeed --num_gpus=1 run_seq2seq_deepspeed.py --model_id google/flan-t5-xl --dataset_path data --epochs 3 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --generation_max_length 1 --lr 1e-5 --deepspeed ds_flan_t5_z3_config.json --block_size 1 --gradient_checkpointing 1
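For context, ds_flan_t5_z3_config.json is a ZeRO stage 3 config with CPU offload for both parameters and optimizer states. A representative sketch of such a config, expressed as a Python dict (the exact contents of the linked file may differ; the "auto" values are resolved by the HF Trainer's DeepSpeed integration):

# Representative ZeRO-3 + CPU offload config; the linked ds_flan_t5_z3_config.json
# may differ in details. "auto" values are filled in by the HF Trainer integration.
ds_config = {
    "fp16": {"enabled": "auto"},
    "optimizer": {"type": "AdamW", "params": {"lr": "auto", "weight_decay": "auto"}},
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "offload_param": {"device": "cpu", "pin_memory": True},
        "overlap_comm": True,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}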

Expected behavior
My assumption is that training with ZeRO 3 and offload should fit on 4x V100 16GB.

System info (please complete the following information):

  • transformers version: 4.26.0
  • Platform: Linux-5.15.0-1027-aws-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • Huggingface_hub version: 0.12.0
  • PyTorch version (GPU?): 1.13.1+cu116 (True)
  • DeepSpeed version: 0.8.0
philschmid added the bug and compression labels on Feb 7, 2023
@stas00 (Collaborator) commented Feb 7, 2023

cc: @tjruwase

I was able to reproduce this large memory usage on a single a100.

The gpu memory gets cleared before each iteration to almost 0, and then peaks to about 20GB during each iteration.

I suspect that a lot of this memory is normal allocations, but it very likely shouldn't need 20GB if offload is to work efficiently.

If I turn both offloads off then the memory usage goes up to 70GB. So we definitely know offload works, just perhaps not as efficiently as we would like it to be.

@stas00 (Collaborator) commented Feb 9, 2023

Hi Philipp,

So Tunji and I spent some time researching this issue - we plugged a bunch of see_memory_usage calls through T5, and it's all activation memory, and there is a lot of it.

That is, the offload works correctly; no bugs there.

We were discussing creating a calculator to estimate activation memory usage, to make such situations easier to resolve.
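For anyone who wants to reproduce this kind of profiling, here is a minimal sketch of the instrumentation (the model and the placement of the probes are illustrative, not the exact ones we used inside t5; it assumes a CUDA GPU):

# Sketch: DeepSpeed's memory reporter around a forward/backward pass of a small
# FLAN-T5 model. Probe placement and the toy batch are illustrative only.
import torch
from deepspeed.runtime.utils import see_memory_usage
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base").cuda()
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
batch = tokenizer("Summarize: hello world", return_tensors="pt").to("cuda")
labels = tokenizer("hi", return_tensors="pt").input_ids.to("cuda")

see_memory_usage("before forward", force=True)
outputs = model(**batch, labels=labels)
see_memory_usage("after forward", force=True)
outputs.loss.backward()
see_memory_usage("after backward", force=True)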

Now, fear not, I have a solution for your situation that comes from HF transformers:

You just need to change your trainer script to actually pass on --gradient_checkpointing 1, which is an HF Trainer feature. I tested on a single GPU that it drops GPU memory usage by about half, so with a small batch size / seq_len you should be able to fit everything on 4x V100 16GB GPUs. Please let me know if you can't.

The reason my initial --gradient_checkpointing 1 suggestion didn't work is that your program silently ignores args it doesn't pass on, so the flag was simply dropped. If you need help with that, let's discuss on Slack.
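For reference, a minimal sketch of the kind of wiring needed so that the flag actually reaches the Trainer (the argparse names and training arguments below are illustrative, not the actual contents of run_seq2seq_deepspeed.py):

# Sketch: make sure the CLI flag actually reaches the HF Trainer.
# The argument names below are illustrative; run_seq2seq_deepspeed.py may differ.
import argparse
from transformers import Seq2SeqTrainingArguments

parser = argparse.ArgumentParser()
parser.add_argument("--gradient_checkpointing", type=int, default=1)
args, _ = parser.parse_known_args()

training_args = Seq2SeqTrainingArguments(
    output_dir="output",
    per_device_train_batch_size=1,
    gradient_checkpointing=bool(args.gradient_checkpointing),  # forward the flag
    deepspeed="ds_flan_t5_z3_config.json",
)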

@philschmid (Author)

Can confirm: with gradient_checkpointing I was able to fit the model and dataset using offloading!

@nonstopfor

(quoting @stas00's reply above recommending --gradient_checkpointing 1)

Hello! I tried to use the activation checkpointing provided by DeepSpeed but the OOM still occurred, while gradient checkpointing in the Hugging Face Trainer did solve the OOM problem. Isn't activation checkpointing in DeepSpeed the same as gradient checkpointing in the Hugging Face Trainer? I'm a bit confused now.

@stas00 (Collaborator) commented Mar 10, 2023

Activation checkpointing and gradient checkpointing are two terms for the same methodology.

Except HF Transformers models don't know anything about Deepspeed's activation checkpointing.

So if you want to use HF Transformers models you do model.gradient_checkpointing_enable() or use --gradient_checkpointing 1 in HF Trainer which will do this for you.

If you write your own model and you want to use Deepspeed's activation checkpointing you use the API prescribed there.

I hope I was able to help with clarity here, @nonstopfor

And I totally agree that it's odd that the same feature acquired two different names - it took me a while to get used to that.
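In code, the two routes look roughly like this (a minimal sketch; the DeepSpeed call is shown only as the shape of the API, for use inside your own model code):

# A minimal sketch of the two routes described above.
from transformers import AutoModelForSeq2SeqLM

# Route 1: HF Transformers model - the library wires up torch's checkpointing for you.
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
model.gradient_checkpointing_enable()  # same effect as --gradient_checkpointing 1 in HF Trainer

# Route 2: your own model - call DeepSpeed's API inside your forward pass instead of
# torch.utils.checkpoint.checkpoint; shape of the call only, see the DeepSpeed docs:
#   import deepspeed
#   hidden = deepspeed.checkpointing.checkpoint(my_block, hidden)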

@nonstopfor

Thanks very much! It works now. I had thought deepspeed.initialize() might enable gradient checkpointing automatically if activation checkpointing is set in the DeepSpeed config file.

@stas00 (Collaborator) commented Mar 10, 2023

Glad to hear you have it working, @nonstopfor

As I explained above, you have to modify the modeling code to replace torch's checkpointing with DeepSpeed's version of it, as explained here:
https://deepspeed.readthedocs.io/en/latest/activation-checkpointing.html#using-activation-checkpointing

Their checkpointing is more flexible since you can offload activations to CPU instead of recalculating them, so if you want to squeeze out more performance you could copy HF's model code, replace our checkpointing with DeepSpeed's, and configure it to use their advanced features.

But chances are that what HF provides out of the box is good enough performance-wise.
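For reference, a sketch of the DeepSpeed config section that drives this behaviour (values are illustrative, and it only takes effect for code that routes through deepspeed.checkpointing.checkpoint, i.e. not for stock HF models):

# Sketch: activation-checkpointing options in a DeepSpeed config, including CPU
# offload of checkpointed activations. Values are illustrative.
ds_config_fragment = {
    "activation_checkpointing": {
        "partition_activations": True,
        "cpu_checkpointing": True,  # offload checkpointed activations to CPU
        "contiguous_memory_optimization": False,
        "number_checkpoints": None,
        "synchronize_checkpoint_boundary": False,
        "profile": False,
    }
}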
