Skip to content

Commit

Permalink
Fix NVQ distance computations in Native provider (#389)
Browse files Browse the repository at this point in the history
* Replace multiplication and addition by FMA in tail computation
* Fix computation of tail elements in nvqCosine8bit, nvqDotProduct8bit, nvqSquareDistance8bit
* When using FloatVector.fromMemorySegment, use MemorySegmentVectorFloat.offset to compute the correct offset
  • Loading branch information
marianotepper authored Jan 23, 2025
1 parent 1ece84b commit 86f747c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,7 @@ public static void quantizePartials(float delta, MemorySegmentVectorFloat partia
//---------------------------------------------
// NVQ quantization instructions start here
//---------------------------------------------

static final FloatVector const1f = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, 1.f);
static final FloatVector const05f = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, 0.5f);

Expand Down Expand Up @@ -742,7 +743,7 @@ static void nvqQuantize8bit(MemorySegmentVectorFloat vector, float alpha, float
var invLogisticScale = 255 / (logisticNQT(maxValue, scaledAlpha, scaledX0) - logisticBias);

for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
var arr = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), i, ByteOrder.LITTLE_ENDIAN);
var arr = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), vector.offset(i), ByteOrder.LITTLE_ENDIAN);
arr = logisticNQT(arr, scaledAlpha, scaledX0);
arr = arr.sub(logisticBias).mul(invLogisticScale);
var bytes = arr.add(const05f)
Expand Down Expand Up @@ -777,7 +778,7 @@ static float nvqLoss(MemorySegmentVectorFloat vector, float alpha, float x0, flo
var invLogisticScale = 1 / logisticScale;

for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
var arr = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), i, ByteOrder.LITTLE_ENDIAN);
var arr = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), vector.offset(i), ByteOrder.LITTLE_ENDIAN);
var recArr = logisticNQT(arr, scaledAlpha, scaledX0);
recArr = recArr.sub(logisticBias).mul(invLogisticScale);
recArr = recArr.add(const05f)
Expand All @@ -803,7 +804,7 @@ static float nvqLoss(MemorySegmentVectorFloat vector, float alpha, float x0, flo
recValue = (recValue - logisticBias) * invLogisticScale;
recValue = Math.round(recValue);
recValue = Math.fma(logisticScale, recValue, logisticBias);
recValue = logitNQT(recValue, scaledAlpha, scaledX0);
recValue = logitNQT(recValue, invScaledAlpha, scaledX0);

squaredSum += MathUtil.square(value - recValue);
}
Expand All @@ -813,19 +814,21 @@ static float nvqLoss(MemorySegmentVectorFloat vector, float alpha, float x0, flo

static float nvqUniformLoss(MemorySegmentVectorFloat vector, float minValue, float maxValue, int nBits) {
float constant = (1 << nBits) - 1;
float delta = maxValue - minValue;

int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length());

FloatVector squaredSumVec = FloatVector.zero(FloatVector.SPECIES_PREFERRED);

for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
var arr = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), i, ByteOrder.LITTLE_ENDIAN);
var recArr = arr.sub(minValue).mul(constant / (maxValue - minValue));
var arr = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), vector.offset(i), ByteOrder.LITTLE_ENDIAN);
var recArr = arr.sub(minValue).mul(constant / delta);
recArr = recArr.add(const05f)
.convert(VectorOperators.F2I, 0)
.reinterpretAsInts()
.convert(VectorOperators.I2F, 0)
.reinterpretAsFloats();
recArr = recArr.fma((maxValue - minValue) / constant, minValue);
recArr = recArr.fma(delta / constant, minValue);

var diff = arr.sub(recArr);
squaredSumVec = diff.fma(diff, squaredSumVec);
Expand All @@ -838,9 +841,9 @@ static float nvqUniformLoss(MemorySegmentVectorFloat vector, float minValue, flo
for (int i = vectorizedLength; i < vector.length(); i++) {
value = vector.get(i);

recValue = (value - minValue) / (maxValue - minValue);
recValue = (value - minValue) / delta;
recValue = Math.round(constant * recValue) / constant;
recValue = recValue / (maxValue - minValue) + minValue;
recValue = recValue * delta + minValue;

squaredSum += MathUtil.square(value - recValue);
}
Expand All @@ -866,7 +869,7 @@ static float nvqSquareDistance8bit(MemorySegmentVectorFloat vector, MemorySegmen
var byteArr = ByteVector.fromMemorySegment(ByteVector.SPECIES_PREFERRED, quantizedVector.get(), i, ByteOrder.LITTLE_ENDIAN);

for (int j = 0; j < 4; j++) {
var v1 = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), i + floatStep * j, ByteOrder.LITTLE_ENDIAN);
var v1 = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), vector.offset(i + floatStep * j), ByteOrder.LITTLE_ENDIAN);
var v2 = nvqDequantize8bit(byteArr, invScaledAlpha, scaledX0, logisticScale, logisticBias, j);

var diff = v1.sub(v2);
Expand All @@ -881,15 +884,14 @@ static float nvqSquareDistance8bit(MemorySegmentVectorFloat vector, MemorySegmen
for (int i = vectorizedLength; i < quantizedVector.length(); i++) {
value2 = Byte.toUnsignedInt(quantizedVector.get(i));
value2 = Math.fma(logisticScale, value2, logisticBias);
value2 = logitNQT(value2, scaledAlpha, scaledX0);
value2 = logitNQT(value2, invScaledAlpha, scaledX0);
diff = vector.get(i) - value2;
squaredSum += MathUtil.square(diff);
}

return squaredSum;
}


static float nvqDotProduct8bit(MemorySegmentVectorFloat vector, MemorySegmentByteSequence quantizedVector,
float alpha, float x0, float minValue, float maxValue) {
FloatVector dotProdVec = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
Expand All @@ -904,12 +906,11 @@ static float nvqDotProduct8bit(MemorySegmentVectorFloat vector, MemorySegmentByt
var logisticBias = logisticNQT(minValue, scaledAlpha, scaledX0);
var logisticScale = (logisticNQT(maxValue, scaledAlpha, scaledX0) - logisticBias) / 255;


for (int i = 0; i < vectorizedLength; i += ByteVector.SPECIES_PREFERRED.length()) {
var byteArr = ByteVector.fromMemorySegment(ByteVector.SPECIES_PREFERRED, quantizedVector.get(), i, ByteOrder.LITTLE_ENDIAN);

for (int j = 0; j < 4; j++) {
var v1 = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), i + floatStep * j, ByteOrder.LITTLE_ENDIAN);
var v1 = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), vector.offset(i + floatStep * j), ByteOrder.LITTLE_ENDIAN);
var v2 = nvqDequantize8bit(byteArr, invScaledAlpha, scaledX0, logisticScale, logisticBias, j);
dotProdVec = v1.fma(v2, dotProdVec);
}
Expand All @@ -922,7 +923,7 @@ static float nvqDotProduct8bit(MemorySegmentVectorFloat vector, MemorySegmentByt
for (int i = vectorizedLength; i < quantizedVector.length(); i++) {
value2 = Byte.toUnsignedInt(quantizedVector.get(i));
value2 = Math.fma(logisticScale, value2, logisticBias);
value2 = logitNQT(value2, scaledAlpha, scaledX0);
value2 = logitNQT(value2, invScaledAlpha, scaledX0);
dotProd = Math.fma(vector.get(i), value2, dotProd);
}

Expand Down Expand Up @@ -953,10 +954,10 @@ static float[] nvqCosine8bit(MemorySegmentVectorFloat vector, MemorySegmentByteS
var byteArr = ByteVector.fromMemorySegment(ByteVector.SPECIES_PREFERRED, quantizedVector.get(), i, ByteOrder.LITTLE_ENDIAN);

for (int j = 0; j < 4; j++) {
var va = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), i + floatStep * j, ByteOrder.LITTLE_ENDIAN);
var va = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), vector.offset(i + floatStep * j), ByteOrder.LITTLE_ENDIAN);
var vb = nvqDequantize8bit(byteArr, invScaledAlpha, scaledX0, logisticScale, logisticBias, j);

var vCentroid = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, centroid.get(), i + floatStep * j, ByteOrder.LITTLE_ENDIAN);
var vCentroid = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, centroid.get(), centroid.offset(i + floatStep * j), ByteOrder.LITTLE_ENDIAN);
vb = vb.add(vCentroid);

vsum = va.fma(vb, vsum);
Expand All @@ -972,7 +973,7 @@ static float[] nvqCosine8bit(MemorySegmentVectorFloat vector, MemorySegmentByteS
for (int i = vectorizedLength; i < vector.length(); i++) {
value2 = Byte.toUnsignedInt(quantizedVector.get(i));
value2 = Math.fma(logisticScale, value2, logisticBias);
value2 = logitNQT(value2, scaledAlpha, scaledX0) + centroid.get(i);
value2 = logitNQT(value2, invScaledAlpha, scaledX0) + centroid.get(i);
sum = Math.fma(vector.get(i), value2, sum);
bMagnitude = Math.fma(value2, value2, bMagnitude);
}
Expand Down Expand Up @@ -1012,6 +1013,6 @@ static void nvqShuffleQueryInPlace8bit(MemorySegmentVectorFloat vector) {
}

//---------------------------------------------
// NVQ quantization instructions end here
// NVQ instructions end here
//---------------------------------------------
}
Original file line number Diff line number Diff line change
Expand Up @@ -1104,7 +1104,7 @@ static float nvqSquareDistance8bit(ArrayVectorFloat vector, ArrayByteSequence qu
float value2, diff;
for (int i = vectorizedLength; i < quantizedVector.length(); i++) {
value2 = Byte.toUnsignedInt(quantizedVector.get(i));
value2 = logisticScale * value2 + logisticBias;
value2 = Math.fma(logisticScale, value2, logisticBias);
value2 = logitNQT(value2, invScaledAlpha, scaledX0);
diff = vector.get(i) - value2;
squaredSum += MathUtil.square(diff);
Expand Down

0 comments on commit 86f747c

Please sign in to comment.