import os
import glob
import shutil
import xml.etree.ElementTree as ET
import torch
import json
import sys
from ultralytics import YOLO
from jinja2 import Environment, FileSystemLoader
from tqdm import tqdm

# --- Configuration ---
# Input and output locations (plain relative paths)
IMAGE_DIR = "20"
ANNOTATION_DIR = "Annotation"
MODEL_DIR = "Models"
OUTPUT_DIR = "output"
CLASSIFIED_IMG_DIR = os.path.join(OUTPUT_DIR, "classified_images")
DETECTION_RESULTS_DIR = os.path.join(OUTPUT_DIR, "detection_results")

# Label vocabulary that ground truth and predictions are compared in
STANDARD_LABELS = {
    "gun",
    "knife",
    "wrench",
    "pliers",
    "scissor",
}

# --- Label Mappings ---
# DN_Medium model: class index -> label name (position in the tuple is the index)
DN_MAP = dict(enumerate((
    "gun",
    "knife",
    "scissor",
    "pliers",
    "wrench",
    "bullet",
    "screwdriver",
    "weapon",
    "grenade",
    "dangerous",
)))

# Box_v11 model: keyword searched for inside a class name -> standard label.
# Insertion order matters because the first matching keyword wins.
BOX_V11_KEYWORDS = {
    "gun": "gun", "knife": "knife", "wrench": "wrench",
    "pliers": "pliers", "scissor": "scissor",
}

# Every other model: capitalized class name -> standard label
GENERAL_MAP = {
    "Gun": "gun",
    "Knife": "knife",
    "Wrench": "wrench",
    "Pliers": "pliers",
    "Scissor": "scissor",
}

def make_json_serializable(data):
    """Recursively replace sets with sorted lists so *data* can be JSON-dumped.

    Dicts, lists and tuples are walked recursively (tuples stay tuples, which
    ``json`` already encodes as arrays). ``set`` and ``frozenset`` values are
    turned into lists; any other value is returned unchanged.

    Args:
        data: Arbitrarily nested structure of dicts, lists, tuples and sets.

    Returns:
        A new structure with the same shape in which every set/frozenset has
        been replaced by a list in a deterministic order.
    """
    if isinstance(data, dict):
        return {key: make_json_serializable(value) for key, value in data.items()}
    if isinstance(data, list):
        return [make_json_serializable(item) for item in data]
    if isinstance(data, tuple):
        # Sets hidden inside tuples would otherwise reach json.dump untouched.
        return tuple(make_json_serializable(item) for item in data)
    if isinstance(data, (set, frozenset)):
        try:
            ordered = sorted(data)
        except TypeError:
            # Mixed, non-comparable element types: fall back to a stable
            # ordering by type name and repr instead of crashing.
            ordered = sorted(data, key=lambda item: (type(item).__name__, repr(item)))
        # Set members may themselves be frozensets/tuples that need converting.
        return [make_json_serializable(item) for item in ordered]
    return data

def get_model_info(model_path):
    """Pick the label map and lookup strategy that belong to a model path.

    The path is searched for a known marker substring; the first marker found
    decides the result, and paths with no marker use the general mapping.

    Returns:
        A ``(label_map, key_type)`` tuple where ``key_type`` is one of
        ``'index'``, ``'name_keyword'`` or ``'name_direct'``.
    """
    path_text = str(model_path)
    # Checked in order: DN_Medium takes precedence over Box_v11.
    known_models = (
        ("DN_Medium", (DN_MAP, 'index')),
        ("Box_v11", (BOX_V11_KEYWORDS, 'name_keyword')),
    )
    for marker, model_info in known_models:
        if marker in path_text:
            return model_info
    return GENERAL_MAP, 'name_direct'

def normalize_xml_labels(label_set):
    """Normalize labels read from XML files.

    Each label is stripped of surrounding whitespace and lowercased, so that
    text such as ``" Gun\\n"`` and ``"gun"`` collapse to the same label.
    Labels that are empty after stripping are dropped.

    Args:
        label_set: Iterable of raw label strings.

    Returns:
        A set of cleaned, lowercase, non-empty labels.
    """
    cleaned_labels = (label.strip().lower() for label in label_set)
    return {label for label in cleaned_labels if label}

def normalize_predicted_labels(predictions, model_map, key_type, model_names):
    """Translate a model's predicted class ids into standard labels.

    Args:
        predictions: Object exposing a ``cls`` attribute with a ``tolist()``
            method that yields class ids; anything without ``cls`` produces
            an empty result.
        model_map: Label map chosen for the model. It is indexed by class id
            when ``key_type`` is ``'index'`` and scanned as keyword -> label
            when ``key_type`` is ``'name_keyword'``.
        key_type: ``'index'``, ``'name_direct'`` or ``'name_keyword'``; any
            other value resolves no labels.
        model_names: Mapping of class id -> class name, used for the two
            name-based key types.

    Returns:
        The set of standard labels found among the predictions.
    """
    found = set()
    if not hasattr(predictions, 'cls'):
        return found

    name_based = key_type in ['name_direct', 'name_keyword']

    for raw_id in predictions.cls.tolist():
        # Resolve the class id to a raw label name for this model family.
        if key_type == 'index':
            raw_label = model_map.get(int(raw_id))
        elif name_based:
            raw_label = model_names.get(int(raw_id))
        else:
            raw_label = ""

        if not raw_label:
            continue

        lowered = raw_label.lower()

        if key_type != 'name_keyword':
            # Direct match: keep the label only if it is a standard one.
            if lowered in STANDARD_LABELS:
                found.add(lowered)
            continue

        # Keyword match: the first keyword contained in the name wins.
        for keyword, standard_label in model_map.items():
            if keyword in lowered:
                found.add(standard_label)
                break

    return found

def get_ground_truth(xml_path):
    """Read the ground-truth label set from an annotation XML file.

    Collects the text of the ``name`` child of every top-level ``object``
    element and normalizes it with ``normalize_xml_labels``. Objects without
    a ``name`` element, or with an empty one, are skipped.

    Returns:
        The normalized label set, or an empty set when the file is missing
        or is not well-formed XML.
    """
    try:
        annotation_root = ET.parse(xml_path).getroot()
        name_nodes = (entry.find('name') for entry in annotation_root.findall('object'))
        raw_names = {
            node.text
            for node in name_nodes
            if node is not None and node.text is not None
        }
        return normalize_xml_labels(raw_names)
    except (FileNotFoundError, ET.ParseError):
        # A missing or broken annotation is treated as "no labels".
        return set()

def classify_images():
    """Copy every JPG image into one folder per ground-truth label.

    Any existing classified-image directory is removed first. An image is
    copied once for each label in its annotation; images whose annotation
    yields no labels are copied into an ``unlabeled`` folder instead.
    """
    print("Step 1: Classifying images based on annotations...")
    if os.path.exists(CLASSIFIED_IMG_DIR):
        shutil.rmtree(CLASSIFIED_IMG_DIR)

    jpg_files = glob.glob(os.path.join(IMAGE_DIR, "*.jpg"))
    if not jpg_files:
        print(f"Warning: No images found in '{IMAGE_DIR}'.")
        return

    for source_path in tqdm(jpg_files, desc="Classifying Images"):
        file_name = os.path.basename(source_path)
        stem, _extension = os.path.splitext(file_name)
        annotation_path = os.path.join(ANNOTATION_DIR, stem + ".xml")

        # An empty label set falls back to the single "unlabeled" bucket.
        buckets = get_ground_truth(annotation_path) or ["unlabeled"]
        for bucket in buckets:
            destination_dir = os.path.join(CLASSIFIED_IMG_DIR, bucket)
            os.makedirs(destination_dir, exist_ok=True)
            shutil.copy(source_path, os.path.join(destination_dir, file_name))
    print(f"Images classified into '{CLASSIFIED_IMG_DIR}'")

def _unique_model_name(model_path, taken_names):
    """Derive a non-empty, not-yet-used report name from a weights path.

    The name is built from the path components with the "Models" and
    "weights" directories and the file name removed. If nothing is left (the
    weights file sits directly in the model root) the file stem is used, and
    a numeric suffix is appended when the name is already in *taken_names*.
    The chosen name is added to *taken_names* before it is returned.
    """
    path_parts = [p for p in model_path.split(os.sep) if p not in ("Models", "weights")]
    base_name = "_".join(path_parts[:-1])
    if not base_name:
        # An empty name would make the model write straight into the shared
        # results directory and show up unnamed in the report.
        base_name = os.path.splitext(os.path.basename(model_path))[0]

    candidate, counter = base_name, 2
    while candidate in taken_names:
        # Two weight files mapping to one name would overwrite each other's
        # result images and report entries.
        candidate = f"{base_name}_{counter}"
        counter += 1
    taken_names.add(candidate)
    return candidate


def run_evaluation():
    """Run every model on every image and write the JSON and HTML reports.

    Steps:
        1. Recreate the detection results directory.
        2. Collect all ``best.pt`` files below the model directory and all
           JPG images (only the first five when ``--test`` is in ``sys.argv``).
        3. For each model, predict on each image, save the annotated image,
           normalize the predicted labels and count correct images.
        4. Dump the collected data to ``report_data.json``.
        5. Render ``template.html`` (looked up in the current working
           directory) to ``report.html``.

    An image counts as correct when the predicted and ground-truth label sets
    share at least one label, or when both are empty. Images whose processing
    raises are reported on stdout and count as incorrect, because accuracy is
    divided by the total number of images.
    """
    is_test_mode = '--test' in sys.argv
    if is_test_mode:
        print("--- RUNNING IN TEST MODE ---")

    print("\nStep 2: Preparing for model evaluation...")

    # --- Device Configuration ---
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device} for model inference.")
    # --------------------------

    # Setup output directories (start from a clean results folder)
    if os.path.exists(DETECTION_RESULTS_DIR):
        shutil.rmtree(DETECTION_RESULTS_DIR)
    os.makedirs(DETECTION_RESULTS_DIR)

    # glob returns matches in arbitrary order; sort so that the model order in
    # the report is the same on every run.
    model_paths = sorted(glob.glob(os.path.join(MODEL_DIR, '**', 'best.pt'), recursive=True))
    image_paths = sorted(glob.glob(os.path.join(IMAGE_DIR, "*.jpg")))

    if is_test_mode:
        image_paths = image_paths[:5]  # Use only 5 images for testing

    if not model_paths or not image_paths:
        print("Error: Models or images not found. Aborting.")
        return

    report_data = {'models': [], 'images': []}

    # Prepare image data structure for the report
    for img_path in image_paths:
        img_name = os.path.basename(img_path)
        xml_path = os.path.join(ANNOTATION_DIR, os.path.splitext(img_name)[0] + '.xml')

        # Create a web-safe relative path for the original image.
        # The report is in 'output/', so we need to go up one level.
        web_original_path = os.path.join('..', img_path).replace(os.sep, '/')

        report_data['images'].append({
            'name': img_name,
            'path': web_original_path,
            'ground_truth': get_ground_truth(xml_path),
            'predictions': {}
        })

    print(f"\nStep 3: Running detection with {len(model_paths)} models on {len(image_paths)} images...")
    used_model_names = set()
    for model_path in model_paths:
        try:
            model = YOLO(model_path)
            model_map, key_type = get_model_info(model_path)

            # Generate a descriptive, unique model name from the path
            model_name = _unique_model_name(model_path, used_model_names)

            model_results_dir = os.path.join(DETECTION_RESULTS_DIR, model_name)
            os.makedirs(model_results_dir, exist_ok=True)

            correct_predictions = 0

            # Process images one by one to avoid "Too many open files" error
            for i, img_path in enumerate(tqdm(image_paths, desc=f"Processing {model_name}")):
                try:
                    # Predict on a single image using the selected device
                    results = model.predict(img_path, device=device, verbose=False)[0]

                    # report_data['images'] was built from image_paths in the
                    # same order, so index i refers to this image's entry.
                    img_info = report_data['images'][i]
                    img_name = img_info['name']

                    # Save detection image
                    output_image_path = os.path.join(model_results_dir, img_name)
                    results.save(filename=output_image_path)

                    # Create a web-safe relative path for the prediction image.
                    # The report is in 'output/', so the path should be relative to that.
                    web_prediction_path = os.path.relpath(output_image_path, OUTPUT_DIR).replace(os.sep, '/')

                    # Normalize predicted labels
                    predicted_labels = normalize_predicted_labels(results.boxes, model_map, key_type, results.names)

                    # Store results for the report
                    img_info['predictions'][model_name] = {
                        'output_image_path': web_prediction_path,
                        'labels': predicted_labels
                    }

                    # Calculate accuracy: correct if any predicted label matches any ground truth label
                    if img_info['ground_truth'] and predicted_labels:
                        if not img_info['ground_truth'].isdisjoint(predicted_labels):
                            correct_predictions += 1
                    elif not img_info['ground_truth'] and not predicted_labels:  # Correctly identified no objects
                        correct_predictions += 1
                except Exception as e:
                    # Best effort: one bad image must not abort the whole model run.
                    print(f"\nError processing image {img_path} with model {model_name}: {e}")

            accuracy = (correct_predictions / len(image_paths)) if image_paths else 0
            report_data['models'].append({'name': model_name, 'accuracy': f"{accuracy:.2%}"})
            print(f"  - Model '{model_name}' accuracy: {accuracy:.2%}")

        except Exception as e:
            # Best effort: a model that fails to load or run is skipped and is
            # therefore absent from report_data['models'].
            print(f"Error processing model {model_path}: {e}")

    # --- Sanitize and save data to JSON for easy regeneration ---
    report_data_path = os.path.join(OUTPUT_DIR, 'report_data.json')
    print(f"\nStep 4: Saving report data to '{report_data_path}'...")

    # Make sure all sets are converted to lists before saving
    final_report_data = make_json_serializable(report_data)

    with open(report_data_path, 'w', encoding='utf-8') as f:
        json.dump(final_report_data, f, ensure_ascii=False, indent=4)
    # ---------------------------------------------

    # Generate HTML report
    print("\nStep 5: Generating HTML report...")
    # NOTE(review): the Environment is created without autoescape, so file and
    # model names are inserted into the HTML as-is unless the template escapes
    # them itself - confirm against template.html.
    env = Environment(loader=FileSystemLoader('.'))
    template = env.get_template('template.html')
    html_content = template.render(models=final_report_data['models'], images=final_report_data['images'])

    report_path = os.path.join(OUTPUT_DIR, 'report.html')
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(html_content)

    print(f"\nDone! Report saved to '{os.path.abspath(report_path)}'")


if __name__ == '__main__':
    # NOTE: classify_images() is deliberately not called here to speed up
    # testing; only the evaluation and reporting pipeline runs.
    run_evaluation() 