from collections import namedtuple, defaultdict import itertools class Options: def __init__(self, **kwargs): self.override_litlen_counts = kwargs.get("override_litlen_counts", { }) self.override_dist_counts = kwargs.get("override_dist_counts", { }) self.max_uncompressed_length = kwargs.get("max_uncompressed_length", 0xffff) self.prune_interval = kwargs.get("prune_interval", 65536) self.max_match_distance = kwargs.get("max_match_distance", 32768) self.search_budget = kwargs.get("search_budget", 4096) self.force_block_types = kwargs.get("force_block_types", []) self.block_size = kwargs.get("block_size", 32768) self.invalid_sym = kwargs.get("invalid_sym", None) self.no_decode = kwargs.get("no_decode", False) Code = namedtuple("Code", "code bits") IntCoding = namedtuple("IntCoding", "symbol base bits") BinDesc = namedtuple("BinDesc", "offset value bits desc") SymExtra = namedtuple("Code", "symbol extra bits") null_code = Code(0,0) def make_int_coding(first_symbol, first_value, bit_sizes): symbol = first_symbol value = first_value codings = [] for bits in bit_sizes: codings.append(IntCoding(symbol, value, bits)) value += 1 << bits symbol += 1 return codings length_coding = make_int_coding(257, 3, [ 0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5, ]) distance_coding = make_int_coding(0, 1, [ 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13, ]) def find_int_coding(codes, value): for coding in codes: if value < coding.base + (1 << coding.bits): return coding class BitBuf: def __init__(self): self.pos = 0 self.data = 0 self.desc = [] def push(self, val, bits, desc=""): if bits == 0: return assert val < 1 << bits val = int(val) self.desc.append(BinDesc(self.pos, val, bits, desc)) self.data |= val << self.pos self.pos += bits def push_rev(self, val, bits, desc=""): if bits == 0: return assert val < 1 << bits rev = 0 for n in range(bits): rev |= ((val >> n) & 1) << bits-n-1 self.push(rev, bits, desc) def push_code(self, code, desc=""): self.push(code.code, code.bits, desc) def push_rev_code(self, code, desc=""): if code is None: raise RuntimeError("Empty code") self.push_rev(code.code, code.bits, desc) def append(self, buf): for desc in buf.desc: self.desc.append(desc._replace(offset = desc.offset + self.pos)) self.data |= buf.data << self.pos self.pos += buf.pos def patch(self, offset, value, bits, desc=""): self.data = self.data & ~(((1 << bits) - 1) << offset) | (value << offset) def to_bytes(self): return bytes((self.data>>p&0xff) for p in range(0, self.pos, 8)) class Literal: def __init__(self, data): self.data = data self.length = len(data) def count_codes(self, litlen_count, dist_count): for c in self.data: litlen_count[c] += 1 def encode(self, buf, litlen_syms, dist_syms, opts): for c in self.data: sym = litlen_syms.get(c, opts.invalid_sym) if c >= 32 and c <= 128: buf.push_rev_code(sym, "Literal '{}' (0x{:02x})".format(chr(c), c)) else: buf.push_rev_code(sym, "Literal {:3d} (0x{:02x})".format(c, c)) def decode(self, result): result += self.data def split(self, pos): assert pos >= 0 return Literal(self.data[:pos]), Literal(self.data[pos:]) def __repr__(self): return "Literal({!r})".format(self.data) class Match: def __init__(self, length, distance): self.length = length self.distance = distance if length < 258: self.lcode = find_int_coding(length_coding, length) else: assert length == 258 self.lcode = IntCoding(285, 0, 0) self.dcode = find_int_coding(distance_coding, distance) def count_codes(self, litlen_count, dist_count): litlen_count[self.lcode.symbol] += 1 dist_count[self.dcode.symbol] += 1 def encode(self, buf, litlen_syms, dist_syms, opts): lsym = litlen_syms.get(self.lcode.symbol, opts.invalid_sym) dsym = dist_syms.get(self.dcode.symbol, opts.invalid_sym) buf.push_rev_code(lsym, "Length: {}".format(self.length)) if self.lcode.bits > 0: buf.push(self.length - self.lcode.base, self.lcode.bits, "Length extra") buf.push_rev_code(dsym, "Distance: {}".format(self.distance)) if self.dcode.bits > 0: buf.push(self.distance - self.dcode.base, self.dcode.bits, "Distance extra") def decode(self, result): begin = len(result) - self.distance assert begin >= 0 for n in range(begin, begin + self.length): result.append(result[n]) def split(self, pos): return self, Literal(b"") def __repr__(self): return "Match({}, {})".format(self.length, self.distance) def make_huffman_bits(syms, max_code_length): if len(syms) == 0: return { } if len(syms) == 1: return { next(iter(syms)): 1 } sym_groups = ((prob, (sym,)) for sym,prob in syms.items()) initial_groups = list(sorted(sym_groups)) groups = initial_groups for n in range(max_code_length-1): packaged = [(a[0]+b[0], a[1]+b[1]) for a,b in zip(groups[0::2], groups[1::2])] groups = list(sorted(packaged + initial_groups)) sym_bits = { } for g in groups[:(len(syms) - 1) * 2]: for sym in g[1]: sym_bits[sym] = sym_bits.get(sym, 0) + 1 return sym_bits def make_huffman_codes(sym_bits, max_code_length): if len(sym_bits) == 0: return { } bl_count = [0] * (max_code_length + 1) next_code = [0] * (max_code_length + 1) for bits in sym_bits.values(): bl_count[bits] += 1 code = 0 for n in range(1, max_code_length + 1): code = (code + bl_count[n - 1]) << 1 next_code[n] = code codes = { } for sym,bits in sorted(sym_bits.items()): codes[sym] = Code(next_code[bits], bits) next_code[bits] += 1 return codes def make_huffman(syms, max_code_length): sym_bits = make_huffman_bits(syms, max_code_length) return make_huffman_codes(sym_bits, max_code_length) def decode(message): result = [] for m in message: m.decode(result) return bytes(result) def encode_huff_bits(bits): encoded = [] for value,copies in itertools.groupby(bits): num = len(list(copies)) assert value < 16 if value == 0: while num >= 11: amount = min(num, 138) encoded.append(SymExtra(18, amount-11, 7)) num -= amount while num >= 3: amount = min(num, 10) encoded.append(SymExtra(17, amount-3, 3)) num -= amount while num >= 1: encoded.append(SymExtra(0, 0, 0)) num -= 1 else: encoded.append(SymExtra(value, 0, 0)) num -= 1 while num >= 3: amount = min(num, 6) encoded.append(SymExtra(16, amount-3, 2)) num -= amount while num >= 1: encoded.append(SymExtra(value, 0, 0)) num -= 1 return encoded def write_encoded_huff_bits(buf, codes, syms, desc): value = 0 prev = 0 for code in codes: sym = code.symbol num = 1 if sym <= 15: buf.push_rev_code(syms[sym], "{} {} bits: {}".format(desc, value, sym)) prev = sym elif sym == 16: num = code.extra + 3 buf.push_rev_code(syms[sym], "{} {}-{} bits: {}".format(desc, value, value+num-1, prev)) elif sym == 17: num = code.extra + 3 buf.push_rev_code(syms[sym], "{} {}-{} bits: {}".format(desc, value, value+num-1, 0)) elif sym == 18: num = code.extra + 11 buf.push_rev_code(syms[sym], "{} {}-{} bits: {}".format(desc, value, value+num-1, 0)) value += num if code.bits > 0: buf.push(code.extra, code.bits, "{} N={}".format(desc, num)) def prune_matches(matches, offset, opts): new_matches = defaultdict(list) begin = offset - opts.max_match_distance for trigraph,chain in matches.items(): new_chain = [o for o in chain if o >= begin] if new_chain: new_matches[trigraph] = new_chain return new_matches def match_block(data, opts=Options()): message = [] matches = defaultdict(list) literal = [] offset = 0 size = len(data) prune_interval = 0 while offset + 3 <= size: trigraph = data[offset:offset+3] advance = 1 match_begin, match_length = 0, 0 search_steps = 0 for m in reversed(matches[trigraph]): length = 3 while offset + length < size and length < 258: if data[offset + length] != data[m + length]: break length += 1 search_steps += 1 if length > match_length and m - offset <= 32768: match_begin, match_length = m, length if search_steps >= opts.search_budget: break if match_length > 0: if literal: message.append(Literal(bytes(literal))) literal.clear() message.append(Match(match_length, offset - match_begin)) advance = match_length else: literal.append(data[offset]) for n in range(advance): if offset >= 3: trigraph = data[offset - 3:offset] matches[trigraph].append(offset - 3) offset += 1 prune_interval += advance if prune_interval >= opts.prune_interval: matches = prune_matches(matches, offset, opts) prune_interval = 0 while offset < size: literal.append(data[offset]) offset += 1 if literal: message.append(Literal(bytes(literal))) return message def compress_block_uncompressed(buf, data, align, final, opts): size = len(data) begin = 0 while begin < size: amount = min(size - begin, opts.max_uncompressed_length) end = begin + amount real_final = final and end == size buf.push(real_final, 1, "BFINAL Final chunk: {}".format(real_final)) buf.push(0b00, 2, "BTYPE Chunk type: Uncompressed") buf.push(0, -(buf.pos + align) & 7, "Pad to byte") buf.push(amount, 16, "LEN: {}".format(amount)) buf.push(~amount&0xffff, 16, "NLEN: ~{}".format(amount)) for byte in data[begin:end]: buf.push(byte, 8, "Byte '{}' ({:02x})".format(chr(byte), byte)) begin = end def compress_block_static(buf, message, final, opts): litlen_bits = [8]*(144-0) + [9]*(256-144) + [7]*(280-256) + [8]*(288-280) distance_bits = [5] * 32 litlen_syms = make_huffman_codes(dict(enumerate(litlen_bits)), 16) distance_syms = make_huffman_codes(dict(enumerate(distance_bits)), 16) buf.push(final, 1, "BFINAL Final chunk: {}".format(final)) buf.push(0b01, 2, "BTYPE Chunk type: Static Huffman") for m in message: m.encode(buf, litlen_syms, distance_syms, opts) # End-of-block buf.push_rev_code(litlen_syms.get(256, opts.invalid_sym), "End-of-block") def compress_block_dynamic(buf, message, final, opts): litlen_count = [0] * 286 distance_count = [0] * 30 # There's always one end-of-block litlen_count[256] = 1 for m in message: m.count_codes(litlen_count, distance_count) for sym,count in opts.override_litlen_counts.items(): litlen_count[sym] = count for sym,count in opts.override_dist_counts.items(): distance_count[sym] = count litlen_map = { sym: count for sym,count in enumerate(litlen_count) if count > 0 } distance_map = { sym: count for sym,count in enumerate(distance_count) if count > 0 } litlen_syms = make_huffman(litlen_map, 15) distance_syms = make_huffman(distance_map, 15) num_litlens = max(itertools.chain((k for k in litlen_map.keys()), (256,))) + 1 num_distances = max(itertools.chain((k for k in distance_map.keys()), (0,))) + 1 litlen_bits = [litlen_syms.get(s, null_code).bits for s in range(num_litlens)] distance_bits = [distance_syms.get(s, null_code).bits for s in range(num_distances)] litlen_bit_codes = encode_huff_bits(litlen_bits) distance_bit_codes = encode_huff_bits(distance_bits) codelen_count = [0] * 20 for code in itertools.chain(litlen_bit_codes, distance_bit_codes): codelen_count[code.symbol] += 1 codelen_map = { sym: count for sym,count in enumerate(codelen_count) if count > 0 } codelen_syms = make_huffman(codelen_map, 8) codelen_permutation = [16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15] num_codelens = 0 for i, p in enumerate(codelen_permutation): if codelen_count[p] > 0: num_codelens = i + 1 num_codelens = max(num_codelens, 4) buf.push(final, 1, "BFINAL Final chunk: {}".format(final)) buf.push(0b10, 2, "BTYPE Chunk type: Dynamic Huffman") buf.push(num_litlens - 257, 5, "HLIT Number of Litlen codes: {} (257 + {})".format(num_litlens, num_litlens - 257)) buf.push(num_distances - 1, 5, "HDIST Number of Distance codes: {} (1 + {})".format(num_distances, num_distances - 1)) buf.push(num_codelens - 4, 4, "HCLEN Number of Codelen codes: {} (4 + {})".format(num_codelens, num_codelens - 4)) for p in codelen_permutation[:num_codelens]: bits = 0 if p in codelen_syms: bits = codelen_syms[p].bits buf.push(bits, 3, "Codelen {} bits: {}".format(p, bits)) write_encoded_huff_bits(buf, litlen_bit_codes, codelen_syms, "Litlen") write_encoded_huff_bits(buf, distance_bit_codes, codelen_syms, "Distance") for m in message: m.encode(buf, litlen_syms, distance_syms, opts) # End-of-block buf.push_rev_code(litlen_syms.get(256, opts.invalid_sym), "End-of-block") def adler32(data): a, b = 1, 0 for d in data: a = (a + d) % 65521 b = (b + a) % 65521 return b << 16 | a def compress_message(message, opts=Options(), *args): buf = BitBuf() # ZLIB CFM byte buf.push(8, 4, "CM=8 Compression method: DEFLATE") buf.push(7, 4, "CINFO=7 Compression info: 32kB window size") # ZLIB FLG byte buf.push(28, 5, "FCHECK (CMF*256+FLG) % 31 == 0") buf.push(0, 1, "FDICT=0 Preset dictionary: No") buf.push(2, 2, "FLEVEL=2 Compression level: Default") multi_part = False multi_messages = [] multi_opts = [] if args: multi_part = True multi_messages = [message] multi_opts = [opts] args_it = iter(args) message = message[:] for msg, opt in zip(args_it, args_it): message += msg multi_messages.append(msg) multi_opts.append(opt) byte_offset = 0 part_pos = 0 num_parts = len(message) overflow_part = Literal(b"") block_message = [] block_opts = opts message_bytes = b"" if opts.no_decode else decode(message) last_part = False multi_index = 0 while not last_part: if multi_part: block_message = multi_messages[multi_index] block_opts = multi_opts[multi_index] size = sum(m.length for m in block_message) block_index = 0 multi_index += 1 last_part = multi_index == len(multi_messages) else: block_message.clear() part, overflow_part = overflow_part.split(opts.block_size) if part.length > 0: block_message.append(part) size = part.length # Append parts until desired block size is reached if size < opts.block_size: while part_pos < num_parts: part = message[part_pos] part_pos += 1 if size + part.length >= opts.block_size: last_part, overflow_part = part.split(opts.block_size - size) if last_part.length > 0: block_message.append(last_part) size += last_part.length break else: block_message.append(part) size += part.length last_part = part_pos >= num_parts and overflow_part.length == 0 # Compress the block best_buf = None block_index = 0 for block_type in range(3): if block_index < len(block_opts.force_block_types): if block_type != block_opts.force_block_types[block_index]: continue block_buf = BitBuf() if block_type == 0: compress_block_uncompressed(block_buf, message_bytes[byte_offset:byte_offset + size], buf.pos, last_part, block_opts) elif block_type == 1: compress_block_static(block_buf, block_message, last_part, block_opts) elif block_type == 2: compress_block_dynamic(block_buf, block_message, last_part, block_opts) if not best_buf or block_buf.pos < best_buf.pos: best_buf = block_buf buf.append(best_buf) byte_offset += size block_index += 1 buf.push(0, -buf.pos & 7, "Pad to byte") adler_hash = adler32(message_bytes) buf.push((adler_hash >> 24) & 0xff, 8, "Adler[24:32]") buf.push((adler_hash >> 16) & 0xff, 8, "Adler[16:24]") buf.push((adler_hash >> 8) & 0xff, 8, "Adler[8:16]") buf.push((adler_hash >> 0) & 0xff, 8, "Adler[0:8]") return buf def deflate(data, opts=Options()): message = match_block(data, opts) encoded = compress_message(message, opts) return encoded def print_huffman(tree): width = max(len(str(s)) for s in tree.keys()) for sym,code in tree.items(): print("".format(sym, width, code.code, code.bits)) def print_buf(buf): for d in buf.desc: val = " {0:0{1}b}".format(d.value, d.bits) if len(val) > 10: val = "0x{0:x}".format(d.value) desc = d.desc patched_value = (buf.data >> d.offset) & ((1 << d.bits) - 1) spacer = "|" if patched_value != d.value: desc += " >>> Patched to: {0:0{1}b} ({0})".format(patched_value, d.bits) spacer = ">" print("{0:>4} {0:>4x} {5}{1:>2} {5} {2:>10} {5} {3:>4} {5} {4}".format(d.offset, d.bits, val, d.value, desc, spacer)) def print_bytes(data): print(''.join('\\x%02x' % b for b in data))