Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
331 changes: 331 additions & 0 deletions vllm/compilation/sequence_parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import is_rocm_aiter_rmsnorm_enabled
from vllm.platforms import current_platform

from .inductor_pass import enable_fake_mode
Expand Down Expand Up @@ -100,6 +101,23 @@
)
return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]

def _aiter_functional_rmsnorm(self, input_tensor, weight_tensor):
    """Apply the ROCm AITER RMSNorm custom op to ``input_tensor``.

    Uses ``self.epsilon`` as the numerical-stability epsilon. Returns the
    normalized tensor produced by ``torch.ops.vllm.rocm_aiter_rms_norm``.
    """
    rms_norm_op = torch.ops.vllm.rocm_aiter_rms_norm.default
    return rms_norm_op(input_tensor, weight_tensor, self.epsilon)

def _aiter_functional_fused_add_rmsnorm(
    self, input_tensor, residual_tensor, weight_tensor
):
    """Apply the ROCm AITER fused residual-add + RMSNorm custom op.

    Returns whatever ``torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add``
    returns (normalized output plus the updated residual), using
    ``self.epsilon`` for numerical stability.
    """
    fused_op = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default
    return fused_op(input_tensor, residual_tensor, weight_tensor, self.epsilon)


class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
"""Helper for sequence parallelism patterns."""
Expand Down Expand Up @@ -255,6 +273,291 @@
FP8_DTYPE = current_platform.fp8_dtype()



class AiterFirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite for the first AITER RMSNorm of a layer.

    Matches ``all_reduce -> rocm_aiter_rms_norm`` and replaces it with
    ``reduce_scatter -> rocm_aiter_rms_norm -> all_gather`` so the norm
    operates on a sequence-sharded tensor.
    """

    def get_inputs(self):
        # Minimal example tensors used only to trace the pattern graph.
        input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4], device=self.device, dtype=self.dtype)
        return [input, weight]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
        ):
            all_reduce = self._all_reduce(input)
            rmsnorm = self._aiter_functional_rmsnorm(all_reduce, weight)
            return rmsnorm, all_reduce

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
        ):
            # NOTE: no logging here — this function is traced exactly once at
            # registration time, so a log line would fire then, not per match.
            reduce_scatter = self._reduce_scatter(input)
            rmsnorm = self._aiter_functional_rmsnorm(reduce_scatter, weight)
            all_gather = self._all_gather(rmsnorm)
            return all_gather, reduce_scatter

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AiterMiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite for a mid-model AITER fused-add RMSNorm.

    Matches ``all_reduce -> rocm_aiter_rmsnorm2d_fwd_with_add`` and rewrites
    it to run on a reduce-scattered shard, all-gathering the normalized
    output while keeping the (sharded) residual as the second result.
    """

    def get_inputs(self):
        # Example tensors for tracing; order must match pattern() arguments.
        mm_out = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        res = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        norm_w = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [res, mm_out, norm_w]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(mm_1)
            fused_out = self._aiter_functional_fused_add_rmsnorm(
                reduced, residual, rms_norm_weights
            )
            return fused_out[0], fused_out[1]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(mm_1)
            fused_out = self._aiter_functional_fused_add_rmsnorm(
                scattered, residual, rms_norm_weights
            )
            gathered = self._all_gather(fused_out[0])
            return gathered, fused_out[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AiterLastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite for the last AITER fused-add RMSNorm.

    Same as the middle-layer pattern, but only the normalized output is
    consumed downstream (the residual output of the fused op is dropped).
    """

    def get_inputs(self):
        # Example tensors for tracing; order must match pattern() arguments.
        mm_out = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        res = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        norm_w = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [res, mm_out, norm_w]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> torch.Tensor:
            reduced = self._all_reduce(mm_1)
            fused_out = self._aiter_functional_fused_add_rmsnorm(
                reduced, residual, rms_norm_weights
            )
            return fused_out[0]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> torch.Tensor:
            scattered = self._reduce_scatter(mm_1)
            fused_out = self._aiter_functional_fused_add_rmsnorm(
                scattered, residual, rms_norm_weights
            )
            return self._all_gather(fused_out[0])

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )



class AiterFirstAllReduceRMSNormQuantPattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite: all_reduce -> AITER RMSNorm -> dynamic
    per-token FP8 quant.

    Replaced with reduce_scatter -> RMSNorm -> quant, all-gathering both the
    quantized activations and their per-token scales afterwards.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
        # quant_op=None: the quant op is matched explicitly in pattern()
        # below rather than via the generic quant-op helper machinery.
        super().__init__(epsilon, dtype, device, quant_op=None)

    def get_inputs(self):
        # Minimal example tensors used only to trace the pattern graph.
        input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4], device=self.device, dtype=self.dtype)
        return [input, weight]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
        ):
            all_reduce = self._all_reduce(input)
            rmsnorm_result = self._aiter_functional_rmsnorm(all_reduce, weight)
            quant_input, scale = torch.ops.vllm.aiter_dynamic_per_token_scaled_quant(
                rmsnorm_result, quant_dtype=FP8_DTYPE
            )
            return quant_input, scale, all_reduce

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
        ):
            # NOTE: no logging here — this function is traced exactly once at
            # registration time, so a log line would fire then, not per match.
            reduce_scatter = self._reduce_scatter(input)
            rmsnorm_result = self._aiter_functional_rmsnorm(reduce_scatter, weight)
            quant_input, scale = torch.ops.vllm.aiter_dynamic_per_token_scaled_quant(
                rmsnorm_result, quant_dtype=FP8_DTYPE
            )
            # Gather both the quantized shard and its per-token scales.
            all_gather = self._all_gather(quant_input)
            all_gather_scale = self._all_gather(scale)
            return all_gather, all_gather_scale, reduce_scatter

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AiterMiddleAllReduceRMSNormQuantPattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite: all_reduce -> AITER fused-add RMSNorm ->
    dynamic per-token FP8 quant, for mid-model layers.

    The residual output of the fused op stays sharded; the quantized output
    and its scales are all-gathered.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
        # quant_op=None: the quant op is matched explicitly in pattern()
        # below rather than via the generic quant-op helper machinery.
        super().__init__(epsilon, dtype, device, quant_op=None)

    def get_inputs(self):
        # Example tensors for tracing; order must match pattern() arguments.
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [residual, mm_1, rms_norm_weights]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            rmsnorm = self._aiter_functional_fused_add_rmsnorm(
                all_reduce, residual, rms_norm_weights
            )
            quant_input, scale = torch.ops.vllm.aiter_dynamic_per_token_scaled_quant(
                rmsnorm[0], quant_dtype=FP8_DTYPE
            )
            return quant_input, scale, rmsnorm[1]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # NOTE: no logging here — this function is traced exactly once at
            # registration time, so a log line would fire then, not per match.
            reduce_scatter = self._reduce_scatter(mm_1)
            rmsnorm = self._aiter_functional_fused_add_rmsnorm(
                reduce_scatter, residual, rms_norm_weights
            )
            quant_input, scale = torch.ops.vllm.aiter_dynamic_per_token_scaled_quant(
                rmsnorm[0], quant_dtype=FP8_DTYPE
            )
            all_gather = self._all_gather(quant_input)
            all_gather_scale = self._all_gather(scale)
            return all_gather, all_gather_scale, rmsnorm[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AiterLastAllReduceRMSNormQuantPattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite: all_reduce -> AITER fused-add RMSNorm ->
    dynamic per-token FP8 quant, for the last layer (residual is dropped).
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
        # quant_op=None: the quant op is matched explicitly in pattern()
        # below rather than via the generic quant-op helper machinery.
        super().__init__(epsilon, dtype, device, quant_op=None)

    def get_inputs(self):
        # Example tensors for tracing.
        # BUG FIX: the example-input list must match pattern()'s signature
        # exactly (residual, mm_1, rms_norm_weights). The previous version
        # also returned a `result` and `scale` tensor, giving five example
        # inputs for a three-argument pattern, which breaks
        # pm.register_replacement's input binding.
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [residual, mm_1, rms_norm_weights]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            rmsnorm, _ = self._aiter_functional_fused_add_rmsnorm(
                all_reduce, residual, rms_norm_weights
            )
            quant_input, scale = torch.ops.vllm.aiter_dynamic_per_token_scaled_quant(
                rmsnorm, quant_dtype=FP8_DTYPE
            )
            return quant_input, scale

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # NOTE: no logging here — this function is traced exactly once at
            # registration time, so a log line would fire then, not per match.
            reduce_scatter = self._reduce_scatter(mm_1)
            rmsnorm, _ = self._aiter_functional_fused_add_rmsnorm(
                reduce_scatter, residual, rms_norm_weights
            )
            quant_input, scale = torch.ops.vllm.aiter_dynamic_per_token_scaled_quant(
                rmsnorm, quant_dtype=FP8_DTYPE
            )
            all_gather = self._all_gather(quant_input)
            all_gather_scale = self._all_gather(scale)
            return all_gather, all_gather_scale

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
Expand Down Expand Up @@ -455,9 +758,37 @@
pass_name="sequence_parallelism_pass"
)

# Lazy %-style args (not an f-string) per logging best practice / Ruff G004;
# also log the actual enablement condition, not just is_rocm().
logger.info(
    "Aiter RMSNorm enabled: %s",
    current_platform.is_rocm() and is_rocm_aiter_rmsnorm_enabled(),
)

Check failure on line 761 in vllm/compilation/sequence_parallelism.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (G004)

vllm/compilation/sequence_parallelism.py:761:21: G004 Logging statement uses f-string

for epsilon in [1e-5, 1e-6]:
# RMSNorm + Static FP8 quantization patterns
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default

if current_platform.is_rocm() and is_rocm_aiter_rmsnorm_enabled():
AiterFirstAllReduceRMSNormQuantPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)

AiterMiddleAllReduceRMSNormQuantPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)

AiterLastAllReduceRMSNormQuantPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)

AiterFirstAllReduceRMSNormPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)

AiterMiddleAllReduceRMSNormPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)

AiterLastAllReduceRMSNormPattern(
epsilon, self.model_dtype, self.device
).register(self.patterns)

FirstAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device, fp8_quant_op
).register(self.patterns)
Expand Down
Loading
Loading