Source code for scholar_flux.api.models.rate_limiter_registry

# /api/models/rate_limiter_registry.py
"""The scholar_flux.api.models.rate_limiter_registry module implements a registry that stores rate limiters by provider.

The `RateLimiterRegistry` implements several helpers for interacting with, retrieving, and creating default and
thread-safe rate limiters for both default and new providers.

"""
from __future__ import annotations

from typing import Any, Optional
from typing_extensions import Self

from scholar_flux.api.models.base_provider_dict import BaseProviderDict
from scholar_flux.api.rate_limiting import RateLimiter, ThreadedRateLimiter
from scholar_flux.exceptions import APIParameterException
import scholar_flux.api.providers as api_providers
import logging

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class RateLimiterRegistry(BaseProviderDict):
    """A registry for creating, retrieving, updating, and deleting rate limiters by provider.

    The `RateLimiterRegistry` standardizes CRUD operations with thread-safe rate limiters for both default
    and custom providers. It ensures compatibility when using rate limiters in active applications.

    This implementation is especially important when using `MultiSearchCoordinators` to enforce normalized
    rate limiting by provider.

    Attributes:
        threaded (bool): Indicates whether the registry should use ThreadedRateLimiters.

    """

    def __init__(self, *args: Any, threaded: bool = False, **kwargs: Any) -> None:
        """Initializes the RateLimiterRegistry and enforces the use of ThreadedRateLimiters when `threaded=True`.

        Args:
            *args (Any): Positional arguments forwarded to the base dictionary initializer.
            threaded (bool): When True, only `ThreadedRateLimiter` values are accepted and created.
            **kwargs (Any): Keyword arguments forwarded to the base dictionary initializer.

        """
        # `threaded` is assigned before delegating to the base initializer because `__setitem__`
        # reads `self.rate_limiter`; presumably the base class populates initial items through
        # `__setitem__` -- TODO confirm against `BaseProviderDict`.
        self.threaded = threaded
        super().__init__(*args, **kwargs)

    @property
    def rate_limiter(self) -> type[RateLimiter | ThreadedRateLimiter]:
        """Helper property that returns the class constructor for a rate limiter.

        Returns:
            A ThreadedRateLimiter if `self.threaded=True`, otherwise the core `RateLimiter`

        """
        return ThreadedRateLimiter if self.threaded else RateLimiter

    def __getitem__(self, key: str) -> RateLimiter:
        """Attempt to retrieve a RateLimiter instance for the given provider name.

        Args:
            key (str): Name of the provider

        Returns:
            RateLimiter: The RateLimiter for the provider if it exists

        """
        return super().__getitem__(key)

    def __setitem__(
        self,
        key: str,
        value: RateLimiter | ThreadedRateLimiter,
    ) -> None:
        """Sets a key-value pair to the current registry where all keys are strings and values are RateLimiters.

        This method overrides the core functionality of dictionaries to ensure that validation for the `value`
        parameter occurs before saving the provider name-rate limiter pair.

        Args:
            key (str): Name of the provider to add to the registry
            value (RateLimiter | ThreadedRateLimiter): The rate limiter for the API Provider

        Raises:
            APIParameterException: If `value` is not an instance of the registry's rate limiter class, or if
                the base `__setitem__` raises a `TypeError` or `ValueError`.

        """
        try:
            if not isinstance(value, self.rate_limiter):
                raise TypeError(
                    f"The value provided to the {self.__class__.__name__} is invalid. "
                    f"Expected a {self.rate_limiter.__name__}, received {type(value)}"
                )
            super().__setitem__(key, value)
        except (TypeError, ValueError) as e:
            raise APIParameterException(e) from e

    def get_from_url(self, provider_url: Optional[str]) -> Optional[RateLimiter | ThreadedRateLimiter]:
        """Attempts to retrieve a RateLimiter for the specified provider from a URL.

        This method retrieves the rate limiter of the provider associated with the provided URL if the
        `scholar_flux.api.provider_registry` resolves the URL to a provider configuration. If a provider does
        not exist, a value of None will be returned instead.

        Args:
            provider_url (Optional[str]): URL of the provider to look up.

        Returns:
            Optional[RateLimiter | ThreadedRateLimiter]:
                The rate limiter of the provider when available. Otherwise None.

        """
        provider_config = api_providers.provider_registry.get_from_url(provider_url)
        if not provider_config:
            return None
        return self.data.get(provider_config.provider_name)

    def get_or_create(
        self, key: str, default_request_delay: Optional[int | float] = None
    ) -> RateLimiter | ThreadedRateLimiter:
        """Helper method that retrieves a rate limiter from the registry or creates one if it doesn't exist.

        This method is useful when a provider may or may not exist in the current registry and otherwise needs
        to be added. If a provider's rate limiter does not yet exist, the registry attempts to create a new
        rate limiter.

        Args:
            key (str):
                The name of the provider to retrieve a rate limiter for, and otherwise create a new rate
                limiter if it doesn't exist.
            default_request_delay (Optional[int | float]):
                The default minimum interval to use when creating a new rate limiter if one does not already
                exist for the provider

        Returns:
            RateLimiter | ThreadedRateLimiter:
                The retrieved rate limiter for the current provider if available. Otherwise a new RateLimiter
                will be created, registered, and returned.

        Raises:
            APIParameterException: If a new rate limiter has to be created and creation fails.

        """
        # `create` -> `add` stores limiters under the normalized provider name, so the lookup has to use
        # the same normalized key. Looking up the raw key would miss the stored limiter and replace it
        # with a fresh one on every call. The guard mirrors `remove`; any other key falls through to
        # `create`, which reports invalid names.
        lookup_key = self._normalize_name(key) if isinstance(key, str) and key else key

        existing_limiter = self.data.get(lookup_key)
        if existing_limiter is not None:
            return existing_limiter

        # otherwise, create (and register) a new rate limiter
        return self.create(key, default_request_delay)

    def create(
        self, provider_name: str, default_request_delay: Optional[int | float] = None
    ) -> RateLimiter | ThreadedRateLimiter:
        """Helper method that creates a new rate limiter for the current provider.

        The minimum interval for the provider is chosen based on the following order of priority:

        1. If the provider exists in the `provider_registry`, use the `request_delay` from its configuration
           settings.
        2. Otherwise, use the `default_request_delay` parameter if it is a float or integer.
        3. If a provider doesn't exist in the registry and `default_request_delay` isn't specified, use the
           `RateLimiter.DEFAULT_MIN_INTERVAL` class parameter.

        Args:
            provider_name (str): The name of the provider to create a new rate limiter for.
            default_request_delay (Optional[int | float]):
                The default minimum interval to use when creating a new rate limiter.

        Returns:
            RateLimiter | ThreadedRateLimiter: The newly created rate limiter, already added to the registry.

        Raises:
            APIParameterException: If validation, construction, or registration of the rate limiter fails.

        """
        try:
            # Validate the fallback delay up front, even if the provider's own configuration ends up
            # taking priority below.
            if default_request_delay is not None:
                self.rate_limiter._validate(default_request_delay)

            provider_config = api_providers.provider_registry.get(provider_name)
            if provider_config:
                # Known provider: its configured request delay takes priority
                request_delay = provider_config.request_delay
            else:
                # Unknown provider: fall back to `default_request_delay` (which may be None)
                request_delay = default_request_delay

            # Creates a new rate limiter with the chosen delay; a value of None is handed to the
            # constructor as-is
            rate_limiter = self.rate_limiter(request_delay)

            # adds the rate limiter to the registry
            self.add(provider_name, rate_limiter)
            return rate_limiter

        except Exception as e:
            # Chain explicitly so the original failure stays attached as `__cause__`
            raise APIParameterException(
                f"Encountered an error when creating a new rate limiter with the provider name, '{provider_name}': {e}"
            ) from e

    def add(self, provider_name: str, rate_limiter: RateLimiter | ThreadedRateLimiter) -> None:
        """Helper method for adding a new provider and rate limiter to the rate limiter registry.

        The provider name is normalized before storage. If a rate limiter is already registered under the
        normalized name, a warning is logged and the previous rate limiter is replaced.

        Args:
            provider_name (str): The name of the provider to register the rate limiter under.
            rate_limiter (RateLimiter | ThreadedRateLimiter): The rate limiter to store for the provider.

        """
        provider_name = self._normalize_name(provider_name)

        if provider_name in self.data:
            logger.warning(f"Overwriting the previous RateLimiter for the provider, '{provider_name}'")

        # NOTE(review): this assigns into `self.data` directly, so the isinstance check performed by
        # `__setitem__` is not applied on this path.
        self.data[provider_name] = rate_limiter

        logger.debug(
            f"Created a new rate limiter for the provider, {provider_name} "
            f"with a request delay of {rate_limiter.min_interval}"
        )

    def remove(self, provider_name: str) -> None:
        """Helper method for removing a provider's rate limiter from the rate limiter registry.

        A message is logged at INFO level when a rate limiter was removed, and at WARNING level when no rate
        limiter was found for the provider name.

        Args:
            provider_name (str): The name of the provider whose rate limiter should be removed.

        """
        # Only non-empty strings are normalized; anything else is used for the lookup unchanged
        if isinstance(provider_name, str) and provider_name:
            provider_name = self._normalize_name(provider_name)

        if self.data.pop(provider_name, None):
            logger.info(f"Removed the rate limiter for the provider, '{provider_name}' from the rate limiter registry")
        else:
            logger.warning(f"A RateLimiter with the provider name, '{provider_name}' was not found")

    @classmethod
    def from_defaults(cls, threaded: bool = False) -> Self:
        """Initializes a new `RateLimiterRegistry` for known providers based on their configurations.

        This method specifically uses the `provider_name` and `request_delay` of each provider listed within
        the `scholar_flux.api.providers.provider_registry` to create rate limiters for all known
        configurations.

        Args:
            threaded (bool): When True, `ThreadedRateLimiter` instances are created instead of `RateLimiter`.

        Returns:
            RateLimiterRegistry: A new rate limiter registry that contains default rate limiters for known
            providers.

        """
        limiter_cls = ThreadedRateLimiter if threaded else RateLimiter
        default_limiters = {
            provider_name: limiter_cls(provider_config.request_delay)
            for provider_name, provider_config in api_providers.provider_registry.items()
        }
        return cls(default_limiters, threaded=threaded)

# Names exported by `from ... import *` for this module.
__all__ = ["RateLimiterRegistry"]