# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
"""Result processor for converting Foundry scenario results to JSONL format."""

import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

from pyrit.models import AttackOutcome, AttackResult
from pyrit.scenario import DatasetConfiguration


def _get_attack_type_name(attack_identifier) -> str:
    """Extract attack type name from attack_identifier regardless of form.

    Handles both the current dict form (pyrit 0.11.0) and a future
    Identifier-object form (anticipated when pyrit adds AttackIdentifier).

    :param attack_identifier: The identifier from AttackResult, either dict or object
    :return: The attack type name string
    :rtype: str
    """
    if attack_identifier is None:
        return "Unknown"
    if isinstance(attack_identifier, dict):
        return attack_identifier.get("__type__", "Unknown")
    # Future: Identifier-style object with class_name attribute
    return getattr(attack_identifier, "class_name", "Unknown")


def _read_seed_content(seed) -> str:
    """Read seed content, handling both direct values and file paths.

    For binary_path data type, reads the file contents. For other types,
    returns the value directly.

    :param seed: The seed object containing the value
    :type seed: SeedPrompt
    :return: The content string
    :rtype: str
    """
    value = seed.value
    data_type = getattr(seed, "data_type", "text")

    if data_type == "binary_path" and os.path.isfile(value):
        try:
            with open(value, "r", encoding="utf-8") as f:
                return f.read()
        except Exception:
            return value  # Fallback to raw value if file read fails
    return value


class FoundryResultProcessor:
    """Processes Foundry scenario results into JSONL format.

    Extracts AttackResult objects from the completed Foundry scenario and
    converts them to the JSONL format expected by the main ResultProcessor.
    This ensures compatibility with existing result processing and reporting
    infrastructure.

    Handles binary_path data type by reading file contents when reconstructing
    context data.
    """

    def __init__(
        self,
        scenario,
        dataset_config: DatasetConfiguration,
        risk_category: str,
    ):
        """Initialize the processor and index the dataset's context data.

        :param scenario: Completed Foundry scenario (ScenarioOrchestrator)
        :type scenario: ScenarioOrchestrator
        :param dataset_config: DatasetConfiguration used for the scenario
        :type dataset_config: DatasetConfiguration
        :param risk_category: The risk category being processed
        :type risk_category: str
        """
        self.scenario = scenario
        self.dataset_config = dataset_config
        self.risk_category = risk_category
        # prompt_group_id (as str) -> {"contexts", "metadata", "objective"}
        self._context_lookup: Dict[str, Dict[str, Any]] = {}
        self._build_context_lookup()

    def _read_context_content(self, seed) -> str:
        """Read context content, handling both direct values and file paths.

        Delegates to the module-level _read_seed_content function.

        :param seed: The seed object containing the value
        :type seed: SeedPrompt
        :return: The context content string
        :rtype: str
        """
        return _read_seed_content(seed)

    def _build_context_lookup(self) -> None:
        """Build lookup from prompt_group_id (UUID) to context data.

        For each seed group that has a prompt_group_id on its first seed and
        contains a SeedObjective, records the objective value, the objective's
        metadata, and the context entries built from its SeedPrompt seeds.
        Seeds flagged ``is_original_context`` in their metadata are skipped
        unless they are also flagged ``is_attack_vehicle``.
        """
        for seed_group in self.dataset_config.get_all_seed_groups():
            if not seed_group.seeds:
                continue

            # Get prompt_group_id from first seed
            group_id = seed_group.seeds[0].prompt_group_id
            if not group_id:
                continue

            # Find objective and context seeds (matched by class name)
            objective_seed = None
            context_seeds = []

            for seed in seed_group.seeds:
                seed_class = seed.__class__.__name__
                if seed_class == "SeedObjective":
                    objective_seed = seed
                elif seed_class == "SeedPrompt":
                    context_seeds.append(seed)

            if objective_seed:
                # Extract context data
                contexts = []
                for ctx_seed in context_seeds:
                    metadata = ctx_seed.metadata or {}
                    # Read content from file if binary_path, otherwise use value directly
                    content = self._read_context_content(ctx_seed)

                    # For XPIA, include the injected vehicle
                    if metadata.get("is_attack_vehicle"):
                        contexts.append(
                            {
                                "content": content,
                                "tool_name": metadata.get("tool_name"),
                                "context_type": metadata.get("context_type"),
                                "is_attack_vehicle": True,
                            }
                        )
                    elif not metadata.get("is_original_context"):
                        # Standard context
                        contexts.append(
                            {
                                "content": content,
                                "tool_name": metadata.get("tool_name"),
                                "context_type": metadata.get("context_type"),
                            }
                        )

                self._context_lookup[str(group_id)] = {
                    "contexts": contexts,
                    "metadata": objective_seed.metadata or {},
                    "objective": objective_seed.value,
                }

    def to_jsonl(self, output_path: str) -> str:
        """Convert scenario results to JSONL format and write them to disk.

        Parent directories of ``output_path`` are created if needed and the
        file is overwritten. Lines are joined with "\\n" (no trailing newline).

        :param output_path: Path to write JSONL file
        :type output_path: str
        :return: JSONL content string
        :rtype: str
        """
        # Get attack results from scenario
        attack_results = self.scenario.get_attack_results()

        # Get memory instance for querying conversations
        memory = self.scenario.get_memory()

        jsonl_lines = []

        # Process each AttackResult
        for attack_result in attack_results:
            entry = self._process_attack_result(attack_result, memory)
            if entry:
                jsonl_lines.append(json.dumps(entry, ensure_ascii=False))

        # Write to file
        jsonl_content = "\n".join(jsonl_lines)
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(jsonl_content)

        return jsonl_content

    @staticmethod
    def _clean_strategy_name(raw_name: str) -> str:
        """Strip trailing "Attack"/"Converter" suffixes from a PyRIT class name.

        Only suffixes are removed, so names containing these words internally
        keep them: "TreeOfAttacksWithPruningAttack" becomes
        "TreeOfAttacksWithPruning" (a plain ``str.replace`` would produce
        "TreeOfsWithPruning"). Suffixes are stripped repeatedly, so
        "FooConverterAttack" and "FooAttackConverter" both become "Foo".

        :param raw_name: Class name as returned by _get_attack_type_name
        :type raw_name: str
        :return: Display name, or "Unknown" if raw_name is not a string
        :rtype: str
        """
        if not isinstance(raw_name, str):
            return "Unknown"
        name = raw_name
        while True:
            for suffix in ("Attack", "Converter"):
                if name.endswith(suffix):
                    name = name[: -len(suffix)]
                    break
            else:
                # No suffix matched on this pass: nothing left to strip.
                return name

    def _process_attack_result(
        self,
        attack_result: AttackResult,
        memory,
    ) -> Optional[Dict[str, Any]]:
        """Process a single AttackResult into JSONL entry format.

        On any failure, a warning is logged and a minimal entry carrying an
        ``"error"`` message and the ``conversation_id`` is returned, so one bad
        result does not abort the whole export.

        :param attack_result: The attack result to process
        :type attack_result: AttackResult
        :param memory: Memory interface for conversation lookup
        :type memory: MemoryInterface
        :return: JSONL entry dictionary (an error entry if processing fails)
        :rtype: Optional[Dict[str, Any]]
        """
        try:
            # Get conversation messages for this result
            conversation_pieces = memory.get_message_pieces(conversation_id=attack_result.conversation_id)

            # Extract prompt_group_id from conversation metadata
            group_id = self._get_prompt_group_id_from_conversation(conversation_pieces)

            # Lookup context and metadata
            context_data = self._context_lookup.get(str(group_id), {}) if group_id else {}

            # Build conversation structure (matching existing format)
            messages = self._build_messages_from_pieces(conversation_pieces)

            conversation = {
                "messages": messages,
            }

            # Build JSONL entry (matching format expected by ResultProcessor)
            entry: Dict[str, Any] = {
                "conversation": conversation,
            }

            # Add context if available (stored as a JSON-encoded string)
            contexts = context_data.get("contexts", [])
            if contexts:
                entry["context"] = json.dumps({"contexts": contexts})

            # Add risk_sub_type if present in metadata
            metadata = context_data.get("metadata", {})
            if metadata.get("risk_subtype"):
                entry["risk_sub_type"] = metadata["risk_subtype"]

            # Add attack success based on outcome
            if attack_result.outcome == AttackOutcome.SUCCESS:
                entry["attack_success"] = True
            elif attack_result.outcome == AttackOutcome.FAILURE:
                entry["attack_success"] = False
            # UNDETERMINED leaves attack_success unset

            # Add strategy information, cleaned for display
            # (e.g., "PromptSendingAttack" -> "PromptSending")
            raw_strategy = _get_attack_type_name(attack_result.attack_identifier)
            entry["attack_strategy"] = self._clean_strategy_name(raw_strategy)

            # Add score information if available
            if attack_result.last_score:
                score = attack_result.last_score
                entry["score"] = {
                    "value": score.score_value,
                    "rationale": score.score_rationale,
                    "metadata": score.score_metadata,
                }

            return entry

        except Exception as e:  # best-effort boundary: keep processing other results
            conversation_id = getattr(attack_result, "conversation_id", None)
            logging.getLogger(__name__).warning(
                "Failed to process AttackResult for conversation %s: %s",
                conversation_id,
                e,
            )
            return {
                "conversation": {"messages": []},
                "error": str(e),
                "conversation_id": conversation_id,
            }

    def _get_prompt_group_id_from_conversation(
        self,
        conversation_pieces: List,
    ) -> Optional[str]:
        """Extract prompt_group_id from conversation pieces.

        Pieces are scanned in order; for each, ``prompt_metadata`` is checked
        before ``labels``. The first truthy ``prompt_group_id`` found wins.

        :param conversation_pieces: List of message pieces from conversation
        :type conversation_pieces: List
        :return: prompt_group_id string or None
        :rtype: Optional[str]
        """
        for piece in conversation_pieces:
            if hasattr(piece, "prompt_metadata") and piece.prompt_metadata:
                group_id = piece.prompt_metadata.get("prompt_group_id")
                if group_id:
                    return str(group_id)

            # Also check labels
            if hasattr(piece, "labels") and piece.labels:
                group_id = piece.labels.get("prompt_group_id")
                if group_id:
                    return str(group_id)

        return None

    def _build_messages_from_pieces(
        self,
        conversation_pieces: List,
    ) -> List[Dict[str, Any]]:
        """Build message list from conversation pieces.

        :param conversation_pieces: List of message pieces
        :type conversation_pieces: List
        :return: List of message dictionaries with "role", "content" and,
            when a piece carries a "context" label, a "context" list
        :rtype: List[Dict[str, Any]]
        """
        messages = []

        # Sort by sequence if available; a missing or None sequence sorts as 0
        # (None cannot be compared with int and would otherwise raise).
        sorted_pieces = sorted(conversation_pieces, key=lambda p: getattr(p, "sequence", None) or 0)

        for piece in sorted_pieces:
            # Get role, handling api_role property
            role = getattr(piece, "api_role", None) or getattr(piece, "role", "user")

            # Get content (prefer converted_value over original_value)
            content = getattr(piece, "converted_value", None) or getattr(piece, "original_value", "")

            message: Dict[str, Any] = {
                "role": role,
                "content": content,
            }

            # Add context from labels if present (for XPIA)
            if hasattr(piece, "labels") and piece.labels:
                context_str = piece.labels.get("context")
                if context_str:
                    try:
                        context_dict = json.loads(context_str) if isinstance(context_str, str) else context_str
                        if isinstance(context_dict, dict) and "contexts" in context_dict:
                            message["context"] = context_dict["contexts"]
                    except (json.JSONDecodeError, TypeError):
                        # Malformed context label: emit the message without it.
                        pass

            messages.append(message)

        return messages

    def get_summary_stats(self) -> Dict[str, Any]:
        """Get summary statistics from the scenario results.

        ASR is computed as successful / (successful + failed); UNDETERMINED
        results are counted but excluded from the ASR denominator.

        :return: Dictionary with ASR and other metrics
        :rtype: Dict[str, Any]
        """
        attack_results = self.scenario.get_attack_results()

        if not attack_results:
            return {
                "total": 0,
                "successful": 0,
                "failed": 0,
                "undetermined": 0,
                "asr": 0.0,
            }

        successful = sum(1 for r in attack_results if r.outcome == AttackOutcome.SUCCESS)
        failed = sum(1 for r in attack_results if r.outcome == AttackOutcome.FAILURE)
        undetermined = sum(1 for r in attack_results if r.outcome == AttackOutcome.UNDETERMINED)
        total = len(attack_results)

        decided = successful + failed
        return {
            "total": total,
            "successful": successful,
            "failed": failed,
            "undetermined": undetermined,
            "asr": successful / decided if decided > 0 else 0.0,
        }
