@@ -57,65 +57,83 @@ def choose_qparams_and_quantize_codebook_coreml(
57
57
assert code_dtype in list (_SUB_BYTE_UINT_BOUNDS .keys ()) + [torch .uint8 ]
58
58
nbits = _DTYPE_TO_BIT_WIDTH [code_dtype ]
59
59
assert nbits >= 1 and nbits <= 8 , f"nbits must be in [1, 8], got { nbits } "
60
-
61
- assert len (block_size ) == input_tensor .dim ()
62
- block_size = block_size .copy ()
63
- for i in range (len (block_size )):
64
- if block_size [i ] == - 1 :
65
- block_size [i ] = input_tensor .shape [i ]
66
- assert block_size [i ] >= 1 and input_tensor .shape [i ] % block_size [i ] == 0 , (
67
- "block_size[i] must divide input_tensor.shape[i]"
68
- )
69
-
70
60
assert input_tensor .dim () == 2 , "Currently only rank 2 tensors are supported"
71
- assert block_size [0 ] == input_tensor .shape [0 ], (
72
- "Currently only support per-grouped channel granularity"
73
- )
74
61
assert cluster_dim == 1 , (
75
62
f"only cluster_dim == 1 is supported right now, got { cluster_dim } "
76
63
)
77
64
78
- num_lut = input_tensor .shape [1 ] // block_size [1 ]
79
- group_size = block_size [1 ]
80
-
81
- # for converting to numpy
82
- input_tensor = input_tensor .detach ()
83
65
original_shape = input_tensor .shape
66
+ N , K = original_shape
67
+ input_tensor = input_tensor .detach ()
84
68
85
- # reshape to (N, K // group_size, group_size)
86
- input_tensor = input_tensor .reshape (input_tensor .shape [0 ], num_lut , group_size )
87
- from coremltools .models .neural_network .quantization_utils import (
88
- _get_kmeans_lookup_table_and_weight ,
69
+ # --- Process block_size ---
70
+ assert len (block_size ) == 2
71
+ processed_block_size = block_size .copy ()
72
+ if processed_block_size [0 ] == - 1 :
73
+ processed_block_size [0 ] = N
74
+ if processed_block_size [1 ] == - 1 :
75
+ processed_block_size [1 ] = K
76
+
77
+ row_block_size , col_block_size = processed_block_size
78
+ assert N % row_block_size == 0 , (
79
+ f"Tensor rows ({ N } ) not divisible by row block size ({ row_block_size } )"
80
+ )
81
+ assert K % col_block_size == 0 , (
82
+ f"Tensor cols ({ K } ) not divisible by col block size ({ col_block_size } )"
89
83
)
90
84
91
- res_lut = []
92
- # each res_w[:, i, :] will use the same lookup table
93
- # res_w: (N, K // group_size, group_size)
94
- res_w = torch .zeros_like (input_tensor , dtype = torch .uint8 )
95
- for i in range (num_lut ):
96
- # lut: (2**nbits, 1)
97
- # w: (N * group_size)
98
- lut , w = _get_kmeans_lookup_table_and_weight (
99
- nbits , input_tensor [:, i , :], force_kmeans1d , cluster_dim , vector_axis
100
- )
101
- res_lut .append (torch .from_numpy (lut ))
102
- res_w [:, i , :] = torch .from_numpy (w .reshape (input_tensor .shape [0 ], group_size ))
103
-
104
- # directly stack all lookup tables along dim 0
105
- # res_lut: (K // group_size, 2 ** nbits)
106
- res_lut = torch .stack (res_lut , dim = 0 )
107
-
108
- # The final LUT should have dimension equal to input_tensor.dim() + 2
109
- # The first input_tensor.dim() dimensions index over the tables,
110
- # input_tensor.dim() + 1 indexes over the nbit indices
111
- # input_tensor.dim() + 2 are the look up values (shape = 1 for scalar)
112
- # res_lut: (N, K // group_size, 2 ** nbits, group_size)
113
- res_lut = res_lut .reshape (1 , num_lut , 2 ** nbits , 1 )
85
+ # --- Determine and execute grouping strategy ---
86
+ assert row_block_size == N or col_block_size == K
87
+ is_col_grouping = row_block_size == N
114
88
115
- # reshape back to (N, K)
116
- res_w = res_w .reshape (* original_shape )
89
+ res_lut_list = []
90
+ from coremltools .models .neural_network .quantization_utils import (
91
+ _get_kmeans_lookup_table_and_weight ,
92
+ )
117
93
118
- return res_lut , res_w
94
+ if is_col_grouping :
95
+ # STRATEGY 1: Group by COLUMNS
96
+ num_luts = K // col_block_size
97
+ reshaped_tensor = input_tensor .reshape (N , num_luts , col_block_size )
98
+ res_codes = torch .zeros_like (reshaped_tensor , dtype = torch .uint8 )
99
+
100
+ for i in range (num_luts ):
101
+ block_to_quantize = reshaped_tensor [:, i , :]
102
+ lut , w = _get_kmeans_lookup_table_and_weight (
103
+ nbits , block_to_quantize , force_kmeans1d , cluster_dim , vector_axis
104
+ )
105
+ res_lut_list .append (torch .from_numpy (lut ))
106
+ res_codes [:, i , :] = torch .from_numpy (w .reshape (N , col_block_size ))
107
+
108
+ # Shape to match CoreML spec: (1, num_luts, 2**nbits, 1)
109
+ final_luts = torch .stack (res_lut_list , dim = 0 ).reshape (1 , num_luts , 2 ** nbits , 1 )
110
+
111
+ else : # is_row_grouping
112
+ # STRATEGY 2: Group by ROWS
113
+ num_luts = N // row_block_size
114
+ reshaped_tensor = input_tensor .reshape (num_luts , row_block_size , K )
115
+ res_codes = torch .zeros_like (reshaped_tensor , dtype = torch .uint8 )
116
+
117
+ for i in range (num_luts ):
118
+ block_to_quantize = reshaped_tensor [i , :, :]
119
+ lut , w = _get_kmeans_lookup_table_and_weight (
120
+ nbits , block_to_quantize , force_kmeans1d , cluster_dim , vector_axis
121
+ )
122
+ res_lut_list .append (torch .from_numpy (lut ))
123
+ res_codes [i , :, :] = torch .from_numpy (w .reshape (row_block_size , K ))
124
+
125
+ final_luts_stacked = torch .stack (
126
+ res_lut_list , dim = 0
127
+ ) # Shape: (num_luts, 2**nbits, 1)
128
+
129
+ # Reshape to the consistent 4D format
130
+ # The shape is (num_row_groups, 1, 2**nbits, 1)
131
+ final_luts = final_luts_stacked .reshape (num_luts , 1 , 2 ** nbits , 1 )
132
+
133
+ # Reshape codes back to the original tensor shape
134
+ final_codes = res_codes .reshape (* original_shape )
135
+
136
+ return final_luts , final_codes
119
137
120
138
121
139
@register_custom_op
0 commit comments