# Source code for resolver_athena_client.client.channel

"""Channel creation utilities for the Athena client."""

import asyncio
import json
import time
from typing import override

import grpc
import httpx
from grpc.aio import Channel

from resolver_athena_client.client.exceptions import (
    CredentialError,
    InvalidHostError,
    OAuthError,
)


class TokenMetadataPlugin(grpc.AuthMetadataPlugin):
    """gRPC auth plugin that attaches a fixed authorization token to calls."""

    def __init__(self, token: str) -> None:
        """Remember the token that will be sent with each request.

        Args:
        ----
            token: The authorization token to add to requests

        """
        self._token: str = token

    @override
    def __call__(
        self,
        _: grpc.AuthMetadataContext,
        callback: grpc.AuthMetadataPluginCallback,
    ) -> None:
        """Hand the authorization metadata to ``callback``.

        This method will be invoked asynchronously in a separate thread.

        Args:
        ----
            _: The auth metadata context supplied by gRPC; it is not used.
            callback: An AuthMetadataPluginCallback to be invoked either
            synchronously or asynchronously. It receives a one-entry metadata
            tuple and ``None`` as the error argument.

        """
        # Single "authorization" entry using the "Token <value>" scheme.
        auth_entry = ("authorization", f"Token {self._token}")
        callback((auth_entry,), None)


[docs] class CredentialHelper: """OAuth credential helper for managing authentication tokens."""
[docs] def __init__( self, client_id: str, client_secret: str, auth_url: str = "https://crispthinking.auth0.com/oauth/token", audience: str = "crisp-athena-live", ) -> None: """Initialize the credential helper. Args: ---- client_id: OAuth client ID client_secret: OAuth client secret auth_url: OAuth token endpoint URL audience: OAuth audience """ if not client_id: msg = "client_id cannot be empty" raise CredentialError(msg) if not client_secret: msg = "client_secret cannot be empty" raise CredentialError(msg) self._client_id: str = client_id self._client_secret: str = client_secret self._auth_url: str = auth_url self._audience: str = audience self._token: str | None = None self._token_expires_at: float | None = None self._lock: asyncio.Lock = asyncio.Lock()
[docs] async def get_token(self) -> str: """Get a valid authentication token. This method will return a cached token if it's still valid, or fetch a new token if needed. Returns ------- A valid authentication token Raises ------ OAuthError: If token acquisition fails TokenExpiredError: If token has expired and refresh fails """ async with self._lock: if self._is_token_valid(): if self._token is None: msg = "Token should be valid but is None" raise RuntimeError(msg) return self._token await self._refresh_token() if self._token is None: msg = "Token refresh failed" raise RuntimeError(msg) return self._token
def _is_token_valid(self) -> bool: """Check if the current token is valid and not expired. Returns ------- True if token is valid, False otherwise """ if not self._token or not self._token_expires_at: return False # Add 30 second buffer before expiration return time.time() < (self._token_expires_at - 30) async def _refresh_token(self) -> None: """Refresh the authentication token by making an OAuth request. Raises ------ OAuthError: If the OAuth request fails """ payload = { "client_id": self._client_id, "client_secret": self._client_secret, "audience": self._audience, "grant_type": "client_credentials", } headers = {"content-type": "application/json"} try: async with httpx.AsyncClient() as client: response = await client.post( self._auth_url, json=payload, headers=headers, timeout=30.0, ) _ = response.raise_for_status() token_data = response.json() self._token = token_data["access_token"] expires_in = token_data.get( "expires_in", 3600 ) # Default 1 hour self._token_expires_at = time.time() + expires_in except httpx.HTTPStatusError as e: error_detail = "" try: error_data = e.response.json() error_desc = error_data.get( "error_description", error_data.get("error", "") ) error_detail = f": {error_desc}" except (json.JSONDecodeError, KeyError): pass msg = ( f"OAuth request failed with status " f"{e.response.status_code}{error_detail}" ) raise OAuthError(msg) from e except (httpx.RequestError, httpx.TimeoutException) as e: msg = f"Failed to connect to OAuth server: {e}" raise OAuthError(msg) from e except KeyError as e: msg = f"Invalid OAuth response format: missing {e}" raise OAuthError(msg) from e except Exception as e: msg = f"Unexpected error during OAuth: {e}" raise OAuthError(msg) from e
[docs] async def invalidate_token(self) -> None: """Invalidate the current token to force a refresh on next use.""" async with self._lock: self._token = None self._token_expires_at = None
async def create_channel_with_credentials(
    host: str,
    credential_helper: CredentialHelper,
) -> Channel:
    """Build a secure gRPC channel authenticated through a credential helper.

    A token is requested from ``credential_helper`` once, combined with
    default SSL channel credentials, and the channel is created with
    keepalive, ping and buffer-size options.

    Args:
    ----
        host: The host address to connect to
        credential_helper: The credential helper for OAuth authentication

    Returns:
    -------
        A secure gRPC channel with OAuth authentication

    Raises:
    ------
        InvalidHostError: If host is empty
        OAuthError: If OAuth authentication fails

    """
    if not host:
        raise InvalidHostError(InvalidHostError.default_message)

    # NOTE: the token is obtained once here and passed to gRPC as a plain
    # string; this function does not query the helper again afterwards.
    token = await credential_helper.get_token()

    # Transport security plus per-call access-token credentials.
    ssl_credentials = grpc.ssl_channel_credentials()
    call_credentials = grpc.access_token_call_credentials(token)
    credentials = grpc.composite_channel_credentials(
        ssl_credentials,
        call_credentials,
    )

    # Keep connections alive longer.
    keepalive_options = [
        ("grpc.keepalive_time_ms", 60000),  # Send keepalive every 60s
        ("grpc.keepalive_timeout_ms", 30000),  # Wait 30s for keepalive ack
        ("grpc.keepalive_permit_without_calls", 1),  # Keepalive when idle
    ]
    # Optimize for persistent streams.
    ping_options = [
        ("grpc.http2.max_pings_without_data", 0),  # Allow unlimited pings
        ("grpc.http2.min_time_between_pings_ms", 60000),  # Min 60s apart
        (
            "grpc.http2.min_ping_interval_without_data_ms",
            30000,
        ),  # Min 30s when idle
    ]
    # Increase buffer sizes for better performance.
    size_options = [
        ("grpc.http2.write_buffer_size", 1024 * 1024),  # 1MB write buffer
        ("grpc.max_receive_message_length", 64 * 1024 * 1024),  # 64MB max
    ]
    options = [*keepalive_options, *ping_options, *size_options]

    return grpc.aio.secure_channel(host, credentials, options=options)