@@ -1828,35 +1828,6 @@ def setup_quant_scales(self, module: torch.nn.Module):
18281828 fc2_global = module .fc2_alpha ,
18291829 )
18301830
def post_load_weights(self, module: torch.nn.Module):
    """Run the parent post-load step, then re-lay out FC1 for CUTEDSL.

    When ``module.moe_backend`` is ``"CUTEDSL"``, the fused ``w3_w1``
    weight and its block scales (``module.quant_scales.fc1_weight_block``)
    are rewritten in place through ``interleave_linear_and_gate`` so that
    GEMM1 can be fused with SwiGLU. Any other backend only gets the parent
    behaviour.

    Args:
        module: MoE module whose ``w3_w1_weight`` and
            ``quant_scales.fc1_weight_block`` buffers are updated in place.
    """
    super().post_load_weights(module)
    if module.moe_backend != "CUTEDSL":
        return

    # Group size used for both the weight and the scale interleave.
    group = 64

    # --- FC1 weight -------------------------------------------------------
    weight_buf = module.w3_w1_weight.data
    packed_weight = weight_buf.view(float4_e2m1x2)
    m = packed_weight.size(1)
    # NOTE(review): the factor of 2 presumably reflects two FP4 values per
    # packed element (float4_e2m1x2) -- confirm against the dtype definition.
    n = packed_weight.size(2) * 2
    mixed_weight = interleave_linear_and_gate(packed_weight,
                                              group_size=group,
                                              dim=1)
    # Reinterpret back to the storage dtype before writing in place.
    weight_buf.copy_(mixed_weight.view(weight_buf.dtype))

    # --- FC1 block scales -------------------------------------------------
    scale_buf = module.quant_scales.fc1_weight_block.data
    scale_cols = n // module.scaling_vector_size
    # The scales are stored swizzled: unswizzle, interleave, swizzle again.
    plain_scale = unswizzle_sf(scale_buf.view(float4_sf_dtype), m,
                               n).view(-1, m, scale_cols)
    mixed_scale = interleave_linear_and_gate(plain_scale,
                                             group_size=group,
                                             dim=1)
    packed_scale = swizzle_sf(mixed_scale, m, n).view(-1, m, scale_cols)
    scale_buf.copy_(packed_scale.view(scale_buf.dtype))
1859-
18601831
18611832class NVFP4CutlassFusedMoEMethod (NVFP4FusedMoEMethod ):
18621833 weight_dtype = FUSED_MOE_NVFP4_WEIGHT_DTYPE
@@ -1935,6 +1906,38 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
19351906 dst_w2_weight_scale .copy_ (dst_w2_weight_scale_interleaved )
19361907
19371908
class NVFP4CuteDslFusedMoEMethod(NVFP4CutlassFusedMoEMethod):
    """NVFP4 fused-MoE method that interleaves FC1 after weight loading."""

    def post_load_weights(self, module: torch.nn.Module):
        """Run the parent post-load step, then interleave FC1 in place.

        The fused ``w3_w1`` weight and its block scales
        (``module.quant_scales.fc1_weight_block``) are rewritten through
        ``interleave_linear_and_gate`` so that GEMM1 can be fused with
        SwiGLU.

        Args:
            module: MoE module whose ``w3_w1_weight`` and
                ``quant_scales.fc1_weight_block`` buffers are updated in
                place.
        """
        super().post_load_weights(module)

        # Group size used for both the weight and the scale interleave.
        group = 64

        # --- FC1 weight ---------------------------------------------------
        weight_buf = module.w3_w1_weight.data
        packed_weight = weight_buf.view(float4_e2m1x2)
        m = packed_weight.size(1)
        # NOTE(review): the factor of 2 presumably reflects two FP4 values
        # per packed element (float4_e2m1x2) -- confirm against the dtype.
        n = packed_weight.size(2) * 2
        mixed_weight = interleave_linear_and_gate(packed_weight,
                                                  group_size=group,
                                                  dim=1)
        # Reinterpret back to the storage dtype before writing in place.
        weight_buf.copy_(mixed_weight.view(weight_buf.dtype))

        # --- FC1 block scales ---------------------------------------------
        scale_buf = module.quant_scales.fc1_weight_block.data
        scale_cols = n // module.scaling_vector_size
        # The scales are stored swizzled: unswizzle, interleave, swizzle
        # again.
        plain_scale = unswizzle_sf(scale_buf.view(float4_sf_dtype), m,
                                   n).view(-1, m, scale_cols)
        mixed_scale = interleave_linear_and_gate(plain_scale,
                                                 group_size=group,
                                                 dim=1)
        packed_scale = swizzle_sf(mixed_scale, m, n).view(-1, m, scale_cols)
        scale_buf.copy_(packed_scale.view(scale_buf.dtype))
1939+
1940+
19381941class NVFP4TRTLLMGenFusedMoEMethod (NVFP4FusedMoEMethod ):
19391942 weight_dtype = float4_sf_dtype
19401943 block_scales_dtype = torch .float8_e4m3fn
0 commit comments