Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimized string.Replace(char, char) #67049

Merged
merged 13 commits into from
Aug 17, 2022
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -967,13 +967,11 @@ private string ReplaceCore(string oldValue, string? newValue, CompareInfo? ci, C
//
public string Replace(char oldChar, char newChar)
{
if (oldChar == newChar)
return this;

int firstIndex = IndexOf(oldChar);

if (firstIndex < 0)
int firstIndex;
if (oldChar == newChar || (firstIndex = IndexOf(oldChar)) < 0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't it against the guidelines to perform an assignment inside an if ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know, but it gives nice machine code here 😉
It's about collapsing the epilogs for the first checks (oldChar == newChar, and firstIndex < 0).
Maybe I think too complicated now, but the other option would be using goto for this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the very least it would be good to have a comment calling out the assignment why it is being done here.

Otherwise, at a glance it may look like a potential bug or comparison using ==

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But, the below is also the more "natural" pattern and more readable:

if (oldChar == newChar)
{
    return this;
}

int firstIndex = IndexOf(oldChar);

if (firstIndex < 0)
{
    return this;
}

Ideally the JIT would handle such a pattern "correctly" and optimize it down accordingly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tannergooding so what's my action here?

  • add the comment explaining that the epilogs get collapsed in generated machine-code
  • write it more naturally and file an issue for the JIT
  • do both, and write it naturally once the JIT issue is fixed

(I'm leaning towards the last option, for perf-reasons -- except JIT issue will be fixed for .NET 7 😉)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think writing it naturally and filing an issue for the JIT is the best choice and don't expect the cost to be significant here.

If the cost is more significant, then adding a comment calling out the assignment and why as well as filing an issue for the JIT is the next best option.

If the issue is actually being fixed for .NET 7, that's all the more reason to do the first approach.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the codegen issue tracked by #8883?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code change here (back to where it was) with 30889ac
If I read the issue from the previous comment correct, so this should cover that case.

{
return this;
}

int remainingLength = Length - firstIndex;
string result = FastAllocateString(Length);
Expand All @@ -987,35 +985,51 @@ public string Replace(char oldChar, char newChar)
}

// Copy the remaining characters, doing the replacement as we go.
ref ushort pSrc = ref Unsafe.Add(ref Unsafe.As<char, ushort>(ref _firstChar), copyLength);
ref ushort pDst = ref Unsafe.Add(ref Unsafe.As<char, ushort>(ref result._firstChar), copyLength);
ref ushort pSrc = ref Unsafe.Add(ref Unsafe.As<char, ushort>(ref _firstChar), (nint)(uint)copyLength);
ref ushort pDst = ref Unsafe.Add(ref Unsafe.As<char, ushort>(ref result._firstChar), (nint)(uint)copyLength);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

off-topic: gosh, the same line with raw pointers is basically

ushort* pDst = ((ushort*)result._firstChar)[copyLength]

"Safe" Unsafe is killing me 🤦‍♂️

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here 🙈 It reads (and writes) like a mess.
Pinning wasn't used here, so I didn't use it too -- also I expect a little bit of regression then.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might become slightly more readable if the Unsafe.As was the outermost.

We could also expose an internal only Vector.LoadUnsafe(ref T, nuint index) API which would also simplify things here.

nint i = 0;

if (Vector.IsHardwareAccelerated && remainingLength >= Vector<ushort>.Count)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imagine a scenario where we are on an AVX2 machine and have 15 characters where index 8 is the first match. We Buffer.Memmove the first 8 (index 0-7, representing 16-bytes) then find we have 7 characters remaining and have to fallback to a scalar loop to handle it.

This effectively pessimizes the support added to "backtrack" so the "trailing" elements can be handled via vectorization.

It would likely be better to check that Length >= Vector<ushort>.Count so that we can backtrack and continue handling the trailing elements efficiently.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea 👍🏻

on an AVX2 machine and have 15 characters

I follow your idea (in general), but is 15 a mistake here? I mean Vector<ushort>.Count on AVX2 is 16, so it would be an access violation to read 16 chars (ushorts) from the end of that string.

Copy link
Member

@tannergooding tannergooding Jul 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, should be 31 and 16 rather than 15 and 8. I was thinking I need to halve the bytes to get a count and then forgot that I need twice as many characters so its just under 2 vectors worth 😅

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done with 8627f6f

{
Vector<ushort> oldChars = new Vector<ushort>(oldChar);
Vector<ushort> newChars = new Vector<ushort>(newChar);
Vector<ushort> oldChars = new(oldChar);
Vector<ushort> newChars = new(newChar);

Vector<ushort> original;
Vector<ushort> equals;
Vector<ushort> results;

do
{
Vector<ushort> original = Unsafe.ReadUnaligned<Vector<ushort>>(ref Unsafe.As<ushort, byte>(ref pSrc));
Vector<ushort> equals = Vector.Equals(original, oldChars);
Vector<ushort> results = Vector.ConditionalSelect(equals, newChars, original);
Unsafe.WriteUnaligned(ref Unsafe.As<ushort, byte>(ref pDst), results);

pSrc = ref Unsafe.Add(ref pSrc, Vector<ushort>.Count);
pDst = ref Unsafe.Add(ref pDst, Vector<ushort>.Count);
remainingLength -= Vector<ushort>.Count;
original = Unsafe.ReadUnaligned<Vector<ushort>>(ref Unsafe.As<ushort, byte>(ref Unsafe.Add(ref pSrc, i)));
equals = Vector.Equals(original, oldChars);
results = Vector.ConditionalSelect(equals, newChars, original);
Unsafe.WriteUnaligned(ref Unsafe.As<ushort, byte>(ref Unsafe.Add(ref pDst, i)), results);

i += Vector<ushort>.Count;
}
while (remainingLength >= Vector<ushort>.Count);
while (i <= (nint)(uint)(remainingLength - Vector<ushort>.Count));
gfoidl marked this conversation as resolved.
Show resolved Hide resolved

// There are [0, Vector<ushort>.Count) elements remaining now.
// As the operation is idempotent, and we know that in total there are at least Vector<ushort>.Count
// elements available, we read a vector from the very end of the string, perform the replace
// and write to the destination at the very end.
// Thus we can eliminate the scalar processing of the remaining elements.
// We perform this operation even if there are 0 elements remaining, as it is cheaper than the
// additional check which would introduce a branch here.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as it is cheaper than the additional check which would introduce a branch here.

Can you quantify this? Even with good branch prediction it's still more expensive?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's hard to pour this statement into numbers, as with a BDN-benchmark the branch predictor will very likely do a great job (they got really smart over the last generation of cpus).

In contrast to real-world usage I assume that it is more likely to have $&gt; 0$ elements remaining than having a remainder of $= 0$. In that case, and with the assumption that the branch predictor predictis $&gt; 0$ elements, the additional check (would be a test-instruction on x86) costs more than just executing the code (which needs to be done anyway).
So we penalize the case of having 0 elements remaining (which is assumed to be less likely), but all the data should be in the cache and cpu's memory system's store buffer should help to minimize that penalty.

When I start working on Vector128/256 support for string.Replace I'll try to examine that further, as there may be a code-path that starts with Vector256 where remainders will be processed by Vector128.


Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps worth adding an assert that current Debug.Assert(this.Length - i <= Vector<ushort>.Count) to make sure we won't skip any data?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, I think in this case a test should fail?
I'll re-check the tests and make sure that case is covered.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests cover these cases, so I don't see a need for the Debug.Assert -- but I'll add it of course if you want.

// -------------------- For Vector<ushort>.Count == 8 (SSE2 / ARM NEON) --------------------
[InlineData("Aaaaaaaa", 'A', 'a', "aaaaaaaa")] // Single iteration of vectorised path; no remainders through non-vectorised path
// Three leading 'a's before a match (copyLength > 0), Single iteration of vectorised path; no remainders through non-vectorised path
[InlineData("aaaAaaaaaaa", 'A', 'a', "aaaaaaaaaaa")]
// Single iteration of vectorised path; 3 remainders through non-vectorised path
[InlineData("AaaaaaaaaAa", 'A', 'a', "aaaaaaaaaaa")]
// ------------------------- For Vector<ushort>.Count == 16 (AVX2) -------------------------
[InlineData("AaaaaaaaAaaaaaaa", 'A', 'a', "aaaaaaaaaaaaaaaa")] // Single iteration of vectorised path; no remainders through non-vectorised path
// Three leading 'a's before a match (copyLength > 0), Single iteration of vectorised path; no remainders through non-vectorised path
[InlineData("aaaAaaaaaaaAaaaaaaa", 'A', 'a', "aaaaaaaaaaaaaaaaaaa")]
// Single iteration of vectorised path; 3 remainders through non-vectorised path
[InlineData("AaaaaaaaAaaaaaaaaAa", 'A', 'a', "aaaaaaaaaaaaaaaaaaa")]
// ----------------------------------- General test data -----------------------------------

i = (nint)(uint)this.Length - Vector<ushort>.Count;
original = Unsafe.ReadUnaligned<Vector<ushort>>(ref Unsafe.As<char, byte>(ref Unsafe.Add(ref _firstChar, i)));
equals = Vector.Equals(original, oldChars);
results = Vector.ConditionalSelect(equals, newChars, original);
Unsafe.WriteUnaligned(ref Unsafe.As<char, byte>(ref Unsafe.Add(ref result._firstChar, i)), results);
}

for (; remainingLength > 0; remainingLength--)
else
{
ushort currentChar = pSrc;
pDst = currentChar == oldChar ? newChar : currentChar;

pSrc = ref Unsafe.Add(ref pSrc, 1);
pDst = ref Unsafe.Add(ref pDst, 1);
for (; i < (nint)(uint)remainingLength; ++i)
{
ushort currentChar = Unsafe.Add(ref pSrc, i);
Unsafe.Add(ref pDst, i) = currentChar == oldChar ? newChar : currentChar;
}
}

return result;
Expand Down