from __future__ import annotations
from typing import (
Generic,
AsyncIterator,
Optional,
Callable,
Any,
overload,
Awaitable,
)
import heapq as _heapq
from .builtins import enumerate as a_enumerate, zip as a_zip
from ._core import aiter, awaitify, ScopedIter, borrow
from ._typing import AnyIterable, ACloseable, LT, T
class _KeyIter(Generic[LT]):
    """
    Sortable pairing of an async iterator with its most recently fetched item

    ``head`` is the current item, ``tail`` the iterator providing the items after
    it, and ``head_key`` the value by which instances are compared. If ``key`` is
    not :py:data:`None` it is awaited on each new ``head`` to compute ``head_key``;
    otherwise the item itself serves as its key. ``reverse`` flips the result
    of ``<`` comparisons.
    """

    __slots__ = ("head", "tail", "reverse", "head_key", "key")

    @overload
    def __init__(
        self,
        head: T,
        tail: AsyncIterator[T],
        reverse: bool,
        head_key: LT,
        key: Callable[[T], Awaitable[LT]],
    ) -> None: ...

    @overload
    def __init__(
        self, head: LT, tail: AsyncIterator[LT], reverse: bool, head_key: LT, key: None
    ) -> None: ...

    def __init__(
        self,
        head: Any,
        tail: AsyncIterator[Any],
        reverse: bool,
        head_key: LT,
        key: Any,
    ) -> None:
        self.head = head
        self.head_key = head_key
        self.tail = tail
        self.key = key
        self.reverse = reverse

    @overload
    @classmethod
    def from_iters(
        cls,
        iterables: "tuple[AnyIterable[T], ...]",
        reverse: bool,
        key: Callable[[T], Awaitable[LT]],
    ) -> "AsyncIterator[_KeyIter[LT]]": ...

    @overload
    @classmethod
    def from_iters(
        cls, iterables: "tuple[AnyIterable[LT], ...]", reverse: bool, key: None
    ) -> "AsyncIterator[_KeyIter[LT]]": ...

    @classmethod
    async def from_iters(
        cls,
        iterables: "tuple[AnyIterable[Any], ...]",
        reverse: bool,
        key: Optional[Callable[[Any], Any]],
    ) -> "AsyncIterator[_KeyIter[Any]]":
        """
        Yield one instance for each of the ``iterables`` that provides an item

        The first item of every iterable is fetched eagerly to serve as ``head``.
        Iterables that are exhausted right away are skipped.
        """
        for source in iterables:
            tail = aiter(source)
            try:
                first = await tail.__anext__()
            except StopAsyncIteration:
                # nothing to compare by: an empty iterable yields no instance
                continue
            first_key = first if key is None else await key(first)
            yield cls(first, tail, reverse, first_key, key)

    async def pull_head(self) -> bool:
        """
        Pull the next ``head`` element from the iterator and signal success

        Returns :py:data:`False` if ``tail`` is exhausted, in which case ``head``
        and ``head_key`` are left unchanged.
        """
        try:
            fresh = await self.tail.__anext__()
        except StopAsyncIteration:
            return False
        self.head = fresh
        self.head_key = fresh if self.key is None else await self.key(fresh)
        return True

    def __lt__(self, other: "_KeyIter[LT]") -> bool:
        # ``reverse`` inverts the natural ordering of the head keys
        return self.reverse ^ (self.head_key < other.head_key)

    def __eq__(self, other: "_KeyIter[LT]") -> bool:  # type: ignore[override]
        # keys only need to support ``<``: equal means neither precedes the other
        return not self.head_key < other.head_key and not (
            other.head_key < self.head_key
        )
[docs]
async def merge(
    *iterables: AnyIterable[Any],
    key: Optional[Callable[[Any], Any]] = None,
    reverse: bool = False,
) -> AsyncIterator[Any]:
    """
    Merge all pre-sorted (async) ``iterables`` into a single sorted iterator

    This works similar to ``sorted(chain(*iterables), key=key, reverse=reverse)`` but
    operates lazily: at any moment only one item of each iterable is stored for the
    comparison. This allows merging streams of pre-sorted items, such as timestamped
    records from multiple sources.

    The optional ``key`` argument specifies a one-argument (async) callable, which
    provides a substitute for determining the sort order of each item.
    The special value and default :py:data:`None` represents the identity function,
    comparing items directly.

    The default sort order is ascending, that is items with ``a < b`` imply ``a``
    is yielded before ``b``. Use ``reverse=True`` for descending sort order.
    The ``iterables`` must be pre-sorted in the same order.

    Iterators that are still held when merging stops, be it by an error or by
    closing the merged iterator early, are closed if they support it.
    """
    a_key = awaitify(key) if key is not None else None
    # sortable iterators with (reverse) position to ensure stable sort for ties
    iter_heap: "list[tuple[_KeyIter[Any], int]]" = []
    try:
        # Collect the iterators *inside* the ``try``: if fetching a first item or
        # computing its key fails, the iterators gathered so far are still closed.
        async for idx, itr in a_enumerate(
            _KeyIter[Any].from_iters(iterables, reverse, a_key)
        ):
            iter_heap.append((itr, -idx if reverse else idx))
        _heapq.heapify(iter_heap)
        # there are at least two iterators that need merging
        while len(iter_heap) > 1:
            while True:
                itr, idx = iter_heap[0]
                yield itr.head
                if await itr.pull_head():
                    # same iterator, new head: restore the heap order in place
                    _heapq.heapreplace(iter_heap, (itr, idx))
                else:
                    # exhausted: drop it and re-check how many iterators remain
                    _heapq.heappop(iter_heap)
                    break
        # there is only one iterator left, no need for merging
        if iter_heap:
            itr, _ = iter_heap[0]
            yield itr.head
            async for item in itr.tail:
                yield item
    finally:
        for itr, _ in iter_heap:
            if isinstance(itr.tail, ACloseable):
                await itr.tail.aclose()
class ReverseLT(Generic[LT]):
    """
    Helper to reverse ``a < b`` ordering

    ``ReverseLT(a) < ReverseLT(b)`` holds exactly when ``b < a`` holds for the
    wrapped keys. Two wrappers compare equal when neither wrapped key is less
    than the other. This lets sequences that start with a wrapper, such as
    ``(wrapper, index, item)`` tuples, fall through to their next element on
    ties instead of treating every wrapper as distinct.
    """

    __slots__ = ("key",)

    def __init__(self, key: LT):
        self.key = key

    def __lt__(self, other: "ReverseLT[LT]") -> bool:
        return other.key < self.key

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ReverseLT):
            return NotImplemented
        # keys are only required to support ``<``, so derive equality from it
        return not (self.key < other.key or other.key < self.key)

    # equality is defined by the wrapped key, for which no hash is known
    __hash__ = None  # type: ignore[assignment]
# Python's heapq provides a *min*-heap
# When finding the n largest items, heapq tracks the *minimum* item still large enough.
# In other words, during the search we maintain the opposite sort order to the one requested.
# At the end, we sort the min-heap in descending order to get the requested order.
async def _largest(
    iterable: AnyIterable[T],
    n: int,
    key: Callable[[T], Awaitable[LT]],
    reverse: bool,
) -> "list[T]":
    """
    Collect the ``n`` best items of the (async) ``iterable``, best item first

    Items are ranked by the awaited result of ``key``. With ``reverse=False``
    items with larger keys are better; with ``reverse=True`` the keys are wrapped
    in :py:class:`ReverseLT` so that items with smaller keys are better.
    Returns an empty list if ``n`` is not positive or ``iterable`` is empty.

    Every item is stored alongside its negated arrival index. Whenever two keys
    compare equal under ``==``, the comparison falls through to that index: the
    item that arrived later is discarded first and is sorted after the earlier
    one. This is the order that ``sorted(..., reverse=True)[:n]`` produces
    for ties. Since indices are unique, the items themselves are never compared.
    """
    ordered: Callable[[LT], LT] = ReverseLT if reverse else lambda x: x  # type: ignore
    async with ScopedIter(iterable) as iterator:
        # fill the heap with up to n items; ``borrow`` keeps ``iterator`` open
        n_heap = [
            (ordered(await key(item)), -index, item)
            async for index, item in a_zip(range(n), borrow(iterator))
        ]
        if not n_heap:
            return []
        _heapq.heapify(n_heap)
        # the heap root is the worst item that is currently still retained
        worst_key = n_heap[0][0]
        next_index = -n
        async for item in iterator:
            item_key = ordered(await key(item))
            # strict comparison: a later item never displaces an equal earlier one
            if worst_key < item_key:
                _heapq.heapreplace(n_heap, (item_key, next_index, item))
                worst_key = n_heap[0][0]
                next_index -= 1
    # descending heap order puts the best item first
    n_heap.sort(reverse=True)
    return [item for _, _, item in n_heap]
async def _identity(x: T) -> T:
    """Async identity function: the awaited result is the argument itself"""
    return x
[docs]
async def nlargest(
    iterable: AnyIterable[T],
    n: int,
    key: Optional[Callable[[Any], Awaitable[Any]]] = None,
) -> "list[T]":
    """
    Return a sorted list of the ``n`` largest elements from the (async) iterable

    The optional ``key`` argument specifies a one-argument (async) callable, which
    provides a substitute for determining the sort order of each item.
    The special value and default :py:data:`None` represents the identity function,
    comparing items directly.

    The result is equivalent to ``sorted(iterable, key=key, reverse=True)[:n]``,
    but ``iterable`` is consumed lazily and items are discarded eagerly.
    """
    a_key: Callable[[Any], Awaitable[Any]]
    if key is None:
        # no key given: rank items by themselves
        a_key = _identity
    else:
        a_key = awaitify(key)  # type: ignore
    return await _largest(iterable=iterable, n=n, key=a_key, reverse=False)
[docs]
async def nsmallest(
    iterable: AnyIterable[T],
    n: int,
    key: Optional[Callable[[Any], Awaitable[Any]]] = None,
) -> "list[T]":
    """
    Return a sorted list of the ``n`` smallest elements from the (async) iterable

    Provides the reverse functionality to :py:func:`~.nlargest`.

    The optional ``key`` argument specifies a one-argument (async) callable used
    to rank the items; the default :py:data:`None` ranks items by themselves.
    """
    a_key: Callable[[Any], Awaitable[Any]]
    if key is None:
        # no key given: rank items by themselves
        a_key = _identity
    else:
        a_key = awaitify(key)  # type: ignore
    return await _largest(iterable=iterable, n=n, key=a_key, reverse=True)