Skip to content

Commit

Permalink
Add Base64.IsValid and allow Base64.DecodeXx methods to skip whitespa…
Browse files Browse the repository at this point in the history
…ce (#85938)

* Allow Base64Decoder to ignore space chars, add IsValid methods and tests

* Some cleanup of Base64.IsValid changes

This includes making FromBase64Transform significantly faster via SearchValues.

* Address PR feedback and some more cleanup

---------

Co-authored-by: Heath Baron-Morgan <heathbm@outlook.com>
  • Loading branch information
stephentoub and heathbm authored May 9, 2023
1 parent 5b90c47 commit b2d730c
Show file tree
Hide file tree
Showing 11 changed files with 1,080 additions and 45 deletions.
106 changes: 101 additions & 5 deletions src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Xunit;

namespace System.Buffers.Text.Tests
{
public class Base64DecoderUnitTests
public class Base64DecoderUnitTests : Base64TestBase
{
[Fact]
public void BasicDecoding()
Expand Down Expand Up @@ -157,7 +160,7 @@ public void DecodingOutputTooSmall()

Span<byte> decodedBytes = new byte[3];
int consumed, written;
if (numBytes % 4 == 0)
if (numBytes >= 8)
{
Assert.True(OperationStatus.DestinationTooSmall ==
Base64.DecodeFromUtf8(source, decodedBytes, out consumed, out written), "Number of Input Bytes: " + numBytes);
Expand Down Expand Up @@ -373,8 +376,12 @@ public void DecodingInvalidBytes(bool isFinalBlock)
for (int i = 0; i < invalidBytes.Length; i++)
{
// Don't test padding (byte 61 i.e. '='), which is tested in DecodingInvalidBytesPadding
if (invalidBytes[i] == Base64TestHelper.EncodingPad)
// Don't test chars to be ignored (whitespace: 9, 10, 13, 32 i.e. '\t', '\n', '\r', ' ')
if (invalidBytes[i] == Base64TestHelper.EncodingPad ||
Base64TestHelper.IsByteToBeIgnored(invalidBytes[i]))
{
continue;
}

// replace one byte with an invalid input
source[j] = invalidBytes[i];
Expand Down Expand Up @@ -568,8 +575,12 @@ public void DecodeInPlaceInvalidBytes()
Span<byte> buffer = "2222PPPP"u8.ToArray(); // valid input

// Don't test padding (byte 61 i.e. '='), which is tested in DecodeInPlaceInvalidBytesPadding
if (invalidBytes[i] == Base64TestHelper.EncodingPad)
// Don't test chars to be ignored (whitespace: 9, 10, 13, 32 i.e. '\t', '\n', '\r', ' ')
if (invalidBytes[i] == Base64TestHelper.EncodingPad ||
Base64TestHelper.IsByteToBeIgnored(invalidBytes[i]))
{
continue;
}

// replace one byte with an invalid input
buffer[j] = invalidBytes[i];
Expand All @@ -594,7 +605,7 @@ public void DecodeInPlaceInvalidBytes()
{
Span<byte> buffer = "2222PPP"u8.ToArray(); // incomplete input
Assert.Equal(OperationStatus.InvalidData, Base64.DecodeFromUtf8InPlace(buffer, out int bytesWritten));
Assert.Equal(0, bytesWritten);
Assert.Equal(3, bytesWritten);
}
}

Expand Down Expand Up @@ -667,5 +678,90 @@ public void DecodeInPlaceInvalidBytesPadding()
}
}

/// <summary>
/// Decodes Base64 text that contains ignorable whitespace with <c>Base64.DecodeFromUtf8</c> and checks
/// that the outcome matches both the expected payload and <c>Convert.FromBase64String</c>.
/// </summary>
/// <param name="utf8WithCharsToBeIgnored">Base64 text with whitespace inserted at the start, middle or end.</param>
/// <param name="expectedBytes">The bytes the text is expected to decode to.</param>
[Theory]
[MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))]
public void BasicDecodingIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes)
{
    byte[] utf8BytesWithByteToBeIgnored = Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored);

    // Size the destination from the expected payload instead of a magic constant, so the
    // whole-buffer comparisons below stay correct if the member data ever changes length.
    byte[] resultBytes = new byte[expectedBytes.Length];
    OperationStatus result = Base64.DecodeFromUtf8(utf8BytesWithByteToBeIgnored, resultBytes, out int bytesConsumed, out int bytesWritten);

    // Control value from Convert.FromBase64String
    byte[] stringBytes = Convert.FromBase64String(utf8WithCharsToBeIgnored);

    Assert.Equal(OperationStatus.Done, result);
    // Every input byte, ignored whitespace included, is expected to be reported as consumed.
    Assert.Equal(utf8WithCharsToBeIgnored.Length, bytesConsumed);
    Assert.Equal(expectedBytes.Length, bytesWritten);
    // Assert.Equal on the arrays reports the first differing element on failure.
    Assert.Equal(expectedBytes, resultBytes);
    Assert.Equal(stringBytes, resultBytes);
}

/// <summary>
/// Decodes Base64 text that contains ignorable whitespace with <c>Base64.DecodeFromUtf8InPlace</c> and
/// checks the written prefix against the expected payload and <c>Convert.FromBase64String</c>.
/// </summary>
/// <param name="utf8WithCharsToBeIgnored">Base64 text with whitespace inserted at the start, middle or end.</param>
/// <param name="expectedBytes">The bytes the text is expected to decode to.</param>
[Theory]
[MemberData(nameof(ValidBase64Strings_WithCharsThatMustBeIgnored))]
public void DecodeInPlaceIgnoresCharsToBeIgnoredAsConvertToBase64Does(string utf8WithCharsToBeIgnored, byte[] expectedBytes)
{
    // The same buffer is both the encoded input and the decoded output.
    byte[] buffer = Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored);
    OperationStatus status = Base64.DecodeFromUtf8InPlace(buffer, out int bytesWritten);

    // Only the first bytesWritten bytes of the buffer hold decoded data.
    byte[] decoded = buffer.AsSpan(0, bytesWritten).ToArray();

    // Control value from Convert.FromBase64String
    byte[] control = Convert.FromBase64String(utf8WithCharsToBeIgnored);

    Assert.Equal(OperationStatus.Done, status);
    Assert.Equal(expectedBytes.Length, bytesWritten);
    Assert.True(expectedBytes.SequenceEqual(decoded));
    Assert.True(control.SequenceEqual(decoded));
}

/// <summary>
/// Checks that input made up solely of ignorable whitespace decodes successfully to zero bytes.
/// </summary>
/// <param name="utf8WithCharsToBeIgnored">A string containing only whitespace chars the decoder ignores.</param>
[Theory]
[MemberData(nameof(StringsOnlyWithCharsToBeIgnored))]
public void BasicDecodingWithOnlyCharsToBeIgnored(string utf8WithCharsToBeIgnored)
{
    byte[] source = Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored);
    byte[] destination = new byte[5];

    // The consumed count is not asserted by this test, so it is discarded.
    OperationStatus status = Base64.DecodeFromUtf8(source, destination, out _, out int bytesWritten);

    Assert.Equal(OperationStatus.Done, status);
    Assert.Equal(0, bytesWritten);
}

/// <summary>
/// Checks that in-place decoding of input made up solely of ignorable whitespace succeeds
/// and reports zero bytes written.
/// </summary>
/// <param name="utf8WithCharsToBeIgnored">A string containing only whitespace chars the decoder ignores.</param>
[Theory]
[MemberData(nameof(StringsOnlyWithCharsToBeIgnored))]
public void DecodingInPlaceWithOnlyCharsToBeIgnored(string utf8WithCharsToBeIgnored)
{
    Span<byte> buffer = Encoding.UTF8.GetBytes(utf8WithCharsToBeIgnored);

    OperationStatus status = Base64.DecodeFromUtf8InPlace(buffer, out int bytesWritten);

    Assert.Equal(OperationStatus.Done, status);
    Assert.Equal(0, bytesWritten);
}

/// <summary>
/// Decodes input that carries extra whitespace and checks the reported consumed and written
/// counts, then validates the decoded bytes through <c>Base64TestHelper.VerifyDecodingCorrectness</c>.
/// </summary>
/// <param name="inputString">Base64 text, possibly interleaved with or followed by whitespace.</param>
/// <param name="expectedConsumed">Number of input bytes the decoder is expected to report as consumed.</param>
/// <param name="expectedWritten">Number of decoded bytes the decoder is expected to report as written.</param>
[Theory]
[MemberData(nameof(BasicDecodingWithExtraWhitespaceShouldBeCountedInConsumedBytes_MemberData))]
public void BasicDecodingWithExtraWhitespaceShouldBeCountedInConsumedBytes(string inputString, int expectedConsumed, int expectedWritten)
{
    Span<byte> encoded = Encoding.ASCII.GetBytes(inputString);
    int capacity = Base64.GetMaxDecodedFromUtf8Length(encoded.Length);
    Span<byte> decoded = new byte[capacity];

    OperationStatus status = Base64.DecodeFromUtf8(encoded, decoded, out int consumed, out int decodedByteCount);

    Assert.Equal(OperationStatus.Done, status);
    Assert.Equal(expectedConsumed, consumed);
    Assert.Equal(expectedWritten, decodedByteCount);

    bool decodedCorrectly = Base64TestHelper.VerifyDecodingCorrectness(expectedConsumed, expectedWritten, encoded, decoded);
    Assert.True(decodedCorrectly);
}

/// <summary>
/// Produces rows of (input string, expected consumed count, expected written count).
/// </summary>
/// <returns>
/// Five rows of "AQ==" followed by 0 to 4 whitespace chars picked by a fixed-seed <see cref="Random"/>,
/// then one row per spaced variant of a 4-char group repeated four times.
/// </returns>
public static IEnumerable<object[]> BasicDecodingWithExtraWhitespaceShouldBeCountedInConsumedBytes_MemberData()
{
    // Fixed seed keeps the generated whitespace identical from run to run.
    Random rng = new Random(42);

    int trailingCount = 0;
    while (trailingCount < 5)
    {
        char[] trailing = rng.GetItems<char>(" \n\t\r", trailingCount);

        // The expected consumed count includes the whitespace that follows the padding.
        yield return new object[] { "AQ==" + new string(trailing), 4 + trailingCount, 1 };
        trailingCount++;
    }

    string[] groups = { "MTIz", "M TIz", "MT Iz", "MTI z", "MTIz ", "M TI z", "M T I Z " };
    foreach (string group in groups)
    {
        // Four copies of the group: 16 Base64 chars plus embedded spaces, decoding to 12 bytes.
        string repeated = string.Concat(group, group, group, group);
        yield return new object[] { repeated, group.Length * 4, 12 };
    }
}
}
}
111 changes: 111 additions & 0 deletions src/libraries/System.Memory/tests/Base64/Base64TestBase.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Text;

namespace System.Buffers.Text.Tests
{
/// <summary>
/// Shared member data for Base64 tests that exercise the whitespace chars a decoder must ignore:
/// horizontal tab (9), line feed (10), carriage return (13) and space (32).
/// </summary>
public class Base64TestBase
{
    // The chars to be ignored, in ascending code order: '\t' = 9, '\n' = 10, '\r' = 13, ' ' = 32.
    private static readonly char[] s_charsToBeIgnored = { '\t', '\n', '\r', ' ' };

    // Each ignored char is used as a run of one char and as a run of three chars.
    private static readonly int[] s_runLengths = { 1, 3 };

    /// <summary>
    /// Produces rows of (Base64 string containing ignored chars, expected decoded bytes).
    /// </summary>
    /// <returns>
    /// 24 rows: for each of the middle, start and end positions, every ignored char inserted
    /// once and three times into the Base64 encoding of "a b c".
    /// </returns>
    public static IEnumerable<object[]> ValidBase64Strings_WithCharsThatMustBeIgnored()
    {
        // Create a Base64 string
        string text = "a b c";
        byte[] utf8Bytes = Encoding.UTF8.GetBytes(text);
        string base64Utf8String = Convert.ToBase64String(utf8Bytes);

        // Split the Base64 string in half; the second half takes everything after the first.
        int half = base64Utf8String.Length / 2;
        string firstSegment = base64Utf8String.Substring(0, half);
        string secondSegment = base64Utf8String.Substring(half);

        // Ignored chars inserted between the two halves of the Base64 string
        foreach (string ignored in GetRunsOfCharsToBeIgnored())
        {
            yield return new object[] { $"{firstSegment}{ignored}{secondSegment}", utf8Bytes };
        }

        // Ignored chars inserted at the start of the Base64 string
        foreach (string ignored in GetRunsOfCharsToBeIgnored())
        {
            yield return new object[] { $"{ignored}{firstSegment}{secondSegment}", utf8Bytes };
        }

        // Ignored chars inserted at the end of the Base64 string.
        // NOTE: Base64DecoderUnitTests asserts bytesConsumed == input length for these rows,
        // i.e. trailing ignored chars are expected to be counted as consumed.
        foreach (string ignored in GetRunsOfCharsToBeIgnored())
        {
            yield return new object[] { $"{firstSegment}{secondSegment}{ignored}", utf8Bytes };
        }
    }

    /// <summary>
    /// Produces single-column rows whose string consists only of ignored chars.
    /// </summary>
    /// <returns>8 rows: every ignored char repeated once and three times.</returns>
    public static IEnumerable<object[]> StringsOnlyWithCharsToBeIgnored()
    {
        foreach (string ignored in GetRunsOfCharsToBeIgnored())
        {
            yield return new object[] { ignored };
        }
    }

    /// <summary>
    /// Enumerates each ignored char as a run of 1 and then a run of 3, in the order
    /// tab, line feed, carriage return, space.
    /// </summary>
    private static IEnumerable<string> GetRunsOfCharsToBeIgnored()
    {
        foreach (char charToInsert in s_charsToBeIgnored)
        {
            foreach (int numberOfTimesToInsert in s_runLengths)
            {
                yield return new string(charToInsert, numberOfTimesToInsert);
            }
        }
    }
}
}
2 changes: 2 additions & 0 deletions src/libraries/System.Memory/tests/Base64/Base64TestHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ public static class Base64TestHelper
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
};

/// <summary>
/// Tells whether a byte is one of the four whitespace values this helper treats as ignorable:
/// space, horizontal tab, carriage return or line feed.
/// </summary>
/// <param name="charByte">The byte to classify.</param>
/// <returns><see langword="true"/> for ' ', '\t', '\r' and '\n'; otherwise <see langword="false"/>.</returns>
public static bool IsByteToBeIgnored(byte charByte)
{
    switch (charByte)
    {
        case (byte)'\t':
        case (byte)'\n':
        case (byte)'\r':
        case (byte)' ':
            return true;

        default:
            return false;
    }
}

public const byte EncodingPad = (byte)'='; // '=', for padding
public const sbyte InvalidByte = -1; // Designating -1 for invalid bytes in the decoding map

Expand Down
Loading

0 comments on commit b2d730c

Please sign in to comment.