# /api/search_coordinator.py
"""Defines the MultiSearchCoordinator that builds on the features implemented by the SearchCoordinator to create
multiple queries to different providers either sequentially or by using multithreading.
This implementation uses shared rate limiting to ensure that rate limits to different providers are not exceeded.
"""
from __future__ import annotations
from typing import Optional, Generator, Sequence, Iterable
from concurrent.futures import ThreadPoolExecutor
import concurrent.futures
import logging
from collections import UserDict, defaultdict
from scholar_flux.api import ProviderConfig
from scholar_flux.utils import generate_repr_from_string
from scholar_flux.api.models import SearchResultList, SearchResult, PageListInput
from scholar_flux.api.rate_limiting import threaded_rate_limiter_registry
from scholar_flux.api import SearchAPI, SearchCoordinator, ErrorResponse, APIResponse, NonResponse
from scholar_flux.exceptions import InvalidCoordinatorParameterException
logger = logging.getLogger(__name__)
class MultiSearchCoordinator(UserDict):
"""The MultiSearchCoordinator is a utility class for orchestrating searches across multiple providers, pages, and
queries sequentially or using multithreading. This coordinator builds on the SearchCoordinator's core structure to
ensure consistent, rate-limited API requests.
The multi-search coordinator uses shared rate limiters to ensure that requests to the same provider (even across
different queries) will use the same rate limiter.
This implementation uses the `ThreadedRateLimiter.min_interval` parameter from the shared rate limiter of each
provider to determine the `request_delay` across all queries. These settings can be found and modified in the
`scholar_flux.api.rate_limiting.threaded_rate_limiter_registry` by `provider_name`.
For new, unregistered providers, users can override the `MultiSearchCoordinator.DEFAULT_THREADED_REQUEST_DELAY`
class variable to adjust the shared request_delay.
Examples:
>>> from scholar_flux import MultiSearchCoordinator, SearchCoordinator, RecursiveDataProcessor
>>> from scholar_flux.api.rate_limiting import threaded_rate_limiter_registry
>>> multi_search_coordinator = MultiSearchCoordinator()
>>> threaded_rate_limiter_registry['arxiv'].min_interval = 6 # arbitrary rate limit (seconds per request)
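>>> # For providers not yet registered in the registry, the shared default delay can also be adjusted
>>> # at the class level (illustrative only; 10.0 is an arbitrary number of seconds per request)
>>> MultiSearchCoordinator.DEFAULT_THREADED_REQUEST_DELAY = 10.0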
>>>
>>> # Create coordinators for different queries and providers
>>> coordinators = [
... SearchCoordinator(
... provider_name=provider,
... query=query,
... processor=RecursiveDataProcessor(),
... user_agent="SammieH",
... cache_requests=True
... )
... for query in ('ml', 'nlp')
... for provider in ('plos', 'arxiv', 'openalex', 'crossref')
... ]
>>>
>>> # Add coordinators to the multi-search coordinator
>>> multi_search_coordinator.add_coordinators(coordinators)
>>>
>>> # Execute searches across multiple pages
>>> all_pages = multi_search_coordinator.search_pages(pages=[1, 2, 3])
>>>
>>> # filters and retains successful requests from the multi-provider search
>>> filtered_pages = all_pages.filter()
>>> # The results will contain successfully processed responses across all queries, pages, and providers
>>> print(filtered_pages) # Output will be a list of SearchResult objects
>>> # Extracts successfully processed records into a list of records where each record is a dictionary
>>> record_dict = filtered_pages.join() # retrieves a list of records
>>> print(record_dict) # Output will be a flattened list of all records
"""
DEFAULT_THREADED_REQUEST_DELAY: float | int = 6.0
def __init__(self, *args, **kwargs):
"""Initializes the MultiSearchCoordinator, allowing positional and keyword arguments to be specified when
creating the MultiSearchCoordinator.
The initialization of the MultiSearchCoordinator operates similarly to that of a regular dict, with the caveat
that all values must be SearchCoordinator instances and are validated on assignment.
"""
super().__init__(*args, **kwargs)
def __setitem__(
self,
key: str,
value: SearchCoordinator,
) -> None:
"""Sets an item in the MultiSearchCoordinator.
Args:
key (str): The key used to retrieve a SearchCoordinator
value (SearchCoordinator): The value (SearchCoordinator) to associate with the key.
Raises:
InvalidCoordinatorParameterException: If the value is not a SearchCoordinator instance.
"""
self._verify_search_coordinator(value)
super().__setitem__(key, value)
@classmethod
def _verify_search_coordinator(cls, search_coordinator: SearchCoordinator):
"""Helper method that ensures that the current value is a SearchCoordinator.
Raises:
InvalidCoordinatorParameterException: If the received value is not a SearchCoordinator instance
"""
if not isinstance(search_coordinator, SearchCoordinator):
raise InvalidCoordinatorParameterException(
f"Expected a SearchCoordinator, received type {type(search_coordinator)}"
)
@property
def coordinators(self) -> list[SearchCoordinator]:
"""Utility property for quickly retrieving a list of all currently registered coordinators."""
return list(self.data.values())
def add(self, search_coordinator: SearchCoordinator):
"""Adds a new SearchCoordinator to the MultiSearchCoordinator instance.
Args:
search_coordinator (SearchCoordinator): A search coordinator to add to the MultiSearchCoordinator dict
Raises:
InvalidCoordinatorParameterException: If the received value is not a SearchCoordinator instance.
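Example:
A minimal, illustrative sketch (assumes an existing `multi_search_coordinator` instance; keys are derived
from the provider name, query, and a hash of the coordinator):
>>> multi_search_coordinator.add(SearchCoordinator(provider_name='plos', query='ml'))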
"""
self._verify_search_coordinator(search_coordinator)
search_coordinator = self._normalize_rate_limiter(search_coordinator)
key = self._create_key(search_coordinator)
# skipping re-validation via __setitem__ since the coordinator was already verified above
super().__setitem__(key, search_coordinator)
def add_coordinators(self, search_coordinators: Iterable[SearchCoordinator]):
"""Helper method for adding a sequence of coordinators at a time."""
# ignore flagging singular coordinators as invalid by adding them to a list beforehand
search_coordinators = (
[search_coordinators] if isinstance(search_coordinators, SearchCoordinator) else search_coordinators
)
if not isinstance(search_coordinators, (Sequence, Iterable)) or isinstance(search_coordinators, str):
raise InvalidCoordinatorParameterException(
f"Expected a sequence or iterable of search_coordinators, received type {type(search_coordinators)}"
)
for search_coordinator in search_coordinators:
self.add(search_coordinator)
def search(
self,
page: int = 1,
iterate_by_group: bool = False,
max_workers: Optional[int] = None,
multithreading: bool = True,
**kwargs,
) -> SearchResultList:
"""Public method used to search for a single or multiple pages from multiple providers at once using a
sequential or multithreading approach. This approach delegates the search to search_pages to retrieve a single
page for query and provider using an iterative approach to search for articles grouped by provider.
Note that the `MultiSearchCoordinator.search_pages` method uses shared rate limiters to ensure
that APIs are not overwhelmed by the number of requests being sent within a specific time interval.
Args:
page (int): The page number to request from each query and API Provider. Defaults to 1. A sequence of page
numbers is also accepted and forwarded to `search_pages` as-is.
iterate_by_group (bool): If True, sequential retrieval fetches all pages from one provider before moving to
the next provider; otherwise providers are iterated in round-robin order. Ignored when multithreading.
max_workers (Optional[int]): The maximum number of worker threads to use when multithreading.
multithreading (bool): Determines whether retrieval is threaded by provider or performed sequentially.
**kwargs: Additional keyword arguments forwarded to each coordinator's `iter_pages` method, such as
from_request_cache, from_process_cache, and use_workflow.
Returns:
SearchResultList: The list containing the retrieved and processed page for each query and provider. Pages that
encounter non-stopping errors are included as SearchResults whose response is an ErrorResponse
with error and message attributes further explaining any issues that occurred during processing.
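Example:
A minimal, illustrative sketch (assumes coordinators were previously registered via `add_coordinators`):
>>> results = multi_search_coordinator.search(page=1)
>>> successful = results.filter()  # retain only successfully processed responses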
"""
return self.search_pages(
pages=[page] if isinstance(page, int) else page,
iterate_by_group=iterate_by_group,
max_workers=max_workers,
multithreading=multithreading,
**kwargs,
)
def search_pages(
self,
pages: Sequence[int] | PageListInput,
iterate_by_group: bool = False,
max_workers: Optional[int] = None,
multithreading: bool = True,
**kwargs,
) -> SearchResultList:
"""Public method used to search articles from multiple providers at once using a sequential or multithreading
approach. This approach uses `iter_pages` under the.
Note that the `MultiSearchCoordinator.search_pages` method uses shared rate limiters to ensure
that APIs are not overwhelmed by the number of requests being sent within a specific time interval.
Args:
pages (Sequence[int] | PageListInput): A sequence of page numbers to iteratively request from each API Provider.
iterate_by_group (bool): If True, sequential retrieval fetches all pages from one provider before moving to
the next provider; otherwise a single page is retrieved from each provider in round-robin order.
Ignored when multithreading is enabled.
max_workers (Optional[int]): The maximum number of worker threads to use when multithreading.
multithreading (bool): Determines whether retrieval is threaded by provider or performed sequentially.
**kwargs: Additional keyword arguments forwarded to each coordinator's `iter_pages` method:
from_request_cache (bool): Whether to try to retrieve the response from the requests-cache storage.
from_process_cache (bool): Whether to attempt to pull processed responses from the cache storage.
use_workflow (bool): Whether to use a workflow if available. Workflows are utilized by default.
Returns:
SearchResultList: The list containing all retrieved and processed pages from the API. Pages that encounter
non-stopping errors are included as SearchResults whose response is an ErrorResponse with
error and message attributes further explaining any issues that occurred during processing.
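Example:
A minimal, illustrative sketch (assumes coordinators were previously registered):
>>> results = multi_search_coordinator.search_pages(pages=[1, 2, 3], multithreading=True)
>>> records = results.filter().join()  # flatten successfully processed records across queries and providers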
"""
search_results = SearchResultList()
if max_workers is not None and not isinstance(max_workers, int):
raise InvalidCoordinatorParameterException(
"Expected max_workers to be a positive integer, " f"Received a value of type {type(max_workers)}"
)
pages = SearchCoordinator._validate_page_list_input(pages)
if not self.data:
logger.warning(
"A coordinator has not yet been registered with the MultiSearchCoordinator: "
"returning an empty list..."
)
return search_results
if multithreading:
search_iterator: Generator[SearchResult, None, None] = self.iter_pages_threaded(
pages, max_workers=max_workers, **kwargs
)
else:
search_iterator = self.iter_pages(pages, iterate_by_group=iterate_by_group, **kwargs)
for search_result in search_iterator:
search_results.append(search_result)
logger.debug("Completed multi-search coordinated retrieval and processing")
return search_results
def iter_pages(
self, pages: Sequence[int] | PageListInput, iterate_by_group: bool = False, **kwargs
) -> Generator[SearchResult, None, None]:
"""Helper method that creates and joins a sequence of generator functions for retrieving and processing records
from each combination of queries, pages, and providers in sequence. This implementation uses the
SearchCoordinator.iter_pages to dynamically identify when page retrieval should halt for each API provider,
accounting for errors, timeouts, and less than the expected amount of records before filtering records with pre-
specified criteria.
Args:
pages (Sequence[int] | PageListInput): A sequence of page numbers to iteratively request from each API Provider.
iterate_by_group (bool): If True, retrieves all pages from one provider before moving to the next provider;
otherwise a single page is retrieved from each provider in round-robin order.
**kwargs: Additional keyword arguments forwarded to each coordinator's `iter_pages` method:
from_request_cache (bool): Whether to try to retrieve the response from the requests-cache storage.
from_process_cache (bool): Whether to attempt to pull processed responses from the cache storage.
use_workflow (bool): Whether to use a workflow if available. Workflows are utilized by default.
Yields:
SearchResult: Iteratively returns the SearchResult for each provider, query, and page using a generator
expression. Each result contains the requested page number (page), the name of the provider
(provider_name), and the result of the search containing a ProcessedResponse,
an ErrorResponse, or None (api response)
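Example:
A minimal, illustrative sketch (assumes coordinators were previously registered):
>>> for result in multi_search_coordinator.iter_pages(pages=[1, 2]):
...     print(result.provider_name, result.page)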
"""
# group coordinators by provider so that each provider's queries share the same generator chain (and rate limiter)
provider_search_dict = self.group_by_provider()
# creates a dictionary of generators grouped by provider. On each yield, each generator retrieves a single page
provider_generator_dict = {
provider_name: self._process_provider_group(group, pages, **kwargs)
for provider_name, group in provider_search_dict.items()
}
if iterate_by_group:
# Retrieve all pages from a single provider before moving to the next provider
yield from self._grouped_iteration(provider_generator_dict)
else:
# Retrieve a single page number for all providers before moving to the next page
yield from self._round_robin_iteration(provider_generator_dict)
@classmethod
def _grouped_iteration(
cls, provider_generator_dict: dict[str, Generator[SearchResult, None, None]]
) -> Generator[SearchResult, None, None]:
"""Helper method for iteratively retrieves all pages from a single provider before moving to the next.
Args:
provider_generator_dict (Mapping[str, Generator[SearchResult, None, None]]):
A dictionary containing provider names as keys and generators as values.
Yields:
SearchResult: A search result containing the provider name, query, page, and response from each
API Provider
"""
for provider_name, generator in provider_generator_dict.items():
yield from cls._process_page_generator(provider_name, generator)
@classmethod
def _round_robin_iteration(
cls, provider_generator_dict: dict[str, Generator[SearchResult, None, None]]
) -> Generator[SearchResult, None, None]:
"""Helper method for iteratively yielding each page from each provider in a cyclical order. This method is
implemented to ensure faster iteration given common rate-limits associated with API Providers. Note that the
received generator dictionary will be popped as each generator is consumed.
Args:
provider_generator_dict (Mapping[str, Generator[SearchResult, None, None]]):
A dictionary containing provider names as keys and generators as values.
Yields:
SearchResult: A search result containing the provider name, query, page, and response from each
API Provider
"""
while provider_generator_dict:
inactive_generators = []
for provider_name, generator in provider_generator_dict.items():
try:
yield next(generator)
# on success, continue on to the next provider's generator in the cycle
except StopIteration:
logger.debug(f"Successfully halted retrieval for provider, {provider_name}")
inactive_generators.append(provider_name)
except Exception as e:
logger.error(
"Encountered an unexpected error during iteration for provider, " f"{provider_name}: {e}"
)
inactive_generators.append(provider_name)
for provider_name in inactive_generators:
provider_generator_dict.pop(provider_name)
def iter_pages_threaded(
self, pages: Sequence[int] | PageListInput, max_workers: Optional[int] = None, **kwargs
) -> Generator[SearchResult, None, None]:
"""Threading by provider to respect rate limits Helper method that implements threading to simultaneously
retrieve a sequence of generator functions for retrieving and processing records from each combination of
queries, pages, and providers in a multi-threaded set of sequences grouped by provider.
This implementation also uses the SearchCoordinator.iter_pages to dynamically identify when page retrieval
should halt for each API provider, accounting for errors, timeouts, and less than the expected amount of
records before filtering records with pre-specified criteria.
Note, that as threading is performed by provider, this method will not differ significantly in speed from
the `MultiSearchCoordinator.iter_pages` method if only a single provider has been specified.
Args:
pages (Sequence[int] | PageListInput): A sequence of page numbers to request from each API Provider.
max_workers (Optional[int]): The maximum number of worker threads. Defaults to one worker per provider
group, capped at 8.
**kwargs: Additional keyword arguments forwarded to each coordinator's `iter_pages` method:
from_request_cache (bool): Whether to try to retrieve the response from the requests-cache storage.
from_process_cache (bool): Whether to attempt to pull processed responses from the cache storage.
use_workflow (bool): Whether to use a workflow if available. Workflows are utilized by default.
Yields:
SearchResult: Iteratively returns the SearchResult for each provider, query, and page using a generator
expression as each SearchResult becomes available after multi-threaded processing.
Each result contains the requested page number (page), the name of the provider
(provider_name), and the result of the search containing a ProcessedResponse, an ErrorResponse,
or None (api response)
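Example:
A minimal, illustrative sketch (assumes coordinators for more than one provider were previously registered):
>>> for result in multi_search_coordinator.iter_pages_threaded(pages=[1, 2], max_workers=4):
...     print(result.provider_name, result.page)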
"""
provider_groups = self.group_by_provider()
workers = max_workers if max_workers is not None else min(8, len(provider_groups) or 1)
if workers < 1:
logger.warning(f"The value for workers ({workers}) is non-positive: defaulting to 1 worker")
workers = 1
# creates a dictionary of generators grouped by provider. On each yield, each generator retrieves a single page
provider_generator_dict = {
provider_name: self._process_provider_group(group, pages, **kwargs)
for provider_name, group in provider_groups.items()
}
with ThreadPoolExecutor(max_workers=workers) as executor:
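# each worker fully consumes one provider's generator; results are yielded as each provider's batch completes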
futures = [
executor.submit(list, self._process_page_generator(provider_name, generator))
for provider_name, generator in provider_generator_dict.items()
]
for future in concurrent.futures.as_completed(futures):
yield from future.result()
@classmethod
def _process_page_generator(
cls, provider_name: str, generator: Generator[SearchResult, None, None]
) -> Generator[SearchResult, None, None]:
"""Helper method for safely consuming a generator, accounting for errors that could stop iteration during
threaded retrieval of page data.
Args:
provider_name (str): The name of the current provider
generator (Generator[SearchResult, None, None]):
A generator that returns a SearchResult upon the successful retrieval of the next page
Yields:
SearchResult: The next search result from the generator if there is at least one more page to retrieve
"""
try:
yield from generator
logger.debug(f"Successfully halted retrieval for provider, {provider_name}")
except Exception as e:
logger.error("Encountered an unexpected error during iteration for provider, " f"{provider_name}: {e}")
def _process_provider_group(
self, provider_coordinators: dict[str, SearchCoordinator], pages: Sequence[int] | PageListInput, **kwargs
) -> Generator[SearchResult, None, None]:
"""Helper method used to process all queries and pages for a single provider under a common thread. This method
is especially useful during multithreading given that API Providers often have hard limits on the total number
of requests that can be sent within a provider-specific interval.
Args:
provider_coordinators (dict[str, SearchCoordinator]):
A dictionary of all coordinators corresponding to a single provider.
pages (Sequence[int] | PageListInput): A list, set, or other common sequence of integer page numbers
corresponding to records/articles to iteratively request from the API Provider.
**kwargs: Keyword arguments to pass to the `iter_pages` method call to facilitate single or multithreaded
record page retrieval
Yields:
SearchResult: Iteratively returns the SearchResult for each provider, query, and page using a generator
expression as each SearchResult becomes available after multi-threaded processing.
Each result contains the requested page number (page), the name of the provider
(provider_name), and the result of the search containing a ProcessedResponse, an ErrorResponse,
or None (api response)
"""
# All coordinators in this group share the same threaded rate limiter
# will be used to flag non-retryable error codes from the provider for early stopping across queries if needed
last_response: Optional[APIResponse] = None
# an explicit request_delay overrides the provider's shared rate-limiter interval for every query in the group
request_delay_override = kwargs.pop("request_delay", None)
for search_coordinator in provider_coordinators.values():
provider_name = ProviderConfig._normalize_name(search_coordinator.api.provider_name)
if (
isinstance(last_response, ErrorResponse)
and not isinstance(last_response, NonResponse)
and isinstance(last_response.status_code, int)
and last_response.status_code != 200
and last_response.status_code not in search_coordinator.retry_handler.retry_statuses
):
# breaks if a non-retryable status code is encountered.
logger.warning(
f"Encountered a non-retryable response during retrieval: {last_response}. "
f"Halting retrieval for provider, {provider_name}"
)
break
# use the provider's shared threaded rate-limiter interval unless an explicit request_delay was provided
default_request_delay = search_coordinator.api._rate_limiter.min_interval
request_delay = request_delay_override if request_delay_override is not None else default_request_delay
# iterate over the current coordinator given its session, query, and settings
for page in search_coordinator.iter_pages(pages, **kwargs, request_delay=request_delay):
if isinstance(page, SearchResult):
last_response = page.response_result
yield page
def current_providers(self) -> set[str]:
"""Extracts a set of names corresponding to the each API provider assigned to the MultiSearchCoordinator."""
return {ProviderConfig._normalize_name(coordinator.api.provider_name) for coordinator in self.data.values()}
def group_by_provider(self) -> dict[str, dict[str, SearchCoordinator]]:
"""Groups all coordinators by provider name to facilitate retrieval with normalized components where needed.
Especially helpful in the latter retrieval of articles when using multithreading by provider (as opposed to by
page) to account for strict rate limits. All coordinated searches corresponding to a provider would appear under
a nested dictionary to facilitate orchestration on the same thread with the same rate limiter.
Returns:
dict[str, dict[str, SearchCoordinator]]:
A dictionary mapping each normalized provider name to a nested dictionary of that provider's
coordinators, keyed by their original MultiSearchCoordinator keys.
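Example:
An illustrative sketch (actual keys depend on the registered coordinators):
>>> grouped = multi_search_coordinator.group_by_provider()
>>> list(grouped)  # e.g. ['plos', 'arxiv'] when coordinators for those providers were added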
"""
provider_search_dict: dict[str, dict[str, SearchCoordinator]] = defaultdict(dict)
for key, coordinator in self.data.items():
provider_name = ProviderConfig._normalize_name(coordinator.api.provider_name)
provider_search_dict[provider_name][key] = coordinator
return dict(provider_search_dict)
def _normalize_rate_limiter(self, search_coordinator: SearchCoordinator):
"""Helper method that retrieves the threaded rate_limiter for the coordinator's provider and normalizes the rate
limiter used for searches."""
provider_name = ProviderConfig._normalize_name(search_coordinator.api.provider_name)
# ensure the same rate limiter is shared across all coordinators for the same provider, even when threading
# if the provider doesn't already exist, initialize the provider rate limiter in the registry
threaded_rate_limiter = threaded_rate_limiter_registry.get_or_create(
provider_name, self.DEFAULT_THREADED_REQUEST_DELAY
)
if threaded_rate_limiter:
search_coordinator.api = SearchAPI.update(search_coordinator.api, rate_limiter=threaded_rate_limiter)
return search_coordinator
@classmethod
def _create_key(cls, search_coordinator: SearchCoordinator):
"""Create a hashed key from a coordinator using the provider name, query, and structure of the
SearchCoordinator."""
hash_value = hash(repr(search_coordinator))
provider_name = ProviderConfig._normalize_name(search_coordinator.api.provider_name)
query = str(search_coordinator.api.query)
key = f"{provider_name}_{query}:{hash_value}"
return key
def structure(self, flatten: bool = False, show_value_attributes: bool = True) -> str:
"""Helper method that shows the current structure of the MultiSearchCoordinator."""
class_name = self.__class__.__name__
attributes = {key: coordinator.summary() for key, coordinator in self.data.items()}
return generate_repr_from_string(class_name, attributes)
def __repr__(self) -> str:
"""Helper method for generating a string representation of the current list of coordinators."""
return self.structure()
__all__ = ["MultiSearchCoordinator"]