-
Notifications
You must be signed in to change notification settings - Fork 802
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Supplement ND SBP signatures for reshape op #9858
Conversation
Co-authored-by: Yipeng Li <jamesonli1313@gmail.com>
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9858/ |
const size_t kMaxSplitAxis = 8; | ||
const size_t kCarryDigit = kMaxSplitAxis + 2; | ||
auto Mesure = [](const NdSbp& nd_sbp) -> size_t { | ||
size_t value = 0; | ||
for (int i = 0; i < nd_sbp.sbp_parallel_size(); ++i) { | ||
size_t cur_dim_value = 0; | ||
const auto& sbp = nd_sbp.sbp_parallel(i); | ||
if (sbp.has_split_parallel()) { | ||
CHECK_LT(sbp.split_parallel().axis(), kMaxSplitAxis); | ||
cur_dim_value = sbp.split_parallel().axis(); | ||
} else if (sbp.has_broadcast_parallel()) { | ||
cur_dim_value = kMaxSplitAxis; | ||
} else if (sbp.has_partial_sum_parallel()) { | ||
cur_dim_value = kMaxSplitAxis + 1; | ||
} else { | ||
UNIMPLEMENTED(); | ||
} | ||
value += cur_dim_value * std::pow(kCarryDigit, (nd_sbp.sbp_parallel_size() - i - 1)); | ||
} | ||
return value; | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果rank mesh始终保持不变的话其实是没问题的,就是有的时候可能sbp会被缩成1维
这样S0跟空的是没有区别的,举一个例子:
S1 的值是 1
(S1, S0) 的值也是 1
为了保证不同的sbp顺序绝对不同,建议 S0 -> 1, S1 -> 2, ..., Si -> i+1
此时 kCarryDigit = kMaxSplitAxis + 3;
还有就是 kMaxSplitAxis 在别的文件有具体的定义,可以复用那个,万一以后扩张维度了,只需要改一个就行了。
另外不要用power,
value = value * kCarryDigit + curr_dim_value;
就行。
最后就是建议用数据结构 map<size_t, NdSbp> 来排序
Mesure得到的key是一个size_t,本身就是有序的,用map阔以避免每次compare都重新计算一次key
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9858/ |
if (sbp.has_split_parallel()) { | ||
CHECK_LT(sbp.split_parallel().axis(), kMaxSplitAxis); | ||
// from 1 to 8 | ||
cur_dim_value = sbp.split_parallel().axis() + 1; | ||
} else if (sbp.has_broadcast_parallel()) { | ||
// 9 | ||
cur_dim_value = kMaxSplitAxis + 1; | ||
} else if (sbp.has_partial_sum_parallel()) { | ||
// 10 | ||
cur_dim_value = kMaxSplitAxis + 2; | ||
} else { | ||
UNIMPLEMENTED(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
别的地方没什么问题了。
这里我想了一下,上次我们开会时是讨论到说用97作为不同blob的进位对吧,如果说NdSbp的数字不超过96,是不会有任何风险的。
那现在把B映射到9,P映射到10,而进位是11,这样(B, B)就会是9*11+9 = 108,是超过了97的。
而B是出现很频繁的SBP,也就是风险会出现得频繁一些。
但是如果把B映射到1,P映射到2,Si -> i+3,这样要超过96起码是 88+9,也就是 (S5, S6),(S5, S7)或者 (S6, 任意SBP) 才有可能有风险。
实际中S5少见,更不用说 (S5, S6)了,甚至 (S5, S5) 都是没有问题的。要出问题,至少有一个S6或者 S7,也就是张量起码要有7维。
所以把B,P映射的数字前调能够很有效地避免大部分的风险。(当然即使有风险也不一定会出问题,素数能有效地规避掉一些,但是如果能够避免大部分的风险,还是避免的好)
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9858/ |
对于绝大多数 op 来说,我们只需要列举其可能支持的 1-D SBP signatures,在 2-D (可推及到 ND) 时,对其 1-D SBP signature list 做叉积即可得到 2-D SBP signature list。
但 reshape op 某些情况就属于例外,见如下的例子:
从上面的例子中可以发现一个规律:高维的 SBP signatures 不能完全由低维组合而来。
基于以上理由,为 op 提供一个新的重载函数 EnumerateNdSbpSignatures:其会在 1-D SBP signatures 被列举完后,并由 1-D 叉积产生了 2-D SBP signature list 后被调用。作为当 1-D 叉积不能产生全部的 2-D SBP signatures 的时候,提供一种手段来补充额外的 2-D SBP signatures。
为 reshape 实现了 EnumerateNdSbpSignatures,算法简单来说就是,找到那些被 reshape 的 dimension,从高到低连续按 rank num 切分,直到失败,或者能均匀切到每个 rank 上(从高到底是为了保证切分连续性)。
EnumerateNdSbpSignatures 与已有的 GetNdSbpSignatureList 重载区别是:EnumerateNdSbpSignatures 是在 1-D SBP signatures 叉积之后的额外补充。而 GetNdSbpSignatureList 是完全重载 2-D SBP signatures 的列举逻辑,不会包含 1-D SBP 的叉积生成,其主要作用是为了 source op,用户可以直接通过 attr 来设置输出的 sbp,而无需推导。