Feat graph load to new device #10335

Merged · 25 commits · Oct 6, 2023
4 changes: 2 additions & 2 deletions oneflow/core/graph/stream_id.cpp
@@ -20,9 +20,9 @@ namespace oneflow {
 
 // StreamId encoding (bits)
 // | reserved | node_index     | device_type | device_index  | stream_index |
-// | -- 21 -- | ----- 19 ----- | ---- 5 ---- | ----- 7 ----- |              |
+// | -- 18 -- | ----- 19 ----- | ---- 5 ---- | ----- 7 ----- |              |
 // |          |                   DeviceId                   |              |
-// |          | ------------------- 31 --------------------- | ---- 12 ---- |
+// |          | ------------------- 31 --------------------- | ---- 15 ---- |
 // |                                StreamId                                |
 // | -------------------------------- 64 ---------------------------------- |
 
4 changes: 2 additions & 2 deletions oneflow/core/graph/task_id.cpp
@@ -20,9 +20,9 @@ namespace oneflow {
 
 // TaskId encoding (maybe extended to 128 bits in future)
 // | rank                       | device_type | device_index  |              |
-// | ----------- 19 ----------- | ---- 5 ---- | ----- 7 ----- |              |
+// | ----------- 16 ----------- | ---- 5 ---- | ----- 7 ----- |              |
 // |                         DeviceId                          | stream_index |
-// | ------------------------- 31 --------------------------- | ---- 12 ---- |
+// | ------------------------- 28 --------------------------- | ---- 15 ---- |
 // |                               StreamId                                   | task_index |
 // | -------------------------------- 43 ----------------------------------- | --- 21 --- |
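A note for readers checking the arithmetic: the wider stream_index (12 → 15 bits) is paid for by shrinking the reserved field in StreamId and the rank field in TaskId, keeping both totals at 64 bits. The new TaskId layout can be unpacked with plain shifts and masks — the sketch below is an illustration only (decode_task_id is not part of the PR) and mirrors the _get_bits/_modify_bits helpers this PR adds to python/oneflow/nn/graph/util.py further down.

# Sketch: unpack the new 64-bit TaskId described above.
# Low to high: task_index 21 | stream_index 15 | device_index 7 | device_type 5 | rank 16.
def decode_task_id(task_id):
    task_index = task_id & ((1 << 21) - 1)
    stream_index = (task_id >> 21) & ((1 << 15) - 1)
    device_index = (task_id >> 36) & ((1 << 7) - 1)
    device_type = (task_id >> 43) & ((1 << 5) - 1)
    rank = (task_id >> 48) & ((1 << 16) - 1)
    return rank, device_type, device_index, stream_index, task_index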
24 changes: 23 additions & 1 deletion python/oneflow/nn/graph/cache.py
@@ -20,7 +20,7 @@
 
 from oneflow.framework.args_tree import ArgsTree
 from oneflow.framework.tensor import Tensor
-import oneflow as flow
+import oneflow
 
 
 class LRUCache(object):
@@ -134,6 +134,28 @@ def runtime_state_dict(
         destination[state_dict["graph_name"]] = state_dict
         return destination
 
+    @staticmethod
+    def runtime_state_dict_to(
+        state_dict: Union[
+            Dict[str, Union[Dict[str, Tensor], str]],
+            Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],
+        ],
+        device: str,
+    ) -> Union[
+        Dict[str, Union[Dict[str, Tensor], str]],
+        Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],
+    ]:
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+        for (key, sub_state_dict) in state_dict.items():
+            dest_sub_state_dict = oneflow.nn.Graph.runtime_state_dict_to(
+                sub_state_dict, device
+            )
+            dest_sub_state_dict["cache_order"] = sub_state_dict["cache_order"]
+            dest_sub_state_dict["cache_key"] = sub_state_dict["cache_key"]
+            destination[key] = dest_sub_state_dict
+        return destination
+
     def _init_and_get_a_graph_in_cache(self, cache_key):
         self._base_graph._print(
             0,
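For orientation, the nested structure that GraphCache.runtime_state_dict_to walks can be pictured as below. This shape is inferred from the loop above and from runtime_state_dict (which stores each sub dict under its graph name), not from any documented schema, so treat the field list as an assumption.

# Hypothetical shape of a GraphCache runtime state dict (illustration only).
cached_state_dict = {
    "graph_0": {
        "graph_name": "graph_0",
        "job_id": 1,
        "cache_order": 0,   # copied through verbatim by runtime_state_dict_to
        "cache_key": "...",  # likewise copied through verbatim
        # ... plus the usual per-graph entries: "inputs", "outputs", "exe_plan", ...
    },
}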
78 changes: 65 additions & 13 deletions python/oneflow/nn/graph/graph.py
@@ -52,6 +52,9 @@
     GraphIR,
     seq_to_func_return,
     sys_exc_error_msg,
+    _rsd_sub_destination_to,
+    _job_to,
+    _plan_to,
 )
 from oneflow.framework.args_tree import ArgsTree
 from oneflow.nn.modules.module import Module
@@ -1069,34 +1072,35 @@ def _fill_sub_destination(dest_dict, name_list, tensor_tuple)
             assert len(tensor_tuple) == len(name_list)
             for name_idx in range(len(name_list)):
                 tensor_item = tensor_tuple[name_idx]
-                dest_dict[name_list[name_idx]] = (tensor_item, tensor_item.device.type)
+                device_str = ":".join(
+                    (tensor_item.device.type, str(tensor_item.device.index))
+                )
+                dest_dict[name_list[name_idx]] = (tensor_item, device_str)
 
         # The original outputs are needed to build the output buffer.
         tuple_idx = -1
 
-        def gen_index_in_tuple(eager_out):
+        def gen_index_in_tuple(item):
             nonlocal tuple_idx
-            tuple_idx += 1
-            return "_OFTPI" + str(tuple_idx)  # oneflow tuple index
+            if isinstance(item, Tensor):
+                tuple_idx += 1
+                return "_OFTPI" + str(tuple_idx)  # oneflow tuple index
+            else:
+                return item
 
         inputs_sub_destination = OrderedDict()
         _fill_sub_destination(
             inputs_sub_destination, self._input_op_names, self._inputs_tensor_tuple
         )
 
-        _eager_inputs_args, _eager_inputs_kwargs = self.__map_io(
-            "input",
-            gen_index_in_tuple,
-            *self.inputs_original[0],
-            **self.inputs_original[1],
+        _eager_inputs_args, _eager_inputs_kwargs = self.__map_io_lite(
+            gen_index_in_tuple, *self.inputs_original[0], **self.inputs_original[1],
         )
         destination["inputs"] = inputs_sub_destination
         destination["inputs_original"] = (_eager_inputs_args, _eager_inputs_kwargs)
 
         tuple_idx = -1
-        _eager_outputs, _ = self.__map_io(
-            "output", gen_index_in_tuple, *self._eager_outputs
-        )
+        _eager_outputs, _ = self.__map_io_lite(gen_index_in_tuple, *self._eager_outputs)
         destination["outputs_original"] = _eager_outputs
         assert len(self._outputs_tensor_tuple) == tuple_idx + 1
         outputs_sub_destination = OrderedDict()
@@ -1146,7 +1150,7 @@ def load_runtime_state_dict(
             Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],
         ],
         *,
-        warmup_with_run: bool = False,
+        warmup_with_run: bool = True,
     ) -> None:
         if self._run_with_cache == True:
             return self._dynamic_input_graph_cache.load_runtime_state_dict(
@@ -1293,6 +1297,7 @@ def get_tensor_in_tuple(tensor_tuple, map_item):
             self.__run(
                 *_eager_inputs_args, **_eager_inputs_kwargs
             )  # pre-run to warm up
+            oneflow._oneflow_internal.eager.Sync()
         build_graph_end = time.perf_counter()
         self.__print(
             0,
@@ -1304,6 +1309,53 @@
             + "\n",
         )
 
+    @staticmethod
+    def runtime_state_dict_to(
+        state_dict: Union[
+            Dict[str, Union[Dict[str, Tensor], str]],
+            Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],
+        ],
+        device: str,
+    ) -> Union[
+        Dict[str, Union[Dict[str, Tensor], str]],
+        Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],
+    ]:
+        if "job_id" not in state_dict:
+            from oneflow.nn.graph.cache import GraphCache
+
+            return GraphCache.runtime_state_dict_to(state_dict, device)
+
+        dest_device = oneflow.device(device)
+        assert dest_device.type == "cuda", "device must be cuda."
+
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+        destination["oneflow_version"] = state_dict["oneflow_version"]
+        destination["graph_name"] = state_dict["graph_name"]
+        destination["job_id"] = state_dict["job_id"]
+        destination["inputs"] = _rsd_sub_destination_to(state_dict["inputs"], device)
+        destination["inputs_original"] = state_dict["inputs_original"]
+        destination["outputs"] = _rsd_sub_destination_to(state_dict["outputs"], device)
+        destination["outputs_original"] = state_dict["outputs_original"]
+        destination["oneflow_with_eager_tensor"] = state_dict[
+            "oneflow_with_eager_tensor"
+        ]
+        if "states" in state_dict:
+            destination["states"] = _rsd_sub_destination_to(
+                state_dict["states"], device
+            )
+        destination["exe_plan"] = _plan_to(state_dict["exe_plan"], dest_device)
+        if "forward_graph" in state_dict:
+            forward_graph = deepcopy(state_dict["forward_graph"])
+            _job_to(forward_graph, dest_device)
+            destination["forward_graph"] = forward_graph
+        if "compile_graph" in state_dict:
+            compile_graph = deepcopy(state_dict["compile_graph"])
+            _job_to(compile_graph, dest_device)
+            destination["compile_graph"] = compile_graph
+        destination["id_state"] = state_dict["id_state"]
+        return destination
+
     def build_graph(self, *args, **kwargs):
         # Build graph
         try:
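Put together, relocating a compiled graph might look like the sketch below; g and g2 stand for hypothetical Graph instances saved on one CUDA device and restored on another (only the runtime_state_dict_to step is new in this PR).

# Sketch: save on cuda:0, relocate, then load on cuda:1.
state_dict = g.runtime_state_dict()
moved = oneflow.nn.Graph.runtime_state_dict_to(state_dict, "cuda:1")
g2.load_runtime_state_dict(moved)  # warmup_with_run now defaults to True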
119 changes: 119 additions & 0 deletions python/oneflow/nn/graph/util.py
@@ -16,12 +16,15 @@
 import sys
 from string import Template
 from typing import Callable, Dict, Union, List, Tuple, Optional
+from collections import OrderedDict
 
 from google.protobuf import text_format
 from google.protobuf.message import Message
 
 import oneflow
 import oneflow.core.job.job_pb2 as job_pb
+import oneflow.core.job.plan_pb2 as plan_pb
+import oneflow.core.common.device_type_pb2 as device_type
 import oneflow.core.operator.op_conf_pb2 as op_conf_util
 from oneflow.framework.tensor import Tensor
 
@@ -308,3 +311,119 @@ def seq_to_func_return(seq, need_unpack=False):
     if need_unpack:
         return seq[0]
     return seq
+
+
+def _rsd_sub_destination_to(origin_dict, dest_device_str):
+    dest_dict = OrderedDict()
+    for k, v in origin_dict.items():
+        tensor_item, device_str = v
+        dest_dict[k] = (
+            tensor_item.to(device=oneflow.device(dest_device_str), copy=True),
+            dest_device_str,
+        )
+    return dest_dict
+
+
+def _parallel_conf_to(parallel_conf, dest_device):
+    if parallel_conf.device_tag == "cuda":
+        assert len(parallel_conf.device_name) == 1
+        parallel_conf.device_name[0] = "@0:" + str(dest_device.index)
+
+
+def _mem_case_to(mem_case, dest_device):
+    if mem_case.device_type == device_type.DeviceType.kCUDA:
+        mem_case.device_id = dest_device.index
+    if (
+        mem_case.HasField("pinned_device_type")
+        and mem_case.pinned_device_type == device_type.DeviceType.kCUDA
+    ):
+        mem_case.pinned_device_id = dest_device.index
+
+
+def _job_to(job, dest_device):
+    for pg in job.placement.placement_group:
+        _parallel_conf_to(pg.parallel_conf, dest_device)
+    for bpg in job.placement.blob_placement_group:
+        _parallel_conf_to(bpg.parallel_conf, dest_device)
+
+
+def _modify_bits(original_num, k, j, new_num):
+    if k > j:
+        return original_num
+    mask = ((1 << (j - k + 1)) - 1) << k
+    cleared_num = original_num & ~mask
+    modified_num = cleared_num | ((new_num & ((1 << (j - k + 1)) - 1)) << k)
+    return modified_num
+
+
+def _get_bits(original_num, k, j):
+    mask = ((1 << (j - k + 1)) - 1) << k
+    cleared_num = (original_num & mask) >> k
+    return cleared_num
+
+
+def _task_id_to(task_id, dest_device):
+    if _get_bits(task_id, 43, 48) == 2:
+        new_id = _modify_bits(task_id, 36, 43, dest_device.index)
+        return new_id
+    else:
+        return task_id
+
+
+def _thrd_id_to(thrd_id, dest_device):
+    if _get_bits(thrd_id, 22, 27) == 2:
+        new_id = _modify_bits(thrd_id, 15, 22, dest_device.index)
+        return new_id
+    else:
+        return thrd_id
+
+
+def _plan_to(plan_str, dest_device):
+    plan = plan_pb.Plan()
+    plan.ParseFromString(plan_str)
+    for task in plan.task:
+        task.task_id = _task_id_to(task.task_id, dest_device)
+        task.thrd_id = _thrd_id_to(task.thrd_id, dest_device)
+        for node in task.exec_sequence.exec_node:
+            _parallel_conf_to(
+                node.kernel_conf.op_attribute.parallel_conf_signature.op_parallel_conf,
+                dest_device,
+            )
+        for name, regst in task.produced_regst_desc.items():
+            regst.producer_task_id = _task_id_to(regst.producer_task_id, dest_device)
+            for c_task_id_idx in range(len(regst.consumer_task_id)):
+                regst.consumer_task_id[c_task_id_idx] = _task_id_to(
+                    regst.consumer_task_id[c_task_id_idx], dest_device
+                )
+            _mem_case_to(regst.mem_case, dest_device)
+    for mem_block in plan.block_chunk_list.mem_block:
+        _mem_case_to(mem_block.mem_case, dest_device)
+        mem_block.thrd_id_hint = _thrd_id_to(mem_block.thrd_id_hint, dest_device)
+    for chunk in plan.block_chunk_list.chunk:
+        _mem_case_to(chunk.mem_case, dest_device)
+
+    new_ctrl_regst_desc_id2producer_task_id = {}
+    for (
+        regst_desc_id,
+        producer_task_id,
+    ) in plan.ctrl_regst_desc_info.ctrl_regst_desc_id2producer_task_id.items():
+        new_ctrl_regst_desc_id2producer_task_id[regst_desc_id] = _task_id_to(
+            producer_task_id, dest_device
+        )
+    for (
+        regst_desc_id,
+        producer_task_id,
+    ) in new_ctrl_regst_desc_id2producer_task_id.items():
+        plan.ctrl_regst_desc_info.ctrl_regst_desc_id2producer_task_id[
+            regst_desc_id
+        ] = producer_task_id
+
+    for job_id, op_attr_tab in plan.job_id2op_attribute_ref_table.items():
+        for _, op_attr in op_attr_tab.op_name2op_attribute.items():
+            _parallel_conf_to(
+                op_attr.parallel_conf_signature.op_parallel_conf, dest_device
+            )
+
+    return plan.SerializeToString()
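The bit helpers are easiest to read against the TaskId diagram from task_id.cpp: _get_bits(n, k, j) extracts the bit span [k, j] and _modify_bits overwrites it. Note that _get_bits(task_id, 43, 48) actually reads six bits — the 5-bit device_type plus the lowest rank bit — so the == 2 comparison (2 being DeviceType.kCUDA) matches CUDA tasks on rank 0, the single-machine case this feature targets (see the "@0:" rewrite in _parallel_conf_to). A worked example, illustrative only and reusing the helpers above:

# Illustration: move the device_index field of a TaskId-style integer.
task_id = (2 << 43) | (0 << 36) | (7 << 21) | 5  # rank 0, cuda:0, stream 7, task 5
assert _get_bits(task_id, 43, 48) == 2           # device_type is kCUDA, rank bit 0
moved = _modify_bits(task_id, 36, 43, 1)         # device_index 0 -> 1
assert _get_bits(moved, 36, 42) == 1             # relocated to cuda:1
assert _get_bits(moved, 43, 48) == 2             # device_type untouched (bit 43 was 0)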