
Refine mt5 #410

Merged
merged 2 commits into from
Oct 28, 2022

Conversation

xiezipeng-ML
Contributor

#406 (comment)

The SendRecv in SelfAttention mentioned in the comment above is necessary; it is at this point in the code. The reason Megatron does not have it is that the algorithm differs: the T5 in Megatron has no position_bias.

Here position_bias, with SBP (S(0), B), needs to be computed together with attention_scores, which has SBP (S(0), S(1)), so a (S(0), B) -> (S(0), S(1)) conversion is required. Currently the 2-D SBP implementation performs this with SendRecv, but it could instead be implemented with SameDim0AllScatter, which has no communication overhead.
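A small illustration (using NumPy as a stand-in, not the real OneFlow runtime) of why (S(0), B) -> (S(0), S(1)) can be communication-free: under Broadcast on dim 1, every rank in that mesh dimension already holds the full dim-1 data, so converting to S(1) only requires each rank to keep its own local column block (the SameDim0AllScatter idea), rather than exchanging data via SendRecv. The shapes and mesh size here are illustrative.

```python
import numpy as np

# Logical position_bias tensor (its S(0) slicing along dim 0 is omitted
# for simplicity; we focus on the dim-1 Broadcast -> Split(1) step).
full = np.arange(16).reshape(4, 4)
ranks_dim1 = 2  # device mesh size along dim 1

def local_shard(rank1):
    # Under (S(0), B) each rank in the dim-1 group holds `full` entirely,
    # so S(1) is obtained by a purely local column slice: no communication.
    cols = full.shape[1] // ranks_dim1
    return full[:, rank1 * cols : (rank1 + 1) * cols]

shards = [local_shard(r) for r in range(ranks_dim1)]
# The local shards tile the logical tensor exactly along dim 1.
assert np.array_equal(np.concatenate(shards, axis=1), full)
```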

However, the (S(0), B) -> (S(0), S(1)) conversion above does not need to be done in every layer: position_bias is produced by compute_bias in layer 0, and all subsequent layers reuse layer 0's position_bias, so the conversion only needs to happen once. Note that before position_bias is added to attention_scores, it is first added to attention_mask, whose SBP is (S(0), B) (see here); after that addition, position_bias's SBP also becomes (S(0), B).

We only need to move the line position_bias = position_bias.to_global(placement=attention_scores.placement) into the preceding if scope, right after position_bias = position_bias + (1 - attention_mask) * -1000, so that the (S(0), B) -> (S(0), S(1)) conversion is performed only once.
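The effect of that move can be sketched with a plain-Python mock (the function and variable names here are illustrative, not the real libai code): a stand-in for to_global counts how many times the (S(0), B) -> (S(0), S(1)) conversion fires across the layers of one forward pass, before and after hoisting the call into the `if` branch where position_bias is first computed.

```python
conversions = 0

def to_global(bias):
    # Stand-in for position_bias.to_global(placement=attention_scores.placement),
    # i.e. the call that triggers the (S(0), B) -> (S(0), S(1)) transfer.
    global conversions
    conversions += 1
    return bias

def forward_before(num_layers):
    # Before the refine: to_global sits outside the `if`, so it runs per layer.
    global conversions
    conversions = 0
    position_bias = None
    for _ in range(num_layers):
        if position_bias is None:
            position_bias = "bias"  # compute_bias at layer 0
            # position_bias = position_bias + (1 - attention_mask) * -1000
        position_bias = to_global(position_bias)  # runs every layer
    return conversions

def forward_after(num_layers):
    # After the refine: to_global moved inside the `if`, so it runs once.
    global conversions
    conversions = 0
    position_bias = None
    for _ in range(num_layers):
        if position_bias is None:
            position_bias = "bias"  # compute_bias at layer 0
            # position_bias = position_bias + (1 - attention_mask) * -1000
            position_bias = to_global(position_bias)  # runs only at layer 0
    return conversions

print(forward_before(12), forward_after(12))  # 12 1
```

With 12 layers the conversion count drops from 12 to 1, which is the intent of the change; as the benchmark below shows, the end-to-end impact is small since the conversion is cheap relative to the rest of the step.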

Following wenxiao's note above, this PR refines the location of the to_global call for mt5's position_bias accordingly.

@strint
Collaborator

strint commented Oct 28, 2022

We need a benchmark to see how the performance metrics change before merging.

@xiezipeng-ML
Contributor Author

# num_layers=6
After the change, GPU memory: 2459 MiB
[10/28 06:00:17 libai]: >>> done with building model. Building time: 0.902 seconds
[10/28 06:03:40 lb.utils.events]:  eta: 0:51:33  iteration: 849/24000  consumed_samples: 6800  total_loss: 3.505  time: 0.1323 s/iter  data_time: 0.0114 s/iter total_throughput: 60.49 samples/s lr: 8.49e-05

Before the change, GPU memory: 2459 MiB
done with building model. Building time: 1.366 seconds
[10/28 05:59:29 lb.utils.events]:  eta: 0:51:30  iteration: 849/24000  consumed_samples: 6800  total_loss: 3.513  time: 0.1322 s/iter  data_time: 0.0118 s/iter total_throughput: 60.51 samples/s lr: 8.49e-05  

# num_layers=12
After the change, GPU memory: 3587 MiB
[10/28 06:06:15 libai]: >>> done with building model. Building time: 1.312 seconds
[10/28 06:12:40 lb.utils.events]:  eta: 1:37:56  iteration: 849/24000  consumed_samples: 6800  total_loss: 3.514  time: 0.2518 s/iter  data_time: 0.0117 s/iter total_throughput: 31.77 samples/s lr: 8.49e-05 

Before the change, GPU memory: 3587 MiB
[10/28 06:13:43 libai]: >>> done with building model. Building time: 1.555 seconds
[10/28 06:20:14 lb.utils.events]:  eta: 1:39:01  iteration: 849/24000  consumed_samples: 6800  total_loss: 3.514  time: 0.2579 s/iter  data_time: 0.0176 s/iter total_throughput: 31.03 samples/s lr: 8.49e-05 

@strint The impact of this change appears to be limited.

@xiezipeng-ML xiezipeng-ML merged commit b3c5ba2 into main Oct 28, 2022
@xiezipeng-ML xiezipeng-ML deleted the refine_mt5 branch October 28, 2022 06:35
@xiezipeng-ML xiezipeng-ML requested review from oneflow-ci-bot and removed request for oneflow-ci-bot October 28, 2022 06:36