Skip to content

Commit f3672a9

Browse files
committed
updates for hash_tree
1 parent 3c3766a commit f3672a9

File tree

1 file changed

+34
-21
lines changed

1 file changed

+34
-21
lines changed

ssz/hash_tree.py

+34-21
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
partial,
33
)
44
import itertools
5-
from numbers import (
6-
Integral,
7-
)
85
from typing import (
96
Any,
107
Callable,
@@ -14,6 +11,7 @@
1411
Tuple,
1512
Union,
1613
cast,
14+
overload,
1715
)
1816

1917
# `transform` comes from a non-public API which is considered stable, but future changes
@@ -40,6 +38,7 @@
4038
from pyrsistent.typing import (
4139
PMap,
4240
PVector,
41+
PVectorEvolver,
4342
)
4443

4544
from ssz.constants import (
@@ -163,7 +162,9 @@ def extend(self, value: Iterable[Hash32]) -> "HashTree":
163162
def __add__(self, other: Iterable[Hash32]) -> "HashTree":
164163
return self.extend(other)
165164

166-
def __mul__(self, times: int) -> "HashTree":
165+
# we override __mul__ to allow for a more natural syntax
166+
# when using the evolver
167+
def __mul__(self, times: int) -> "HashTree": # type: ignore[override]
167168
if times <= 0:
168169
raise ValueError(f"Multiplier must be greater or equal to 1, got {times}")
169170

@@ -202,7 +203,7 @@ def remove(self, value: Hash32) -> "HashTree":
202203
return self.__class__.compute(chunks, self.chunk_count)
203204

204205

205-
class HashTreeEvolver:
206+
class HashTreeEvolver(PVectorEvolver[Hash32]):
206207
def __init__(self, hash_tree: "HashTree") -> None:
207208
self.original_hash_tree = hash_tree
208209
self.updated_chunks: PMap[int, Hash32] = pmap()
@@ -211,7 +212,18 @@ def __init__(self, hash_tree: "HashTree") -> None:
211212
#
212213
# Getters
213214
#
215+
@overload
214216
def __getitem__(self, index: int) -> Hash32:
217+
...
218+
219+
@overload
220+
def __getitem__(self, index: slice) -> "HashTreeEvolver":
221+
...
222+
223+
def __getitem__(self, index: Union[int, slice]) -> Union[Hash32, "HashTreeEvolver"]:
224+
if isinstance(index, slice):
225+
raise NotImplementedError("Slicing not implemented.")
226+
215227
if index < 0:
216228
index += len(self)
217229

@@ -235,33 +247,34 @@ def is_dirty(self) -> bool:
235247
#
236248
# Setters
237249
#
238-
def set(self, index: Integral, value: Hash32) -> None:
250+
def set(self, index: int, value: Hash32) -> "HashTreeEvolver":
239251
self[index] = value
252+
return self
240253

241-
def __setitem__(self, index: Integral, value: Hash32) -> None:
242-
idx = int(index)
243-
244-
if idx < 0:
245-
idx += len(self)
254+
def __setitem__(self, index: int, value: Hash32) -> None:
255+
if index < 0:
256+
index += len(self)
246257

247-
if 0 <= idx < len(self.original_hash_tree):
248-
self.updated_chunks = self.updated_chunks.set(idx, value)
249-
elif idx < len(self):
250-
index_in_appendix = idx - len(self.original_hash_tree)
258+
if 0 <= index < len(self.original_hash_tree):
259+
self.updated_chunks = self.updated_chunks.set(index, value)
260+
elif index < len(self):
261+
index_in_appendix = index - len(self.original_hash_tree)
251262
self.appended_chunks = self.appended_chunks.set(index_in_appendix, value)
252263
else:
253-
raise IndexError(f"Index out of bounds: {idx}")
264+
raise IndexError(f"Index out of bounds: {index}")
254265

255266
#
256267
# Length changing modifiers
257268
#
258-
def append(self, value: Hash32) -> None:
269+
def append(self, value: Hash32) -> "HashTreeEvolver":
259270
self.appended_chunks = self.appended_chunks.append(value)
260271
self._check_chunk_count()
272+
return self
261273

262-
def extend(self, values: Iterable[Hash32]) -> None:
274+
def extend(self, values: Iterable[Hash32]) -> "HashTreeEvolver":
263275
self.appended_chunks = self.appended_chunks.extend(values)
264276
self._check_chunk_count()
277+
return self
265278

266279
def _check_chunk_count(self) -> None:
267280
chunk_count = self.original_hash_tree.chunk_count
@@ -271,10 +284,10 @@ def _check_chunk_count(self) -> None:
271284
#
272285
# Not implemented
273286
#
274-
def delete(self, index: int, stop: Optional[int] = None) -> None:
287+
def delete(self, index: int, stop: Optional[int] = None) -> None: # type: ignore[override] # noqa: E501
275288
raise NotImplementedError()
276289

277-
def __delitem__(self, index: int) -> None:
290+
def __delitem__(self, index: Union[int, slice]) -> None:
278291
raise NotImplementedError()
279292

280293
def remove(self, value: Hash32) -> None:
@@ -432,7 +445,7 @@ def set_chunk_in_tree(hash_tree: RawHashTree, index: int, chunk: Hash32) -> RawH
432445
for layer_index, hash_index in zip(parent_layer_indices, parent_hash_indices)
433446
)
434447

435-
hash_tree_with_updated_branch = pipe(
448+
hash_tree_with_updated_branch: PVector[PVector[Any]] = pipe(
436449
hash_tree_with_updated_chunk, *update_functions
437450
)
438451

0 commit comments

Comments
 (0)