Skip to content

Commit

Permalink
Optimize Varint Parsing for 32 and 64 bits
Browse files Browse the repository at this point in the history
This change is intended to remove some push/pop pairs and unify the common path.

Diff:

https://gcc.godbolt.org/z/4f8a9zsjW

For the 64-bit case, we got rid of an extra push/pop. For the 32-bit case, we think that lower register pressure plus shld is better than ror+test+js, and there are no push/pops. Fun fact: the top 32 bits of the 64-bit intermediate result can be polluted with ones, but that does not matter for the final assignment of the resulting pair.

PiperOrigin-RevId: 506846358
  • Loading branch information
protobuf-github-bot authored and copybara-github committed Feb 3, 2023
1 parent 390605c commit ac76ae9
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 79 deletions.
121 changes: 43 additions & 78 deletions src/google/protobuf/generated_message_tctable_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -728,9 +728,9 @@ class PROTOBUF_EXPORT TcParser final {

// Shift "byte" left by n * 7 bits, filling vacated bits with ones.
template <int n>
inline PROTOBUF_ALWAYS_INLINE uint64_t
shift_left_fill_with_ones(uint64_t byte, uint64_t ones) {
return (byte << (n * 7)) | (ones >> (64 - (n * 7)));
inline PROTOBUF_ALWAYS_INLINE int64_t shift_left_fill_with_ones(uint64_t byte,
uint64_t ones) {
return static_cast<int64_t>((byte << (n * 7)) | (ones >> (64 - (n * 7))));
}

// Shift "byte" left by n * 7 bits, filling vacated bits with ones, and
Expand All @@ -746,17 +746,22 @@ inline PROTOBUF_ALWAYS_INLINE bool shift_left_fill_with_ones_was_negative(
asm("shldq %3, %2, %1"
: "=@ccs"(sign_bit), "+r"(byte)
: "r"(ones), "i"(n * 7));
res = byte;
res = static_cast<int64_t>(byte);
return sign_bit;
#else
// Generic fallback:
res = shift_left_fill_with_ones<n>(byte, ones);
return static_cast<int64_t>(res) < 0;
return res < 0;
#endif
}

inline PROTOBUF_ALWAYS_INLINE std::pair<const char*, uint64_t>
Parse64FallbackPair(const char* p, int64_t res1) {
template <class VarintType>
inline PROTOBUF_ALWAYS_INLINE std::pair<const char*, VarintType>
ParseFallbackPair(const char* p, int64_t res1) {
constexpr bool kIs64BitVarint = std::is_same<VarintType, uint64_t>::value;
constexpr bool kIs32BitVarint = std::is_same<VarintType, uint32_t>::value;
static_assert(kIs64BitVarint || kIs32BitVarint,
"Only 32 or 64 bit varints are supported");
auto ptr = reinterpret_cast<const int8_t*>(p);

// The algorithm relies on sign extension for each byte to set all high bits
Expand Down Expand Up @@ -787,23 +792,30 @@ Parse64FallbackPair(const char* p, int64_t res1) {
goto done3;

// For the remainder of the chunks, check the sign of the AND result.
res1 &= shift_left_fill_with_ones<3>(ptr[3], ones);
if (res1 >= 0) goto done4;
res2 &= shift_left_fill_with_ones<4>(ptr[4], ones);
if (res2 >= 0) goto done5;
res3 &= shift_left_fill_with_ones<5>(ptr[5], ones);
if (res3 >= 0) goto done6;
res1 &= shift_left_fill_with_ones<6>(ptr[6], ones);
if (res1 >= 0) goto done7;
res2 &= shift_left_fill_with_ones<7>(ptr[7], ones);
if (res2 >= 0) goto done8;
res3 &= shift_left_fill_with_ones<8>(ptr[8], ones);
if (res3 >= 0) goto done9;
res2 &= shift_left_fill_with_ones<3>(ptr[3], ones);
if (res2 >= 0) goto done4;
res1 &= shift_left_fill_with_ones<4>(ptr[4], ones);
if (res1 >= 0) goto done5;
if (kIs64BitVarint) {
res2 &= shift_left_fill_with_ones<5>(ptr[5], ones);
if (res2 >= 0) goto done6;
res3 &= shift_left_fill_with_ones<6>(ptr[6], ones);
if (res3 >= 0) goto done7;
res1 &= shift_left_fill_with_ones<7>(ptr[7], ones);
if (res1 >= 0) goto done8;
res3 &= shift_left_fill_with_ones<8>(ptr[8], ones);
if (res3 >= 0) goto done9;
} else if (kIs32BitVarint) {
if (PROTOBUF_PREDICT_TRUE(!(ptr[5] & 0x80))) goto done6;
if (PROTOBUF_PREDICT_TRUE(!(ptr[6] & 0x80))) goto done7;
if (PROTOBUF_PREDICT_TRUE(!(ptr[7] & 0x80))) goto done8;
if (PROTOBUF_PREDICT_TRUE(!(ptr[8] & 0x80))) goto done9;
}

// For valid 64bit varints, the 10th byte/ptr[9] should be exactly 1. In this
// case, the continuation bit of ptr[8] already set the top bit of res3
// correctly, so all we have to do is check that the expected case is true.
if (PROTOBUF_PREDICT_TRUE(ptr[9] == 1)) goto done10;
if (PROTOBUF_PREDICT_TRUE(kIs64BitVarint && ptr[9] == 1)) goto done10;

if (PROTOBUF_PREDICT_FALSE(ptr[9] & 0x80)) {
// If the continue bit is set, it is an unterminated varint.
Expand All @@ -814,7 +826,7 @@ Parse64FallbackPair(const char* p, int64_t res1) {
// over-serialized varint. This case should not happen, but if it does (say, due
// to a nonconforming serializer), deassert the continuation bit that came
// from ptr[8].
if ((ptr[9] & 1) == 0) {
if (kIs64BitVarint && (ptr[9] & 1) == 0) {
#if defined(__GCC_ASM_FLAG_OUTPUTS__) && defined(__x86_64__)
// Use a small instruction since this is an uncommon code path.
asm("btcq $63,%0" : "+r"(res3));
Expand Down Expand Up @@ -916,7 +928,8 @@ PROTOBUF_NOINLINE const char* TcParser::FastTV64S1(PROTOBUF_TC_PARAM_DECL) {
hasbits |= (uint64_t{1} << hasbit_idx);
}

auto tmp = Parse64FallbackPair(ptr, static_cast<int8_t>(data.data >> 8));
auto tmp =
ParseFallbackPair<uint64_t>(ptr, static_cast<int8_t>(data.data >> 8));
data.data = 0; // Indicate to the compiler that we don't need this anymore.
ptr = tmp.first;
if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) {
Expand Down Expand Up @@ -949,63 +962,15 @@ PROTOBUF_NOINLINE const char* TcParser::FastTV32S1(PROTOBUF_TC_PARAM_DECL) {
hasbits |= (uint64_t{1} << hasbit_idx);
}

// Few registers
auto* out = &RefAt<FieldType>(msg, data_offset);
uint64_t res = 0xFF & (data.data >> 8);
/* if (PROTOBUF_PREDICT_FALSE(res & 0x80)) */ {
res = RotRight7AndReplaceLowByte(res, ptr[1]);
if (PROTOBUF_PREDICT_FALSE(res & 0x80)) {
res = RotRight7AndReplaceLowByte(res, ptr[2]);
if (PROTOBUF_PREDICT_FALSE(res & 0x80)) {
res = RotRight7AndReplaceLowByte(res, ptr[3]);
if (PROTOBUF_PREDICT_FALSE(res & 0x80)) {
res = RotRight7AndReplaceLowByte(res, ptr[4]);
if (PROTOBUF_PREDICT_FALSE(res & 0x80)) {
if (PROTOBUF_PREDICT_FALSE(ptr[5] & 0x80)) {
if (PROTOBUF_PREDICT_FALSE(ptr[6] & 0x80)) {
if (PROTOBUF_PREDICT_FALSE(ptr[7] & 0x80)) {
if (PROTOBUF_PREDICT_FALSE(ptr[8] & 0x80)) {
if (ptr[9] & 0x80) return Error(PROTOBUF_TC_PARAM_PASS);
*out = RotateLeft(res, 28);
ptr += 10;
PROTOBUF_MUSTTAIL return ToTagDispatch(
PROTOBUF_TC_PARAM_PASS);
}
*out = RotateLeft(res, 28);
ptr += 9;
PROTOBUF_MUSTTAIL return ToTagDispatch(
PROTOBUF_TC_PARAM_PASS);
}
*out = RotateLeft(res, 28);
ptr += 8;
PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_PASS);
}
*out = RotateLeft(res, 28);
ptr += 7;
PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_PASS);
}
*out = RotateLeft(res, 28);
ptr += 6;
PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_PASS);
}
*out = RotateLeft(res, 28);
ptr += 5;
PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_PASS);
}
*out = RotateLeft(res, 21);
ptr += 4;
PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_PASS);
}
*out = RotateLeft(res, 14);
ptr += 3;
PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_PASS);
}
*out = RotateLeft(res, 7);
ptr += 2;
PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_PASS);
auto tmp =
ParseFallbackPair<uint32_t>(ptr, static_cast<int8_t>(data.data >> 8));
data.data = 0; // Indicate to the compiler that we don't need this anymore.
ptr = tmp.first;
if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) {
return Error(PROTOBUF_TC_PARAM_PASS);
}
*out = res;
ptr += 1;

RefAt<FieldType>(msg, data_offset) = static_cast<FieldType>(tmp.second);
PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_PASS);
}

Expand Down
2 changes: 1 addition & 1 deletion src/google/protobuf/generated_message_tctable_lite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ inline PROTOBUF_ALWAYS_INLINE const char* ParseVarint(const char* p,
*value = byte;
return p + 1;
} else {
auto tmp = Parse64FallbackPair(p, byte);
auto tmp = ParseFallbackPair<std::make_unsigned_t<Type>>(p, byte);
if (PROTOBUF_PREDICT_TRUE(tmp.first)) {
*value = static_cast<Type>(tmp.second);
}
Expand Down

0 comments on commit ac76ae9

Please sign in to comment.