
Commit 187c2a0
Fix E1136 (PaddlePaddle#563)
icyblade authored Sep 21, 2023
1 parent 229080b · commit 187c2a0
Showing 2 changed files with 10 additions and 9 deletions.
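
The fix swaps the PEP 585 built-in generics (list[...], dict[...]) used in annotations for their typing-module equivalents. On interpreters older than Python 3.9 the built-in containers are not subscriptable, so those annotations raise TypeError as soon as the module is imported, and pylint flags each subscript as E1136 (unsubscriptable-object). typing.List and typing.Dict mean the same thing to type checkers but have been subscriptable since Python 3.5. A minimal sketch of the pattern (the combine helper is illustrative, not code from this repository):

from typing import Dict, List

import torch


def combine(state_dicts: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    # typing.List / typing.Dict are subscriptable on older interpreters;
    # merging shard dicts key by key here is purely for illustration.
    merged: Dict[str, torch.Tensor] = {}
    for sd in state_dicts:
        merged.update(sd)
    return merged


# The pre-fix spelling fails at definition time on Python 3.8:
#   def combine(state_dicts: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
#       TypeError: 'type' object is not subscriptable   (pylint: E1136)

An alternative would have been from __future__ import annotations, which defers annotation evaluation, but switching to the typing aliases also addresses the lint error directly.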
flash_attn/models/gpt.py (3 changes: 2 additions & 1 deletion)
@@ -6,6 +6,7 @@
 from collections import OrderedDict, namedtuple
 from collections.abc import Sequence
 from functools import partial
+from typing import Dict, List

 import torch
 import torch.nn as nn
@@ -810,7 +811,7 @@ def shard_qkv_headdim(state_dict, key):
     return state_dict


-def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: GPT2Config):
+def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config):
     """Convert the list of sharded state_dict of a GPT model with tensor parallel to
     the state_dict of a standard GPT model.
flash_attn/models/llama.py (16 changes: 8 additions & 8 deletions)
@@ -6,7 +6,7 @@
 import re
 from collections import OrderedDict
 from pathlib import Path
-from typing import Union
+from typing import Dict, List, Union

 import torch
 import torch.nn.functional as F
@@ -17,8 +17,8 @@


 def remap_state_dict_meta_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in Meta format to standard GPT format.
     This function modifies state_dict in place.
@@ -113,8 +113,8 @@ def key_mapping_attn(key):


 def remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in Hugging Face format to standard GPT format.
     This function modifies state_dict in place.
@@ -217,8 +217,8 @@ def key_mapping_attn(key):


 def inv_remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in standard GPT format to Hugging Face format.
     This function is meant to be the inverse of remap_state_dict_hf_llama, up to a
@@ -382,7 +382,7 @@ def config_from_checkpoint(

 def state_dicts_from_checkpoint(
     checkpoint_path: Union[str, os.PathLike], model_name: str
-) -> list[dict]:
+) -> List[dict]:
     # Need to sort, otherwise we mess up the ordering and the weights are wrong
     return [
         torch.load(path, map_location="cpu")
