Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve flow.load for better error message #10138

Merged
merged 28 commits into from
Apr 23, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
7aaea2e
improve flow.load
daquexian Apr 14, 2023
5bc0e06
fix tests
daquexian Apr 14, 2023
c291bed
auto format by CI
oneflow-ci-bot Apr 14, 2023
9bd8b34
fix save()
daquexian Apr 17, 2023
02993fb
Merge branch 'improve_load' of github.com:Oneflow-Inc/oneflow into im…
daquexian Apr 17, 2023
21fb958
fix save()
daquexian Apr 17, 2023
f723ead
Merge branch 'master' into improve_load
mergify[bot] Apr 17, 2023
f9fb08c
Merge branch 'master' into improve_load
mergify[bot] Apr 17, 2023
f40e594
Merge branch 'master' into improve_load
mergify[bot] Apr 17, 2023
c89e59b
fix mock torch bug
daquexian Apr 19, 2023
fb47f3c
Merge branch 'master' into improve_load
mergify[bot] Apr 19, 2023
a64d71c
auto format by CI
oneflow-ci-bot Apr 19, 2023
0185343
Merge branch 'master' into improve_load
mergify[bot] Apr 19, 2023
ed02249
Merge branch 'master' into improve_load
mergify[bot] Apr 19, 2023
25db152
Merge branch 'master' into improve_load
mergify[bot] Apr 20, 2023
c1f37e3
Merge branch 'master' into improve_load
mergify[bot] Apr 20, 2023
394892f
Merge branch 'master' into improve_load
mergify[bot] Apr 20, 2023
29cf37f
Merge branch 'master' into improve_load
mergify[bot] Apr 20, 2023
bfcf153
Merge branch 'master' into improve_load
mergify[bot] Apr 20, 2023
1315ea9
Merge branch 'master' into improve_load
mergify[bot] Apr 21, 2023
bd76d0d
Merge branch 'master' into improve_load
mergify[bot] Apr 21, 2023
5f73640
Merge branch 'master' into improve_load
mergify[bot] Apr 21, 2023
1ba171c
Merge branch 'master' into improve_load
mergify[bot] Apr 21, 2023
111fd66
Merge branch 'master' into improve_load
mergify[bot] Apr 21, 2023
7e7095e
Merge branch 'master' into improve_load
mergify[bot] Apr 23, 2023
1eaa8bf
Merge branch 'master' into improve_load
mergify[bot] Apr 23, 2023
3c815ad
Merge branch 'master' into improve_load
mergify[bot] Apr 23, 2023
9f81a95
Merge branch 'master' into improve_load
mergify[bot] Apr 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 41 additions & 29 deletions python/oneflow/framework/check_point_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
MAP_LOCATION: TypeAlias = Optional[
Union[Callable[[Tensor, str], Tensor], flow.device, str, flow.placement]
]
FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes], Path]
FILE_LIKE: TypeAlias = Union[os.PathLike, BinaryIO, IO[bytes], Path]


class _opener(object):
Expand Down Expand Up @@ -107,7 +107,7 @@ def _open_file_like(path_or_buffer, mode):


def _is_path(path_or_buffer):
return isinstance(path_or_buffer, str) or isinstance(path_or_buffer, Path)
return isinstance(path_or_buffer, Path)


def _check_seekable(f) -> bool:
Expand Down Expand Up @@ -475,9 +475,8 @@ def tensor_pickling_context(


def is_oneflow_pickle_file(path: FILE_LIKE, support_pytorch_format: bool) -> bool:
if _is_path(path):
if not path.is_file():
return False
if _is_path(path) and not path.is_file():
return False
try:
with _open_file_like(path, "rb") as f:
content = pickle.load(f)
Expand Down Expand Up @@ -517,20 +516,27 @@ def load_from_oneflow_single_file(
def is_file_and_support_pytorch_format(
    path: FILE_LIKE, support_pytorch_format: bool
) -> bool:
    """Decide whether ``path`` can be handled by the PyTorch-format loader.

    The check is performed by actually attempting ``torch.load`` on ``path``.

    Args:
        path: A ``Path`` or a file-like object to probe.
        support_pytorch_format: When ``False`` the probe is skipped entirely.

    Returns:
        ``False`` if PyTorch-format support is disabled, if ``path`` is a
        ``Path`` that is not a regular file, or if ``torch.load`` fails.
        On success returns ``(True, (content,))`` so the already-loaded
        object does not have to be read a second time.
        NOTE(review): the tuple is presumably unpacked by ``load_if`` and
        forwarded to ``load_from_pytorch_file`` as ``torch_obj`` - confirm
        against the decorator, and widen the ``-> bool`` annotation if so.
    """
    if not support_pytorch_format:
        return False
    if _is_path(path) and not path.is_file():
        return False
    try:
        # Import the real torch, not the oneflow mock.
        with flow.mock_torch.disable():
            import torch

            # NOTE: torch.load unpickles the file; only use on trusted input.
            content = torch.load(path, map_location="cpu")
            return True, (content,)
    except Exception:
        # Any ordinary failure (torch missing, not a torch file, corrupt data)
        # means "this loader does not apply". Unlike a bare ``except:``, this
        # lets KeyboardInterrupt / SystemExit propagate.
        return False


@load_if(is_file_and_support_pytorch_format)
def load_from_pytorch_file(
path: FILE_LIKE, global_src_rank, map_location: MAP_LOCATION,
path: FILE_LIKE, global_src_rank, map_location: MAP_LOCATION, torch_obj: Any = None
):
with flow.mock_torch.disable():
import torch

if global_src_rank is None or global_src_rank == flow.env.get_rank():
torch_obj = torch.load(path, map_location="cpu")
if torch_obj is not None:
with flow.mock_torch.disable():
import torch

def torch_tensor_to_flow(x):
if isinstance(x, torch.Tensor):
Expand All @@ -539,17 +545,17 @@ def torch_tensor_to_flow(x):
return x

flow_obj = ArgsTree(torch_obj).map_leaf(torch_tensor_to_flow)
else:
flow_obj = None
if global_src_rank is not None:
flow_obj = flow.utils.global_view.to_global(
flow_obj,
placement=flow.placement("cpu", [global_src_rank]),
sbp=flow.sbp.broadcast,
warn_on_non_tensor_leaf=False,
)
flow_obj = _map_location(flow_obj, map_location)
return flow_obj
else:
flow_obj = None
if global_src_rank is not None:
flow_obj = flow.utils.global_view.to_global(
flow_obj,
placement=flow.placement("cpu", [global_src_rank]),
sbp=flow.sbp.broadcast,
warn_on_non_tensor_leaf=False,
)
flow_obj = _map_location(flow_obj, map_location)
return flow_obj


def is_dir_and_has_pickle_file(path: FILE_LIKE, support_pytorch_format: bool) -> bool:
Expand Down Expand Up @@ -585,7 +591,7 @@ def load_from_oneflow_pickle_dir(


def load(
path: FILE_LIKE,
path: Union[FILE_LIKE, str],
global_src_rank: Optional[int] = None,
map_location: MAP_LOCATION = None,
*,
Expand All @@ -611,7 +617,7 @@ def load(
The loaded object
"""
if isinstance(path, str):
path: Path = Path(path)
path = Path(path)
rank = flow.env.get_rank()
if global_src_rank is None or global_src_rank == rank:
for i, (condition, load) in enumerate(load_methods):
Expand All @@ -621,7 +627,11 @@ def load(
_broadcast_py_object(i, global_src_rank)
break
else:
raise NotImplementedError("No valid load method found for {}".format(path))
if _is_path(path):
err_msg = f'Cannot load file "{path}"'
else:
err_msg = "Cannot load the data"
raise ValueError(err_msg)
else:
i = _broadcast_py_object(None, global_src_rank)
load = load_methods[i][1]
Expand Down Expand Up @@ -687,6 +697,9 @@ def save(
save_as_external_data (bool): useful only if path_or_buffer is a string or
os.PathLike object containing a file name
"""
if isinstance(path_or_buffer, str):
path_or_buffer = Path(path_or_buffer)

if isinstance(obj, graph_util.Graph):
if not _is_path(path_or_buffer):
raise ValueError(
Expand All @@ -703,7 +716,6 @@ def save(
pickled_bytes = pickle.dumps(obj)

if _is_path(path_or_buffer) and save_as_external_data:
path_or_buffer: Path = Path(path_or_buffer)
path_or_buffer.mkdir(exist_ok=True)
path_or_buffer = path_or_buffer / PICKLE_FILENAME

Expand Down
20 changes: 16 additions & 4 deletions python/oneflow/test/exceptions/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,28 @@
import torch


@flow.unittest.skip_unless_1n1d()
class TestSaveLoad(flow.unittest.TestCase):
@flow.unittest.skip_unless_1n1d()
def test_support_pytorch_with_global_src_rank(test_case):
conv_torch = torch.nn.Conv2d(3, 3, 3)
conv_flow = flow.nn.Conv2d(3, 3, 3)
with tempfile.NamedTemporaryFile() as f:
torch.save(conv_torch.state_dict(), f.name)
with test_case.assertRaises(NotImplementedError) as ctx:
conv_flow.load_state_dict(flow.load(f.name, support_pytorch=False))
test_case.assertTrue("No valid load method found" in str(ctx.exception))
with test_case.assertRaises(ValueError) as ctx:
conv_flow.load_state_dict(
flow.load(f.name, support_pytorch_format=False)
)
test_case.assertTrue("Cannot load file" in str(ctx.exception))

def test_load_invalid_file(test_case):
    """flow.load on a file with unrecognised content raises ValueError."""
    # The context manager closes the temporary file even when an assertion
    # below fails; the previous explicit f.close() was skipped in that case.
    with tempfile.NamedTemporaryFile() as f:
        f.write(b"invalid file")
        # Flush so the bytes are on disk before the file is loaded by name.
        f.flush()
        with test_case.assertRaises(ValueError) as ctx:
            flow.load(f.name)
    test_case.assertTrue("Cannot load file" in str(ctx.exception))


if __name__ == "__main__":
Expand Down