diff --git a/src/index.ts b/src/index.ts index 742c8ac..a858c0c 100644 --- a/src/index.ts +++ b/src/index.ts @@ -182,7 +182,8 @@ const msb = (val: number) => { // read finite state entropy const rfse = (dat: Uint8Array, bt: number, mal: number): [number, FSEDT] => { // table pos - let tpos = (bt << 3) + 4; + // use multiplication, not <<, to avoid int32 overflow when bt >= 2^28 + let tpos = bt * 8 + 4; // accuracy log const al = (dat[bt] & 15) + 5; if (al > mal) err(3); @@ -201,7 +202,7 @@ const rfse = (dat: Uint8Array, bt: number, mal: number): [number, FSEDT] => { const nbits = new u8(buf, bb1 + sz); while (sym < 255 && probs > 0) { const bits = msb(probs + 1); - const cbt = tpos >> 3; + const cbt = Math.floor(tpos / 8); // mask const msk = (1 << (bits + 1)) - 1; let val = ((dat[cbt] | (dat[cbt + 1] << 8) | (dat[cbt + 2] << 16)) >> (tpos & 7)) & msk; @@ -224,7 +225,7 @@ const rfse = (dat: Uint8Array, bt: number, mal: number): [number, FSEDT] => { if (!val) { do { // repeat byte - const rbt = tpos >> 3; + const rbt = Math.floor(tpos / 8); re = ((dat[rbt] | (dat[rbt + 1] << 8)) >> (tpos & 7)) & 3; tpos += 2; sym += re; @@ -260,7 +261,7 @@ const rfse = (dat: Uint8Array, bt: number, mal: number): [number, FSEDT] => { const nb = nbits[i] = al - msb(ns); nstate[i] = (ns << nb) - sz; } - return [(tpos + 7) >> 3, { + return [Math.floor((tpos + 7) / 8), { b: al, s: syms, n: nbits, @@ -285,7 +286,7 @@ const rhu = (dat: Uint8Array, bt: number): [number, HDT] => { // end byte, fse decode table const [ebt, fdt] = rfse(dat, bt + 1, 6); bt += hb; - const epos = ebt << 3; + const epos = ebt * 8; // last byte const lb = dat[bt]; if (!lb) err(0); @@ -293,16 +294,16 @@ const rhu = (dat: Uint8Array, bt: number): [number, HDT] => { let st1 = 0, st2 = 0, btr1 = fdt.b, btr2 = btr1; // fse pos // pre-increment to account for original deficit of 1 - let fpos = (++bt << 3) - 8 + msb(lb); + let fpos = (++bt) * 8 - 8 + msb(lb); for (;;) { fpos -= btr1; if (fpos < epos) break; - let cbt = 
fpos >> 3; + let cbt = Math.floor(fpos / 8); st1 += ((dat[cbt] | (dat[cbt + 1] << 8)) >> (fpos & 7)) & ((1 << btr1) - 1); hw[++wc] = fdt.s[st1]; fpos -= btr2; if (fpos < epos) break; - cbt = fpos >> 3; + cbt = Math.floor(fpos / 8); st2 += ((dat[cbt] | (dat[cbt + 1] << 8)) >> (fpos & 7)) & ((1 << btr2) - 1); hw[++wc] = fdt.s[st2]; btr1 = fdt.n[st1]; @@ -527,11 +528,11 @@ const rzb = (dat: Uint8Array, st: DZstdState, out?: Uint8Array) => { const [mlt, oct, llt] = st.t = dts; const lb = dat[ebt - 1]; if (!lb) err(0); - let spos = (ebt << 3) - 8 + msb(lb) - llt.b, cbt = spos >> 3, oubt = 0; + let spos = ebt * 8 - 8 + msb(lb) - llt.b, cbt = Math.floor(spos / 8), oubt = 0; let lst = ((dat[cbt] | (dat[cbt + 1] << 8)) >> (spos & 7)) & ((1 << llt.b) - 1); - cbt = (spos -= oct.b) >> 3; + cbt = Math.floor((spos -= oct.b) / 8); let ost = ((dat[cbt] | (dat[cbt + 1] << 8)) >> (spos & 7)) & ((1 << oct.b) - 1); - cbt = (spos -= mlt.b) >> 3; + cbt = Math.floor((spos -= mlt.b) / 8); let mst = ((dat[cbt] | (dat[cbt + 1] << 8)) >> (spos & 7)) & ((1 << mlt.b) - 1); for (++ns; --ns;) { const llc = llt.s[lst]; @@ -541,19 +542,19 @@ const rzb = (dat: Uint8Array, st: DZstdState, out?: Uint8Array) => { const ofc = oct.s[ost]; const obtr = oct.n[ost]; - cbt = (spos -= ofc) >> 3; + cbt = Math.floor((spos -= ofc) / 8); const ofp = 1 << ofc; let off = ofp + (((dat[cbt] | (dat[cbt + 1] << 8) | (dat[cbt + 2] << 16) | (dat[cbt + 3] << 24)) >>> (spos & 7)) & (ofp - 1)); - cbt = (spos -= mlb[mlc]) >> 3; + cbt = Math.floor((spos -= mlb[mlc]) / 8); let ml = mlbl[mlc] + (((dat[cbt] | (dat[cbt + 1] << 8) | (dat[cbt + 2] << 16)) >> (spos & 7)) & ((1 << mlb[mlc]) - 1)); - cbt = (spos -= llb[llc]) >> 3; + cbt = Math.floor((spos -= llb[llc]) / 8); const ll = llbl[llc] + (((dat[cbt] | (dat[cbt + 1] << 8) | (dat[cbt + 2] << 16)) >> (spos & 7)) & ((1 << llb[llc]) - 1)); - cbt = (spos -= lbtr) >> 3; + cbt = Math.floor((spos -= lbtr) / 8); lst = llt.t[lst] + (((dat[cbt] | (dat[cbt + 1] << 8)) >> (spos & 7)) & 
((1 << lbtr) - 1)); - cbt = (spos -= mbtr) >> 3; + cbt = Math.floor((spos -= mbtr) / 8); mst = mlt.t[mst] + (((dat[cbt] | (dat[cbt + 1] << 8)) >> (spos & 7)) & ((1 << mbtr) - 1)); - cbt = (spos -= obtr) >> 3; + cbt = Math.floor((spos -= obtr) / 8); ost = oct.t[ost] + (((dat[cbt] | (dat[cbt + 1] << 8)) >> (spos & 7)) & ((1 << obtr) - 1)); if (off > 3) { @@ -770,4 +771,4 @@ export class Decompress { * Handler called whenever data is decompressed */ ondata: ZstdStreamHandler; -} \ No newline at end of file +} diff --git a/tests/large_frame_overflow_test.ts b/tests/large_frame_overflow_test.ts new file mode 100644 index 0000000..df484a6 --- /dev/null +++ b/tests/large_frame_overflow_test.ts @@ -0,0 +1,76 @@ +import { assertEquals } from "https://deno.land/std@0.103.0/testing/asserts.ts"; +import * as fzstd from "../src/index.ts"; + +// Regression test for an int32 bit-position overflow in the one-shot +// `decompress()` path — see docs/large-frame-overflow.md. Once the byte +// offset `bt` into the compressed buffer reaches 2^28 (256 MiB), +// `bt << 3` wraps negative in JavaScript and the sequence decoder reads +// the wrong bytes. We generate 6-bit-alphabet noise with a short marker +// planted every 1 KiB so `zstd -1` emits btype-2 blocks containing +// sequences (`ns > 0`) that exercise the bit-position math past the +// 2^28 threshold. +// +// Requires the `zstd` CLI on PATH. +// Run: deno test --no-check --allow-run=zstd tests/large_frame_overflow_test.ts + +Deno.test("decompress() handles frames > 256 MiB", async () => { + const SIZE = 500 * 1024 * 1024; + const THRESHOLD = 1 << 28; + const MARK = new TextEncoder().encode("__fzstd_match_marker__"); + + // 6-bit-alphabet random bytes + planted markers every 1 KiB. + // Incompressible data would be emitted as raw blocks (btype 0), bypassing + // the buggy path; markerless data would give `ns == 0` and skip the + // sequence decoder. This recipe avoids both pitfalls so the bug is exercised. 
+ const source = new Uint8Array(SIZE); + for (let off = 0; off < SIZE; off += 65536) { + crypto.getRandomValues(source.subarray(off, Math.min(off + 65536, SIZE))); + } + for (let i = 0; i < SIZE; i++) source[i] &= 0x3f; + for (let j = 0; j + MARK.length <= SIZE; j += 1024) source.set(MARK, j); + + // Compress via the zstd CLI (generating a > 256 MiB frame with fzstd + // itself would be circular). Stream stdin -> stdout to avoid temp files. + const proc = new Deno.Command("zstd", { + args: ["-1", "--single-thread"], + stdin: "piped", + stdout: "piped", + stderr: "null", + clearEnv: true, + }).spawn(); + + const stdoutChunks = (async () => { + const chunks: Uint8Array[] = []; + for await (const chunk of proc.stdout) chunks.push(chunk); + return chunks; + })(); + const writer = proc.stdin.getWriter(); + await writer.write(source); + await writer.close(); + const chunks = await stdoutChunks; + const { code } = await proc.status; + assertEquals(code, 0); + + const compSize = chunks.reduce((s, c) => s + c.length, 0); + const compressed = new Uint8Array(compSize); + { + let o = 0; + for (const c of chunks) { compressed.set(c, o); o += c.length; } + } + // The bug only triggers when the compressed frame exceeds 2^28 bytes. + // Fail loudly if zstd compresses better than expected and the test + // would otherwise pass without exercising the bug. + if (compressed.length <= THRESHOLD) { + throw new Error( + `compressed size ${compressed.length} <= 2^28; bug would not be exercised`, + ); + } + + // Before the fix this throws `invalid zstd data`. After the fix, bytes + // match the source. + const decompressed = fzstd.decompress(compressed); + assertEquals(decompressed.length, SIZE); + const srcHash = new Uint8Array(await crypto.subtle.digest("SHA-256", source)); + const decHash = new Uint8Array(await crypto.subtle.digest("SHA-256", decompressed)); + assertEquals(decHash, srcHash); +});