"""Channel creation utilities for the Athena client."""
import json
import logging
import threading
import time
from dataclasses import dataclass
from typing import override
import grpc
import httpx
from grpc.aio import Channel
from resolver_athena_client.client.exceptions import (
CredentialError,
InvalidHostError,
OAuthError,
)
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class TokenData:
"""Immutable snapshot of token state."""
access_token: str
expires_at: float
scheme: str
issued_at: float
proactive_refresh_threshold: float = 0.25
def __post_init__(self) -> None:
"""Validate that proactive_refresh_threshold is between 0 and 1."""
if (
self.proactive_refresh_threshold <= 0
or self.proactive_refresh_threshold >= 1
):
msg = "proactive_refresh_threshold must be between 0 and 1"
raise ValueError(msg)
def is_valid(self) -> bool:
"""Check if this token is still valid (with a 30-second buffer)."""
return time.time() < (self.expires_at - 30)
def is_old(self) -> bool:
"""Check if this token should be proactively refreshed.
A token is considered "old" if less than the
proactive_refresh_threshold of its lifetime remains. This allows
background refresh to happen before expiry while the token is still
usable.
"""
current_time = time.time()
total_lifetime = self.expires_at - self.issued_at
time_remaining = self.expires_at - current_time
return time_remaining < (
total_lifetime * self.proactive_refresh_threshold
)
[docs]
class CredentialHelper:
"""OAuth credential helper for managing authentication tokens."""
[docs]
def __init__(
self,
client_id: str,
client_secret: str,
auth_url: str = "https://crispthinking.auth0.com/oauth/token",
audience: str = "crisp-athena-live",
proactive_refresh_threshold: float = 0.25,
) -> None:
"""Initialize the credential helper.
Args:
----
client_id: OAuth client ID
client_secret: OAuth client secret
auth_url: OAuth token endpoint URL
audience: OAuth audience
proactive_refresh_threshold: Fraction of token lifetime to trigger
proactive refresh (default 0.25 for 25%)
"""
if not client_id:
msg = "client_id cannot be empty"
raise CredentialError(msg)
if not client_secret:
msg = "client_secret cannot be empty"
raise CredentialError(msg)
self._client_id: str = client_id
self._client_secret: str = client_secret
self._auth_url: str = auth_url
self._audience: str = audience
self._token_data: TokenData | None = None
self._lock: threading.Lock = threading.Lock()
self._refresh_thread: threading.Thread | None = None
if proactive_refresh_threshold <= 0 or proactive_refresh_threshold >= 1:
msg = "proactive_refresh_threshold must be a float between 0 and 1"
raise ValueError(msg)
self._proactive_refresh_threshold: float = proactive_refresh_threshold
[docs]
def get_token(self) -> TokenData:
"""Get valid token data, refreshing if necessary.
Returns
-------
A valid ``TokenData`` containing access token, expiry, and scheme
Raises
------
OAuthError: If token acquisition fails
RuntimeError: If token is unexpectedly None after refresh
"""
token_data = self._token_data
# Fast path: token is valid and fresh
if token_data is not None and token_data.is_valid():
# If token is old, trigger background refresh
if token_data.is_old():
self._start_background_refresh()
return token_data
# Slow path: token is expired or missing, must block
with self._lock:
token_data = self._token_data
if token_data is not None and token_data.is_valid():
return token_data
self._refresh_token()
token_data = self._token_data
if token_data is None:
msg = "Token is unexpectedly None after refresh"
raise RuntimeError(msg)
return token_data
def _start_background_refresh(self) -> None:
"""Start a background thread to refresh the token.
Only starts a new thread if one isn't already running.
This method is safe to call multiple times - it only starts a new
thread if no refresh is currently in progress.
"""
# Quick check without lock - if refresh thread exists and is
# alive, skip
if self._refresh_thread is not None and self._refresh_thread.is_alive():
return
# Try to acquire lock and start refresh
if self._lock.acquire(blocking=False):
try:
# Double-check: another thread might have started refresh,
# or the token may have been refreshed.
refresh_not_active = (
self._refresh_thread is None
or not self._refresh_thread.is_alive()
)
token_needs_refresh = (
self._token_data is None or self._token_data.is_old()
)
refresh_needed = refresh_not_active and token_needs_refresh
if refresh_needed:
self._refresh_thread = threading.Thread(
target=self._background_refresh,
daemon=True,
)
self._refresh_thread.start()
finally:
self._lock.release()
def _background_refresh(self) -> None:
"""Background thread target for token refresh.
Acquires the lock and refreshes the token. Errors are logged
but silently ignored since the next foreground request will
retry if needed.
"""
with self._lock:
# Check if token still needs refresh (prevent stampede)
token_data = self._token_data
if token_data is not None and not token_data.is_old():
# Token was already refreshed by another thread
return
try:
self._refresh_token()
except Exception as e: # noqa: BLE001
# Log but don't raise - background refresh failures
# are recoverable (next get_token() will retry)
logger.debug(
"Background token refresh failed, "
"will retry on next request: %s",
e,
)
def _refresh_token(self) -> None:
"""Refresh the authentication token by making an OAuth request.
This is a synchronous call (suitable for the gRPC metadata-plugin
thread) and must be called while ``self._lock`` is held.
Raises
------
OAuthError: If the OAuth request fails
"""
payload = {
"client_id": self._client_id,
"client_secret": self._client_secret,
"audience": self._audience,
"grant_type": "client_credentials",
}
headers = {"content-type": "application/json"}
try:
with httpx.Client() as client:
response = client.post(
self._auth_url,
json=payload,
headers=headers,
timeout=30.0,
)
_ = response.raise_for_status()
raw = response.json()
access_token: str = raw["access_token"]
expires_in: int = raw.get("expires_in", 3600) # Default 1 hour
token_type = raw.get("token_type", "Bearer")
scheme: str = token_type.strip() if token_type else "Bearer"
current_time = time.time()
self._token_data = TokenData(
access_token=access_token,
expires_at=current_time + expires_in,
scheme=scheme,
issued_at=current_time,
proactive_refresh_threshold=self._proactive_refresh_threshold,
)
except httpx.HTTPStatusError as e:
error_detail = ""
try:
error_data = e.response.json()
error_desc = error_data.get(
"error_description", error_data.get("error", "")
)
error_detail = f": {error_desc}"
except (json.JSONDecodeError, KeyError):
pass
msg = (
f"OAuth request failed with status "
f"{e.response.status_code}{error_detail}"
)
raise OAuthError(msg) from e
except (httpx.RequestError, httpx.TimeoutException) as e:
msg = f"Failed to connect to OAuth server: {e}"
raise OAuthError(msg) from e
except KeyError as e:
msg = f"Invalid OAuth response format: missing {e}"
raise OAuthError(msg) from e
except Exception as e:
msg = f"Unexpected error during OAuth: {e}"
raise OAuthError(msg) from e
[docs]
def invalidate_token(self) -> None:
"""Invalidate the current token to force a refresh on next use."""
with self._lock:
self._token_data = None
class _AutoRefreshTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin):
"""gRPC auth plugin that fetches a fresh token for every RPC."""
def __init__(self, credential_helper: CredentialHelper) -> None:
"""Initialize with a credential helper.
Args:
----
credential_helper: The helper that manages token lifecycle
"""
self._credential_helper: CredentialHelper = credential_helper
@override
def __call__(
self,
_: grpc.AuthMetadataContext,
callback: grpc.AuthMetadataPluginCallback,
) -> None:
"""Supply authorization metadata for an RPC.
Called by the gRPC runtime on a background thread before each
RPC. On success the token is forwarded using the scheme from
the OAuth token response (typically ``Bearer``); on failure
the error is passed to the callback so gRPC can surface it as
an RPC error.
Args:
----
callback: gRPC callback to receive metadata or an error
"""
try:
token_data = self._credential_helper.get_token()
scheme = token_data.scheme
token = token_data.access_token
metadata = (("authorization", f"{scheme} {token}"),)
callback(metadata, None)
except Exception as err: # noqa: BLE001
callback((), err)
[docs]
async def create_channel_with_credentials(
host: str,
credential_helper: CredentialHelper,
) -> Channel:
"""Create a gRPC channel with OAuth credential helper.
Args:
----
host: The host address to connect to
credential_helper: The credential helper for OAuth authentication
Returns:
-------
A secure gRPC channel with OAuth authentication
Raises:
------
InvalidHostError: If host is empty
"""
if not host:
raise InvalidHostError(InvalidHostError.default_message)
# Create credentials with per-RPC token refresh
credentials = grpc.composite_channel_credentials(
grpc.ssl_channel_credentials(),
grpc.metadata_call_credentials(
_AutoRefreshTokenAuthMetadataPlugin(credential_helper)
),
)
# Configure gRPC options for persistent connections
options = [
# Keep connections alive longer
("grpc.keepalive_time_ms", 60000), # Send keepalive every 60s
("grpc.keepalive_timeout_ms", 30000), # Wait 30s for keepalive ack
(
"grpc.keepalive_permit_without_calls",
1,
), # Allow keepalive when idle
# Optimize for persistent streams
("grpc.http2.max_pings_without_data", 0), # Allow unlimited pings
(
"grpc.http2.min_time_between_pings_ms",
60000,
), # Min 60s between pings
(
"grpc.http2.min_ping_interval_without_data_ms",
30000,
), # Min 30s when idle
# Increase buffer sizes for better performance
("grpc.http2.write_buffer_size", 1024 * 1024), # 1MB write buffer
(
"grpc.max_receive_message_length",
64 * 1024 * 1024,
), # 64MB max message
]
return grpc.aio.secure_channel(host, credentials, options=options)