import torch
import timm
from collections import OrderedDict
import types
from PIL import Image
import os
import re
from torchvision import transforms
from concurrent.futures import ThreadPoolExecutor, as_completed

# Checkpoint file for every active classifier, keyed by the classifier name
# used throughout this module (the iteration order is also the order in which
# models are loaded and evaluated). Disabled classifiers stay commented out;
# note they still point at the older ``agri_image_ai`` directory.
MODEL_PATHS = {
    "obstacle": "/var/www/html/ai-image-ml/obstacle_classifier.pt",
    "moisture": "/var/www/html/ai-image-ml/moisture_classifier.pt",
    # "soil": "/var/www/html/agri_image_ai/soil_visibility_classifier.pt",
    # "ruler": "/var/www/html/agri_image_ai/ruler_classifier.pt",
    "clarity": "/var/www/html/ai-image-ml/clarity_classifier.pt",
    "pipe": "/var/www/html/ai-image-ml/pipe_classifier.pt",
}


########################################################

import os
import torch
from collections import OrderedDict
import timm
import types
from torchvision.models import resnet  # for allowlisting if needed

# Architecture factory for each classifier key in MODEL_PATHS.

def _efficientnet_b0_single_logit():
    """Build a fresh timm EfficientNet-B0 (no pretrained weights) with one output unit."""
    return timm.create_model('efficientnet_b0', pretrained=False, num_classes=1)


# Every classifier currently shares the same architecture. Each value is a
# zero-argument factory, so callers get a new, independent model per call.
MODEL_ARCH = {
    classifier: _efficientnet_b0_single_logit
    for classifier in ("obstacle", "clarity", "pipe", "moisture")
}

def strip_module_prefix(state_dict):
    new_state = OrderedDict()
    for k, v in state_dict.items():
        new_key = k[len("module."):] if k.startswith("module.") else k
        new_state[new_key] = v
    return new_state

def try_extract_state_dict(checkpoint):
    """Classify a loaded checkpoint as either a weight mapping or a model object.

    Recognised shapes, checked in this order:
      * a wrapper dict holding the payload under ``"state_dict"``,
        ``"model_state_dict"`` or ``"model"`` (for example a training
        checkpoint that also stores epoch/optimizer entries);
      * a plain mapping whose keys are all strings (a bare ``state_dict``);
      * an object exposing a bound ``eval`` method (an ``nn.Module``-like model).

    Returns:
        tuple: ``(state_dict, None)`` for weight mappings, with any
        ``"module."`` key prefix stripped; ``(None, model)`` for model
        objects; ``(None, None)`` when the format is not recognised.
    """
    if isinstance(checkpoint, dict):  # OrderedDict is a dict subclass
        # Wrapper keys must be checked BEFORE the plain-mapping case: a wrapper
        # checkpoint also has only string keys, so testing the plain case first
        # would hand the wrapper itself back as if it were the weights.
        for wrapper_key in ("state_dict", "model_state_dict", "model"):
            candidate = checkpoint.get(wrapper_key)
            if isinstance(candidate, dict):
                return strip_module_prefix(candidate), None
            if candidate is not None and hasattr(candidate, "eval"):
                return None, candidate

        # No wrapper key: a string-keyed mapping is treated as a bare state_dict.
        if all(isinstance(k, str) for k in checkpoint):
            return strip_module_prefix(checkpoint), None

    # A whole model saved with torch.save(model).
    if hasattr(checkpoint, "eval") and isinstance(checkpoint.eval, types.MethodType):
        return None, checkpoint

    return None, None

def robust_torch_load(path):
    """Load the checkpoint at *path* onto the CPU.

    Strategy:
      1) ``torch.load(..., weights_only=True)``: the restricted unpickler that
         only rebuilds tensors and plain containers.
      2) If that load is rejected because the file pickles other globals (for
         example a whole model saved with ``torch.save(model)``), retry with
         ``weights_only=False``. That retry is a full, unrestricted unpickle
         and can execute code embedded in the file. Only point MODEL_PATHS
         at checkpoints you trust.

    Returns:
        Whatever ``torch.load`` returns (OrderedDict, dict, nn.Module, ...).

    Raises:
        RuntimeError: if the ``weights_only=False`` retry also fails.
        Exception: any other failure of the first load (missing file, corrupt
            archive, ...) is re-raised unchanged.
    """
    cpu = torch.device("cpu")
    try:
        return torch.load(path, map_location=cpu, weights_only=True)
    except Exception as e:
        msg = str(e)
        rejected_by_allowlist = (
            "Weights only load failed" in msg
            or "Unsupported global" in msg
            or "UnsupportedGlobal" in msg
        )
        if not rejected_by_allowlist:
            raise

    # SECURITY: unrestricted unpickling of a trusted local file.
    # torch.serialization.safe_globals is deliberately NOT used here. It only
    # affects weights_only=True loads, so it would add no protection. It is
    # also missing before torch 2.5, where calling it raises AttributeError.
    try:
        return torch.load(path, map_location=cpu, weights_only=False)
    except Exception as e2:
        raise RuntimeError(
            f"Failed loading checkpoint with weights_only=False: {e2}"
        ) from e2

# MAIN loader loop: for every configured classifier, load its checkpoint and
# store a ready-to-use model (switched to eval mode) in ``models_dict``.
# Any failure is printed and re-raised, so importing this module fails fast
# instead of running with a missing or broken classifier.
models_dict = {}
for key, path in MODEL_PATHS.items():
    try:
        if not os.path.exists(path):
            raise FileNotFoundError(f"path not found: {path}")

        if key not in MODEL_ARCH:
            raise KeyError(f"No architecture defined for key '{key}'")

        # robustly load the checkpoint (returns dict/OrderedDict or nn.Module)
        checkpoint = robust_torch_load(path)

        state_dict, loaded_model = try_extract_state_dict(checkpoint)

        if state_dict is not None:
            # Weights only: build the architecture and load them into it.
            arch_model = MODEL_ARCH[key]()

            # strict=False tolerates partial matches, but on its own it would
            # also accept a checkpoint for a different architecture. Every
            # weight would be skipped and the model would run with random
            # initialisation. Refuse that case outright.
            matched = set(arch_model.state_dict()).intersection(state_dict)
            if not matched:
                raise RuntimeError(
                    f"Checkpoint keys do not match the '{key}' architecture; "
                    "no weights would be loaded."
                )

            incompatible = arch_model.load_state_dict(state_dict, strict=False)
            if incompatible.missing_keys or incompatible.unexpected_keys:
                print(
                    f"⚠️ Partial weight load for {key}: "
                    f"{len(incompatible.missing_keys)} missing, "
                    f"{len(incompatible.unexpected_keys)} unexpected keys"
                )

            arch_model.eval()
            models_dict[key] = arch_model
            print(f"✅ Loaded weights into architecture for: {key}")

        elif loaded_model is not None:
            # checkpoint was a saved model object; unwrap if DataParallel
            real_model = getattr(loaded_model, "module", loaded_model)
            real_model.eval()
            models_dict[key] = real_model
            print(f"✅ Loaded full saved model for: {key}")

        else:
            raise RuntimeError("Unrecognized checkpoint format (neither state_dict nor model object).")

    except Exception as e:
        print(f"❌ Failed to load model '{key}' from {path}: {e}")
        raise


########################################################
# # Define model architecture per classifier
# MODEL_ARCH = {
#     "obstacle": lambda: timm.create_model('efficientnet_b0', pretrained=False, num_classes=1)
#      ,
#     # "soil": lambda: timm.create_model('efficientnet_b0', pretrained=False, num_classes=1),
#     # "ruler": lambda: timm.create_model('efficientnet_b0', pretrained=False, num_classes=1),
#      "clarity": lambda: timm.create_model('efficientnet_b0', pretrained=False, num_classes=1),
#      "pipe": lambda: timm.create_model('efficientnet_b0', pretrained=False, num_classes=1)
# }

#     # Load models
# models_dict = {}
# for key, path in MODEL_PATHS.items():
#     try:
#         model = MODEL_ARCH[key]()
#         torch.serialization.add_safe_globals([resnet.ResNet])
#         #model.load_state_dict(torch.load(path, map_location=torch.device("cpu")))
#         model = torch.load(path, map_location=torch.device("cpu"))
#         model.eval()
#         models_dict[key] = model
#         print(f"✅ Loaded model: {key}")
#     except Exception as e:
#         print(f"❌ Failed to load model '{key}' from {path}: {e}")
#         raise

# Preprocessing applied to every image before inference: resize to 224x224,
# convert to a tensor, then normalise each channel with the standard ImageNet
# mean/std. NOTE(review): presumably this matches the training-time
# preprocessing of the classifiers; confirm.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
    ]
)

# Prediction for a single image
def predict_all(image_path, activity=""):
    """Run every loaded classifier on one image and decide whether to accept it.

    Args:
        image_path: path of the image file to score.
        activity: free-text activity label for the image. At present it is
            only printed, because the moisture rule that would use it is
            commented out. It defaults to "" so callers that only have a path
            (such as batch_filter_images_parallel) can call this with one
            argument.

    Returns:
        dict: ``{"accepted": bool, "reasons": [str, ...]}``. Images smaller
        than 50 px on either side are rejected early. Any exception raised
        while scoring becomes a rejection whose reason starts with
        ``"Error: "`` rather than propagating.
    """
    print("activity ,", activity)
    try:
        # Context manager releases the file handle right away. This matters
        # when many images are scored from worker threads.
        with Image.open(image_path) as raw:
            image = raw.convert("RGB")
        if image.size[0] < 50 or image.size[1] < 50:
            return {"accepted": False, "reasons": ["Image too small."]}

        input_tensor = transform(image).unsqueeze(0)
        results = {}
        with torch.no_grad():
            for key, model in models_dict.items():
                output = model(input_tensor)
                # .item() requires a single-element output, i.e. one logit per
                # model; sigmoid > 0.5 turns it into a yes/no decision.
                prob = torch.sigmoid(output).item()
                results[key] = prob > 0.5

        failed = []
        # if not results["soil"]:
        #     failed.append("Soil not visible inside pipe.")
        if results["obstacle"]:
            failed.append("Pipe is obstructed.")
        # if not results["ruler"]:
        #     failed.append("Ruler not present inside pipe.")
        if not results["clarity"]:
            failed.append("Image is unclear.")
        if not results["pipe"]:
            failed.append("No pipe in the Image.")
        # if re.search("dry", activity, re.IGNORECASE):
        #     if not results["moisture"]:
        #         failed.append("Soil has moisture.")
        # if results["ruler"] and not results["soil"]:
        #     failed.append("Ruler is present but soil is not visible through pipe.")

        return {
            "accepted": len(failed) == 0,
            "reasons": failed
        }

    except Exception as e:
        return {"accepted": False, "reasons": [f"Error: {str(e)}"]}

# Parallel batch filtering
def batch_filter_images_parallel(source_dir, accepted_dir, max_workers=4, activity=""):
    """Score every .jpg/.jpeg/.png in *source_dir* and move accepted ones to *accepted_dir*.

    Args:
        source_dir: directory scanned (non-recursively) for images.
        accepted_dir: destination for accepted images; created if missing.
        max_workers: number of worker threads used for scoring.
        activity: activity label forwarded to ``predict_all`` for every image.

    Returns:
        list[dict]: one ``predict_all``-style result per image, each with a
        ``"filename"`` entry added, in completion order. An exception while
        processing an image is recorded as a rejected result instead of
        aborting the batch.
    """
    import shutil  # local import: only this function moves files

    os.makedirs(accepted_dir, exist_ok=True)
    results = []

    def process_image(filename):
        path = os.path.join(source_dir, filename)
        # predict_all takes the activity as its second argument; pass it
        # explicitly so the call does not depend on a default there.
        result = predict_all(path, activity)
        result["filename"] = filename
        if result["accepted"]:
            # shutil.move falls back to copy+delete when the two directories
            # are on different filesystems, where os.rename raises OSError.
            shutil.move(path, os.path.join(accepted_dir, filename))
        return result

    image_files = [
        f for f in os.listdir(source_dir)
        if f.lower().endswith((".jpg", ".jpeg", ".png"))
    ]

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_file = {executor.submit(process_image, f): f for f in image_files}
        for future in as_completed(future_to_file):
            try:
                result = future.result()
                results.append(result)
                status = "✅ Accepted" if result["accepted"] else "❌ Rejected"
                print(f"{status}: {result['filename']} → {', '.join(result['reasons']) if result['reasons'] else 'No issues'}")
            except Exception as e:
                filename = future_to_file[future]
                print(f"❌ Error processing {filename}: {str(e)}")
                results.append({
                    "filename": filename,
                    "accepted": False,
                    "reasons": [f"Error: {str(e)}"]
                })

    return results

# Script entry point: filter ./raw_images into ./accepted_images (both
# resolved relative to the current working directory).
if __name__ == "__main__":
    SOURCE_DIR, ACCEPTED_DIR = "raw_images", "accepted_images"
    results = batch_filter_images_parallel(SOURCE_DIR, ACCEPTED_DIR)



#!/usr/bin/env python3
# """
# filter_images.py

# Batch-filter images using multiple classifiers (obstacle, clarity, pipe).
# Robust checkpoint loader compatible with PyTorch 2.x (handles weights_only behavior
# and fallbacks for older torch versions).

# Usage:
#     python filter_images.py --source raw_images --accepted accepted_images --workers 6
#     python filter_images.py --source raw_images --accepted accepted_images --allow-unsafe
# """
# import os
# import re
# import importlib
# import argparse
# from concurrent.futures import ThreadPoolExecutor, as_completed
# from PIL import Image

# import torch
# import timm
# from torchvision import transforms

# # Compatibility import for UnpicklingError (PyTorch may or may not re-export it)
# try:
#     from torch.serialization import UnpicklingError
# except ImportError:
#     from pickle import UnpicklingError

# # -----------------------------
# # Configuration: update paths
# # -----------------------------
# MODEL_PATHS = {
#     "obstacle": "/var/www/html/agri_image_ai/obstacle_classifier.pt",
#     # "soil": "/var/www/html/agri_image_ai/soil_visibility_classifier.pt",
#     # "ruler": "/var/www/html/agri_image_ai/ruler_classifier.pt",
#     "clarity": "/var/www/html/agri_image_ai/clarity_classifier.pt",
#     "pipe": "/var/www/html/agri_image_ai/pipe_classifier.pt"
# }

# MODEL_ARCH = {
#     "obstacle": lambda: timm.create_model('efficientnet_b0', pretrained=False, num_classes=1),
#     # "soil": lambda: timm.create_model('efficientnet_b0', pretrained=False, num_classes=1),
#     # "ruler": lambda: timm.create_model('efficientnet_b0', pretrained=False, num_classes=1),
#     "clarity": lambda: timm.create_model('efficientnet_b0', pretrained=False, num_classes=1),
#     "pipe": lambda: timm.create_model('efficientnet_b0', pretrained=False, num_classes=1)
# }

# # -----------------------------
# # Image transform
# # -----------------------------
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

# # -----------------------------
# # Helper: import dotted attribute
# # -----------------------------
# def _try_import_attr(dotted_path):
#     """
#     Import and return the attribute/class for a dotted path like
#     "torchvision.models.resnet.ResNet". Returns the imported object on success,
#     or None on failure.
#     """
#     try:
#         parts = dotted_path.split(".")
#         for split_idx in range(len(parts) - 1, 0, -1):
#             module_name = ".".join(parts[:split_idx])
#             attr_name = ".".join(parts[split_idx:])
#             try:
#                 module = importlib.import_module(module_name)
#             except Exception:
#                 continue
#             obj = module
#             for a in attr_name.split("."):
#                 obj = getattr(obj, a)
#             return obj
#     except Exception:
#         return None
#     return None

# # -----------------------------
# # Robust checkpoint loader
# # -----------------------------
# def load_checkpoint_safely(path, allow_unsafe=False):
#     """
#     Attempts to read the checkpoint and return either a state_dict-like mapping
#     or the loaded object. Strategy:
#       1) Try torch.load with weights_only=True (safe).
#       2) If UnpicklingError mentions disallowed globals, try to import them and add to safe globals, then retry.
#       3) If still failing, and allow_unsafe True (or env var set), fallback to torch.load(..., weights_only=False) (UNSAFE).
#     """
#     # 1) try safe load first
#     try:
#         obj = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
#         return obj
#     except UnpicklingError as e:
#         msg = str(e)
#         # extract dotted globals from common error message formats
#         globals_found = re.findall(r"GLOBAL\s+([A-Za-z0-9_.]+)", msg)
#         if not globals_found:
#             globals_found = re.findall(r"Unsupported global:\s*([A-Za-z0-9_.]+)", msg)

#         allowed_objs = []
#         for g in globals_found:
#             imported = _try_import_attr(g)
#             if imported is not None:
#                 allowed_objs.append(imported)

#         if allowed_objs:
#             # Try to add to safe globals if API exists (PyTorch 2.6+)
#             try:
#                 if hasattr(torch.serialization, "add_safe_globals"):
#                     torch.serialization.add_safe_globals(allowed_objs)
#                     obj = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
#                     return obj
#                 else:
#                     # fallback to context manager if available (some versions)
#                     safe_ctx = getattr(torch.serialization, "safe_globals", None)
#                     if safe_ctx is not None:
#                         with safe_ctx(allowed_objs):
#                             obj = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
#                             return obj
#                     # else: cannot register safe globals in this torch version; continue
#             except Exception:
#                 # continue to optional unsafe fallback
#                 pass

#         # 3) final fallback if explicitly allowed
#         if allow_unsafe or os.environ.get("ALLOW_UNSAFE_CHECKPOINTS", "") == "1":
#             # WARNING: this will execute pickled code. Only do if you trust the file.
#             obj = torch.load(path, map_location=torch.device("cpu"), weights_only=False)
#             return obj

#         # If we reach here, no safe way to load; raise helpful message
#         raise UnpicklingError(
#             "Failed to load checkpoint safely. "
#             "If you trust this checkpoint and want to permit unsafe loading, set environment variable "
#             "`ALLOW_UNSAFE_CHECKPOINTS=1` or run with --allow-unsafe. "
#             "Preferred alternative: re-save the checkpoint on the training machine using "
#             "`torch.save(model.state_dict(), 'weights.pth')`."
#         )
#     except Exception as e:
#         # propagate other exceptions (file corrupted, incompatible versions, etc.)
#         raise

# # -----------------------------
# # Load all models
# # -----------------------------
# def load_all_models(model_paths, model_arch, allow_unsafe=False):
#     models_dict = {}
#     for key, path in model_paths.items():
#         if not os.path.isfile(path):
#             raise FileNotFoundError(f"Model file not found: {path} for key '{key}'")
#         try:
#             model = model_arch[key]()
#             loaded = load_checkpoint_safely(path, allow_unsafe=allow_unsafe)

#             if isinstance(loaded, dict):
#                 # common keys that hold state dict
#                 if "state_dict" in loaded:
#                     sd = loaded["state_dict"]
#                 elif "model_state_dict" in loaded:
#                     sd = loaded["model_state_dict"]
#                 else:
#                     sd = loaded

#                 # remove "module." prefix if present (DataParallel)
#                 new_sd = {}
#                 for k, v in sd.items():
#                     new_key = k[len("module."):] if k.startswith("module.") else k
#                     new_sd[new_key] = v

#                 model.load_state_dict(new_sd)
#             else:
#                 # loaded could be a full nn.Module saved object
#                 if hasattr(loaded, "state_dict"):
#                     model.load_state_dict(loaded.state_dict())
#                 else:
#                     raise RuntimeError("Checkpoint loaded to unsupported object type; expected dict-like or nn.Module.")

#             model.eval()
#             models_dict[key] = model
#             print(f"✅ Loaded model: {key} from {path}")
#         except Exception as e:
#             print(f"❌ Failed to load model '{key}' from {path}: {e}")
#             raise
#     return models_dict

# # -----------------------------
# # Prediction logic
# # -----------------------------
# def predict_all(image_path, models_dict):
#     try:
#         image = Image.open(image_path).convert("RGB")
#         if image.size[0] < 50 or image.size[1] < 50:
#             return {"accepted": False, "reasons": ["Image too small."]}

#         input_tensor = transform(image).unsqueeze(0)  # shape (1, C, H, W)
#         results = {}
#         with torch.no_grad():
#             for key, model in models_dict.items():
#                 output = model(input_tensor)
#                 # some models return tuple (outputs, ...) so handle that
#                 if isinstance(output, tuple):
#                     output = output[0]
#                 # flatten to scalar safely
#                 try:
#                     prob = torch.sigmoid(output.view(-1)[0]).item()
#                 except Exception:
#                     # fallback if output already scalar tensor
#                     prob = torch.sigmoid(output).item() if hasattr(output, "item") else float(output)

#                 results[key] = prob > 0.5

#         failed = []
#         # rules (adjust per your requirements)
#         if results.get("obstacle"):
#             failed.append("Pipe is obstructed.")
#         if not results.get("clarity", False):
#             failed.append("Image is unclear.")
#         if not results.get("pipe", False):
#             failed.append("No pipe in the Image.")

#         return {"accepted": len(failed) == 0, "reasons": failed}

#     except Exception as e:
#         return {"accepted": False, "reasons": [f"Error: {str(e)}"]}

# # -----------------------------
# # Batch processing (parallel)
# # -----------------------------
# def batch_filter_images_parallel(source_dir, accepted_dir, models_dict, max_workers=4):
#     os.makedirs(accepted_dir, exist_ok=True)
#     results = []

#     def process_image(filename):
#         path = os.path.join(source_dir, filename)
#         result = predict_all(path, models_dict)
#         result["filename"] = filename
#         if result["accepted"]:
#             dest = os.path.join(accepted_dir, filename)
#             try:
#                 os.replace(path, dest)
#             except Exception:
#                 # fallback to copy + remove if os.replace fails (cross-device)
#                 from shutil import copy2
#                 copy2(path, dest)
#                 os.remove(path)
#         return result

#     image_files = [
#         f for f in os.listdir(source_dir)
#         if f.lower().endswith((".jpg", ".jpeg", ".png"))
#     ]

#     with ThreadPoolExecutor(max_workers=max_workers) as executor:
#         future_to_file = {executor.submit(process_image, f): f for f in image_files}
#         for future in as_completed(future_to_file):
#             filename = future_to_file[future]
#             try:
#                 result = future.result()
#                 results.append(result)
#                 status = "✅ Accepted" if result["accepted"] else "❌ Rejected"
#                 reasons = ", ".join(result["reasons"]) if result["reasons"] else "No issues"
#                 print(f"{status}: {result['filename']} → {reasons}")
#             except Exception as e:
#                 print(f"❌ Error processing {filename}: {str(e)}")
#                 results.append({"filename": filename, "accepted": False, "reasons": [f"Error: {str(e)}"]})

#     return results

# # -----------------------------
# # CLI entrypoint
# # -----------------------------
# def main():
#     parser = argparse.ArgumentParser(description="Filter images via multiple classifiers.")
#     parser.add_argument("--source", "-s", default="raw_images", help="Directory with source images")
#     parser.add_argument("--accepted", "-a", default="accepted_images", help="Directory to move accepted images into")
#     parser.add_argument("--workers", "-w", type=int, default=4, help="Number of parallel workers")
#     parser.add_argument("--allow-unsafe", action="store_true", help="Allow unsafe checkpoint loading (weights_only=False). Use only for trusted checkpoints.")
#     args = parser.parse_args()

#     # If user passed allow-unsafe, set env var for downstream behavior
#     if args.allow_unsafe:
#         os.environ["ALLOW_UNSAFE_CHECKPOINTS"] = "1"

#     print("Loading models...")
#     models_dict = load_all_models(MODEL_PATHS, MODEL_ARCH, allow_unsafe=(args.allow_unsafe or os.environ.get("ALLOW_UNSAFE_CHECKPOINTS","")=="1"))

#     print(f"Processing images from '{args.source}' → accepted will go to '{args.accepted}' with {args.workers} workers.")
#     results = batch_filter_images_parallel(args.source, args.accepted, models_dict, max_workers=args.workers)

#     # Print summary
#     accepted = sum(1 for r in results if r["accepted"])
#     rejected = len(results) - accepted
#     print(f"\nSummary: processed {len(results)} images — accepted: {accepted}, rejected: {rejected}")

# if __name__ == "__main__":
#     main()


# import torch
# import torchvision.models.resnet  # Required for safe deserialization
# from torchvision import transforms
# from PIL import Image
# import os
# from concurrent.futures import ThreadPoolExecutor, as_completed

# # Paths to trained models
# # "soil": "/var/www/html/agri_image_ai/soil_visibility_classifier.pt",
# #     "obstacle": "/var/www/html/agri_image_ai/obstacle_classifier.pt",
# #     "ruler": "/var/www/html/agri_image_ai/ruler_classifier.pt",
# # "clarity": "/var/www/html/agri_image_ai/clarity_classifier.pt"
# MODEL_PATHS = { 
#     "obstacle": "/var/www/html/agri_image_ai/obstacle_classifier.pt",
# }

# # Load models safely
# models_dict = {}
# for key, path in MODEL_PATHS.items():
#     try:
#         with torch.serialization.safe_globals([torchvision.models.resnet.ResNet]):
#             model = torch.load(path, map_location=torch.device("cpu"), weights_only=False)
#             model.eval()
#             models_dict[key] = model
#     except Exception as e:
#         print(f"❌ Failed to load model '{key}' from {path}: {e}")
#         raise

# # Image transform
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

# # Prediction for a single image
# def predict_all(image_path):
#     try:
#         image = Image.open(image_path).convert("RGB")
#         if image.size[0] < 50 or image.size[1] < 50:
#             return {"accepted": False, "reasons": ["Image too small."]}

#         input_tensor = transform(image).unsqueeze(0)
#         results = {}
#         with torch.no_grad():
#             for key, model in models_dict.items():
#                 output = model(input_tensor)
#                 prob = torch.sigmoid(output).item()
#                 results[key] = prob > 0.5

#         failed = []
#         # if not results["soil"]:
#         #     failed.append("Soil not visible inside pipe.")
#         if results["obstacle"]:
#             failed.append("Pipe is obstructed.")
#         # if not results["ruler"]:
#         #     failed.append("Ruler not present inside pipe.")
#         # if not results["clarity"]:
#         #     failed.append("Image is unclear.")
#         # if results["ruler"] and not results["soil"]:
#         #     failed.append("Ruler is present but soil is not visible through pipe.")

#         return {
#             "accepted": len(failed) == 0,
#             "reasons": failed
#         }

#     except Exception as e:
#         return {"accepted": False, "reasons": [f"Error: {str(e)}"]}

# # Parallel batch filtering
# def batch_filter_images_parallel(source_dir, accepted_dir, max_workers=4):
#     os.makedirs(accepted_dir, exist_ok=True)
#     results = []

#     def process_image(filename):
#         path = os.path.join(source_dir, filename)
#         result = predict_all(path)
#         result["filename"] = filename
#         if result["accepted"]:
#             os.rename(path, os.path.join(accepted_dir, filename))
#         return result

#     image_files = [
#         f for f in os.listdir(source_dir)
#         if f.lower().endswith((".jpg", ".jpeg", ".png"))
#     ]

#     with ThreadPoolExecutor(max_workers=max_workers) as executor:
#         future_to_file = {executor.submit(process_image, f): f for f in image_files}
#         for future in as_completed(future_to_file):
#             try:
#                 result = future.result()
#                 results.append(result)
#                 status = "✅ Accepted" if result["accepted"] else "❌ Rejected"
#                 print(f"{status}: {result['filename']} → {', '.join(result['reasons']) if result['reasons'] else 'No issues'}")
#             except Exception as e:
#                 filename = future_to_file[future]
#                 print(f"❌ Error processing {filename}: {str(e)}")
#                 results.append({
#                     "filename": filename,
#                     "accepted": False,
#                     "reasons": [f"Error: {str(e)}"]
#                 })

#     return results

# # Run batch if executed directly
# if __name__ == "__main__":
#     SOURCE_DIR = "raw_images"
#     ACCEPTED_DIR = "accepted_images"
#     results = batch_filter_images_parallel(SOURCE_DIR, ACCEPTED_DIR)

# import torch
# import torchvision.models.resnet  # Required for safe deserialization
# from torchvision import transforms
# from PIL import Image
# import os

# # Paths to trained models (excluding 'moisture' and 'angle')
# MODEL_PATHS = {
#     "soil": "/var/www/html/agri_image_ai/soil_visibility_classifier.pt",
#     "obstacle": "/var/www/html/agri_image_ai/obstacle_classifier.pt",
#     "ruler": "/var/www/html/agri_image_ai/ruler_classifier.pt",
#     "clarity": "/var/www/html/agri_image_ai/clarity_classifier.pt"
# }

# # Load models safely using PyTorch 2.6 rules
# models_dict = {}
# for key, path in MODEL_PATHS.items():
#     try:
#         with torch.serialization.safe_globals([torchvision.models.resnet.ResNet]):
#             model = torch.load(path, map_location=torch.device("cpu"), weights_only=False)
#             model.eval()
#             models_dict[key] = model
#     except Exception as e:
#         print(f"❌ Failed to load model '{key}' from {path}: {e}")
#         raise

# # Image transform (same as training)
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

# # Prediction function for a single image
# def predict_all(image_path):
#     try:
#         image = Image.open(image_path).convert("RGB")
#         if image.size[0] < 50 or image.size[1] < 50:
#             return {"accepted": False, "reasons": ["Image too small."]}

#         input_tensor = transform(image).unsqueeze(0)
#         results = {}
#         with torch.no_grad():
#             for key, model in models_dict.items():
#                 output = model(input_tensor)
#                 prob = torch.sigmoid(output).item()
#                 results[key] = prob > 0.5

#         # Evaluate conditions
#         failed = []
#         if not results["soil"]:
#             failed.append("Soil not visible inside pipe.")
#         if results["obstacle"]:
#             failed.append("Pipe is obstructed.")
#         if not results["ruler"]:
#             failed.append("Ruler not present inside pipe.")
#         if not results["clarity"]:
#             failed.append("Image is unclear.")

#         # Dependency: if ruler is present, soil must be visible
#         if results["ruler"] and not results["soil"]:
#             failed.append("Ruler is present but soil is not visible through pipe.")

#         return {
#             "accepted": len(failed) == 0,
#             "reasons": failed
#         }

#     except Exception as e:
#         return {"accepted": False, "reasons": [f"Error: {str(e)}"]}

# # Optional batch filtering
# if __name__ == "__main__":
#     SOURCE_DIR = "raw_images"
#     ACCEPTED_DIR = "accepted_images"
#     os.makedirs(ACCEPTED_DIR, exist_ok=True)

#     for filename in os.listdir(SOURCE_DIR):
#         if filename.lower().endswith((".jpg", ".jpeg", ".png")):
#             path = os.path.join(SOURCE_DIR, filename)
#             result = predict_all(path)
#             if result["accepted"]:
#                 os.rename(path, os.path.join(ACCEPTED_DIR, filename))
#                 print(f"✅ Accepted: {filename}")
#             else:
#                 print(f"❌ Rejected: {filename} → Reasons: {', '.join(result['reasons'])}")




# import torch
# from torchvision import transforms, models
# from PIL import Image
# import os

# # Paths to trained models
# MODEL_PATHS = {
#    # "angle": "angle_classifier.pt",
#     "soil": "soil_visibility_classifier.pt",
#     "obstacle": "obstacle_classifier.pt",
#     "ruler": "ruler_classifier.pt",
#    # "moisture": "moisture_classifier.pt",
#     "clarity": "clarity_classifier.pt"
# }

# # Load models
# models_dict = {}
# for key, path in MODEL_PATHS.items():
#     model = torch.load(path, map_location=torch.device("cpu"))
#     model.eval()
#     models_dict[key] = model

# # Image transform
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

# # Prediction function
# def predict_all(image_path):
#     try:
#         image = Image.open(image_path).convert("RGB")
#         if image.size[0] < 50 or image.size[1] < 50:
#             return {"accepted": False, "reasons": ["Image too small."]}

#         input_tensor = transform(image).unsqueeze(0)
#         results = {}
#         with torch.no_grad():
#             for key, model in models_dict.items():
#                 output = model(input_tensor)
#                 prob = torch.sigmoid(output).item()
#                 results[key] = prob > 0.5

#         # Evaluate conditions
#         failed = []
#       #  if not results["angle"]: failed.append("Image not perpendicular.")
#         if not results["soil"]: failed.append("Soil not visible inside pipe.")
#         if results["obstacle"]: failed.append("Pipe is obstructed.")
#         if not results["ruler"]: failed.append("Ruler not present inside pipe.")
#       #  if results["moisture"]: failed.append("Field appears wet.")
#         if not results["clarity"]: failed.append("Image is unclear.")

#         return {
#             "accepted": len(failed) == 0,
#             "reasons": failed
#         }

#     except Exception as e:
#         return {"accepted": False, "reasons": [f"Error: {str(e)}"]}

# # Optional batch filtering
# if __name__ == "__main__":
#     SOURCE_DIR = "raw_images"
#     ACCEPTED_DIR = "accepted_images"
#     os.makedirs(ACCEPTED_DIR, exist_ok=True)

#     for filename in os.listdir(SOURCE_DIR):
#         if filename.lower().endswith((".jpg", ".jpeg", ".png")):
#             path = os.path.join(SOURCE_DIR, filename)
#             result = predict_all(path)
#             if result["accepted"]:
#                 os.rename(path, os.path.join(ACCEPTED_DIR, filename))
#                 print(f"✅ Accepted: {filename}")
#             else:
#                 print(f"❌ Rejected: {filename} → Reasons: {', '.join(result['reasons'])}")


# import torch
# from torchvision import transforms
# from PIL import Image
# import os

# # Load models
# models = {
#     "angle": torch.load("angle_classifier.pt", map_location=torch.device('cpu')).eval(),
#     "soil": torch.load("soil_visibility_classifier.pt", map_location=torch.device('cpu')).eval(),
#     "obstacle": torch.load("obstacle_classifier.pt", map_location=torch.device('cpu')).eval(),
#     "ruler": torch.load("ruler_classifier.pt", map_location=torch.device('cpu')).eval(),
#     "moisture": torch.load("moisture_classifier.pt", map_location=torch.device('cpu')).eval(),
#     "clarity": torch.load("clarity_classifier.pt", map_location=torch.device('cpu')).eval()
# }

# # Transform
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

# # Prediction function
# def predict_all(image_path):
#     try:
#         image = Image.open(image_path).convert("RGB")
#         if image.size[0] < 50 or image.size[1] < 50:
#             return False
#         input_tensor = transform(image).unsqueeze(0)

#         results = {}
#         with torch.no_grad():
#             for key, model in models.items():
#                 output = model(input_tensor)
#                 prob = torch.sigmoid(output).item()
#                 results[key] = prob > 0.5

#         # Accept only if all required conditions are met
#         return all([
#             results["angle"],       # Perpendicular
#             results["soil"],        # Soil visible
#             not results["obstacle"],# No obstacle
#             results["ruler"],       # Ruler present
#             not results["moisture"],# Dry soil
#             results["clarity"]      # Image is clear
#         ])

#     except Exception as e:
#         print(f"Error with {image_path}: {e}")
#         return False

# # Filter images
# SOURCE_DIR = "raw_images"
# ACCEPTED_DIR = "accepted_images"
# os.makedirs(ACCEPTED_DIR, exist_ok=True)

# for filename in os.listdir(SOURCE_DIR):
#     if filename.lower().endswith((".jpg", ".jpeg", ".png")):
#         path = os.path.join(SOURCE_DIR, filename)
#         if predict_all(path):
#             os.rename(path, os.path.join(ACCEPTED_DIR, filename))
#             print(f"✅ Accepted: {filename}")
#         else:
#             print(f"❌ Rejected: {filename}")
