From c2a196ae6ed89c8f3ca41907593525a290466d8e Mon Sep 17 00:00:00 2001 From: Ming Huang Date: Tue, 21 May 2024 19:50:43 +0000 Subject: [PATCH 1/3] Support auto-generated FP8 meta for CKPT converters. Signed-off-by: Ming Huang --- .../utils/te_pax_t5x_ckpt_converter/README.md | 4 +- .../converter/main.py | 19 ++- .../converter/paxml_converters.py | 65 ++++++-- .../converter/t5x_converters.py | 151 ++++++++++++++---- .../converter/utils.py | 113 ++++++++++--- 5 files changed, 279 insertions(+), 73 deletions(-) diff --git a/rosetta/utils/te_pax_t5x_ckpt_converter/README.md b/rosetta/utils/te_pax_t5x_ckpt_converter/README.md index fcadf7163..d41a93b4d 100644 --- a/rosetta/utils/te_pax_t5x_ckpt_converter/README.md +++ b/rosetta/utils/te_pax_t5x_ckpt_converter/README.md @@ -63,7 +63,7 @@ python converter/main.py \ --input-path=/your_path_to_src_ckpt \ --output-path=/your_path_to_output_ckpt \ --fw=pax \ - --direction=fw2tw \ + --direction=fw2te \ --pax-repeat \ --num-of-layer=8 \ --num-of-head=6 \ @@ -154,7 +154,7 @@ restoring it to keep training. #### The folder structure of CKPT by Pax and T5X If you would like to run the converted CKPTs with frameworks, you may expect the converted CKPTs have the same folder -structure with CKPTs stored by frameworks. In this case, you could set `--output-path` to be the same stucture as the +structure with CKPTs stored by frameworks. In this case, you could set `--output-path` to be the same stucture as the CKPTs from frameworks, and no need to pre-generate folders, since it would be generated when needed. For Pax, you could set `--output-path` be like ` /${your_path_to_output}/checkpoints/checkpoint_${step}`. For T5X, you could set `--output-path` be like `/${your_path_to_output}/checkpoint_${step}`. diff --git a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/main.py b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/main.py index 31d52b6d7..d8e6caa31 100644 --- a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/main.py +++ b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/main.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import argparse from paxml_converters import Pax2TEConvertHelper, Pax2TERepeatConvertHelper @@ -109,6 +108,17 @@ def parse_args(): default=False, help="indicate if skip the conversion for LayerNorm.") + parser.add_argument('--gen-fp8-meta', + action="store_true", + default=False, + help="indicate if generate corresponding FP8 meta." + " Only works when --direction=fw2te") + parser.add_argument( + '--amax-history-len', + type=int, + default=1, + help="the length of amax history, which is only used when --gen-fp8-meta is specified.") + parser.add_argument('--pax-repeat', action="store_true", default=False, @@ -129,7 +139,8 @@ def parse_args(): def get_convert_helper(args): model_config = ModelConfig(args.num_of_layer, args.embed_dim, args.num_of_head, args.head_dim, - args.mlp_intermediate_dim, args.kernel_chunk_size) + args.mlp_intermediate_dim, args.kernel_chunk_size, + args.amax_history_len) convert_helper_cls = None @@ -140,8 +151,8 @@ def get_convert_helper(args): convert_helper_cls = T5X_CONVERT_HELPER_DICT[(args.direction, args.t5x_fuse_qkv)] assert convert_helper_cls is not None, "Not Supported." 
- return convert_helper_cls(args.input_path, args.output_path, model_config, - args.weight_only, args.skip_ln) + return convert_helper_cls(args.input_path, args.output_path, model_config, args.weight_only, + args.skip_ln, args.gen_fp8_meta) if __name__ == "__main__": diff --git a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/paxml_converters.py b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/paxml_converters.py index 3596c0e0a..8aa607b2b 100644 --- a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/paxml_converters.py +++ b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/paxml_converters.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import jax.numpy as jnp from utils import ConvertHelper @@ -26,6 +25,10 @@ def catagories(self): return ['mdl_vars.params'] return ['mdl_vars.params', "opt_states_0_2.m.params", "opt_states_0_2.v.params"] + @property + def fp8_meta_catagories(self): + return {'mdl_vars.params': 'mdl_vars.fp8_metas'} + class Pax2TEConvertHelper(PaxConvertHelperBase): @@ -46,8 +49,11 @@ def _generate_ckpt_map(self): f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.linear.w": self._get_convert_pkg( f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel", - (hidden_dim, mlp_intermediate_dim), 0, - lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))), + (hidden_dim, mlp_intermediate_dim), + 0, + lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1])), + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer2.bias.b": self._get_convert_pkg( f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wo_bias", @@ -57,7 +63,10 @@ def _generate_ckpt_map(self): f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer2.linear.w": self._get_convert_pkg( f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wo_kernel", - (mlp_intermediate_dim, hidden_dim), 1), + (mlp_intermediate_dim, hidden_dim), + 1, + gen_fp8_meta=True, + fp8_meta_postfix='1'), f"lm.transformer.x_layers_{i}.ff_layer.layer_norm.bias": self._get_convert_pkg( f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.ln_bias", @@ -90,9 +99,12 @@ def _generate_ckpt_map(self): f"lm.transformer.x_layers_{i}.self_attention.combined_qkv.w": self._get_convert_pkg( f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.qkv.kernel", - (3, hidden_dim, num_of_head, head_dim), 0, + (3, hidden_dim, num_of_head, head_dim), + 0, lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])), - lambda x: jnp.transpose(x, (1, 0, 2))), + lambda x: jnp.transpose(x, (1, 0, 2)), + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"lm.transformer.x_layers_{i}.self_attention.post.b": self._get_convert_pkg( f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.out.bias", @@ -102,9 +114,12 @@ def _generate_ckpt_map(self): f"lm.transformer.x_layers_{i}.self_attention.post.w": self._get_convert_pkg( f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.out.kernel", - (hidden_dim, num_of_head, head_dim), 1, + (hidden_dim, num_of_head, head_dim), + 1, lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])), - lambda x: jnp.transpose(x, (1, 0))) + lambda x: jnp.transpose(x, (1, 0)), + gen_fp8_meta=True, + fp8_meta_postfix='0') }) return ckpt_map @@ -199,6 +214,10 @@ def catagories(self): f"opt_states_0.p#{num_of_layer}#i-1_2.v.params" ] + @property + def fp8_meta_catagories(self): + return {'mdl_vars.params': 'mdl_vars.fp8_metas'} + class 
Pax2TERepeatConvertHelper(PaxRepeatConvertHelperBase): @@ -220,8 +239,12 @@ def _generate_ckpt_map(self): 'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w': self._get_convert_pkg( 'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wi_kernel', - (num_of_layer, hidden_dim, mlp_intermediate_dim), 1, - lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))), + (num_of_layer, hidden_dim, mlp_intermediate_dim), + 1, + lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1])), + gen_fp8_meta=True, + fp8_meta_postfix='0', + fp8_meta_shape_prefix=(num_of_layer,)), 'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.bias.b': self._get_convert_pkg( 'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wo_bias', @@ -231,7 +254,11 @@ def _generate_ckpt_map(self): 'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.linear.w': self._get_convert_pkg( 'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wo_kernel', - (num_of_layer, mlp_intermediate_dim, hidden_dim), 2), + (num_of_layer, mlp_intermediate_dim, hidden_dim), + 2, + gen_fp8_meta=True, + fp8_meta_postfix='1', + fp8_meta_shape_prefix=(num_of_layer,)), 'lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.bias': self._get_convert_pkg( 'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.ln_bias', @@ -264,9 +291,13 @@ def _generate_ckpt_map(self): 'lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w': self._get_convert_pkg( 'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.attention.qkv.kernel', - (num_of_layer, 3, hidden_dim, num_of_head, head_dim), 1, + (num_of_layer, 3, hidden_dim, num_of_head, head_dim), + 1, lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])), - lambda x: jnp.transpose(x, (0, 2, 1, 3))), + lambda x: jnp.transpose(x, (0, 2, 1, 3)), + gen_fp8_meta=True, + fp8_meta_postfix='0', + fp8_meta_shape_prefix=(num_of_layer,)), 'lm.transformer.repeat.sub.x_layers_0.self_attention.post.b': self._get_convert_pkg( 'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.attention.out.bias', @@ -276,9 +307,13 @@ def _generate_ckpt_map(self): 'lm.transformer.repeat.sub.x_layers_0.self_attention.post.w': self._get_convert_pkg( 'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.attention.out.kernel', - (num_of_layer, hidden_dim, num_of_head, head_dim), 2, + (num_of_layer, hidden_dim, num_of_head, head_dim), + 2, lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])), - lambda x: jnp.transpose(x, (0, 2, 1))) + lambda x: jnp.transpose(x, (0, 2, 1)), + gen_fp8_meta=True, + fp8_meta_postfix='0', + fp8_meta_shape_prefix=(num_of_layer,)) }) return ckpt_map diff --git a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/t5x_converters.py b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/t5x_converters.py index 06c669aa4..28dbfbbd7 100644 --- a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/t5x_converters.py +++ b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/t5x_converters.py @@ -25,6 +25,10 @@ def catagories(self): return CATAGORIES[:1] return CATAGORIES + @property + def fp8_meta_catagories(self): + return {CATAGORIES[0]: 'flax_mutables.fp8_metas'} + class T5X2TENotFuseQKVConvertHelper(T5XConvertHelperBase): @@ -32,6 +36,9 @@ def _generate_ckpt_map(self): ckpt_map = {} embed_dim = self.model_config.embed_dim + num_of_head = self.model_config.num_of_head + head_dim = self.model_config.head_dim + hidden_dim = num_of_head * head_dim mlp_intermediate_dim = self.model_config.mlp_intermediate_dim for i in 
range(self.model_config.num_of_layer): @@ -43,19 +50,32 @@ def _generate_ckpt_map(self): just_copy=True), f"encoder.layers_{i}.attention.query.kernel": self._get_convert_pkg(f"encoder.layers_{i}.attention.query.kernel", + (embed_dim, hidden_dim), None, - None, - just_copy=True), + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"encoder.layers_{i}.attention.key.kernel": self._get_convert_pkg(f"encoder.layers_{i}.attention.key.kernel", + (embed_dim, hidden_dim), None, - None, - just_copy=True), + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"encoder.layers_{i}.attention.value.kernel": self._get_convert_pkg(f"encoder.layers_{i}.attention.value.kernel", + (embed_dim, hidden_dim), None, + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='0'), + f"encoder.layers_{i}.attention.out.kernel": + self._get_convert_pkg(f"encoder.layers_{i}.attention.out.kernel", + (hidden_dim, embed_dim), None, - just_copy=True), + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"encoder.layers_{i}.pre_mlp_layer_norm.scale": self._get_convert_pkg(f"encoder.layers_{i}.mlp.scale", None, @@ -66,12 +86,16 @@ def _generate_ckpt_map(self): (embed_dim, mlp_intermediate_dim), None, extra_src_paths=[f"encoder.layers_{i}.mlp.wi_1.kernel"], - stack_dim=1), + stack_dim=1, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"encoder.layers_{i}.mlp.wo.kernel": self._get_convert_pkg(f"encoder.layers_{i}.mlp.wo_kernel", + (mlp_intermediate_dim, embed_dim), None, - None, - just_copy=True), + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='1'), f"decoder.layers_{i}.pre_self_attention_layer_norm.scale": self._get_convert_pkg(f"decoder.layers_{i}.self_attention.query.scale", None, @@ -79,19 +103,32 @@ def _generate_ckpt_map(self): just_copy=True), f"decoder.layers_{i}.self_attention.query.kernel": self._get_convert_pkg(f"decoder.layers_{i}.self_attention.query.kernel", + (embed_dim, hidden_dim), None, - None, - just_copy=True), + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"decoder.layers_{i}.self_attention.key.kernel": self._get_convert_pkg(f"decoder.layers_{i}.self_attention.key.kernel", + (embed_dim, hidden_dim), None, - None, - just_copy=True), + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"decoder.layers_{i}.self_attention.value.kernel": self._get_convert_pkg(f"decoder.layers_{i}.self_attention.value.kernel", + (embed_dim, hidden_dim), None, + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='0'), + f"decoder.layers_{i}.self_attention.out.kernel": + self._get_convert_pkg(f"decoder.layers_{i}.self_attention.out.kernel", + (hidden_dim, embed_dim), None, - just_copy=True), + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"decoder.layers_{i}.pre_cross_attention_layer_norm.scale": self._get_convert_pkg( f"decoder.layers_{i}.encoder_decoder_attention.query.scale", @@ -101,21 +138,35 @@ def _generate_ckpt_map(self): f"decoder.layers_{i}.encoder_decoder_attention.query.kernel": self._get_convert_pkg( f"decoder.layers_{i}.encoder_decoder_attention.query.kernel", + (embed_dim, hidden_dim), None, - None, - just_copy=True), + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"decoder.layers_{i}.encoder_decoder_attention.key.kernel": self._get_convert_pkg( f"decoder.layers_{i}.encoder_decoder_attention.key.kernel", + (embed_dim, hidden_dim), None, - None, - just_copy=True), + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"decoder.layers_{i}.encoder_decoder_attention.value.kernel": 
self._get_convert_pkg( f"decoder.layers_{i}.encoder_decoder_attention.value.kernel", + (embed_dim, hidden_dim), None, + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='0'), + f"decoder.layers_{i}.encoder_decoder_attention.out.kernel": + self._get_convert_pkg( + f"decoder.layers_{i}.encoder_decoder_attention.out.kernel", + (hidden_dim, embed_dim), None, - just_copy=True), + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"decoder.layers_{i}.pre_mlp_layer_norm.scale": self._get_convert_pkg(f"decoder.layers_{i}.mlp.scale", None, @@ -126,12 +177,16 @@ def _generate_ckpt_map(self): (embed_dim, mlp_intermediate_dim), None, extra_src_paths=[f"decoder.layers_{i}.mlp.wi_1.kernel"], - stack_dim=1), + stack_dim=1, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"decoder.layers_{i}.mlp.wo.kernel": self._get_convert_pkg(f"decoder.layers_{i}.mlp.wo_kernel", + (mlp_intermediate_dim, embed_dim), None, - None, - just_copy=True), + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='1'), }) return ckpt_map @@ -276,7 +331,15 @@ def _generate_ckpt_map(self): f"encoder.layers_{i}.attention.key.kernel", f"encoder.layers_{i}.attention.value.kernel" ], - stack_dim=1), + stack_dim=1, + gen_fp8_meta=True, + fp8_meta_postfix='0'), + f"encoder.layers_{i}.attention.out.kernel": + self._get_convert_pkg(f"encoder.layers_{i}.attention.out.kernel", + (hidden_dim, embed_dim), + None, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"encoder.layers_{i}.pre_mlp_layer_norm.scale": self._get_convert_pkg(f"encoder.layers_{i}.mlp.scale", None, @@ -287,12 +350,16 @@ def _generate_ckpt_map(self): (embed_dim, mlp_intermediate_dim), None, extra_src_paths=[f"encoder.layers_{i}.mlp.wi_1.kernel"], - stack_dim=1), + stack_dim=1, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"encoder.layers_{i}.mlp.wo.kernel": self._get_convert_pkg(f"encoder.layers_{i}.mlp.wo_kernel", None, None, - just_copy=True), + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='1'), f"decoder.layers_{i}.pre_self_attention_layer_norm.scale": self._get_convert_pkg(f"decoder.layers_{i}.self_attention.qkv.scale", None, @@ -306,7 +373,15 @@ def _generate_ckpt_map(self): f"decoder.layers_{i}.self_attention.key.kernel", f"decoder.layers_{i}.self_attention.value.kernel" ], - stack_dim=1), + stack_dim=1, + gen_fp8_meta=True, + fp8_meta_postfix='0'), + f"decoder.layers_{i}.self_attention.out.kernel": + self._get_convert_pkg(f"decoder.layers_{i}.self_attention.out.kernel", + (hidden_dim, embed_dim), + None, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"decoder.layers_{i}.pre_cross_attention_layer_norm.scale": self._get_convert_pkg( f"decoder.layers_{i}.encoder_decoder_attention.query.scale", @@ -316,9 +391,11 @@ def _generate_ckpt_map(self): f"decoder.layers_{i}.encoder_decoder_attention.query.kernel": self._get_convert_pkg( f"decoder.layers_{i}.encoder_decoder_attention.query.kernel", + (embed_dim, hidden_dim), None, - None, - just_copy=True), + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"decoder.layers_{i}.encoder_decoder_attention.key.kernel": self._get_convert_pkg( f"decoder.layers_{i}.encoder_decoder_attention.kv.kernel", @@ -327,7 +404,17 @@ def _generate_ckpt_map(self): extra_src_paths=[ f"decoder.layers_{i}.encoder_decoder_attention.value.kernel" ], - stack_dim=1), + stack_dim=1, + gen_fp8_meta=True, + fp8_meta_postfix='0'), + f"decoder.layers_{i}.encoder_decoder_attention.out.kernel": + self._get_convert_pkg( + f"decoder.layers_{i}.encoder_decoder_attention.out.kernel", + (hidden_dim, embed_dim), + None, + 
just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"decoder.layers_{i}.pre_mlp_layer_norm.scale": self._get_convert_pkg(f"decoder.layers_{i}.mlp.scale", None, @@ -338,12 +425,16 @@ def _generate_ckpt_map(self): (embed_dim, mlp_intermediate_dim), None, extra_src_paths=[f"decoder.layers_{i}.mlp.wi_1.kernel"], - stack_dim=1), + stack_dim=1, + gen_fp8_meta=True, + fp8_meta_postfix='0'), f"decoder.layers_{i}.mlp.wo.kernel": self._get_convert_pkg(f"decoder.layers_{i}.mlp.wo_kernel", None, None, - just_copy=True), + just_copy=True, + gen_fp8_meta=True, + fp8_meta_postfix='1'), }) return ckpt_map diff --git a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/utils.py b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/utils.py index c9fd3defb..4b4a41706 100644 --- a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/utils.py +++ b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/utils.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import asyncio import os import shutil @@ -40,6 +39,7 @@ class ModelConfig: head_dim: int mlp_intermediate_dim: int kernel_chunk_size: int = None + amax_history_len: int = 1 @dataclass @@ -51,40 +51,68 @@ class ConvertPkg: extra_src_paths: list[str] stack_dim: int just_copy: bool + gen_fp8_meta: bool + fp8_meta_prefix: str + fp8_meta_postfix: str + fp8_meta_shape_prefix: tuple[int] class ConvertHelper: def __init__(self, input_path: str, output_path: str, model_config: ModelConfig, - weight_only: bool, skip_ln: bool): + weight_only: bool, skip_ln: bool, gen_fp8_meta: bool): self.input_path = input_path self.output_path = output_path self.model_config = model_config self.weight_only = weight_only self.skip_ln = skip_ln + self.gen_fp8_meta = gen_fp8_meta @property def catagories(self): raise NotImplementedError - def _get_convert_pkg(self, - target_path, - shape, - chunk_dim, - *converters, - extra_src_paths=[], - stack_dim=0, - just_copy=False): + @property + def fp8_meta_catagories(self): + raise NotImplementedError + + def _get_convert_pkg( + self, + target_path, + shape, + chunk_dim, + *converters, + extra_src_paths=[], + stack_dim=0, + just_copy=False, + gen_fp8_meta=False, + fp8_meta_prefix='', + fp8_meta_postfix='0', + fp8_meta_shape_prefix=tuple(), + ): return ConvertPkg(target_path, shape, chunk_dim, tuple(converters), extra_src_paths, - stack_dim, just_copy) + stack_dim, just_copy, gen_fp8_meta, fp8_meta_prefix, fp8_meta_postfix, + fp8_meta_shape_prefix) def _unpack_convert_pkg(self, pkg): return pkg.target_path, pkg.shape, pkg.chunk_dim, pkg.converters, \ - pkg.extra_src_paths, pkg.stack_dim, pkg.just_copy + pkg.extra_src_paths, pkg.stack_dim, pkg.just_copy, \ + pkg.gen_fp8_meta, pkg.fp8_meta_prefix, pkg.fp8_meta_postfix, \ + pkg.fp8_meta_shape_prefix def _generate_ckpt_map(self): raise NotImplementedError + def _generate_fp8_path(self, prefix, target_path): + fp8_meta_full_prefix = prefix + '.' 
+ target_path[:target_path.rfind('.')] + return fp8_meta_full_prefix + # fp8_meta_amax_full_paths = [] + # fp8_meta_scale_full_paths = [] + # for t in ['i', 'w', 'g']: + # fp8_meta_amax_full_paths.append(fp8_meta_full_prefix + f'.amax_{t}_{postfix}') + # fp8_meta_scale_full_paths.append(fp8_meta_full_prefix + f'.scale_{t}_{postfix}') + # return *fp8_meta_amax_full_paths, *fp8_meta_amax_full_paths + def generate_ckpt_map_with_full_name(self): ckpt_map = self._generate_ckpt_map() @@ -98,9 +126,12 @@ def is_ln_weights(key): return False if self.skip_ln: + keys_to_pop = [] for key in ckpt_map: if is_ln_weights(key): - ckpt_map.pop(key) + keys_to_pop.append(key) + for key in keys_to_pop: + ckpt_map.pop(key) ckpt_map_with_full_name = {} for prefix in self.catagories: @@ -112,13 +143,21 @@ def is_ln_weights(key): ckpt_pkgs_with_full_name = [] for pkg in ckpt_pkgs: target_path, shape, chunk_dim, converters, \ - extra_src_paths, stack_dim, just_copy = self._unpack_convert_pkg(pkg) + extra_src_paths, stack_dim, just_copy, \ + gen_fp8_meta, fp8_meta_prefix, fp8_meta_postfix, \ + fp8_meta_shape_prefix= self._unpack_convert_pkg(pkg) full_src_name = prefix + '.' + src_path full_target_name = prefix + '.' + target_path full_extra_src_names = None if extra_src_paths is None else \ [prefix + '.'+ esp for esp in extra_src_paths] + if prefix not in self.fp8_meta_catagories: + gen_fp8_meta = False + else: + fp8_meta_prefix = self.fp8_meta_catagories[prefix] + fp8_meta_prefix = self._generate_fp8_path(fp8_meta_prefix, target_path) + ckpt_pkgs_with_full_name.append( self._get_convert_pkg(full_target_name, shape, @@ -126,7 +165,11 @@ def is_ln_weights(key): *converters, extra_src_paths=full_extra_src_names, stack_dim=stack_dim, - just_copy=just_copy)) + just_copy=just_copy, + gen_fp8_meta=gen_fp8_meta, + fp8_meta_prefix=fp8_meta_prefix, + fp8_meta_postfix=fp8_meta_postfix, + fp8_meta_shape_prefix=fp8_meta_shape_prefix)) ckpt_map_with_full_name[full_src_name] = ckpt_pkgs_with_full_name @@ -143,12 +186,15 @@ def convert(self): for ckpt_pkg in ckpt_map_with_full_path[folder]: target_path, shape, chunk_dim, converters, \ - extra_src_paths, stack_dim, just_copy = self._unpack_convert_pkg(ckpt_pkg) + extra_src_paths, stack_dim, just_copy, \ + gen_fp8_meta, fp8_meta_prefix, fp8_meta_postfix, \ + fp8_meta_shape_prefix = self._unpack_convert_pkg(ckpt_pkg) if just_copy: src_path = os.path.join(self.input_path, folder) target_path = os.path.join(self.output_path, target_path) copy_ckpt(src_path, target_path) + target_shape = shape else: target_path = os.path.join(self.output_path, target_path) @@ -156,7 +202,7 @@ def convert(self): for src in [folder, *extra_src_paths]: skip_pool.add(src) src_path = os.path.join(self.input_path, src) - jnp_arrs.append(serialize_tensor(src_path, shape)) + jnp_arrs.append(deserialize_tensor(src_path, shape)) if len(jnp_arrs) == 1: jnp_arr = jnp_arrs[0] @@ -166,8 +212,30 @@ def convert(self): for converter in converters: jnp_arr = converter(jnp_arr) - deserialize_tensor(target_path, jnp_arr, chunk_dim, - self.model_config.kernel_chunk_size) + target_shape = jnp_arr.shape + serialize_tensor(target_path, jnp_arr, chunk_dim, + self.model_config.kernel_chunk_size) + + if gen_fp8_meta and self.gen_fp8_meta: + jarr = deserialize_tensor(target_path, target_shape) + for t in ['i', 'w', 'g']: + fp8_meta_amax_path = \ + os.path.join(self.output_path, + fp8_meta_prefix + f'.amax_{t}_{fp8_meta_postfix}') + fp8_meta_scale_path = \ + os.path.join(self.output_path, + fp8_meta_prefix + 
f'.scale_{t}_{fp8_meta_postfix}') + + amax = jnp.zeros(fp8_meta_shape_prefix + + (self.model_config.amax_history_len,), + dtype=jnp.float32) + scale = jnp.ones(fp8_meta_shape_prefix + (1,), dtype=jnp.float32) + if t == 'w': + w_amax = jnp.max(jnp.abs(jarr)).astype(jnp.float32) + amax = amax.at[0].set(w_amax) + + serialize_tensor(fp8_meta_amax_path, amax, None, None) + serialize_tensor(fp8_meta_scale_path, scale, None, None) for folder in os.listdir(self.input_path): if folder not in ckpt_map_with_full_path and folder not in skip_pool: @@ -212,7 +280,7 @@ def get_cast_tspec_serialize(tspec, value): return tspec -def serialize_tensor(path: str, shape: tuple, dtype=jnp.float32): +def deserialize_tensor(path: str, shape: tuple, dtype=jnp.float32): path = epath.Path(path) tspec = get_json_tspec(path) @@ -221,12 +289,13 @@ def serialize_tensor(path: str, shape: tuple, dtype=jnp.float32): return jnp_arr -def deserialize_tensor(path: str, tensor: jnp.ndarray, chunk_dim: int = None, chunk_size=None): +def serialize_tensor(path: str, tensor: jnp.ndarray, chunk_dim: int = None, chunk_size=None): path = epath.Path(path) tspec = get_json_tspec(path) tspec['metadata'] = serialization._get_metadata(tensor) - del tspec['metadata']['dtype'] + if 'dtype' in tspec['metadata']: + del tspec['metadata']['dtype'] if chunk_dim is not None: chunk_shape = tuple([ From ebc9745836c3ab7a3f68e13526f65b9326160182 Mon Sep 17 00:00:00 2001 From: Ming Huang Date: Thu, 23 May 2024 18:42:10 +0000 Subject: [PATCH 2/3] Changeing default amax-length and Modifying README.md Signed-off-by: Ming Huang --- .../utils/te_pax_t5x_ckpt_converter/README.md | 44 ++++++++++++++++--- .../converter/main.py | 2 +- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/rosetta/utils/te_pax_t5x_ckpt_converter/README.md b/rosetta/utils/te_pax_t5x_ckpt_converter/README.md index d41a93b4d..1c4ce7d4a 100644 --- a/rosetta/utils/te_pax_t5x_ckpt_converter/README.md +++ b/rosetta/utils/te_pax_t5x_ckpt_converter/README.md @@ -23,6 +23,9 @@ the size to chucnk kernel (weighs) then store, only support with --fw=pax. Setting None means no chunking. (default: None) --weight-only indicate if the source checkpoint only includes weights. (default: False) --skip-ln indicate if skip the conversion for LayerNorm. (default: False) +--gen-fp8-meta indicate if generate corresponding FP8 meta. Only works when --direction=fw2te (default: False) +--amax-history-len AMAX_HISTORY_LEN + the length of amax history, which is only used when --gen-fp8-meta is specified. (default: 1024) --pax-repeat indicate if the source Pax checkpoint enables Repeat. (default: False) --t5x-fuse-qkv indicate if the source T5X checkpoint enables fused_qkv_params of TE. (default: False) ``` @@ -145,12 +148,41 @@ python converter/main.py \ ### Notes #### Running converted CKPTs with Transformer Engine (TE) + FP8 -If you run the converted TE checkpoints ,from frameworks Pax or T5X, with FP8 enabled, you might enounter -an error said that there is not FP8 meta found in the given checkpoint at restoring phase. That is because the -original checkpoints to convert do not contains information about FP8 meta. To address this issue, please run -a training process with the same model config on the target framework, plus TE and FP8, then store a checkpoint -at step 0. Next, use the converted checkpoint to replace weights of the checkpoint from famework + TE + FP8, and -restoring it to keep training. 
+We now support auto-generating FP8 meta when converting framework checkpoints to TE checkpoints, so the converted checkpoints can be used for subsequent FP8 training.
+To enable this feature, please add `--gen-fp8-meta` to your command when running the converter.
+Additionally, use `--amax-history-len` to specify the amax history length that the subsequent FP8 training will use.
+
+Examples:
+- Pax -> TE (Repeat) with FP8 and an amax history length of 1024:
+```bash
+python converter/main.py \
+    --input-path=/your_path_to_src_ckpt \
+    --output-path=/your_path_to_output_ckpt \
+    --fw=pax \
+    --direction=fw2te \
+    --pax-repeat \
+    --gen-fp8-meta \
+    --amax-history-len=1024 \
+    --num-of-layer=8 \
+    --num-of-head=6 \
+    --head-dim=64 \
+    --mlp-intermediate-dim=1024
+```
+
+- T5X -> TE/FusedQKV with FP8 and an amax history length of 1024:
+```bash
+python converter/main.py \
+    --input-path=/your_path_to_src_ckpt \
+    --output-path=/your_path_to_output_ckpt \
+    --fw=t5x \
+    --direction=fw2te \
+    --t5x-fuse-qkv \
+    --gen-fp8-meta \
+    --amax-history-len=1024 \
+    --embed-dim=512 \
+    --num-of-layer=8 \
+    --num-of-head=6 \
+    --head-dim=64 \
+    --mlp-intermediate-dim=1024
+```
 
 #### The folder structure of CKPT by Pax and T5X
 If you would like to run the converted CKPTs with frameworks, you may expect the converted CKPTs have the same folder
diff --git a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/main.py b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/main.py
index d8e6caa31..2bb5dd692 100644
--- a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/main.py
+++ b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/main.py
@@ -116,7 +116,7 @@ def parse_args():
     parser.add_argument(
         '--amax-history-len',
         type=int,
-        default=1,
+        default=1024,
         help="the length of amax history, which is only used when --gen-fp8-meta is specified.")
 
     parser.add_argument('--pax-repeat',

From c46cded5deb9fa034c6f167b3795d0e49b6eed67 Mon Sep 17 00:00:00 2001
From: Ming Huang
Date: Thu, 23 May 2024 19:46:13 +0000
Subject: [PATCH 3/3] Adding some notes for resuming converted FP8 checkpoints

Signed-off-by: Ming Huang
---
 rosetta/utils/te_pax_t5x_ckpt_converter/README.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/rosetta/utils/te_pax_t5x_ckpt_converter/README.md b/rosetta/utils/te_pax_t5x_ckpt_converter/README.md
index 1c4ce7d4a..00bfb4599 100644
--- a/rosetta/utils/te_pax_t5x_ckpt_converter/README.md
+++ b/rosetta/utils/te_pax_t5x_ckpt_converter/README.md
@@ -184,6 +184,10 @@ python converter/main.py \
     --mlp-intermediate-dim=1024
 ```
 
+NOTE:
+In the generated FP8 meta, only the amax of the weights is accurate. Therefore, when resuming training with the converted FP8
+checkpoints, expect a few steps to be needed for the FP8 meta of the inputs and gradients to adjust.
+
 #### The folder structure of CKPT by Pax and T5X
 If you would like to run the converted CKPTs with frameworks, you may expect the converted CKPTs have the same folder
 structure with CKPTs stored by frameworks. In this case, you could set `--output-path` to be the same stucture as the
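
As a reference for how the auto-generated FP8 meta is laid out: the converter derives the FP8-meta paths from each converted kernel's path by switching to the category's FP8-meta prefix (from `fp8_meta_catagories`), stripping the kernel name, and appending `amax_{t}_{postfix}` / `scale_{t}_{postfix}` entries for the input (`i`), weight (`w`), and gradient (`g`) tensors. The snippet below is a minimal sketch of that derivation based on `_generate_fp8_path()` and `convert()` in `utils.py`; the concrete Pax layer path and variable names are illustrative only.

```python
# Illustrative sketch of the FP8-meta path derivation (example path only).
prefix = 'mdl_vars.fp8_metas'   # fp8_meta_catagories['mdl_vars.params'] for Pax
target = 'lm.transformer.x_layers_0.transformerlayer.cld.mlp.wi_kernel'
postfix = '0'                   # fp8_meta_postfix of this kernel's ConvertPkg

# Strip the kernel name, then append the amax/scale entries per tensor kind.
meta_prefix = prefix + '.' + target[:target.rfind('.')]
amax_paths = [f'{meta_prefix}.amax_{t}_{postfix}' for t in ('i', 'w', 'g')]
scale_paths = [f'{meta_prefix}.scale_{t}_{postfix}' for t in ('i', 'w', 'g')]
# e.g. 'mdl_vars.fp8_metas.lm.transformer.x_layers_0.transformerlayer.cld.mlp.amax_w_0'
```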
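The values written to those paths are initialized as follows: for each of `i`, `w`, and `g`, an amax history of zeros with length `--amax-history-len` and a scale of ones; only the weight amax can be recovered from the source checkpoint, so `max(abs(kernel))` is stored in the first history slot (for repeated Pax layers the shapes are additionally prefixed with `(num_of_layer,)` via `fp8_meta_shape_prefix`). Below is a minimal sketch mirroring that logic for a single kernel; the helper `build_fp8_meta` is illustrative and not part of the converter.

```python
import jax.numpy as jnp

def build_fp8_meta(kernel: jnp.ndarray, amax_history_len: int = 1024):
    """Illustrative helper mirroring the FP8-meta initialization in convert()."""
    metas = {}
    for t in ('i', 'w', 'g'):  # input, weight, gradient
        amax = jnp.zeros((amax_history_len,), dtype=jnp.float32)
        scale = jnp.ones((1,), dtype=jnp.float32)
        if t == 'w':
            # The weight amax is the only statistic recoverable from the source
            # checkpoint: the absolute maximum of the converted kernel values.
            amax = amax.at[0].set(jnp.max(jnp.abs(kernel)).astype(jnp.float32))
        metas[f'amax_{t}'] = amax
        metas[f'scale_{t}'] = scale
    return metas

# Example: FP8 meta for a (512, 1024) MLP kernel with constant value 0.02.
meta = build_fp8_meta(jnp.ones((512, 1024)) * 0.02)
print(meta['amax_w'][0], meta['scale_w'])   # 0.02 [1.]
```

This is also why PATCH 3/3 notes that only the weight amax is accurate: the input and gradient statistics start from zeros and ones and need a few training steps to calibrate after resuming.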