Source code for resolver_athena_client.client.utils
"""Utility functions for handling classification responses and errors."""
import logging
from typing import TYPE_CHECKING
from resolver_athena_client.client.exceptions import ClassificationOutputError
if TYPE_CHECKING:
from resolver_athena_client.generated.athena.models_pb2 import (
ClassificationOutput,
ClassifyResponse,
)
logger = logging.getLogger(__name__)
[docs]
def process_classification_outputs(
response: "ClassifyResponse",
*,
raise_on_error: bool = False,
log_errors: bool = True,
) -> list["ClassificationOutput"]:
"""Process classification outputs from a response, handling errors properly.
Args:
----
response: The ClassifyResponse containing outputs to process
raise_on_error: If True, raises ClassificationOutputError when an output
contains an error. If False, logs the error and skips the output.
log_errors: If True, logs error information for failed outputs
Returns:
-------
List of successful ClassificationOutput objects (excludes outputs with
errors when raise_on_error=False)
Raises:
------
ClassificationOutputError: When raise_on_error=True and an output
contains an error
"""
successful_outputs: list[ClassificationOutput] = []
for output in response.outputs:
if output.error and output.error.message:
error_msg = (
f"Classification failed for {output.correlation_id[:8]}: "
f"{output.error.message}"
)
if output.error.details:
error_msg += f" ({output.error.details})"
if log_errors:
logger.error(
"Output error [%s]: %s (code: %s)",
output.correlation_id[:8],
output.error.message,
output.error.code,
)
if raise_on_error:
raise ClassificationOutputError(
correlation_id=output.correlation_id,
error=output.error,
)
# Skip this output if not raising
continue
successful_outputs.append(output)
return successful_outputs
[docs]
def get_output_error_summary(response: "ClassifyResponse") -> dict[str, int]:
"""Get a summary of error types in the response outputs.
Args:
----
response: The ClassifyResponse to analyze
Returns:
-------
Dictionary mapping error code names to their counts
"""
error_counts: dict[str, int] = {}
for output in response.outputs:
if output.error and output.error.message:
error_code_name = str(output.error.code)
error_counts[error_code_name] = (
error_counts.get(error_code_name, 0) + 1
)
return error_counts
[docs]
def has_output_errors(response: "ClassifyResponse") -> bool:
"""Check if any outputs in the response contain errors.
Args:
----
response: The ClassifyResponse to check
Returns:
-------
True if any output contains an error, False otherwise
"""
return any(
output.error and output.error.message for output in response.outputs
)
def get_successful_outputs(
response: "ClassifyResponse",
) -> list["ClassificationOutput"]:
"""Get only the successful outputs from a response, filtering out errors.
Args:
----
response: The ClassifyResponse to filter
Returns:
-------
List of ClassificationOutput objects that don't contain errors
"""
return [
output
for output in response.outputs
if not (output.error and output.error.message)
]
def log_output_errors(response: "ClassifyResponse") -> None:
"""Log all output errors in a response for debugging purposes.
Args:
----
response: The ClassifyResponse to analyze for errors
"""
for output in response.outputs:
if output.error and output.error.message:
logger.error(
"Classification error [%s]: %s (code: %s)",
output.correlation_id[:8],
output.error.message,
output.error.code,
)
if output.error.details:
logger.debug(
"Error details [%s]: %s",
output.correlation_id[:8],
output.error.details,
)