# /api/models/base_provider_dict.py
"""The scholar_flux.api.models.base_provider_dict module implements a BaseProviderDict that extends the dictionary to
resolve provider names to a normalized key, so that minor naming variations map to the same provider entry."""
from __future__ import annotations
from typing import Any, Optional
import re
from collections import UserDict
from scholar_flux.api.models.provider_config import ProviderConfig
from scholar_flux.utils.repr_utils import generate_repr_from_string
class BaseProviderDict(UserDict[str, Any]):
    """The BaseProviderDict extends the dictionary to resolve minor naming variations in keys to the same provider name.

    The BaseProviderDict uses the `ProviderConfig._normalize_name` method to ignore underscores and case-sensitivity.
    Writes made through item assignment, `update`, and in-place union (`|=`) all pass through `__setitem__`, so the
    stored keys are normalized and validated. Lookups and deletions normalize string keys before accessing the data.
    """

    def __contains__(self, key: object) -> bool:
        """Determine whether a provider name, after normalization, is present in the current dictionary.

        Args:
            key (object): Name of the provider to look up. Non-string keys are never considered present.

        Returns:
            bool: True if the normalized key exists in the dictionary, False otherwise.
        """
        if isinstance(key, str):
            return self._normalize_name(key) in self.data
        # Keys written through __setitem__ are always normalized strings, so a non-string key is treated as absent.
        return False

    def __getitem__(self, key: str) -> Any:
        """Retrieve the value associated with the given provider name.

        Args:
            key (str): Name of the provider. String keys are normalized before lookup; other keys are used as-is.

        Returns:
            Any: The value associated with the provider.

        Raises:
            KeyError: If the (normalized) key is not present and no `__missing__` hook supplies a value.
        """
        if isinstance(key, str):
            key = self._normalize_name(key)
        return super().__getitem__(key)

    def __setitem__(
        self,
        key: str,
        value: Any,
    ) -> None:
        """Add a key-value pair to the BaseProviderDict with key normalization.

        This method overrides `UserDict.__setitem__` to verify that the key used as a provider name is a non-empty
        string once normalized.

        Args:
            key (str): Name of the provider to add to the dictionary.
            value (Any): The value to associate with the provider.

        Raises:
            TypeError: If the key is not a string.
            ValueError: If the normalized key is an empty string.
        """
        # Normalize the provider name before it is ever registered so lookups can find it regardless of spelling.
        normalized_key = self._normalize_name(key)
        if not normalized_key:
            raise ValueError(f"The key provided to the {self.__class__.__name__} is empty. Expected a non-empty string")
        super().__setitem__(normalized_key, value)

    def __delitem__(self, key: str) -> None:
        """Delete the entry for the given provider name.

        Args:
            key (str): Name of the provider. String keys are normalized before deletion.

        Raises:
            KeyError: If the (normalized) key does not exist in the dictionary.
        """
        if isinstance(key, str):
            key = self._normalize_name(key)
        super().__delitem__(key)

    def __ior__(self, other: Any) -> "BaseProviderDict":
        """Merge another mapping, or an iterable of key/value pairs, into this dictionary in place (`self |= other`).

        `UserDict.__ior__` writes directly into `self.data`, which would bypass key normalization and the empty-key
        check. This override routes every incoming pair through `__setitem__` (via `update`) instead.

        Args:
            other (Any): A mapping or an iterable of (key, value) pairs.

        Returns:
            BaseProviderDict: This same instance, updated in place.

        Raises:
            TypeError: If a key is not a string, or if `other` is not a mapping or an iterable of pairs.
            ValueError: If a key normalizes to an empty string.
        """
        self.update(other)
        return self

    @classmethod
    def _normalize_name(cls, key: str) -> str:
        """Validate and normalize a provider name.

        Args:
            key (str): Name of the provider.

        Returns:
            str: The provider name as normalized by `ProviderConfig._normalize_name`.

        Raises:
            TypeError: If the key is not a string.
        """
        # Reject non-string keys up front so callers get a clear TypeError instead of a failure inside normalization.
        if not isinstance(key, str):
            raise TypeError(
                f"The key provided to the {cls.__name__} is invalid. Expected a string, received {type(key)}"
            )
        return ProviderConfig._normalize_name(key)

    @property
    def providers(self) -> list[str]:
        """Return the names (keys) of all providers in the current registry.

        Returns:
            list[str]: The normalized provider names, in insertion order.
        """
        return list(self.data)

    def find(self, key: str | re.Pattern, regex: Optional[bool] = None) -> list[str]:
        """Identify providers whose names match the given key using either prefix or regex matching.

        This is a "fuzzy" lookup: flexible matching that is more forgiving than an exact key lookup. Regex matching
        (`re.search`) is used when `regex=True`, or when `regex` is left as None and a compiled Pattern is supplied.
        Otherwise, provider names are filtered by prefix (`str.startswith`) after normalizing the key.

        Args:
            key (str | re.Pattern): The string or compiled pattern to match against provider names.
            regex (Optional[bool]): Whether to use regular expressions. None (the default) enables regex matching
                only for compiled patterns; False forces prefix matching using a compiled pattern's `.pattern` text.

        Returns:
            list[str]: Names of the stored providers that match, in insertion order. An empty list is returned when
            `key` is neither a string nor a compiled pattern.

        Note:
            String keys are normalized with `_normalize_name` in both modes, while compiled patterns are applied
            as-is in regex mode. Stored provider names are already normalized, so compiled patterns should be
            written against the normalized form.
        """
        if not isinstance(key, (str, re.Pattern)):
            return []

        use_regex = regex or (regex is None and isinstance(key, re.Pattern))
        if use_regex:
            # NOTE(review): if `_normalize_name` lowercases its input, normalizing a regex *string* can change the
            # meaning of case-sensitive escapes (e.g. "\D" becomes "\d"); confirm, or pass a compiled pattern instead.
            pattern = self._normalize_name(key) if isinstance(key, str) else key
            return [provider_name for provider_name in self.providers if re.search(pattern, provider_name)]

        prefix = self._normalize_name(key.pattern if isinstance(key, re.Pattern) else key)
        return [provider_name for provider_name in self.providers if provider_name.startswith(prefix)]

    def structure(self, flatten: bool = False, show_value_attributes: bool = True) -> str:
        """Build a string showing the current structure of the BaseProviderDict (or subclass).

        Args:
            flatten (bool): Forwarded to `generate_repr_from_string`.
            show_value_attributes (bool): Forwarded to `generate_repr_from_string`.

        Returns:
            str: The representation produced by `generate_repr_from_string` for the class name and stored data.
        """
        class_name = self.__class__.__name__
        dictionary_elements = self.data
        return generate_repr_from_string(
            class_name, dictionary_elements, flatten=flatten, show_value_attributes=show_value_attributes, as_dict=True
        )

    def __repr__(self) -> str:
        """Display the dictionary in a user-friendly manner using `structure()`."""
        return self.structure()
# Public API of this module: only BaseProviderDict is exported by a wildcard import.
__all__ = ["BaseProviderDict"]