import plyvel import struct import numpy as np # functions for reading IAVL tree def read_varint(x: bytes, offset: int = 0) -> tuple[int, int]: result = 0 factor = 1 for i, b in enumerate(x[offset:]): if b >= 128: result = result + (b - 128) * factor else: result = result + b * factor return result // 2, offset+i+1 factor *= 128 def read_uvarint(x: bytes, offset: int = 0) -> tuple[int, int]: result = 0 factor = 1 for i, b in enumerate(x[offset:]): if b >= 128: result = result + (b - 128) * factor else: result = result + b * factor return result, offset+i+1 factor *= 128 def write_uvarint(x: int) -> list[int]: if x < 0: raise Exception('write_uvarint only supports positive integers') elif x == 0: return [0] result = [] while x > 0: result.append(128 + x % 128) x //= 128 result[-1] -= 128 return result def read_key(key: bytes) -> tuple[int, int] | None: if not key.startswith(b's'): return None version = struct.unpack_from('>Q', key[1:9])[0] nonce = struct.unpack_from('>I', key[9:13])[0] return (version, nonce) def write_key(key: tuple[int, int]) -> bytes: version = struct.pack('>Q', key[0]) nonce = struct.pack('>I', key[1]) return b's' + version + nonce def read_node(node: bytes) -> tuple[int, int, bytes, tuple[int, int], tuple[int, int]] | tuple[int, int, list[int], bytes] | tuple[int, int]: if node.startswith(b's'): return read_key(node) n = 0 height, n = read_varint(node, n) if height == 0: length, n = read_varint(node, n) size, n = read_uvarint(node, n) key = node[n:n+size] n += size valuesize, n = read_uvarint(node, n) value = node[n:n+valuesize] return (height, length, key, value) else: length, n = read_varint(node, n) size, n = read_uvarint(node, n) key = node[n:n+size] n += size hashsize, n = read_uvarint(node, n) n += hashsize mode, n = read_uvarint(node, n) left_version, n = read_varint(node, n) left_nonce, n = read_varint(node, n) right_version, n = read_varint(node, n) right_nonce, n = read_varint(node, n) return (height, length, key, (left_version, left_nonce), (right_version, right_nonce)) def get_raw(db, prefix: bytes, version: int, searchkey: bytes) -> None | bytes: root = db.get(prefix + write_key((version, 1))) if root is None: return None node = read_node(root) if len(node) == 2: # root copy? node = read_node(db.get(prefix + write_key(node))) while node[0] > 0: # print(node) nodekey = node[2] if searchkey < nodekey: next = node[3] else: next = node[4] node = read_node(db.get(prefix + write_key(next))) if node[2] == searchkey: return node[3] else: return None def get_next_key_raw(db, prefix: bytes, version: int, searchkey: bytes) -> None | bytes: root = db.get(prefix + write_key((version, 1))) if root is None: return None node = read_node(root) lowest_geq_key = node[2] if node[2] >= searchkey else None if len(node) == 2: # root copy? node = read_node(db.get(prefix + write_key(node))) while node[0] > 0: # print(node) nodekey = node[2] if searchkey < nodekey: next = node[3] else: next = node[4] node = read_node(db.get(prefix + write_key(next))) if node[2] >= searchkey and (lowest_geq_key is None or node[2] < lowest_geq_key): lowest_geq_key = node[2] return lowest_geq_key def get(db, prefix: str, version: int, format: str, searchkey: list) -> None | bytes: return get_raw(db, prefix.encode('utf-8'), version, encode_key(format, searchkey)) def parse_pb(data): n = 0 results = [] while n < len(data): key, n = read_uvarint(data, n) ty = key & 7 key >>= 3 if ty == 2: l, n = read_uvarint(data, n) val = data[n:n+l] n += l elif ty == 0: val, n = read_uvarint(data, n) else: raise Exception(f'unknown type {ty}, {data[n:]}') results.append((key, val)) return results # find max height def next_key(db, k: bytes) -> bytes | None: it = db.iterator(start = k) try: nk, _ = next(it) return nk except StopIteration: return None finally: it.close() def max_height(db) -> int: testnr = 1<<63 for i in range(62, -1, -1): prefix = b's/k:emissions/s' n = next_key(db, prefix + struct.pack('>Q', testnr)) if n is not None and n.startswith(prefix): # print(f'{testnr:16x} is low') testnr += 1 << i else: # print(f'{testnr:16x} is high') testnr -= 1 << i n = next_key(db, prefix + struct.pack('>Q', testnr)) if n is not None and n.startswith(prefix): return testnr else: return testnr - 1 # encode and decode keys def encode_key(format: str, key: list) -> bytes: result_bytes = [] result_bytes.append(key[0]) for i, f in enumerate(format): if i >= len(key) - 1: break if f == 's': result_bytes += list(key[i+1].encode('utf-8')) if i < len(format) - 1: result_bytes += [0] elif f == 'Q': result_bytes += list(struct.pack('>Q', key[i+1])) elif f == 'q': result_bytes += list(struct.pack('>Q', key[i+1] + (1<<63))) elif f == 'b': data = list(bytes.fromhex(key[i+1])) result_bytes += write_uvarint(len(data)) result_bytes += data return bytes(result_bytes) def decode_key(format: str, key: bytes) -> list: result = [] result.append(key[0]) idx = 1 for f in format: if f == 's': end = key[idx:].find(b'\x00') if end < 0: result.append(key[idx:].decode('utf-8')) idx = len(key) break else: result.append(key[idx:idx+end].decode('utf-8')) idx += end + 1 elif f == 'Q': v = struct.unpack('>Q', key[idx:idx+8])[0] result.append(v) idx += 8 elif f == 'q': v = struct.unpack('>Q', key[idx:idx+8])[0] result.append(v - (1<<63)) idx += 8 elif f == 'b': length, offset = read_uvarint(key[idx:]) data = key[idx+offset:idx+offset+length] result.append(data.hex().upper()) idx += offset + length if idx < len(key): result.append(key[idx:]) return result # iteration class IAVLTreeIteratorRaw: def __init__(self, db, prefix: bytes, version: int, start: bytes | None = None, end: bytes | None = None): self.db = db self.prefix = prefix self.version = version self.start = start self.end = end self.stack = [] self.lookups = [] def __iter__(self): return self def get_node(self, key): key_enc = self.prefix + write_key(key) self.lookups.append(key_enc) return self.db.get(key_enc) def __next__(self): if len(self.stack) == 0: # get root node root = self.get_node((self.version, 1)) if root is None: raise StopIteration node = read_node(root) if len(node) == 2: # link to other root node node = read_node(self.get_node(node)) self.stack.append(((self.version, 1), node)) # walk tree to either last before start or first after start while node[0] > 0: # print(node) nodekey = node[2] if self.start is None or self.start < nodekey: next = node[3] else: next = node[4] node = read_node(self.get_node(next)) self.stack.append((next, node)) # return early if we ended up at first item after start if self.start is None or node[2] >= self.start: return (node[2], node[3]) # print('Stack:', [x[0] for x in self.stack]) # go up to first parent which we're a left child of key = None for i in range(len(self.stack)-1, 0, -1): current_key = self.stack[i][0] parent_node = self.stack[i-1][1] self.stack.pop() left = parent_node[3] right = parent_node[4] if current_key == left: key = right break # are we at the right end of the tree? if key is None: raise StopIteration # go right node = read_node(self.get_node(key)) self.stack.append((key, node)) # go left until at a leaf while node[0] > 0: key = node[3] node = read_node(self.get_node(key)) self.stack.append((key, node)) if self.end is not None and node[2] >= self.end: raise StopIteration return (node[2], node[3]) class IAVLTreeIterator: def __init__(self, db, prefix: bytes, version: int, format: str, start: bytes | None = None, end: bytes | None = None): self.format = format self.inner = IAVLTreeIteratorRaw(db, prefix, version, start, end) def __iter__(self): return self def __next__(self): (k, v) = next(self.inner) return (decode_key(self.format, k), v) def next_bs(x: bytes) -> bytes | None: if len(x) == 0: return None x_enc = None for i in range(len(x),0,-1): if x[i-1] != 255: x_enc = x[:i-1] + bytes([x[i-1] + 1]) + bytes([0 for _ in range(len(x)-i)]) break return x_enc def iterate(db, prefix, version, format = '', key = None, start = None, end = None): prefix_enc = prefix.encode('utf-8') if key is not None: start_enc = encode_key(format, key) end_enc = next_bs(start_enc) else: start_enc = encode_key(format, start) if start is not None else None end_enc = encode_key(format, end) if end is not None else None return IAVLTreeIterator(db, prefix_enc, version, format, start = start_enc, end = end_enc) def count(db, prefix, version, format = '', key = None, start = None, end = None): prefix_enc = prefix.encode('utf-8') if key is not None: start_enc = encode_key(format, key) end_enc = next_bs(start_enc) else: start_enc = encode_key(format, start) if start is not None else None end_enc = encode_key(format, end) if end is not None else None startidx = indexof_raw(db, prefix_enc, version, start_enc) if start_enc is not None else 0 if end_enc is not None: endidx = indexof_raw(db, prefix_enc, version, end_enc) else: # get full count it = IAVLTreeIteratorRaw(db, prefix_enc, version) try: next(it) endidx = it.stack[0][1][1] # just read the length field of the root element except StopIteration: endidx = 0 return endidx - startidx def indexof_raw(db, prefix: bytes, version: int, key: bytes) -> int: """ Find how many items come before `key` in the tree. If `key` doesn't exist, how many items come before the slot it would get inserted at """ it = IAVLTreeIteratorRaw(db, prefix, version, start=key) try: next(it) except StopIteration: # get root count return read_node(db.get(prefix + write_key(it.stack[0][0])))[1] keys = [p[1][3] for p, c in zip(it.stack, it.stack[1:]) if c[0] == p[1][4]] keys_encoded = [prefix + write_key(k) for k in keys] count = sum([read_node(db.get(k))[1] for k in keys_encoded]) return count def indexof(db, prefix: str, version: int, format: str, key: list) -> int: return indexof_raw(db, prefix.encode('utf-8'), version, encode_key(format, key))