Source code for scholar_flux.api.multisearch_coordinator

# /api/multisearch_coordinator.py
"""Defines the MultiSearchCoordinator that builds on the features implemented by the SearchCoordinator to create
multiple queries to different providers either sequentially or by using multithreading.

This implementation uses shared rate limiting to ensure that rate limits to different providers are not exceeded.

"""
from __future__ import annotations
from typing import Optional, Generator, Sequence, Iterable, Any
from typing_extensions import Self
from concurrent.futures import ThreadPoolExecutor
import concurrent.futures
import logging

from collections import UserDict, defaultdict
from scholar_flux.api import ProviderConfig
from scholar_flux.utils import generate_repr_from_string
from scholar_flux.api.models import SearchResultList, SearchResult, PageListInput
from scholar_flux.api.rate_limiting import threaded_rate_limiter_registry
from scholar_flux.api import SearchAPI, SearchCoordinator, ErrorResponse, APIResponse, NonResponse
from scholar_flux.exceptions import InvalidCoordinatorParameterException


logger = logging.getLogger(__name__)


class MultiSearchCoordinator(UserDict[str, SearchCoordinator]):
    """Orchestrates searches across multiple providers, pages, and queries, sequentially or with multithreading.

    This coordinator builds on the SearchCoordinator's core structure to ensure consistent, rate-limited API
    requests. Coordinators registered through `add` / `add_coordinators` are switched to a shared rate limiter so
    that requests to the same provider (even across different queries) use the same rate limiter.

    The `min_interval` of each provider's shared rate limiter is used as the default `request_delay` across all
    queries for that provider. These settings can be found and modified by provider name in the
    `threaded_rate_limiter_registry` imported from `scholar_flux.api.rate_limiting`.

    For new, unregistered providers, users can override the `MultiSearchCoordinator.DEFAULT_THREADED_REQUEST_DELAY`
    class variable to adjust the shared request delay.

    Examples:
        >>> from scholar_flux import MultiSearchCoordinator, SearchCoordinator, RecursiveDataProcessor
        >>> from scholar_flux.api.rate_limiting import threaded_rate_limiter_registry
        >>> multi_search_coordinator = MultiSearchCoordinator()
        >>> threaded_rate_limiter_registry['arxiv'].min_interval = 6 # arbitrary rate limit (seconds per request)
        >>>
        >>> # Create coordinators for different queries and providers
        >>> coordinators = [
        ...     SearchCoordinator(
        ...         provider_name=provider,
        ...         query=query,
        ...         processor=RecursiveDataProcessor(),
        ...         user_agent="SammieH",
        ...         cache_requests=True
        ...     )
        ...     for query in ('ml', 'nlp')
        ...     for provider in ('plos', 'arxiv', 'openalex', 'crossref')
        ... ]
        >>>
        >>> # Add coordinators to the multi-search coordinator
        >>> multi_search_coordinator.add_coordinators(coordinators)
        >>>
        >>> # Execute searches across multiple pages
        >>> all_pages = multi_search_coordinator.search_pages(pages=[1, 2, 3])
        >>>
        >>> # filters and retains successful requests from the multi-provider search
        >>> filtered_pages = all_pages.filter()
        >>> # The results will contain successfully processed responses across all queries, pages, and providers
        >>> print(filtered_pages) # Output will be a list of SearchResult objects
        >>> # Extracts successfully processed records into a list of records where each record is a dictionary
        >>> record_dict = filtered_pages.join() # retrieves a list of records
        >>> print(record_dict) # Output will be a flattened list of all records

    """

    # Interval handed to the rate limiter registry when a provider has no shared rate limiter yet.
    DEFAULT_THREADED_REQUEST_DELAY: float | int = 6.0

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initializes the MultiSearchCoordinator from optional positional and keyword arguments.

        Initialization operates similarly to that of a regular dict, with the caveat that values are expected to be
        SearchCoordinator instances (enforced by `__setitem__`).

        """
        super().__init__(*args, **kwargs)

    def __setitem__(
        self,
        key: str,
        value: SearchCoordinator,
    ) -> None:
        """Sets an item in the MultiSearchCoordinator.

        Args:
            key (str): The key used to retrieve a SearchCoordinator
            value (SearchCoordinator): The value (SearchCoordinator) to associate with the key.

        Raises:
            InvalidCoordinatorParameterException: If the value is not a SearchCoordinator instance.

        """
        self._verify_search_coordinator(value)
        super().__setitem__(key, value)

    @classmethod
    def _verify_search_coordinator(cls, search_coordinator: SearchCoordinator) -> None:
        """Helper method that ensures that the received value is a SearchCoordinator.

        Raises:
            InvalidCoordinatorParameterException: If the received value is not a SearchCoordinator instance

        """
        if not isinstance(search_coordinator, SearchCoordinator):
            raise InvalidCoordinatorParameterException(
                f"Expected a SearchCoordinator, received type {type(search_coordinator)}"
            )

    @property
    def coordinators(self) -> list[SearchCoordinator]:
        """Utility property for quickly retrieving a list of all currently registered coordinators."""
        return list(self.data.values())

    def add(self, search_coordinator: SearchCoordinator) -> None:
        """Adds a new SearchCoordinator to the MultiSearchCoordinator instance.

        The coordinator's API is updated to use the shared rate limiter of its provider, and the coordinator is
        stored under a key generated by `_create_key`.

        Args:
            search_coordinator (SearchCoordinator): A search coordinator to add to the MultiSearchCoordinator dict

        Raises:
            InvalidCoordinatorParameterException: If the received value is not a SearchCoordinator

        """
        self._verify_search_coordinator(search_coordinator)
        search_coordinator = self._normalize_rate_limiter(search_coordinator)
        key = self._create_key(search_coordinator)
        # the coordinator was verified above: skip the re-validation performed by this class's __setitem__
        super().__setitem__(key, search_coordinator)

    @classmethod
    def from_coordinators(cls, search_coordinators: Iterable[SearchCoordinator]) -> Self:
        """Constructs a new `MultiSearchCoordinator` instance from an iterable of coordinators."""
        multi_search_coordinator = cls()
        multi_search_coordinator.add_coordinators(search_coordinators)
        return multi_search_coordinator

    def add_coordinators(self, search_coordinators: Iterable[SearchCoordinator]) -> None:
        """Helper method for adding several coordinators at a time.

        Args:
            search_coordinators (Iterable[SearchCoordinator]):
                The coordinators to register. A single SearchCoordinator is also accepted and is treated as a
                one-element list.

        Raises:
            InvalidCoordinatorParameterException:
                If the input is a string or is not iterable, or if any element is not a SearchCoordinator.

        """
        # avoid flagging a single coordinator as invalid by wrapping it in a list beforehand
        search_coordinators = (
            [search_coordinators] if isinstance(search_coordinators, SearchCoordinator) else search_coordinators
        )

        if not isinstance(search_coordinators, (Sequence, Iterable)) or isinstance(search_coordinators, str):
            raise InvalidCoordinatorParameterException(
                f"Expected a sequence or iterable of search_coordinators, received type {type(search_coordinators)}"
            )

        for search_coordinator in search_coordinators:
            self.add(search_coordinator)

    def search(
        self,
        page: int = 1,
        iterate_by_group: bool = False,
        max_workers: Optional[int] = None,
        multithreading: bool = True,
        **kwargs: Any,
    ) -> SearchResultList:
        """Searches for a single page from every registered coordinator, sequentially or with multithreading.

        This method delegates to `search_pages`. An integer `page` is wrapped in a one-element list; any other
        value is forwarded unchanged as the `pages` argument.

        Note that the `MultiSearchCoordinator.search_pages` method uses shared rate limiters to ensure that APIs
        are not overwhelmed by the number of requests being sent within a specific time interval.

        Args:
            page (int): The page number to request from each API Provider.
            iterate_by_group (bool):
                Determines whether all searches should be performed by page or by group. Only used for sequential
                retrieval (`multithreading=False`). Note that page-based iteration is significantly faster due to
                API rate limits. This is set to `False` by default as a result.
            max_workers (Optional[int]):
                Determines how many threads should operate at one time. Applies only when `multithreading` is set
                to `True`. When `None`, one worker per provider is used, up to a maximum of 8 workers.
            multithreading (bool):
                Multithreading is used when this parameter is set to `True`. Otherwise, sequential iteration is
                performed. Multithreading is enabled by default.
            **kwargs:
                Additional keyword arguments forwarded to `search_pages` and, from there, to each coordinator's
                `iter_pages` call (e.g. `from_request_cache`, `from_process_cache`, `use_workflow`).

        Returns:
            SearchResultList:
                The list of SearchResult entries gathered across providers and queries. Requests that fail without
                halting retrieval are reported through the response held by the corresponding SearchResult.

        """
        return self.search_pages(
            pages=[page] if isinstance(page, int) else page,
            iterate_by_group=iterate_by_group,
            max_workers=max_workers,
            multithreading=multithreading,
            **kwargs,
        )

    def search_page(
        self,
        page: int = 1,
        **kwargs: Any,
    ) -> SearchResultList:
        """Retrieves a single page from all registered coordinators.

        This method provides API compatibility with SearchCoordinator.search_page, returning results wrapped in
        SearchResult containers with provider metadata.

        Args:
            page (int): The page number to retrieve from each provider.
            **kwargs:
                Additional arguments to pass to `MultiSearchCoordinator.search_pages` or the `iter_pages` method
                of each individual coordinator.

        Returns:
            SearchResultList: Results from all coordinators for the specified page.

        """
        return self.search_pages(
            pages=[page] if isinstance(page, int) else page,
            **kwargs,
        )

    def search_pages(
        self,
        pages: Sequence[int] | PageListInput,
        iterate_by_group: bool = False,
        max_workers: Optional[int] = None,
        multithreading: bool = True,
        *,
        min_records: Optional[int] = None,
        page_offset: int = 0,
        **kwargs: Any,
    ) -> SearchResultList:
        """Searches for records from multiple providers using a sequential or multithreading approach.

        Note that this method uses shared rate limiters to ensure that APIs are not overwhelmed by the number of
        requests being sent within a specific time interval.

        Args:
            pages (Sequence[int] | PageListInput):
                A sequence of page numbers to iteratively request from each API Provider.
            iterate_by_group (bool):
                When `True`, all pages of one provider are retrieved before moving to the next provider. Only used
                when `multithreading` is `False`.
            max_workers (Optional[int]):
                The number of threads to use when `multithreading` is `True`. When `None`, one worker per provider
                is used, up to a maximum of 8 workers.
            multithreading (bool): Whether to retrieve providers concurrently (`True` by default).
            min_records (Optional[int]):
                The total number of records to retrieve sequentially. Only used when `pages` is empty; otherwise
                `pages` is validated and used directly.
            page_offset (int):
                The page offset to begin record retrieval from (0 by default). This parameter is only relevant
                when a `min_records` value is used instead of page numbers.
            **kwargs:
                Additional keyword arguments forwarded to each coordinator's `iter_pages` call (e.g.
                `from_request_cache`, `from_process_cache`, `use_workflow`).

        Returns:
            SearchResultList:
                The list of SearchResult entries gathered across providers and queries. An empty list is returned
                when no coordinator has been registered. Requests that fail without halting retrieval are reported
                through the response held by the corresponding SearchResult.

        Raises:
            InvalidCoordinatorParameterException: If `max_workers` is neither `None` nor an integer.

        """
        search_results = SearchResultList()

        if max_workers is not None and not isinstance(max_workers, int):
            raise InvalidCoordinatorParameterException(
                f"Expected max_workers to be a positive integer, but received a value of type {type(max_workers)}"
            )

        # an empty `pages` value combined with an integer `min_records` defers page resolution to each provider
        pages = [] if not pages and isinstance(min_records, int) else SearchCoordinator._validate_page_list_input(pages)

        if not self.data:
            logger.warning(
                "A coordinator has not yet been registered with the MultiSearchCoordinator: "
                "returning an empty list..."
            )
            return search_results

        if multithreading:
            search_iterator: Generator[SearchResult, None, None] = self.iter_pages_threaded(
                pages, max_workers=max_workers, min_records=min_records, page_offset=page_offset, **kwargs
            )
        else:
            search_iterator = self.iter_pages(
                pages, iterate_by_group=iterate_by_group, min_records=min_records, page_offset=page_offset, **kwargs
            )

        for search_result in search_iterator:
            search_results.append(search_result)

        # use the module logger (not the root logger) so that this message follows the package's logging config
        logger.debug("Completed multi-search coordinated retrieval and processing")
        return search_results

    def search_records(self, min_records: int, page_offset: int = 0, **kwargs: Any) -> SearchResultList:
        """Helper method for retrieving a minimum of `min_records` records across all API providers.

        The page list for each provider is derived from `min_records`, the provider's records per page, and
        `page_offset`. Any `pages` keyword argument supplied by the caller is replaced with an empty list.

        Note that this method uses shared rate limiters to ensure that APIs are not overwhelmed by the number of
        requests being sent within a specific time interval.

        Args:
            min_records (int): The total number of records to retrieve sequentially.
            page_offset (int): The page offset to begin record retrieval from (0 by default).
            **kwargs:
                Additional keyword arguments forwarded to `search_pages` (e.g. `multithreading`, `max_workers`,
                `from_request_cache`, `from_process_cache`, `use_workflow`).

        Returns:
            SearchResultList: The list of SearchResult entries gathered across providers and queries.

        """
        # an empty page list makes `search_pages` fall back to `min_records`
        kwargs["pages"] = []
        return self.search_pages(min_records=min_records, page_offset=page_offset, **kwargs)

    def iter_pages(
        self, pages: Sequence[int] | PageListInput, iterate_by_group: bool = False, **kwargs: Any
    ) -> Generator[SearchResult, None, None]:
        """Sequentially yields results for each combination of queries, pages, and providers.

        One generator is created per provider (see `_process_provider_group`), and the generators are consumed
        either provider-by-provider or in a round-robin fashion. Each coordinator's own `iter_pages` method decides
        when page retrieval should halt for that coordinator.

        Args:
            pages (Sequence[int] | PageListInput):
                A sequence of page numbers to iteratively request from each API Provider.
            iterate_by_group (bool):
                When `True`, all results of one provider are yielded before moving to the next provider. When
                `False` (default), one result per provider is yielded in turn.
            **kwargs:
                Additional keyword arguments forwarded to `_process_provider_group` (`min_records`, `page_offset`)
                and to each coordinator's `iter_pages` call (e.g. `from_request_cache`, `from_process_cache`,
                `use_workflow`).

        Yields:
            SearchResult:
                The SearchResult for each provider, query, and page. Each result contains the requested page
                number (page), the name of the provider (provider_name), and the result of the search containing
                a ProcessedResponse, an ErrorResponse, or None (api response)

        """
        provider_search_dict = self.group_by_provider()

        # one lazy generator per provider: each `next` call retrieves a single page
        provider_generator_dict = {
            provider_name: self._process_provider_group(group, pages, **kwargs)
            for provider_name, group in provider_search_dict.items()
        }

        if iterate_by_group:
            # Retrieve all pages from a single provider before moving to the next provider
            yield from self._grouped_iteration(provider_generator_dict)
        else:
            # Retrieve a single page for every provider before moving to the next page
            yield from self._round_robin_iteration(provider_generator_dict)

    @classmethod
    def _grouped_iteration(
        cls, provider_generator_dict: dict[str, Generator[SearchResult, None, None]]
    ) -> Generator[SearchResult, None, None]:
        """Helper method that retrieves all pages from a single provider before moving to the next.

        Args:
            provider_generator_dict (dict[str, Generator[SearchResult, None, None]]):
                A dictionary containing provider names as keys and generators as values.

        Yields:
            SearchResult: A search result containing the provider name, query, page, and response from each
                          API Provider

        """
        for provider_name, generator in provider_generator_dict.items():
            yield from cls._process_page_generator(provider_name, generator)

    @classmethod
    def _round_robin_iteration(
        cls, provider_generator_dict: dict[str, Generator[SearchResult, None, None]]
    ) -> Generator[SearchResult, None, None]:
        """Helper method for iteratively yielding one page from each provider in a cyclical order.

        This method is implemented to ensure faster iteration given common rate-limits associated with API
        Providers. Note that the received dictionary is mutated: a provider is removed once its generator is
        exhausted or raises an exception.

        Args:
            provider_generator_dict (dict[str, Generator[SearchResult, None, None]]):
                A dictionary containing provider names as keys and generators as values.

        Yields:
            SearchResult: A search result containing the provider name, query, page, and response from each
                          API Provider

        """
        while provider_generator_dict:
            inactive_generators = []

            for provider_name, generator in provider_generator_dict.items():
                try:
                    yield next(generator)
                except StopIteration:
                    logger.debug(f"Successfully halted retrieval for provider, {provider_name}")
                    inactive_generators.append(provider_name)
                except Exception as e:
                    logger.error(
                        "Encountered an unexpected error during iteration for provider, " f"{provider_name}: {e}"
                    )
                    inactive_generators.append(provider_name)

            # removal is deferred until after the pass: the dictionary cannot change size while being iterated
            for provider_name in inactive_generators:
                provider_generator_dict.pop(provider_name)

    def iter_pages_threaded(
        self,
        pages: Sequence[int] | PageListInput,
        max_workers: Optional[int] = None,
        **kwargs: Any,
    ) -> Generator[SearchResult, None, None]:
        """Retrieves results concurrently, using one task per provider to respect rate limits.

        Every provider group is consumed to completion inside a thread pool task, and the results of each provider
        are yielded together once that provider's task finishes. Each coordinator's own `iter_pages` method
        decides when page retrieval should halt for that coordinator.

        Note that, as threading is performed by provider, this method will not differ significantly in speed from
        the `MultiSearchCoordinator.iter_pages` method if only a single provider has been specified.

        Args:
            pages (Sequence[int] | PageListInput):
                A sequence of page numbers to request from each API Provider.
            max_workers (Optional[int]):
                The number of threads to use. When `None`, one worker per provider is used, up to a maximum of 8
                workers. Values below 1 are replaced by a single worker (with a logged warning).
            **kwargs:
                Additional keyword arguments forwarded to `_process_provider_group` (`min_records`, `page_offset`)
                and to each coordinator's `iter_pages` call (e.g. `from_request_cache`, `from_process_cache`,
                `use_workflow`).

        Yields:
            SearchResult:
                The SearchResult for each provider, query, and page. Each result contains the requested page
                number (page), the name of the provider (provider_name), and the result of the search containing
                a ProcessedResponse, an ErrorResponse, or None (api response)

        """
        provider_groups = self.group_by_provider()
        workers = max_workers if max_workers is not None else min(8, len(provider_groups) or 1)

        if workers < 1:
            logger.warning(f"The value for workers ({workers}) is non-positive: defaulting to 1 worker")
            workers = 1

        # one lazy generator per provider: each generator is fully consumed by a single worker thread
        provider_generator_dict = {
            provider_name: self._process_provider_group(group, pages, **kwargs)
            for provider_name, group in provider_groups.items()
        }

        with ThreadPoolExecutor(max_workers=workers) as executor:
            futures = [
                executor.submit(list, self._process_page_generator(provider_name, generator))
                for provider_name, generator in provider_generator_dict.items()
            ]

            for future in concurrent.futures.as_completed(futures):
                yield from future.result()

    @classmethod
    def _process_page_generator(
        cls, provider_name: str, generator: Generator[SearchResult, None, None]
    ) -> Generator[SearchResult, None, None]:
        """Helper method for safely consuming a generator without letting an error propagate to the caller.

        Exceptions raised by the wrapped generator are logged and end the iteration for that provider; results
        yielded before the error are kept.

        Args:
            provider_name (str): The name of the current provider
            generator (Generator[SearchResult, None, None]):
                A generator that returns a SearchResult upon the successful retrieval of the next page

        Yields:
            SearchResult: The next search result from the generator if there is at least one more page to retrieve

        """
        try:
            yield from generator
            logger.debug(f"Successfully halted retrieval for provider, {provider_name}")
        except Exception as e:
            logger.error("Encountered an unexpected error during iteration for provider, " f"{provider_name}: {e}")

    def _process_provider_group(
        self,
        provider_coordinators: dict[str, SearchCoordinator],
        pages: Sequence[int] | PageListInput,
        *,
        min_records: Optional[int] = None,
        page_offset: int = 0,
        **kwargs: Any,
    ) -> Generator[SearchResult, None, None]:
        """Helper method used to process all queries and pages for a single provider under a common thread.

        This method is especially useful during multithreading given that API Providers often have hard limits on
        the total number of requests that can be sent within a provider-specific interval.

        Retrieval for the provider stops early (remaining coordinators are skipped) when the most recent response
        is an ErrorResponse, other than a NonResponse, whose integer status code is neither 200 nor one of the
        next coordinator's retryable statuses.

        Args:
            provider_coordinators (dict[str, SearchCoordinator]):
                A dictionary of all coordinators corresponding to a single provider.
            pages (Sequence[int] | PageListInput):
                A list, set, or other common sequence of integer page numbers corresponding to records/articles to
                iteratively request from the API Provider.
            min_records (Optional[int]):
                The total number of records to retrieve sequentially. Only used when `pages` is empty.
            page_offset (int): The page offset to begin record retrieval from (0 by default).
            **kwargs:
                Keyword arguments to pass to each coordinator's `iter_pages` call. A `request_delay` entry, when
                present, is applied to every coordinator of the provider; otherwise the `min_interval` of each
                coordinator's rate limiter is used.

        Yields:
            SearchResult:
                The SearchResult for each query and page of the provider. Each result contains the requested page
                number (page), the name of the provider (provider_name), and the result of the search containing
                a ProcessedResponse, an ErrorResponse, or None (api response)

        """
        # Extract the caller's override once so that it applies to every coordinator in the group. A private
        # sentinel distinguishes "not provided" from an explicitly provided value such as `None`.
        missing = object()
        request_delay_override = kwargs.pop("request_delay", missing)

        # tracks the latest response to flag non-retryable status codes for early stopping across queries
        last_response: Optional[APIResponse] = None

        for search_coordinator in provider_coordinators.values():
            if (
                isinstance(last_response, ErrorResponse)
                and not isinstance(last_response, NonResponse)
                and isinstance(last_response.status_code, int)
                and last_response.status_code != 200
                and last_response.status_code not in search_coordinator.retry_handler.retry_statuses
            ):
                # a non-retryable status code was encountered: further requests to this provider are skipped
                logger.warning(
                    f"Encountered a non-retryable response during retrieval: {last_response}. "
                    f"Halting retrieval for provider, {search_coordinator.display_name}"
                )
                break

            # default to the interval of the rate limiter assigned to this coordinator's API
            request_delay = (
                search_coordinator.api._rate_limiter.min_interval
                if request_delay_override is missing
                else request_delay_override
            )

            current_pages = (
                PageListInput.from_record_count(min_records, search_coordinator.api.records_per_page, page_offset)
                if not pages and min_records is not None
                else pages
            )

            # iterate over the current coordinator given its session, query, and settings
            for page in search_coordinator.iter_pages(current_pages, **kwargs, request_delay=request_delay):
                if isinstance(page, SearchResult):
                    last_response = page.response_result
                yield page

    def current_providers(self) -> set[str]:
        """Extracts the set of normalized provider names across all registered coordinators."""
        return {ProviderConfig._normalize_name(coordinator.provider_name) for coordinator in self.data.values()}

    def group_by_provider(self) -> dict[str, dict[str, SearchCoordinator]]:
        """Groups all coordinators by normalized provider name.

        Especially helpful when using multithreading by provider (as opposed to by page) to account for strict
        rate limits: all coordinators of a provider appear under one nested dictionary so that they can be
        orchestrated together.

        Returns:
            dict[str, dict[str, SearchCoordinator]]:
                A dictionary that maps each normalized provider name to a nested dictionary of that provider's
                coordinators, keyed by their keys in the MultiSearchCoordinator.

        """
        provider_search_dict: dict[str, dict[str, SearchCoordinator]] = defaultdict(dict)

        for key, coordinator in self.data.items():
            provider_name = ProviderConfig._normalize_name(coordinator.provider_name)
            provider_search_dict[provider_name][key] = coordinator

        # return a plain dict so that later lookups of unknown providers do not create empty groups
        return dict(provider_search_dict)

    def _normalize_rate_limiter(self, search_coordinator: SearchCoordinator) -> SearchCoordinator:
        """Helper method that assigns the provider's shared threaded rate limiter to the coordinator's API.

        The rate limiter is looked up by normalized provider name in the `threaded_rate_limiter_registry`. If the
        provider has no rate limiter yet, one is created using `DEFAULT_THREADED_REQUEST_DELAY`.

        Args:
            search_coordinator (SearchCoordinator): The coordinator whose API should use the shared rate limiter.

        Returns:
            SearchCoordinator: The same coordinator, with its `api` attribute updated when a rate limiter is found.

        """
        provider_name = ProviderConfig._normalize_name(search_coordinator.provider_name)

        # the same rate limiter must be shared by every coordinator of a provider, including across threads
        threaded_rate_limiter = threaded_rate_limiter_registry.get_or_create(
            provider_name, self.DEFAULT_THREADED_REQUEST_DELAY
        )

        if threaded_rate_limiter:
            search_coordinator.api = SearchAPI.update(search_coordinator.api, rate_limiter=threaded_rate_limiter)

        return search_coordinator

    @classmethod
    def _create_key(cls, search_coordinator: SearchCoordinator) -> str:
        """Creates a key from the provider name, the query, and a hash of the coordinator's representation.

        Note that the hash is computed with the built-in `hash` on a string, which Python may salt per process:
        keys should not be expected to remain identical across interpreter sessions.

        """
        hash_value = hash(repr(search_coordinator))
        provider_name = ProviderConfig._normalize_name(search_coordinator.provider_name)
        query = str(search_coordinator.api.query)
        key = f"{provider_name}_{query}:{hash_value}"
        return key

    def select(
        self,
        query: Optional[str] = None,
        provider_name: Optional[str] = None,
    ) -> list[SearchCoordinator]:
        """Helper method that enables the selection of coordinators based on their query and/or provider name.

        A filter that is left as `None` matches every coordinator. When both filters are provided, a coordinator
        must satisfy both to be selected. Provider names are compared after normalization.

        Args:
            query (Optional[str]): The query that selected coordinators must use.
            provider_name (Optional[str]): The provider that selected coordinators must belong to.

        Returns:
            list[SearchCoordinator]: The coordinators matching all provided filters.

        """
        provider_name = (
            ProviderConfig._normalize_name(provider_name) if isinstance(provider_name, str) else provider_name
        )

        # Each filter is parenthesized on its own: `and` binds tighter than `or`, so an unparenthesized
        # expression would ignore the query filter whenever the provider name matches.
        return [
            coordinator
            for coordinator in self.coordinators
            if (query is None or query == coordinator.api.query)
            and (
                provider_name is None
                or provider_name == ProviderConfig._normalize_name(coordinator.provider_name)
            )
        ]

    def structure(self, flatten: bool = False, show_value_attributes: bool = True) -> str:
        """Helper method that shows the current structure of the MultiSearchCoordinator.

        The representation lists the summary of each registered coordinator by key. The `flatten` and
        `show_value_attributes` parameters are accepted but are not currently used by this implementation.

        """
        class_name = self.__class__.__name__
        attributes = {key: coordinator.summary() for key, coordinator in self.data.items()}
        return generate_repr_from_string(class_name, attributes)

    def __repr__(self) -> str:
        """Helper method for generating a string representation of the current list of coordinators."""
        return self.structure()
# Names exported by `from <module> import *`: only the coordinator class is part of the public API.
__all__ = ["MultiSearchCoordinator"]