kvstore/iavltree.py
2025-10-13 13:04:33 -04:00

432 lines
13 KiB
Python

import plyvel
import struct
# functions for reading IAVL tree
def read_varint(x: bytes, offset: int = 0) -> tuple[int, int]:
    """Decode a zigzag-encoded signed varint from ``x`` starting at ``offset``.

    The number is stored little-endian in groups of 7 bits; a byte below 128
    marks the last group. The unsigned result is then zigzag-decoded: even
    values map to non-negative integers, odd values to negative ones.

    Returns:
        ``(value, next_offset)`` where ``next_offset`` is the index of the
        first byte after the varint.

    Raises:
        ValueError: if ``x`` ends before a terminating byte is found.
    """
    raw = 0
    shift = 0
    # Index into ``x`` directly instead of slicing, to avoid copying the tail.
    for pos in range(offset, len(x)):
        byte = x[pos]
        raw |= (byte & 0x7F) << shift
        if byte < 0x80:
            # Zigzag decode: identical to ``raw // 2`` for even (non-negative)
            # values, and yields the proper negative number for odd values.
            return (raw >> 1) ^ -(raw & 1), pos + 1
        shift += 7
    raise ValueError(f'truncated varint at offset {offset}')
def read_uvarint(x: bytes, offset: int = 0) -> tuple[int, int]:
    """Decode an unsigned varint from ``x`` starting at ``offset``.

    The number is stored little-endian in groups of 7 bits; a byte below 128
    marks the last group.

    Returns:
        ``(value, next_offset)`` where ``next_offset`` is the index of the
        first byte after the varint.

    Raises:
        ValueError: if ``x`` ends before a terminating byte is found
            (including when there are no bytes at ``offset`` at all).
    """
    value = 0
    shift = 0
    # Index into ``x`` directly instead of slicing, to avoid copying the tail.
    for pos in range(offset, len(x)):
        byte = x[pos]
        value |= (byte & 0x7F) << shift
        if byte < 0x80:
            return value, pos + 1
        shift += 7
    raise ValueError(f'truncated uvarint at offset {offset}')
def write_uvarint(x: int) -> list[int]:
    """Encode a non-negative integer as an unsigned varint.

    Returns the encoded bytes as a list of ints: 7 payload bits per entry,
    least-significant group first, with 128 added to every entry except the
    last one.

    Raises:
        Exception: if ``x`` is negative.
    """
    if x < 0:
        raise Exception('write_uvarint only supports positive integers')
    if x == 0:
        return [0]
    groups = []
    remaining = x
    while remaining > 0:
        remaining, low7 = divmod(remaining, 128)
        # Every group but the final one carries the continuation flag.
        groups.append(low7 + 128 if remaining > 0 else low7)
    return groups
def read_key(key: bytes) -> tuple[int, int] | None:
if not key.startswith(b's'):
return None
version = struct.unpack_from('>Q', key[1:9])[0]
nonce = struct.unpack_from('>I', key[9:13])[0]
return (version, nonce)
def write_key(key: tuple[int, int]) -> bytes:
    """Encode a ``(version, nonce)`` pair as a 13-byte node key.

    Layout: the marker byte ``b's'``, the version as a big-endian unsigned
    64-bit integer, then the nonce as a big-endian unsigned 32-bit integer.
    """
    version, nonce = key[0], key[1]
    return b's' + struct.pack('>QI', version, nonce)
def read_node(node: bytes) -> tuple[int, int, bytes, tuple[int, int], tuple[int, int]] | tuple[int, int, list[int], bytes] | tuple[int, int]:
    """Decode a stored tree node.

    Three shapes are returned, distinguishable by length and first element:

    * ``(version, nonce)`` when the value starts with ``b's'`` (a reference
      to another node key, decoded with ``read_key``);
    * ``(0, length, key, value)`` for a leaf (height 0), with ``key`` and
      ``value`` as byte strings sliced out of ``node``;
    * ``(height, length, key, left, right)`` for an inner node, where
      ``left`` and ``right`` are ``(version, nonce)`` pairs of the children.

    ``length`` is the second varint of the node; elsewhere in this file it
    is summed to count the items below a node.
    """
    if node.startswith(b's'):
        return read_key(node)
    # Header shared by leaves and inner nodes: height, length, key.
    pos = 0
    height, pos = read_varint(node, pos)
    length, pos = read_varint(node, pos)
    keylen, pos = read_uvarint(node, pos)
    key = node[pos:pos + keylen]
    pos += keylen
    if height == 0:
        valuelen, pos = read_uvarint(node, pos)
        return (height, length, key, node[pos:pos + valuelen])
    # Inner node: skip the length-prefixed hash, then read (and ignore) the
    # mode field before the two child references.
    hashlen, pos = read_uvarint(node, pos)
    pos += hashlen
    _mode, pos = read_uvarint(node, pos)
    left_version, pos = read_varint(node, pos)
    left_nonce, pos = read_varint(node, pos)
    right_version, pos = read_varint(node, pos)
    right_nonce, pos = read_varint(node, pos)
    return (height, length, key, (left_version, left_nonce), (right_version, right_nonce))
def get_raw(db: plyvel.DB, prefix: bytes, version: int, searchkey: bytes) -> None | bytes:
    """Look up ``searchkey`` in the tree stored under ``prefix`` at ``version``.

    Starts at node ``(version, 1)``, follows a root reference if the root is
    one, then descends: left when ``searchkey`` is smaller than the inner
    node's key, right otherwise.

    Returns the leaf value when the leaf reached holds exactly ``searchkey``;
    ``None`` when it does not, or when the version has no root.
    """
    root_raw = db.get(prefix + write_key((version, 1)))
    if root_raw is None:
        return None
    current = read_node(root_raw)
    if len(current) == 2:  # root is a reference to another node
        current = read_node(db.get(prefix + write_key(current)))
    while current[0] > 0:
        child = current[3] if searchkey < current[2] else current[4]
        current = read_node(db.get(prefix + write_key(child)))
    return current[3] if current[2] == searchkey else None
def get_next_key_raw(db: plyvel.DB, prefix: bytes, version: int, searchkey: bytes) -> None | bytes:
    """Return the smallest key ``>= searchkey`` seen on the search path.

    Walks from the root of ``version`` towards the leaf where ``searchkey``
    would live and tracks the lowest node key that is not below
    ``searchkey``, counting both inner-node keys and the final leaf key.

    NOTE(review): this is the true successor only if every inner-node key is
    itself a key stored in a leaf (presumably the smallest key of its right
    subtree); confirm against the tree format.

    Returns ``None`` when the version has no root or no visited key
    qualifies.
    """
    root = db.get(prefix + write_key((version, 1)))
    if root is None:
        return None
    node = read_node(root)
    # Resolve a root reference *before* touching node[2]: a reference is a
    # 2-tuple (version, nonce) and has no key field.
    if len(node) == 2:
        node = read_node(db.get(prefix + write_key(node)))
    lowest_geq_key = node[2] if node[2] >= searchkey else None
    while node[0] > 0:
        child = node[3] if searchkey < node[2] else node[4]
        node = read_node(db.get(prefix + write_key(child)))
        if node[2] >= searchkey and (lowest_geq_key is None or node[2] < lowest_geq_key):
            lowest_geq_key = node[2]
    return lowest_geq_key
def get(db: plyvel.DB, prefix: str, version: int, format: str, searchkey: list) -> None | bytes:
    """Structured-key wrapper around ``get_raw``.

    ``prefix`` is UTF-8 encoded and ``searchkey`` is serialised with
    ``encode_key(format, searchkey)`` before the lookup.
    """
    raw_prefix = prefix.encode('utf-8')
    raw_key = encode_key(format, searchkey)
    return get_raw(db, raw_prefix, version, raw_key)
def parse_pb(data):
    """Split a protobuf-style byte string into ``(field_number, value)`` pairs.

    Each field starts with a uvarint tag whose low 3 bits are the wire type
    and whose remaining bits are the field number. Only two wire types are
    handled: 0 (value is a uvarint) and 2 (value is a length-prefixed byte
    slice of ``data``).

    Raises:
        Exception: on any other wire type.
    """
    fields = []
    pos = 0
    total = len(data)
    while pos < total:
        tag, pos = read_uvarint(data, pos)
        wire_type, field_no = tag & 7, tag >> 3
        if wire_type == 0:
            value, pos = read_uvarint(data, pos)
        elif wire_type == 2:
            size, pos = read_uvarint(data, pos)
            value = data[pos:pos + size]
            pos += size
        else:
            raise Exception(f'unknown type {wire_type}, {data[pos:]}')
        fields.append((field_no, value))
    return fields
# find max height
def next_key(db, k: bytes) -> bytes | None:
it = db.iterator(start = k)
try:
nk, _ = next(it)
return nk
except StopIteration:
return None
finally:
it.close()
def max_height(db: plyvel.DB, prefix: bytes) -> int:
    """Find the highest version that has node keys under ``prefix``.

    Binary-searches the 64-bit version space. Each probe asks the database
    for the first key at or after ``prefix + b's' + version + nonce 1`` and
    checks that it is still a node key of this store.

    Returns 0 when no node key exists under ``prefix``.
    """
    node_prefix = prefix + b's'

    def has_version_at_or_above(version: int) -> bool:
        # Every probe, including the final one, must use the full node-key
        # layout. A probe without the b's' marker sorts before all node keys
        # and would match whenever the store is non-empty.
        found = next_key(db, node_prefix + struct.pack('>QI', version, 1))
        # Require the b's' marker as well, so that non-node keys sorting
        # after the node keys are not mistaken for a higher version.
        return found is not None and found.startswith(node_prefix)

    candidate = 1 << 63
    for bit in range(62, -1, -1):
        if has_version_at_or_above(candidate):
            candidate += 1 << bit
        else:
            candidate -= 1 << bit
    # After the loop the answer is either ``candidate`` or ``candidate - 1``.
    return candidate if has_version_at_or_above(candidate) else candidate - 1
def min_max_height(db: plyvel.DB, prefix: bytes) -> tuple[int, int]:
    """Return ``(lowest, highest)`` version available under ``prefix``.

    The highest version comes from ``max_height``. The lowest is found by
    binary search on whether node ``(h, 1)`` exists, so the search assumes
    that once a version has a root, every later version up to the maximum
    has one too.

    When nothing is stored the result is ``(1, 0)``.
    """
    hmax = max_height(db, prefix)
    h = 1 << hmax.bit_length()
    inc = h >> 1
    # Halve the step until it is exhausted rather than for a fixed number of
    # rounds, so the search stays exact for arbitrarily large versions.
    while True:
        if h > hmax:
            highenough = True
        else:
            highenough = db.get(prefix + write_key((h, 1))) is not None
        if inc == 0:
            # ``highenough`` now describes the final ``h``.
            break
        h += -inc if highenough else inc
        inc >>= 1
    if not highenough:
        h += 1
    return (h, hmax)
# encode and decode keys
def encode_key(format: str, key: list) -> bytes:
    """Serialise a structured key into its byte form.

    ``key[0]`` is a single byte value (an int) written first. Each following
    element is encoded according to the character at the same position in
    ``format``:

    * ``'s'``: UTF-8 text, followed by a zero byte unless it is the last
      format position;
    * ``'Q'``: big-endian unsigned 64-bit integer;
    * ``'q'``: signed 64-bit integer stored with an offset of ``1 << 63``;
    * ``'b'``: hex string, written as a uvarint length plus the raw bytes.

    ``key`` may be shorter than ``format`` (encoding stops after the last
    element supplied); elements beyond ``format`` and unknown format
    characters are ignored.
    """
    out = bytearray()
    out.append(key[0])
    last_slot = len(format) - 1
    # zip() stops as soon as either the format or the supplied fields run out.
    for slot, (kind, field) in enumerate(zip(format, key[1:])):
        if kind == 's':
            out += field.encode('utf-8')
            if slot < last_slot:
                out.append(0)
        elif kind == 'Q':
            out += struct.pack('>Q', field)
        elif kind == 'q':
            out += struct.pack('>Q', field + (1 << 63))
        elif kind == 'b':
            payload = bytes.fromhex(field)
            out.extend(write_uvarint(len(payload)))
            out += payload
    return bytes(out)
def decode_key(format: str, key: bytes) -> list:
    """Parse a byte key into a list of fields according to ``format``.

    The first element of the result is the first byte of ``key`` as an int.
    Then, per format character:

    * ``'s'``: UTF-8 text up to the next zero byte; if there is none, the
      rest of the key is the text and decoding stops;
    * ``'Q'``: big-endian unsigned 64-bit integer;
    * ``'q'``: the same 8 bytes minus ``1 << 63``;
    * ``'b'``: uvarint length plus that many bytes, returned as upper-case
      hex.

    Any bytes left over after the format is consumed are appended as one
    final ``bytes`` element.
    """
    fields = [key[0]]
    pos = 1
    for kind in format:
        if kind == 's':
            nul = key.find(b'\x00', pos)
            if nul < 0:
                # Unterminated string: it runs to the end of the key.
                fields.append(key[pos:].decode('utf-8'))
                pos = len(key)
                break
            fields.append(key[pos:nul].decode('utf-8'))
            pos = nul + 1
        elif kind == 'Q':
            fields.append(struct.unpack_from('>Q', key, pos)[0])
            pos += 8
        elif kind == 'q':
            fields.append(struct.unpack_from('>Q', key, pos)[0] - (1 << 63))
            pos += 8
        elif kind == 'b':
            size, body = read_uvarint(key, pos)
            fields.append(key[body:body + size].hex().upper())
            pos = body + size
    if pos < len(key):
        fields.append(key[pos:])
    return fields
# iteration
class IAVLTreeIteratorRaw:
    """Iterate the ``(key, value)`` leaf pairs of one tree version in key order.

    Iteration starts at the first leaf whose key is ``>= start`` (or the
    leftmost leaf when ``start`` is ``None``) and stops before the first
    leaf whose key is ``>= end`` (or at the rightmost leaf when ``end`` is
    ``None``).

    Attributes read by other code in this file:
        stack: path from the root to the current leaf, as a list of
            ``(node_key, decoded_node)`` pairs.
        lookups: every encoded database key fetched so far, in order.
    """

    def __init__(self, db: plyvel.DB, prefix: bytes, version: int, start: bytes | None = None, end: bytes | None = None):
        self.db = db
        self.prefix = prefix
        self.version = version
        self.start = start
        self.end = end
        self.stack = []
        self.lookups = []

    def __iter__(self):
        return self

    def get_node(self, key):
        """Fetch the raw bytes stored for node ``key``, recording the lookup."""
        key_enc = self.prefix + write_key(key)
        self.lookups.append(key_enc)
        return self.db.get(key_enc)

    def __next__(self):
        leaf = None
        if not self.stack:
            leaf = self._seek_start()
        if leaf is None:
            leaf = self._advance()
        # The end bound applies to every item, including the very first one;
        # otherwise an empty range would still yield one out-of-range leaf.
        if self.end is not None and leaf[2] >= self.end:
            raise StopIteration
        return (leaf[2], leaf[3])

    def _seek_start(self):
        """Descend from the root towards ``start``, filling ``self.stack``.

        Returns the leaf reached if its key is ``>= start`` (or ``start`` is
        ``None``); returns ``None`` when the leaf is the last one *before*
        ``start``, in which case the caller must advance once. Raises
        ``StopIteration`` when the version has no root.
        """
        root_key = (self.version, 1)
        root = self.get_node(root_key)
        if root is None:
            raise StopIteration
        node = read_node(root)
        if len(node) == 2:  # link to other root node
            node = read_node(self.get_node(node))
        self.stack.append((root_key, node))
        while node[0] > 0:
            if self.start is None or self.start < node[2]:
                child = node[3]
            else:
                child = node[4]
            node = read_node(self.get_node(child))
            self.stack.append((child, node))
        if self.start is None or node[2] >= self.start:
            return node
        return None

    def _advance(self):
        """Move ``self.stack`` to the next leaf in key order and return it.

        Raises ``StopIteration`` when the current leaf is the rightmost one.
        """
        # Climb until we leave a left child; its right sibling is next.
        sibling = None
        while len(self.stack) > 1:
            child_key, _ = self.stack.pop()
            parent = self.stack[-1][1]
            if child_key == parent[3]:
                sibling = parent[4]
                break
        if sibling is None:
            raise StopIteration
        node = read_node(self.get_node(sibling))
        self.stack.append((sibling, node))
        # Then follow left children down to a leaf.
        while node[0] > 0:
            child = node[3]
            node = read_node(self.get_node(child))
            self.stack.append((child, node))
        return node
class IAVLTreeIterator:
    """Wrap ``IAVLTreeIteratorRaw`` and decode each key with ``decode_key``.

    ``start`` and ``end`` are already-encoded byte keys passed straight to
    the raw iterator; values are yielded unchanged.
    """

    def __init__(self, db: plyvel.DB, prefix: bytes, version: int, format: str, start: bytes | None = None, end: bytes | None = None):
        self.inner = IAVLTreeIteratorRaw(db, prefix, version, start, end)
        self.format = format

    def __iter__(self):
        return self

    def __next__(self):
        # StopIteration from the raw iterator propagates to the caller.
        raw_key, value = next(self.inner)
        decoded = decode_key(self.format, raw_key)
        return (decoded, value)
def next_bs(x: bytes) -> bytes | None:
if len(x) == 0:
return None
x_enc = None
for i in range(len(x),0,-1):
if x[i-1] != 255:
x_enc = x[:i-1] + bytes([x[i-1] + 1]) + bytes([0 for _ in range(len(x)-i)])
break
return x_enc
def iterate(db: plyvel.DB, prefix, version, format = '', key = None, start = None, end = None):
    """Build an ``IAVLTreeIterator`` over a key range of one tree version.

    When ``key`` is given, the range runs from its encoding up to (but not
    including) ``next_bs`` of that encoding, and ``start``/``end`` are
    ignored. Otherwise ``start`` and ``end`` are encoded individually and
    either may be ``None`` for an open bound.
    """
    prefix_enc = prefix.encode('utf-8')
    if key is None:
        start_enc = None if start is None else encode_key(format, start)
        end_enc = None if end is None else encode_key(format, end)
    else:
        start_enc = encode_key(format, key)
        end_enc = next_bs(start_enc)
    return IAVLTreeIterator(db, prefix_enc, version, format, start = start_enc, end = end_enc)
def count(db: plyvel.DB, prefix, version, format = '', key = None, start = None, end = None):
    """Count the items in a key range without visiting each of them.

    The range is chosen as in ``iterate``. The result is the index of the
    end bound minus the index of the start bound (0 when open). With an open
    end bound, the end index is the length field of the root node, or 0 when
    the tree yields nothing.
    """
    prefix_enc = prefix.encode('utf-8')
    if key is None:
        start_enc = None if start is None else encode_key(format, start)
        end_enc = None if end is None else encode_key(format, end)
    else:
        start_enc = encode_key(format, key)
        end_enc = next_bs(start_enc)
    if start_enc is None:
        first = 0
    else:
        first = indexof_raw(db, prefix_enc, version, start_enc)
    if end_enc is None:
        # Stepping the iterator once loads the root onto its stack, where
        # the total item count can be read off.
        walker = IAVLTreeIteratorRaw(db, prefix_enc, version)
        try:
            next(walker)
        except StopIteration:
            past_last = 0
        else:
            past_last = walker.stack[0][1][1]
    else:
        past_last = indexof_raw(db, prefix_enc, version, end_enc)
    return past_last - first
def indexof_raw(db: plyvel.DB, prefix: bytes, version: int, key: bytes) -> int:
    """
    Find how many items come before `key` in the tree. If `key` doesn't exist, how many
    items come before the slot it would get inserted at.

    Returns 0 for a version that has no root node.
    """
    it = IAVLTreeIteratorRaw(db, prefix, version, start=key)
    try:
        next(it)
    except StopIteration:
        if not it.stack:
            # No root was found, so the tree is empty and nothing precedes `key`.
            return 0
        # `key` sorts after every item: the answer is the root's item count.
        return it.stack[0][1][1]
    # Wherever the path to the found leaf turns right, the whole left
    # sibling subtree lies before `key`; add up those subtree sizes.
    left_siblings = [
        parent[1][3]
        for parent, child in zip(it.stack, it.stack[1:])
        if child[0] == parent[1][4]
    ]
    return sum(read_node(db.get(prefix + write_key(k)))[1] for k in left_siblings)
def indexof(db: plyvel.DB, prefix: str, version: int, format: str, key: list) -> int:
    """Structured-key wrapper around ``indexof_raw``.

    ``prefix`` is UTF-8 encoded and ``key`` is serialised with
    ``encode_key(format, key)`` before the lookup.
    """
    raw_prefix = prefix.encode('utf-8')
    raw_key = encode_key(format, key)
    return indexof_raw(db, raw_prefix, version, raw_key)