diff options
Diffstat (limited to 'src/backend/utils/adt/encode.c')
-rw-r--r-- | src/backend/utils/adt/encode.c | 96 |
1 files changed, 56 insertions, 40 deletions
diff --git a/src/backend/utils/adt/encode.c b/src/backend/utils/adt/encode.c index 4759be25e91..8449aaac56a 100644 --- a/src/backend/utils/adt/encode.c +++ b/src/backend/utils/adt/encode.c @@ -15,7 +15,7 @@ #include <ctype.h> -#include "common/hex_decode.h" +#include "common/hex.h" #include "mb/pg_wchar.h" #include "utils/builtins.h" #include "utils/memutils.h" @@ -32,10 +32,12 @@ */ struct pg_encoding { - uint64 (*encode_len) (const char *data, size_t dlen); - uint64 (*decode_len) (const char *data, size_t dlen); - uint64 (*encode) (const char *data, size_t dlen, char *res); - uint64 (*decode) (const char *data, size_t dlen, char *res); + uint64 (*encode_len) (const char *src, size_t srclen); + uint64 (*decode_len) (const char *src, size_t srclen); + uint64 (*encode) (const char *src, size_t srclen, + char *dst, size_t dstlen); + uint64 (*decode) (const char *src, size_t srclen, + char *dst, size_t dstlen); }; static const struct pg_encoding *pg_find_encoding(const char *name); @@ -81,11 +83,7 @@ binary_encode(PG_FUNCTION_ARGS) result = palloc(VARHDRSZ + resultlen); - res = enc->encode(dataptr, datalen, VARDATA(result)); - - /* Make this FATAL 'cause we've trodden on memory ... */ - if (res > resultlen) - elog(FATAL, "overflow - encode estimate too small"); + res = enc->encode(dataptr, datalen, VARDATA(result), resultlen); SET_VARSIZE(result, VARHDRSZ + res); @@ -129,11 +127,7 @@ binary_decode(PG_FUNCTION_ARGS) result = palloc(VARHDRSZ + resultlen); - res = enc->decode(dataptr, datalen, VARDATA(result)); - - /* Make this FATAL 'cause we've trodden on memory ... */ - if (res > resultlen) - elog(FATAL, "overflow - decode estimate too small"); + res = enc->decode(dataptr, datalen, VARDATA(result), resultlen); SET_VARSIZE(result, VARHDRSZ + res); @@ -145,32 +139,20 @@ binary_decode(PG_FUNCTION_ARGS) * HEX */ -static const char hextbl[] = "0123456789abcdef"; - -uint64 -hex_encode(const char *src, size_t len, char *dst) -{ - const char *end = src + len; - - while (src < end) - { - *dst++ = hextbl[(*src >> 4) & 0xF]; - *dst++ = hextbl[*src & 0xF]; - src++; - } - return (uint64) len * 2; -} - +/* + * Those two wrappers are still needed to match with the layer of + * src/common/. + */ static uint64 hex_enc_len(const char *src, size_t srclen) { - return (uint64) srclen << 1; + return pg_hex_enc_len(srclen); } static uint64 hex_dec_len(const char *src, size_t srclen) { - return (uint64) srclen >> 1; + return pg_hex_dec_len(srclen); } /* @@ -192,12 +174,12 @@ static const int8 b64lookup[128] = { }; static uint64 -pg_base64_encode(const char *src, size_t len, char *dst) +pg_base64_encode(const char *src, size_t srclen, char *dst, size_t dstlen) { char *p, *lend = dst + 76; const char *s, - *end = src + len; + *end = src + srclen; int pos = 2; uint32 buf = 0; @@ -213,6 +195,8 @@ pg_base64_encode(const char *src, size_t len, char *dst) /* write it out */ if (pos < 0) { + if ((p - dst + 4) > dstlen) + elog(ERROR, "overflow of destination buffer in base64 encoding"); *p++ = _base64[(buf >> 18) & 0x3f]; *p++ = _base64[(buf >> 12) & 0x3f]; *p++ = _base64[(buf >> 6) & 0x3f]; @@ -223,25 +207,30 @@ pg_base64_encode(const char *src, size_t len, char *dst) } if (p >= lend) { + if ((p - dst + 1) > dstlen) + elog(ERROR, "overflow of destination buffer in base64 encoding"); *p++ = '\n'; lend = p + 76; } } if (pos != 2) { + if ((p - dst + 4) > dstlen) + elog(ERROR, "overflow of destination buffer in base64 encoding"); *p++ = _base64[(buf >> 18) & 0x3f]; *p++ = _base64[(buf >> 12) & 0x3f]; *p++ = (pos == 0) ? _base64[(buf >> 6) & 0x3f] : '='; *p++ = '='; } + Assert((p - dst) <= dstlen); return p - dst; } static uint64 -pg_base64_decode(const char *src, size_t len, char *dst) +pg_base64_decode(const char *src, size_t srclen, char *dst, size_t dstlen) { - const char *srcend = src + len, + const char *srcend = src + srclen, *s = src; char *p = dst; char c; @@ -289,11 +278,21 @@ pg_base64_decode(const char *src, size_t len, char *dst) pos++; if (pos == 4) { + if ((p - dst + 1) > dstlen) + elog(ERROR, "overflow of destination buffer in base64 decoding"); *p++ = (buf >> 16) & 255; if (end == 0 || end > 1) + { + if ((p - dst + 1) > dstlen) + elog(ERROR, "overflow of destination buffer in base64 decoding"); *p++ = (buf >> 8) & 255; + } if (end == 0 || end > 2) + { + if ((p - dst + 1) > dstlen) + elog(ERROR, "overflow of destination buffer in base64 decoding"); *p++ = buf & 255; + } buf = 0; pos = 0; } @@ -305,6 +304,7 @@ pg_base64_decode(const char *src, size_t len, char *dst) errmsg("invalid base64 end sequence"), errhint("Input data is missing padding, is truncated, or is otherwise corrupted."))); + Assert((p - dst) <= dstlen); return p - dst; } @@ -340,7 +340,7 @@ pg_base64_dec_len(const char *src, size_t srclen) #define DIG(VAL) ((VAL) + '0') static uint64 -esc_encode(const char *src, size_t srclen, char *dst) +esc_encode(const char *src, size_t srclen, char *dst, size_t dstlen) { const char *end = src + srclen; char *rp = dst; @@ -352,6 +352,8 @@ esc_encode(const char *src, size_t srclen, char *dst) if (c == '\0' || IS_HIGHBIT_SET(c)) { + if ((rp - dst + 4) > dstlen) + elog(ERROR, "overflow of destination buffer in escape encoding"); rp[0] = '\\'; rp[1] = DIG(c >> 6); rp[2] = DIG((c >> 3) & 7); @@ -361,6 +363,8 @@ esc_encode(const char *src, size_t srclen, char *dst) } else if (c == '\\') { + if ((rp - dst + 2) > dstlen) + elog(ERROR, "overflow of destination buffer in escape encoding"); rp[0] = '\\'; rp[1] = '\\'; rp += 2; @@ -368,6 +372,8 @@ esc_encode(const char *src, size_t srclen, char *dst) } else { + if ((rp - dst + 1) > dstlen) + elog(ERROR, "overflow of destination buffer in escape encoding"); *rp++ = c; len++; } @@ -375,11 +381,12 @@ esc_encode(const char *src, size_t srclen, char *dst) src++; } + Assert((rp - dst) <= dstlen); return len; } static uint64 -esc_decode(const char *src, size_t srclen, char *dst) +esc_decode(const char *src, size_t srclen, char *dst, size_t dstlen) { const char *end = src + srclen; char *rp = dst; @@ -388,7 +395,11 @@ esc_decode(const char *src, size_t srclen, char *dst) while (src < end) { if (src[0] != '\\') + { + if ((rp - dst + 1) > dstlen) + elog(ERROR, "overflow of destination buffer in escape decoding"); *rp++ = *src++; + } else if (src + 3 < end && (src[1] >= '0' && src[1] <= '3') && (src[2] >= '0' && src[2] <= '7') && @@ -400,12 +411,16 @@ esc_decode(const char *src, size_t srclen, char *dst) val <<= 3; val += VAL(src[2]); val <<= 3; + if ((rp - dst + 1) > dstlen) + elog(ERROR, "overflow of destination buffer in escape decoding"); *rp++ = val + VAL(src[3]); src += 4; } else if (src + 1 < end && (src[1] == '\\')) { + if ((rp - dst + 1) > dstlen) + elog(ERROR, "overflow of destination buffer in escape decoding"); *rp++ = '\\'; src += 2; } @@ -423,6 +438,7 @@ esc_decode(const char *src, size_t srclen, char *dst) len++; } + Assert((rp - dst) <= dstlen); return len; } @@ -504,7 +520,7 @@ static const struct { "hex", { - hex_enc_len, hex_dec_len, hex_encode, hex_decode + hex_enc_len, hex_dec_len, pg_hex_encode, pg_hex_decode } }, { |