encode/base: correct decode table indexing

This commit is contained in:
Danny Robson 2018-01-17 20:10:26 +11:00
parent d323197459
commit d037e71bba
2 changed files with 58 additions and 36 deletions

View File

@ -200,47 +200,48 @@ namespace util::encode {
if (src.size () % group_symbols)
throw std::invalid_argument ("base-encoded strings must be a proper multiple of symbols");
union {
uint_fast32_t num;
uint8_t bytes[group_bytes];
};
const bool padded = src.end ()[-1] == '=';
auto cursor = std::cbegin (src);
for (size_t i = 0, last = std::size (src) / group_symbols - padded?1:0;
i != last;
++i)
{
num = std::accumulate (
cursor, cursor + group_symbols,
uint_fast32_t {0},
[] (auto a, auto b) {
return a << symbol_bits | dec_v<Size>[static_cast<unsigned>(b)];
}
);
const int groups = std::size (src) / group_symbols;
const bool padded = std::cend (src)[-1] == '=';
cursor += group_symbols;
for (int i = 0; i < groups - (padded?1:0); ++i) {
uint64_t accum = 0;
dst = std::copy (std::crbegin (bytes), std::crend (bytes), dst);
for (int j = 0; j < group_symbols; ++j) {
auto symbol = dec_v<Size>[unsigned(*cursor++)];
accum <<= symbol_bits;
accum |= symbol;
}
if (cursor != std::end (src)) {
auto last = std::find (cursor, std::cend (src), '=');
num = std::accumulate (
cursor, last,
uint_fast32_t{0},
[] (auto a, auto b) {
return a << symbol_bits | dec_v<Size>[static_cast<unsigned> (b)];
for (int j = group_bytes - 1; j >= 0; --j) {
const uint8_t byte = accum >> (j*8);
*dst++ = byte;
}
}
if (padded) {
uint64_t accum = 0;
int symbols = 0;
for (int j = 0; j < group_symbols; ++j) {
auto symbol = *cursor++;
if (symbol == '=')
break;
symbols++;
accum <<= symbol_bits;
accum |= dec_v<Size>[unsigned(symbol)];
}
);
auto symbols = last - cursor;
auto bits = symbols * symbol_bits;
auto shift = bits%8;
num >>= shift;
accum >>= shift;
for (auto i = bits / 8; i; )
*dst++ = bytes[--i];
for (int j = bits / 8 - 1; j >= 0; --j) {
auto byte = (accum >> (j*8)) & 0xff;
*dst++ = byte;
}
}
return dst;

View File

@ -8,24 +8,39 @@
static constexpr char input[] = "foobar";
//-----------------------------------------------------------------------------
///////////////////////////////////////////////////////////////////////////////
// test vectors from rfc4648
template <int Size>
struct output { static const char *value[std::size (input)]; };
//-----------------------------------------------------------------------------
template <>
const char* output<64>::value[] = {
"", "Zg==", "Zm8=", "Zm9v", "Zm9vYg==", "Zm9vYmE=", "Zm9vYmFy",
"",
"Zg==",
"Zm8=",
"Zm9v",
"Zm9vYg==",
"Zm9vYmE=",
"Zm9vYmFy",
};
//-----------------------------------------------------------------------------
template <>
const char* output<32>::value[] {
"", "MY======", "MZXQ====", "MZXW6===", "MZXW6YQ=", "MZXW6YTB", "MZXW6YTBOI======"
"",
"MY======",
"MZXQ====",
"MZXW6===",
"MZXW6YQ=",
"MZXW6YTB",
"MZXW6YTBOI======"
};
//-----------------------------------------------------------------------------
#if 0
template <>
const char* output<32h>::value[] {
@ -37,7 +52,13 @@ const char* output<32h>::value[] {
//-----------------------------------------------------------------------------
template <>
const char *output<16>::value[] = {
"", "66", "666F", "666F6F", "666F6F62", "666F6F6261", "666F6F626172"
"",
"66",
"666F",
"666F6F",
"666F6F62",
"666F6F6261",
"666F6F626172"
};