diff --git a/open_lm/attention.py b/open_lm/attention.py
index e0e8aba5..4c55c3bf 100644
--- a/open_lm/attention.py
+++ b/open_lm/attention.py
@@ -4,6 +4,15 @@ from torch.nn import functional as F
 
 import xformers.ops as xops
 
+# Adding flag if using TE FP8
+using_te = False
+try:
+    import transformer_engine.pytorch as te
+
+    using_te = True
+except ImportError as ie:
+    using_te = False
+
 
 def get_rectangular_causal_mask(shape, q_seq_len, k_seq_len, device, dtype):
     """Create a rectangular causal mask.
@@ -137,6 +146,55 @@ def torch_attn(queries, keys, values, is_causal, attention_mask=None):
         )
 
 
+def torch_attn_te(queries, keys, values, is_causal, attention_mask=None):
+    _, num_q_heads, _, _ = queries.shape
+    _, _, hidden_dim_k, _ = values.shape
+    scaleddotproductattn_module = te.DotProductAttention(num_attention_heads=num_q_heads, kv_channels=hidden_dim_k)
+    if is_causal and keys.shape[1] > queries.shape[1] > 1:
+        q_seq_len = queries.shape[1]
+        k_seq_len = keys.shape[1]
+        # Same as above, we would like to use:
+        # mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask().materialize((1, 1, q_seq_len, k_seq_len), queries.dtype, queries.device)
+        mask = get_rectangular_causal_mask((1, 1), q_seq_len, k_seq_len, queries.device, queries.dtype)
+        if attention_mask is not None:
+            apply_attention_mask_(mask, attention_mask, queries_dtype=queries.dtype)
+        return (
+            scaleddotproductattn_module(
+                queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), attention_mask=mask
+            )
+            .transpose(1, 2)
+            .contiguous()
+        )
+    else:
+        if attention_mask is None:
+            bias = None
+            # If we only have one query, assume we don't need to be in causal mode (can attend to all keys).
+            if queries.shape[1] == 1:
+                is_causal = False
+        else:
+            if not is_causal:
+                raise NotImplementedError("attention_mask with is_causal=False is not yet implemented.")
+            # Build causal mask that assumes queries are in the end of the sequence.
+            batch, q_seq_len, heads, _ = queries.shape
+            k_seq_len = keys.shape[1]
+            bias = get_rectangular_causal_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype)
+            if attention_mask is not None:
+                apply_attention_mask_(bias, attention_mask, queries_dtype=queries.dtype)
+            # We apply causal mask in attention instead of using is_causal=True.
+            is_causal = False
+        return (
+            scaleddotproductattn_module(
+                queries.transpose(1, 2),
+                keys.transpose(1, 2),
+                values.transpose(1, 2),
+                attention_mask=bias,
+                attn_mask_type="causal" if is_causal else None,
+            )
+            .transpose(1, 2)
+            .contiguous()
+        )
+
+
 ATTN_ACTIVATIONS = {
     "relu": F.relu,
     "relu_squared": lambda x: torch.pow(F.relu(x), 2),
@@ -189,12 +247,7 @@ def custom_attn(
     return torch.einsum("bhqk,bkhd->bqhd", attn_weight, values)
 
 
-def get_attn_func(
-    attn_name,
-    attn_activation=None,
-    attn_seq_scalar=None,
-    alpha=None,
-):
+def get_attn_func(attn_name, attn_activation=None, attn_seq_scalar=None, alpha=None, use_fp8=False):
     if attn_name == "auto":
         return xformers_attn if torch.cuda.is_available() else torch_attn
     elif attn_name == "xformers_attn":
@@ -206,6 +259,8 @@
         # call .contiguous() on the output tensor. [#188]
         return lambda *args, **kwargs: xformers_attn(*args, **kwargs).contiguous()
     elif attn_name == "torch_attn":
+        # if using_te and use_fp8:
+        #     return torch_attn_te
         return torch_attn
     elif attn_name == "custom_attn":
         assert (
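Note: the TE dispatch inside get_attn_func is left commented out, so torch_attn_te is not reachable from the registry yet. A minimal sketch of what flag-based selection would look like once those two lines are enabled (assumes xformers, and optionally transformer_engine, import cleanly):

    from open_lm.attention import get_attn_func, torch_attn

    attn_fn = get_attn_func("torch_attn", use_fp8=True)
    # With the branch still commented out this returns the plain torch_attn; enabling it
    # would return torch_attn_te whenever transformer_engine is importable.
    assert attn_fn is torch_attn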
diff --git a/open_lm/distributed.py b/open_lm/distributed.py
index 8c07d663..95e22542 100644
--- a/open_lm/distributed.py
+++ b/open_lm/distributed.py
@@ -57,6 +57,7 @@
     args.world_size = 1
     args.rank = 0  # global rank
     args.local_rank = 0
+    args.world_group = None
     # For testing, allow forcing distributed mode to test distributed code path even on one gpu.
     if is_using_distributed() or args.force_distributed:
         if "SLURM_PROCID" in os.environ:
@@ -74,7 +75,7 @@
             os.environ["LOCAL_RANK"] = str(args.local_rank)
             os.environ["RANK"] = str(args.rank)
             os.environ["WORLD_SIZE"] = str(args.world_size)
-            torch.distributed.init_process_group(
+            args.world_group = torch.distributed.init_process_group(
                 backend=args.dist_backend,
                 init_method=args.dist_url,
                 world_size=args.world_size,
@@ -85,7 +86,9 @@
             # Note that this currently assumes that the world size is all gpus in a node.
             assert args.preset_world_size is None, "--preset_world_size with torchrun is not currently supported."
             args.local_rank, _, _ = world_info_from_env()
-            torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url)
+            args.world_group = torch.distributed.init_process_group(
+                backend=args.dist_backend, init_method=args.dist_url
+            )
             args.world_size = torch.distributed.get_world_size()
             args.rank = torch.distributed.get_rank()
         args.distributed = True
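Note: torch.distributed.init_process_group() returns None, so args.world_group is still None after this change; FSDP treats process_group=None as "use the default group", so training behaves the same, but no explicit handle is actually captured. A sketch of grabbing the default group explicitly, if that is the intent:

    import torch.distributed as dist

    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url)
    args.world_group = dist.group.WORLD  # handle to the default, all-ranks group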
diff --git a/open_lm/main.py b/open_lm/main.py
index 7c80f558..ec4261a0 100644
--- a/open_lm/main.py
+++ b/open_lm/main.py
@@ -65,7 +65,6 @@
     terminate_sync_process,
 )
 
-
 LATEST_CHECKPOINT_NAME = "epoch_latest.pt"
 
 
@@ -466,13 +465,18 @@
 
     random_seed(args.seed, 0)
 
+    tensor_parallel_group = None
+    if args.use_fp8:
+        tensor_parallel_group = torch.distributed.new_group(ranks=[0], backend="nccl")
+        logging.info("Using FP8 to run training.")
+
     model = None
     if args.hf_model is not None:
         model = create_wrapped_hf_model(args)
     else:
        # Optional: Use meta device
         with torch.device("meta" if args.experimental_meta_device and args.fsdp else args.device):
-            model = create_model(args)
+            model = create_model(args, tensor_parallel_group)
 
     args.vocab_size = model.vocab_size
     args.seq_len = model.seq_len
@@ -548,8 +552,10 @@
 
         # Initialize FSDP. Use the same seed across workers to ensure reset_parameters is the same across workers.
         random_seed(args.seed, rank=0)
+
         model = FSDP(
             model,
+            process_group=args.world_group,
             auto_wrap_policy=transformer_auto_wrapper_policy,
             device_id=device,
             mixed_precision=mp_policy,
@@ -832,6 +838,7 @@
                 total_steps=total_steps,
                 args=args,
                 tb_writer=writer,
+                data_parallel_group=args.world_group,
             )
 
         if args.distributed:
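Note: torch.distributed.new_group(ranks=[0], backend="nccl") creates a group containing only rank 0; every rank still has to enter the call, and ranks outside the list receive a non-member handle. If the goal is simply "no tensor parallelism" while running data-parallel FP8, one alternative (a sketch, not part of this patch) is a size-1 group per rank:

    import torch.distributed as dist

    # Each rank keeps only the handle to its own singleton group.
    tp_groups = [dist.new_group(ranks=[r], backend="nccl") for r in range(args.world_size)]
    tensor_parallel_group = tp_groups[args.rank]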
diff --git a/open_lm/model.py b/open_lm/model.py
index 0c979c40..ec21a070 100644
--- a/open_lm/model.py
+++ b/open_lm/model.py
@@ -35,6 +35,24 @@
 except ImportError:
     MambaLMHeadModel = None
 
+# Adding flag if using TE FP8
+using_te = False
+LinearTE = nn.Linear
+try:
+    import transformer_engine.pytorch as te
+
+    using_te = True
+
+    class LinearTE(te.Linear):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+        def forward(self, inp: torch.Tensor, is_first_microbatch: bool = True):
+            return super().forward(inp, is_first_microbatch=True)
+
+except ImportError as ie:
+    using_te = False
+
 # from openclip
 _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
 _MODEL_CONFIGS = {}  # directory (model_name: config) of model architecture configs
@@ -86,7 +104,9 @@ class Params:
     seq_len: int = 2048
     post_embed_norm: bool = False
     weight_tying: bool = False
-    norm_type: nn.Module = nn.LayerNorm
+    norm_type: nn.Module = te.LayerNorm if using_te else nn.LayerNorm
+    linear_type: nn.Module = LinearTE if using_te else nn.Linear
+    te_device: str = "cuda" if using_te else None
     attn_func: Callable = xformers_attn if torch.cuda.is_available() else torch_attn
     apply_qk_norm: bool = False
     moe_loss_weight: float = 0.1
@@ -119,8 +139,8 @@ def __init__(self, layer_id, args: Params):
         super().__init__()
         self.n_heads = args.n_heads
         self.head_dim = args.dim // args.n_heads
-        self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False)
-        self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        self.in_proj = args.linear_type(args.dim, 3 * args.n_heads * self.head_dim, bias=False, device=args.te_device)
+        self.out_proj = args.linear_type(args.n_heads * self.head_dim, args.dim, bias=False, device=args.te_device)
         self.pos_embed = get_pos_embed(args)
         self.attn_fn = args.attn_func
         self.apply_qk_norm = args.apply_qk_norm
@@ -130,6 +150,7 @@ def __init__(self, layer_id, args: Params):
             args.norm_type(
                 args.n_heads * self.head_dim,
                 eps=args.norm_eps,
+                device=args.te_device,
             )
             if self.apply_qk_norm
             else nn.Identity()
@@ -138,6 +159,7 @@ def __init__(self, layer_id, args: Params):
             args.norm_type(
                 args.n_heads * self.head_dim,
                 eps=args.norm_eps,
+                device=args.te_device,
             )
             if self.apply_qk_norm
             else nn.Identity()
@@ -195,13 +217,13 @@ class GemmaMLP(nn.Module):
     Modified from https://github.com/google/gemma_pytorch/blob/01062c9ef4cf89ac0c985b25a734164ede017d0b/gemma/model.py#L182-L201
     """
 
-    def __init__(self, dim: int, hidden_dim: int, layer_id: int):
+    def __init__(self, dim: int, hidden_dim: int, layer_id: int, args: Params):
         super().__init__()
         self.dim = dim
         self.hidden_dim = hidden_dim
-        self.gate_proj = nn.Linear(dim, hidden_dim)
-        self.up_proj = nn.Linear(dim, hidden_dim)
-        self.down_proj = nn.Linear(hidden_dim, dim)
+        self.gate_proj = nn.Linear(dim, hidden_dim, device=args.te_device)
+        self.up_proj = nn.Linear(dim, hidden_dim, device=args.te_device)
+        self.down_proj = nn.Linear(hidden_dim, dim, device=args.te_device)
         self._layer_id = layer_id
 
     def forward(self, x):
@@ -225,10 +247,10 @@ def reset_parameters(self):
 # Same as pseudocode provided from xformers SwiGLU
 # https://github.com/facebookresearch/xformers
 class SwiGLUTorch(nn.Module):
-    def __init__(self, in_dim, hidden_dim, out_dim, bias=True):
+    def __init__(self, in_dim, hidden_dim, out_dim, args: Params = Params, bias=True):
         super().__init__()
-        self.w12 = nn.Linear(in_dim, 2 * hidden_dim, bias=bias)
-        self.w3 = nn.Linear(hidden_dim, out_dim, bias=bias)
+        self.w12 = nn.Linear(in_dim, 2 * hidden_dim, bias=bias, device=args.te_device)
+        self.w3 = nn.Linear(hidden_dim, out_dim, bias=bias, device=args.te_device)
 
     def forward(self, x):
         gate, x = self.w12(x).chunk(2, dim=-1)
@@ -252,17 +274,17 @@ def __init__(self, layer_id, args: Params):
         elif args.ffn_type == "swiglu_torch":
             # this follows llama / lit llama -- go to multiple of 256
             self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
-            self.feed_forward = SwiGLUTorch(args.dim, self.hidden_dim, args.dim, bias=False)
+            self.feed_forward = SwiGLUTorch(args.dim, self.hidden_dim, args.dim, args, bias=False)
         elif args.ffn_type == "gelu":
             # Follows mosaic mpt7b, but without a bias.
             self.hidden_dim = args.dim * 4
-            self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False)
-            self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False)
+            self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False, device=args.te_device)
+            self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False, device=args.te_device)
             self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2)
         elif args.ffn_type == "gemma_geglu":
             # this follows llama / lit llama -- go to multiple of 256
             self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
-            self.feed_forward = GemmaMLP(args.dim, self.hidden_dim, layer_id)
+            self.feed_forward = GemmaMLP(args.dim, self.hidden_dim, layer_id, args)
         elif args.ffn_type == "moe":
             moe_args = MoEArgs(
                 hidden_size=args.dim,
@@ -283,10 +305,12 @@ def __init__(self, layer_id, args: Params):
         self.attention_norm = args.norm_type(
             args.dim,
             eps=args.norm_eps,
+            device=args.te_device,
         )
         self.ffn_norm = args.norm_type(
             args.dim,
             eps=args.norm_eps,
+            device=args.te_device,
         )
         self.attention.seq_len = args.seq_len
         self.reset_parameters()
@@ -455,9 +479,15 @@ def create_params(args):
         vocab_size=cfg["vocab_size"],
         post_embed_norm=cfg["post_embed_norm"],
         weight_tying=cfg["weight_tying"],
-        norm_type=get_norm_class(cfg.get("model_norm", args.model_norm)),
+        norm_type=get_norm_class(cfg.get("model_norm", args.model_norm), args.use_fp8),
+        linear_type=LinearTE if (using_te and args.use_fp8) else nn.Linear,
+        te_device="cuda" if (using_te and args.use_fp8) else None,
         attn_func=get_attn_func(
-            args.attn_name, args.attn_activation, args.attn_seq_scalar, args.attn_seq_scalar_alpha
+            args.attn_name,
+            args.attn_activation,
+            args.attn_seq_scalar,
+            args.attn_seq_scalar_alpha,
+            use_fp8=args.use_fp8,
         ),
         apply_qk_norm=cfg.get("qk_norm", args.qk_norm),
         positional_embedding_type=cfg.get("positional_embedding_type", args.positional_embedding_type),
@@ -495,10 +525,23 @@ def forward(self, x):
         return out, None, None
 
 
-def create_model(args):
+def te_linear_ops(model, exclude_modules=["output"], tensor_parallel_group=None):
+    for name, module in model.named_children():
+        if len(list(module.children())) > 0:
+            te_linear_ops(module, exclude_modules, tensor_parallel_group)
+        if isinstance(module, te.Linear):
+            model._modules[name].set_tensor_parallel_group(tensor_parallel_group)
+    return model
+
+
+def create_model(args, tensor_parallel_group=None):
     if "mamba" in args.model:
         model = Mamba(create_params(args))
+        if tensor_parallel_group is not None and using_te:
+            model = te_linear_ops(model.to(torch.bfloat16).cuda(), tensor_parallel_group=tensor_parallel_group)
         return model
     else:
         model = Transformer(create_params(args))
+        if tensor_parallel_group is not None and using_te:
+            model = te_linear_ops(model.to(torch.bfloat16).cuda(), tensor_parallel_group=tensor_parallel_group)
         return model
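Note: te_linear_ops accepts exclude_modules but never consults it. A sketch of the skip it appears intended to perform (hypothetical; assumes modules such as the output head should keep default behavior):

    def te_linear_ops(model, exclude_modules=("output",), tensor_parallel_group=None):
        # Recursively attach the tensor-parallel group to every te.Linear, skipping
        # children whose attribute name is listed in exclude_modules.
        for name, module in model.named_children():
            if name in exclude_modules:
                continue
            if len(list(module.children())) > 0:
                te_linear_ops(module, exclude_modules, tensor_parallel_group)
            if isinstance(module, te.Linear):
                module.set_tensor_parallel_group(tensor_parallel_group)
        return model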
diff --git a/open_lm/norms.py b/open_lm/norms.py
index f02f2e48..fc07993e 100644
--- a/open_lm/norms.py
+++ b/open_lm/norms.py
@@ -8,6 +8,15 @@ import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
+# Adding flag if using TE FP8
+using_te = False
+try:
+    import transformer_engine.pytorch as te
+
+    using_te = True
+except ImportError as ie:
+    using_te = False
+
 
 class LayerNorm(nn.Module):
     # NOTE: taken from official pytorch implementation and modified
@@ -98,6 +107,30 @@ def _cast_if_autocast_enabled(tensor):
         return tensor
 
 
+class LayerNormTE(LayerNorm):
+    def forward(self, x):
+        layer_norm_module = te.LayerNorm(self.normalized_shape, eps=self.eps, device="cuda", params_dtype=x.dtype)
+        output_tensor = layer_norm_module(x)
+        if self.weight is not None and self.bias is not None:
+            output_tensor = output_tensor * self.weight + self.bias
+        return output_tensor
+
+
+class LPLayerNormTE(LayerNorm):
+    def forward(self, x):
+        module_device = x.device
+        downcast_x = _cast_if_autocast_enabled(x)
+        downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
+        downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
+        layer_norm_module = te.LayerNorm(
+            self.normalized_shape, eps=self.eps, device="cuda", params_dtype=downcast_x.dtype
+        )
+        output_tensor = layer_norm_module(downcast_x)
+        if downcast_weight is not None and downcast_bias is not None:
+            output_tensor = output_tensor * downcast_weight + downcast_bias
+        return output_tensor
+
+
 class RmsNorm(nn.Module):
     def __init__(
         self,
@@ -133,17 +166,25 @@ def extra_repr(self) -> str:
         return "{normalized_shape}, eps={eps} ".format(**self.__dict__)
 
 
-def get_norm_class(model_norm):
+def get_norm_class(model_norm, use_fp8=False):
     if model_norm == "default_layer_norm":
         return torch.nn.LayerNorm
     elif model_norm == "lp_layer_norm":
+        if use_fp8 and using_te:
+            return LPLayerNormTE
         return LPLayerNorm
     elif model_norm == "gain_only_lp_layer_norm":
+        if use_fp8 and using_te:
+            return partial(LPLayerNormTE, elementwise_gain=True, elementwise_bias=False)
         return partial(LPLayerNorm, elementwise_gain=True, elementwise_bias=False)
     elif model_norm == "gain_only_layer_norm":
+        if use_fp8 and using_te:
+            return partial(LayerNormTE, elementwise_gain=True, elementwise_bias=False)
         return partial(LayerNorm, elementwise_gain=True, elementwise_bias=False)
     elif model_norm == "no_wb_layer_norm":
+        if use_fp8 and using_te:
+            return partial(LayerNormTE, elementwise_gain=False, elementwise_bias=False)
         return partial(LayerNorm, elementwise_gain=False, elementwise_bias=False)
     elif model_norm == "rms_norm":
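Note: LayerNormTE.forward builds a fresh te.LayerNorm on every call, so the TE module's own affine parameters are re-initialized each forward pass and only the outer self.weight / self.bias are trainable. A sketch that constructs the TE module once and reuses it (hypothetical name; keeps the same external gain/bias handling and the same te.LayerNorm arguments used above):

    class CachedLayerNormTE(LayerNorm):
        def forward(self, x):
            if not hasattr(self, "_te_ln"):
                self._te_ln = te.LayerNorm(self.normalized_shape, eps=self.eps, device="cuda", params_dtype=x.dtype)
            output_tensor = self._te_ln(x)
            if self.weight is not None and self.bias is not None:
                output_tensor = output_tensor * self.weight + self.bias
            return output_tensor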
action="store_true", + default=False, + help="If set, allow FP8 training for the model.", + ) add_model_args(parser) diff --git a/open_lm/train.py b/open_lm/train.py index 0d54bf70..f74ab25e 100644 --- a/open_lm/train.py +++ b/open_lm/train.py @@ -28,6 +28,18 @@ from open_lm.precision import get_autocast from open_lm.meters import AverageMeter +# Adding flag if using TE FP8 +using_te = False +try: + import transformer_engine.pytorch as te + from transformer_engine.common import recipe + + fp8_format = recipe.Format.HYBRID + fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") + using_te = True +except ImportError as ie: + using_te = False + def unwrap_model(model): if hasattr(model, "module"): @@ -44,7 +56,19 @@ def backward(total_loss, scaler): def train_one_epoch( - model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None, averagers=None + model, + data, + loss, + epoch, + step, + optimizer, + scaler, + scheduler, + total_steps, + args, + tb_writer=None, + averagers=None, + data_parallel_group=None, ): """Trains model for one epoch on the provided data. @@ -125,9 +149,12 @@ def train_one_epoch( optimizer.zero_grad() if args.accum_freq == 1: - with autocast(): + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=data_parallel_group) if ( + using_te and args.use_fp8 + ) else autocast(): forward_start = time.time() inputs, targets = sample_chunk(texts, args) + out, _, _ = model(inputs) forward_time_m.update(time.time() - forward_start) @@ -146,7 +173,9 @@ def train_one_epoch( backward_time_m.update(time.time() - backward_start) if averagers is not None and args.log_avg_model_training_loss and i % args.log_avg_model_training_loss == 0: - with autocast(): + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=data_parallel_group) if ( + using_te and args.use_fp8 + ) else autocast(): for key, averager in averagers.avgs_dict.items(): with torch.no_grad(): out_avg, _, _ = averager.av_model(inputs) @@ -168,12 +197,15 @@ def train_one_epoch( if isinstance(model, FSDP) and ii != args.accum_freq - 1: maybe_no_sync = model.no_sync with maybe_no_sync(): - with autocast(): + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=data_parallel_group) if ( + using_te and args.use_fp8 + ) else autocast(): forward_start = time.time() inputs_ii = inputs[ii * per_batch : (ii + 1) * per_batch] if inputs_ii.shape[0] == 0: break targets_ii = targets[ii * per_batch : (ii + 1) * per_batch] + out, _, _ = model(inputs_ii) forward_total_time += time.time() - forward_start @@ -194,7 +226,9 @@ def train_one_epoch( backward_start = time.time() backward(local_loss, scaler) backward_total_time += time.time() - backward_start - with autocast(): + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=data_parallel_group) if ( + using_te and args.use_fp8 + ) else autocast(): if ( averagers is not None and args.log_avg_model_training_loss diff --git a/sagemaker_train/Dockerfile b/sagemaker_train/Dockerfile index 36300449..f4d11c62 100644 --- a/sagemaker_train/Dockerfile +++ b/sagemaker_train/Dockerfile @@ -1,7 +1,7 @@ ARG AWS_REGION # SageMaker PyTorch image -FROM 763104351884.dkr.ecr.${AWS_REGION}.amazonaws.com/pytorch-training:2.1.0-gpu-py310-cu121-ubuntu20.04-sagemaker +FROM 763104351884.dkr.ecr.${AWS_REGION}.amazonaws.com/pytorch-training:2.2.0-gpu-py310-cu121-ubuntu20.04-sagemaker # Run custom installation of libraries # RUN pip install xxx @@ -26,6 +26,8 @@ RUN rm 
diff --git a/sagemaker_train/Dockerfile b/sagemaker_train/Dockerfile
index 36300449..f4d11c62 100644
--- a/sagemaker_train/Dockerfile
+++ b/sagemaker_train/Dockerfile
@@ -1,7 +1,7 @@
 ARG AWS_REGION
 
 # SageMaker PyTorch image
-FROM 763104351884.dkr.ecr.${AWS_REGION}.amazonaws.com/pytorch-training:2.1.0-gpu-py310-cu121-ubuntu20.04-sagemaker
+FROM 763104351884.dkr.ecr.${AWS_REGION}.amazonaws.com/pytorch-training:2.2.0-gpu-py310-cu121-ubuntu20.04-sagemaker
 
 # Run custom installation of libraries
 # RUN pip install xxx
@@ -26,6 +26,8 @@ RUN rm /opt/ml/code/setup.py
 RUN pip install -r /opt/ml/code/requirements.txt
 RUN pip uninstall flash-attn -y
 RUN pip install flash-attn>=2.2
+RUN pip install "s3fs>=2023.6.0"
+RUN pip install --upgrade s3fs
 # # Prevent sagemaker from installing requirements again.
 # RUN rm /opt/ml/code/setup.py
 RUN rm /opt/ml/code/requirements.txt
diff --git a/sagemaker_train/Dockerfile_update b/sagemaker_train/Dockerfile_update
index 1282688c..c9fa936b 100644
--- a/sagemaker_train/Dockerfile_update
+++ b/sagemaker_train/Dockerfile_update
@@ -9,6 +9,8 @@
 COPY . /opt/ml/code/
 # RUN pip install -e /opt/ml/code/
 
 # Prevent sagemaker from installing requirements again.
+RUN pip install "s3fs>=2023.6.0"
+RUN pip install --upgrade s3fs
 RUN rm /opt/ml/code/setup.py
 RUN rm /opt/ml/code/requirements.txt
diff --git a/sagemaker_train/cfg_sample.yaml b/sagemaker_train/cfg_sample.yaml
index 07158730..59b03277 100644
--- a/sagemaker_train/cfg_sample.yaml
+++ b/sagemaker_train/cfg_sample.yaml
@@ -1,11 +1,11 @@
 accum-freq: 4
 beta1: 0.9
 beta2: 0.95
-data-key: "json"
-dataset-resampled: True
+data-key: "json.gz"
+dataset-resampled: False
 # delete-previous-checkpoint: False
 # Total 25B * 40 = 1T tokens
-epochs: 40
+epochs: 1
 fsdp: True
 fsdp-limit-all-gathers: True
 # grad-checkpointing: False
@@ -16,23 +16,27 @@ name: "sample_7b"
 precision: "amp_bfloat16"
 report-to: "wandb"
 seed: 124
-train-data-mix-weights: [0.725, 0.275]
-train-data: ["TODO"]
-train-num-samples: 25_000_000_000
-wandb-project-name: "lm1"
+# train-data-mix-weights: [0.725, 0.275]
+dataset-manifest: ["TODO"]
+train-num-samples: 28_000_000_000
+wandb-project-name: "lm7"
 workers: 4
 logs: /opt/ml/checkpoints/
 
 # Some important parameters, double checked with Mitchell:
-batch-size: 16
-ffn-type: swiglu
+global-batch-size: 32
+ffn-type: swiglu_torch
 # fsdp-amp: False
 fsdp-pure-bf16: True
 fsdp-backward-prefetch: True
-lr: 3.e-4
+fsdp-use-orig-params: True
+lr: 3.e-3
 lr-cooldown-end: 3.e-5
 model-norm: "gain_only_lp_layer_norm"
 qk-norm: True
-warmup: 5000
+warmup: 2000
 wd: 0.1
 z-loss-coefficient: 1.e-4
+attn-name: torch_attn
+torchcompile: True
+use-fp8: False
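Note: FP8 matmuls need an Ada (sm_89) or Hopper (sm_90) class GPU in addition to a working transformer_engine install, which is why use-fp8 stays False in the sample config. A quick preflight sketch before flipping it on:

    import torch

    try:
        import transformer_engine.pytorch as te  # noqa: F401
        has_te = True
    except ImportError:
        has_te = False

    major, minor = torch.cuda.get_device_capability()
    fp8_capable = (major, minor) >= (8, 9)  # Ada (8.9), Hopper (9.0) and newer
    print(f"transformer_engine installed: {has_te}, FP8-capable GPU: {fp8_capable}")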