Source code for scholar_flux.api.rate_limiting.rate_limiter

# /api/rate_limiting/rate_limiter.py
"""The scholar_flux.api.rate_limiting.rate_limiter module implements a simple, general RateLimiter.

ScholarFlux uses and builds upon this `RateLimiter` implementation to ensure that the number of requests to an API
provider does not exceed the limit within the specified time interval.

"""
from __future__ import annotations
from contextlib import contextmanager
import time
from functools import wraps
from scholar_flux.exceptions import APIParameterException
from scholar_flux.utils.repr_utils import generate_repr_from_string
from scholar_flux.api.rate_limiting.history import (
    HistoryDeque,
    RateLimitEvent,
)
from datetime import datetime

from typing import Optional, Iterator, Dict, Any, Callable, TypeVar, ParamSpec, TYPE_CHECKING
import logging

if TYPE_CHECKING:
    from typing_extensions import Self, Literal
    from types import TracebackType

# Module-level logger named after this module; `RateLimiter._sleep` reports each sleep through it.
logger = logging.getLogger(__name__)

# Generic placeholders used to annotate `RateLimiter.__call__`: `P` captures the decorated
# callable's parameters and `R` its return type, so the wrapper is typed like the original.
P = ParamSpec("P")
R = TypeVar("R")


class RateLimiter:
    """A basic rate limiter that spaces function calls (such as API requests) by a minimum interval.

    The `RateLimiter` is used within ScholarFlux to throttle the total number of requests that can be made within a
    defined time interval (measured in seconds).

    This class ensures that calls to `RateLimiter.wait()` (or any decorated function) are spaced by at least
    `min_interval` seconds.

    For multithreading applications, the `RateLimiter` is not thread-safe. Instead, the `ThreadedRateLimiter`
    subclass can provide a thread-safe implementation when required.

    Args:
        min_interval (Optional[float | int]):
            The minimum number of seconds that must elapse before another request is sent or call is performed.
            If `min_interval` is not specified, then the class attribute, `RateLimiter.DEFAULT_MIN_INTERVAL`, is
            assigned to `RateLimiter.min_interval` instead.

    Examples:
        >>> import requests
        >>> from scholar_flux.api import RateLimiter
        >>> rate_limiter = RateLimiter(min_interval = 5)
        >>> # The first call won't sleep, because a prior call using the rate limiter doesn't yet exist
        >>> with rate_limiter:
        ...     response = requests.get("http://httpbin.org/get")
        >>> # will sleep if 5 seconds since the last call hasn't elapsed.
        >>> with rate_limiter:
        ...     response = requests.get("http://httpbin.org/get")
        >>> # Or simply call the `wait` method directly:
        >>> rate_limiter.wait()
        >>> response = requests.get("http://httpbin.org/get")

    Note:
        The class-level history deque is a design choice. This attribute allows class-level monitoring and
        introspection into how request delays are computed. The `HistoryDeque` is described as thread-safe (it uses
        cpython on the backend) and allows global observability, which is helpful for debugging, especially in
        cases where you need to adjust the total amount of requests sent within a given interval to avoid 429
        errors.

    """

    # Interval (in seconds) used whenever no explicit `min_interval` is supplied.
    DEFAULT_MIN_INTERVAL: float | int = 6.1
    # Shared, class-level record of every sleep performed through `_sleep`.
    history: HistoryDeque[RateLimitEvent] = HistoryDeque.create()

    def __init__(self, min_interval: Optional[float | int] = None):
        """Initializes the rate limiter with the `min_interval` argument.

        Args:
            min_interval (Optional[float | int]):
                Minimum number of seconds to wait before the next call is performed or request sent.
                Falls back to `DEFAULT_MIN_INTERVAL` when None.

        Raises:
            APIParameterException: If `min_interval` is not a finite, non-negative number.

        """
        # Assignment goes through the property setter, so the value is validated here.
        self.min_interval = min_interval if min_interval is not None else self.DEFAULT_MIN_INTERVAL
        # None until the first `wait()` completes; afterwards the `time.time()` of the latest call.
        self._last_call: float | int | None = None

    @property
    def min_interval(self) -> float | int:
        """The minimum number of seconds that must elapse before another request is sent or action is taken."""
        return self._min_interval

    @min_interval.setter
    def min_interval(self, min_interval: float | int) -> None:
        """Validates the `min_interval` property upon assignment.

        This setter allows the `min_interval` property to be assigned directly to a rate limiter instance and
        requires no further action (e.g., `rate_limiter.min_interval=4`).

        Args:
            min_interval (float | int):
                The minimum number of seconds that must elapse before another request is sent or call is performed.

        Raises:
            APIParameterException: If the received value is not a finite, non-negative float or integer.

        """
        self._min_interval = self._validate(min_interval)

    @staticmethod
    def _validate(min_interval: float | int) -> float | int:
        """Helper that verifies that `min_interval` is a finite number greater than or equal to 0.

        Args:
            min_interval (float | int): The candidate interval in seconds.

        Returns:
            float | int: The received value, unchanged, when it is valid.

        Raises:
            APIParameterException: If the value is not an int/float, is negative, or is NaN/infinite.

        """
        if not isinstance(min_interval, (int, float)):
            raise APIParameterException(
                f"`min_interval` must be a number greater than or equal to 0. Received value, '{min_interval}'"
            )
        if min_interval < 0:
            raise APIParameterException("min_interval must be non-negative")
        # NaN compares False against everything and +inf is not `< inf`, so this rejects both.
        # NaN would otherwise silently disable throttling (`remaining > 0` is never True) and
        # +inf would be handed to `time.sleep`, failing with a non-API exception.
        if not min_interval < float("inf"):
            raise APIParameterException(
                f"`min_interval` must be a finite number greater than or equal to 0. "
                f"Received value, '{min_interval}'"
            )
        return min_interval

    @staticmethod
    def _validate_timestamp(timestamp: float | int) -> float | int:
        """Validates that `timestamp` is a finite, non-negative number.

        Args:
            timestamp (float | int): The candidate reference time.

        Returns:
            float | int: The received value, unchanged, when it is valid.

        Raises:
            APIParameterException: If the value is not an int/float, is negative, or is NaN/infinite.

        """
        # The chained comparison is False for NaN, negative values, and +inf alike.
        is_valid = isinstance(timestamp, (float, int)) and 0 <= timestamp < float("inf")
        if not is_valid:
            raise APIParameterException(
                f"`timestamp` must a valid timestamp formatted as a non-negative number. Received value, '{timestamp}'"
            )
        return timestamp

    def wait(self, min_interval: Optional[float | int] = None, metadata: Optional[Dict[str, Any]] = None) -> None:
        """Block (`time.sleep`) until at least `min_interval` has passed since the last call.

        This method can be used with the `min_interval` attribute to determine when a search was last sent and
        throttle requests to make sure rate limits aren't exceeded. If not enough time has passed, the call sleeps
        before returning.

        Args:
            min_interval (Optional[float | int]):
                The minimum time to wait until another call is sent. When None, the value returned by
                `default_min_interval()` is used.
            metadata (Optional[Dict[str, Any]]):
                Optional metadata for observability (e.g., url, caller, reason).

        Raises:
            APIParameterException: If the value provided is not a finite, non-negative integer/float.

        """
        min_interval = self._validate(min_interval if min_interval is not None else self.default_min_interval())

        # The very first call (no previous timestamp) and a zero interval never sleep.
        if self._last_call is not None and min_interval:
            self._wait(min_interval, self._last_call, metadata=metadata)

        # record the time we actually proceed
        self._last_call = time.time()

    @classmethod
    def _wait(
        cls, min_interval: float | int, last_call: float | int, metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """Helper method that sleeps for whatever remains of `min_interval` since `last_call`.

        Args:
            min_interval (float | int): The minimum time to wait until another call is sent.
            last_call (float | int):
                The start time. In context, the previously recorded time when the function was called.
            metadata (Optional[Dict[str, Any]]):
                Optional metadata for observability (e.g., url, caller, reason).

        The time to wait is calculated as follows:

            1. Determine the number of seconds that have elapsed since the last call:
               (e.g., `elapsed = time.time() - rate_limiter._last_call`)
            2. Calculate the number of seconds remaining until the minimum interval is reached:
               (e.g., `remaining = rate_limiter.min_interval - elapsed`)
            3. If `remaining` is positive, sleep for that duration:
               (e.g., `time.sleep(remaining)`)

        """
        now = time.time()
        elapsed = now - last_call
        remaining = min_interval - elapsed
        if remaining > 0:
            cls._sleep(remaining, metadata=metadata)

    def default_min_interval(self) -> float | int:
        """Returns the default minimum interval for the current rate limiter."""
        return self.min_interval if self.min_interval is not None else self.DEFAULT_MIN_INTERVAL

    def sleep(self, interval: Optional[float | int] = None, metadata: Optional[Dict[str, Any]] = None) -> None:
        """Simple instance-level implementation of `sleep` that can be overridden when needed.

        Args:
            interval (Optional[float | int]):
                The time interval to sleep. If None, the default minimum interval for the current rate limiter
                is used.
            metadata (Optional[Dict[str, Any]]):
                Optional metadata for observability (e.g., url, caller, reason).

        Raises:
            APIParameterException: If the value provided is not a finite, non-negative integer/float.

        """
        interval = self._validate(interval if interval is not None else self.default_min_interval())
        # A zero interval is a no-op: nothing is logged or recorded.
        if interval > 0:
            self._sleep(interval, metadata=metadata)

    def wait_since(
        self,
        min_interval: Optional[float | int] = None,
        timestamp: Optional[float | int | datetime] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Wait based on a reference timestamp or datetime.

        Args:
            min_interval (Optional[float | int]): Minimum interval to wait. Uses default if None.
            timestamp (Optional[float | int | datetime]):
                Reference time such as a Unix timestamp or datetime. If None, sleeps for min_interval.
            metadata (Optional[Dict[str, Any]]):
                Optional metadata for observability (e.g., url, caller, reason).

        Raises:
            APIParameterException: If `min_interval` or `timestamp` is not a finite, non-negative number.

        """
        if timestamp is not None:
            # Datetimes are converted to a numeric timestamp before validation.
            timestamp = self._validate_timestamp(
                timestamp.timestamp() if isinstance(timestamp, datetime) else timestamp
            )

        min_interval = self._validate(min_interval if min_interval is not None else self.default_min_interval())

        # Append the caller to metadata without overwriting if it exists
        metadata = {"caller": "wait_since"} | (metadata or {})

        if timestamp is None:
            self._sleep(min_interval, metadata=metadata)
        else:
            self._wait(min_interval, timestamp, metadata=metadata)

    @classmethod
    def _sleep(cls, interval: int | float, metadata: Optional[Dict[str, Any]] = None) -> None:
        """Logs the sleeping duration, records it in `history`, and blocks (`time.sleep`) for `interval` seconds.

        Args:
            interval (int | float): The number of seconds to sleep.
            metadata (Optional[Dict[str, Any]]):
                Optional metadata for observability (e.g., url, caller, reason). An `"interval"` key, if present,
                is ignored in favor of the duration that is actually slept.

        """
        # Deferred %-formatting: the message is only rendered when INFO logging is enabled.
        logger.info("RateLimiter: sleeping %.2fs to respect rate limit", interval)

        # Passing `interval` both explicitly and via `**metadata` would raise a TypeError for the
        # duplicated keyword, so a caller-supplied "interval" entry is dropped here.
        extra = {key: value for key, value in (metadata or {}).items() if key != "interval"}
        record = RateLimitEvent.from_metadata(
            interval=interval,
            **extra,
        )
        cls.history.append(record)
        time.sleep(interval)

    def __call__(self, fn: Callable[P, R]) -> Callable[P, R]:
        """Implements a rate limit for the defined function when the `RateLimiter` is used as a decorator.

        This decorator can be used to ensure a function can be called once every `min_interval` seconds and helps
        to ensure that API rate limits are not exceeded.

        Decorator syntax:
            @limiter
            def send_request(...):
                ...

            response = send_request(...)

        """

        @wraps(fn)
        def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
            """Wraps and decorates a function using the rate limiter to limit how frequently it can be called."""
            self.wait()
            return fn(*args, **kwargs)

        return wrapped

    def __enter__(self) -> Self:
        """Enables a `RateLimiter` instance to be used as a context manager for throttling calls or requests.

        Example:
            >>> with limiter:
            ...     do_slow_call()

        """
        self.wait()
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> Literal[False]:
        """Exits the context manager; returns False so exceptions raised in the body propagate."""
        return False

    @contextmanager
    def rate(self, min_interval: float | int, metadata: Optional[Dict[str, Any]] = None) -> Iterator[Self]:
        """Temporarily adjusts the minimum interval between calls or requests when used with a context manager.

        On entry the temporary interval is applied and `wait()` is called (which also records the time of the
        call). After the context manager exits, the original minimum interval is restored.

        Args:
            min_interval (float | int):
                Indicates the minimum interval to be temporarily used during the call.
            metadata (Optional[Dict[str, Any]]):
                Optional metadata for observability (e.g., url, caller, reason).

        Yields:
            RateLimiter: The original rate limiter with a temporarily changed minimum interval.

        Raises:
            APIParameterException: If `min_interval` is not a finite, non-negative number.

        """
        current_min_interval = self.min_interval
        try:
            self.min_interval = self._validate(min_interval)
            self.wait(metadata=metadata)
            yield self
        finally:
            # Always restore the previous interval, even if validation, waiting, or the body failed.
            self.min_interval = current_min_interval

    @classmethod
    def resize_history(cls, maxlen: int) -> None:
        """Resize the global history deque, preserving existing records up to the new limit.

        Args:
            maxlen (int): The new size passed to `HistoryDeque.modify_history_size`.

        """
        # NOTE(review): the result is assigned on `cls`; when called on a subclass this sets the
        # attribute on that subclass rather than on `RateLimiter` — confirm this is intended.
        cls.history = cls.history.modify_history_size(maxlen)

    def __repr__(self) -> str:
        """Defines the string representation of the RateLimiter/subclasses: class name and `min_interval`."""
        class_name = self.__class__.__name__
        attributes = dict(min_interval=self.min_interval)
        return generate_repr_from_string(class_name, attributes, flatten=True)
# Names exported by `from ... import *`; only the `RateLimiter` class is public here.
__all__ = ["RateLimiter"]