import os
from pathlib import Path
from ultralytics import YOLO
import torch
import shutil
import json

# --- Configuration ---
MODELS_DIR = Path("Models")
DATA_DIR = Path("Data_fp1+fp2")
RESULTS_DIR = Path("results")
JSON_RESULTS_FILE = RESULTS_DIR / "results.json"

def run_detection_and_create_json():
    """
    Runs every model over every image, loading one model and processing one image at a time to keep memory usage low.
    Saves the annotated images and compiles all detections into a single structured JSON file.
    """
    # --- 1. Setup and Sanity Checks ---
    if not MODELS_DIR.is_dir() or not DATA_DIR.is_dir():
        print(f"Error: Directory '{MODELS_DIR}' or '{DATA_DIR}' does not exist.")
        return

    model_paths = sorted(MODELS_DIR.rglob("weights/best.pt"))  # rglob already searches recursively; sort for a consistent order
    if not model_paths:
        print(f"No 'best.pt' files found in '{MODELS_DIR}'.")
        return

    # Path.suffix only yields the final extension (e.g. 'img.hif0.jpg' -> '.jpg'), so single extensions suffice here.
    image_files = sorted([f for f in DATA_DIR.glob('*') if f.is_file() and f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
    if not image_files:
        print(f"No image files found in '{DATA_DIR}'.")
        return

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print(f"Found {len(model_paths)} models and {len(image_files)} images.")
    print(f"Using device: {device.upper()}. Processing one model and one image at a time.")

    # --- 2. Clean and Prepare Output Directory ---
    if RESULTS_DIR.exists():
        print(f"Deleting old results directory: '{RESULTS_DIR}'")
        shutil.rmtree(RESULTS_DIR)
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    # --- 3. Initialize Report Data Structure ---
    report_data = {
        "model_names": ["_".join(p.parts[1:-2]) for p in model_paths],
        "image_results": {}
    }
    for img_path in image_files:
        report_data["image_results"][img_path.name] = {
            # Use as_posix() to get web-friendly relative paths (e.g. 'folder/image.jpg')
            "original_path": img_path.as_posix(),
            "models": {}
        }
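
    # For illustration only (model and file names below are hypothetical), the final results.json has this shape:
    # {
    #   "model_names": ["run_a", "run_b"],
    #   "image_results": {
    #     "image_001.jpg": {
    #       "original_path": "Data_fp1+fp2/image_001.jpg",
    #       "models": {
    #         "run_a": {"result_path": "results/run_a/images/image_001.jpg",
    #                   "detected_classes": ["gun", "knife"]}
    #       }
    #     }
    #   }
    # }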

    # --- 4. Main Processing Loop (Model by Model) ---
    for model_idx, model_path in enumerate(model_paths):
        model_name = report_data["model_names"][model_idx]
        print(f"\n--- Loading and processing with model: {model_name} ---")

        model = None
        try:
            model = YOLO(model_path)
            output_dir_for_model = RESULTS_DIR / model_name / 'images'

            print(f"  Processing {len(image_files)} images for this model...")
            for i, image_path in enumerate(image_files):
                # Log progress every 20 images and for the final image.
                if (i + 1) % 20 == 0 or i == len(image_files) - 1:
                    print(f"    - Image {i+1}/{len(image_files)}: {image_path.name}")

                # save=True writes the annotated image to <project>/<name>/, i.e. results/<model_name>/images/.
                results = model.predict(
                    source=str(image_path),
                    save=True,
                    project=str(RESULTS_DIR / model_name),
                    name='images',
                    exist_ok=True,
                    device=device,
                    verbose=False
                )
                result = results[0]
                
                # --- 5. Populate JSON data ---
                detected_classes = []
                if result.boxes and len(result.boxes) > 0:
                    names = result.names
                    # Check if the model is the DN model and apply the custom mapping
                    if "DN" in model_name:
                        dn_labels = {
                            0: "gun", 1: "knife", 2: "scissor", 3: "plier", 4: "wrench",
                            5: "bullet", 6: "screwdriver", 7: "weapon", 8: "grenade", 9: "dangerous"
                        }
                        names = dn_labels
                    class_ids = result.boxes.cls.tolist()
                    detected_classes = sorted({names[int(cls_id)] for cls_id in class_ids})

                # Assumes the annotated image is saved under its original filename (Ultralytics default for image sources).
                result_image_path = output_dir_for_model / image_path.name
                report_data["image_results"][image_path.name]["models"][model_name] = {
                    "result_path": result_image_path.as_posix(),
                    "detected_classes": detected_classes
                }

        except Exception as e:
            print(f"An error occurred with model {model_name}: {e}")
        finally:
            if model is not None:
                del model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                print("  - Cleared CUDA cache.")
    
    # --- 6. Write the final JSON file ---
    with open(JSON_RESULTS_FILE, 'w', encoding='utf-8') as f:
        json.dump(report_data, f, indent=4)

    print("\n--- All models processed successfully! ---")
    print(f"All image results are saved in '{RESULTS_DIR}'.")
    print(f"Structured data report saved to '{JSON_RESULTS_FILE}'.")

if __name__ == "__main__":
    run_detection_and_create_json() 
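
# A minimal sketch of how a downstream script might consume results.json
# (illustrative only, not part of this pipeline):
#
#     with open("results/results.json", encoding="utf-8") as f:
#         report = json.load(f)
#     for image_name, entry in report["image_results"].items():
#         for model_name, res in entry["models"].items():
#             print(image_name, model_name, res["detected_classes"])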