Skip to content

Commit a30f083

Browse files
committed
fix quantized sequence mask being too small, assert conditions
1 parent c6462d1 commit a30f083

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

include/nbl/builtin/hlsl/sampling/quantized_sequence.hlsl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ struct decode_helper<T, D, true>
7171
using sequence_store_type = typename sequence_type::store_type;
7272
using sequence_scalar_type = typename vector_traits<sequence_store_type>::scalar_type;
7373
using return_type = vector<fp_type, D>;
74-
NBL_CONSTEXPR_STATIC_INLINE scalar_type UNormConstant = unorm_constant<8u*sizeof(scalar_type)>::value;
74+
// NBL_CONSTEXPR_STATIC_INLINE scalar_type UNormConstant = unorm_constant<8u*sizeof(scalar_type)>::value;
75+
NBL_CONSTEXPR_STATIC_INLINE scalar_type UNormConstant = unorm_constant<21>::value;
7576

7677
static return_type __call(NBL_CONST_REF_ARG(sequence_type) val, const uint32_t scrambleSeed)
7778
{
@@ -118,7 +119,7 @@ struct QuantizedSequence<T, Dim NBL_PARTIAL_REQ_BOT(SEQUENCE_SPECIALIZATION_CONC
118119
using store_type = T;
119120
NBL_CONSTEXPR_STATIC_INLINE uint16_t StoreBits = uint16_t(8u) * size_of_v<store_type>;
120121
NBL_CONSTEXPR_STATIC_INLINE uint16_t BitsPerComponent = StoreBits / Dim;
121-
NBL_CONSTEXPR_STATIC_INLINE uint16_t Mask = (uint16_t(1u) << BitsPerComponent) - uint16_t(1u);
122+
NBL_CONSTEXPR_STATIC_INLINE store_type Mask = (uint16_t(1u) << BitsPerComponent) - uint16_t(1u);
122123
NBL_CONSTEXPR_STATIC_INLINE uint16_t DiscardBits = StoreBits - BitsPerComponent;
123124
NBL_CONSTEXPR_STATIC_INLINE uint32_t UNormConstant = impl::unorm_constant<BitsPerComponent>::value;
124125

@@ -161,13 +162,13 @@ struct QuantizedSequence<T, Dim NBL_PARTIAL_REQ_BOT(SEQUENCE_SPECIALIZATION_CONC
161162
using scalar_type = typename vector_traits<T>::scalar_type;
162163
NBL_CONSTEXPR_STATIC_INLINE uint16_t StoreBits = uint16_t(8u) * size_of_v<store_type>;
163164
NBL_CONSTEXPR_STATIC_INLINE uint16_t BitsPerComponent = StoreBits / Dim;
164-
NBL_CONSTEXPR_STATIC_INLINE uint16_t Mask = (uint16_t(1u) << BitsPerComponent) - uint16_t(1u);
165+
NBL_CONSTEXPR_STATIC_INLINE scalar_type Mask = (uint16_t(1u) << BitsPerComponent) - uint16_t(1u);
165166
NBL_CONSTEXPR_STATIC_INLINE uint16_t DiscardBits = StoreBits - BitsPerComponent;
166167
NBL_CONSTEXPR_STATIC_INLINE uint32_t UNormConstant = impl::unorm_constant<BitsPerComponent>::value;
167168

168169
scalar_type get(const uint16_t idx)
169170
{
170-
assert(idx > 0 && idx < 3);
171+
assert(idx >= 0 && idx < 3);
171172
if (idx < 2)
172173
{
173174
return data[idx] & Mask;
@@ -182,15 +183,16 @@ struct QuantizedSequence<T, Dim NBL_PARTIAL_REQ_BOT(SEQUENCE_SPECIALIZATION_CONC
182183

183184
void set(const uint16_t idx, const scalar_type value)
184185
{
185-
assert(idx > 0 && idx < 3);
186+
assert(idx >= 0 && idx < 3);
186187
if (idx < 2)
187188
{
189+
const scalar_type trunc_val = value >> DiscardBits;
188190
data[idx] &= ~Mask;
189-
data[idx] |= (value >> DiscardBits) & Mask;
191+
data[idx] |= trunc_val &Mask;
190192
}
191193
else
192194
{
193-
const scalar_type zbits = StoreBits-BitsPerComponent;
195+
const uint16_t zbits = StoreBits-BitsPerComponent;
194196
const scalar_type zmask = (uint16_t(1u) << zbits) - uint16_t(1u);
195197
const scalar_type trunc_val = value >> (DiscardBits-1u);
196198
data[0] &= Mask;
@@ -211,20 +213,20 @@ struct QuantizedSequence<T, Dim NBL_PARTIAL_REQ_BOT(SEQUENCE_SPECIALIZATION_CONC
211213
using scalar_type = typename vector_traits<T>::scalar_type;
212214
NBL_CONSTEXPR_STATIC_INLINE uint16_t StoreBits = uint16_t(8u) * size_of_v<store_type>;
213215
NBL_CONSTEXPR_STATIC_INLINE uint16_t BitsPerComponent = StoreBits / Dim;
214-
NBL_CONSTEXPR_STATIC_INLINE uint16_t Mask = (uint16_t(1u) << BitsPerComponent) - uint16_t(1u);
216+
NBL_CONSTEXPR_STATIC_INLINE scalar_type Mask = (uint16_t(1u) << BitsPerComponent) - uint16_t(1u);
215217
NBL_CONSTEXPR_STATIC_INLINE uint16_t DiscardBits = StoreBits - BitsPerComponent;
216218
NBL_CONSTEXPR_STATIC_INLINE uint32_t UNormConstant = impl::unorm_constant<BitsPerComponent>::value;
217219

218220
scalar_type get(const uint16_t idx)
219221
{
220-
assert(idx > 0 && idx < 4);
222+
assert(idx >= 0 && idx < 4);
221223
const uint16_t i = (idx & uint16_t(2u)) >> uint16_t(1u);
222224
return (data[i] >> (BitsPerComponent * (idx & uint16_t(1u)))) & Mask;
223225
}
224226

225227
void set(const uint16_t idx, const scalar_type value)
226228
{
227-
assert(idx > 0 && idx < 4);
229+
assert(idx >= 0 && idx < 4);
228230
const uint16_t i = (idx & uint16_t(2u)) >> uint16_t(1u);
229231
const uint16_t odd = idx & uint16_t(1u);
230232
data[i] &= hlsl::mix(~Mask, Mask, bool(odd));
@@ -245,7 +247,7 @@ struct QuantizedSequence<T, Dim NBL_PARTIAL_REQ_BOT(SEQUENCE_SPECIALIZATION_CONC
245247

246248
base_type get(const uint16_t idx)
247249
{
248-
assert(idx > 0 && idx < 2);
250+
assert(idx >= 0 && idx < 2);
249251
base_type a;
250252
a[0] = data[uint16_t(2u) * idx];
251253
a[1] = data[uint16_t(2u) * idx + 1];
@@ -254,7 +256,7 @@ struct QuantizedSequence<T, Dim NBL_PARTIAL_REQ_BOT(SEQUENCE_SPECIALIZATION_CONC
254256

255257
void set(const uint16_t idx, const base_type value)
256258
{
257-
assert(idx > 0 && idx < 2);
259+
assert(idx >= 0 && idx < 2);
258260
base_type a;
259261
data[uint16_t(2u) * idx] = value[0];
260262
data[uint16_t(2u) * idx + 1] = value[1];
@@ -275,13 +277,13 @@ struct QuantizedSequence<T, Dim NBL_PARTIAL_REQ_BOT(SEQUENCE_SPECIALIZATION_CONC
275277
NBL_CONSTEXPR_STATIC_INLINE uint16_t StoreBits = uint16_t(8u) * size_of_v<store_type>;
276278
NBL_CONSTEXPR_STATIC_INLINE uint16_t BitsPerComponent = StoreBits / Dim;
277279
NBL_CONSTEXPR_STATIC_INLINE uint16_t LeftoverBitsPerComponent = BitsPerComponent - uint16_t(8u) * size_of_v<scalar_type>;
278-
NBL_CONSTEXPR_STATIC_INLINE uint16_t Mask = (uint16_t(1u) << LeftoverBitsPerComponent) - uint16_t(1u);
280+
NBL_CONSTEXPR_STATIC_INLINE scalar_type Mask = (uint16_t(1u) << LeftoverBitsPerComponent) - uint16_t(1u);
279281
NBL_CONSTEXPR_STATIC_INLINE uint16_t DiscardBits = StoreBits - BitsPerComponent;
280282
NBL_CONSTEXPR_STATIC_INLINE uint32_t UNormConstant = impl::unorm_constant<8u*sizeof(scalar_type)>::value;
281283

282284
base_type get(const uint16_t idx)
283285
{
284-
assert(idx > 0 && idx < 3);
286+
assert(idx >= 0 && idx < 3);
285287
base_type a;
286288
a[0] = data[idx];
287289
a[1] = (data[3] >> (LeftoverBitsPerComponent * idx)) & Mask;
@@ -290,7 +292,7 @@ struct QuantizedSequence<T, Dim NBL_PARTIAL_REQ_BOT(SEQUENCE_SPECIALIZATION_CONC
290292

291293
void set(const uint16_t idx, const base_type value)
292294
{
293-
assert(idx > 0 && idx < 3);
295+
assert(idx >= 0 && idx < 3);
294296
data[idx] = value[0];
295297
data[3] &= ~Mask;
296298
data[3] |= ((value[1] >> DiscardBits) & Mask) << (LeftoverBitsPerComponent * idx);

0 commit comments

Comments
 (0)