import json
import logging
import os
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
from typing import List

import imagehash
import requests
from fastapi import FastAPI
from PIL import Image
from pydantic import BaseModel

# Import your actual prediction logic
from filter_images import predict_all

# Logging: configure the root logger once and grab this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
TEMP_DIR = "temp_images"  # scratch directory for downloaded images
HASH_FILE = "image_hashes.json"  # persisted {hex perceptual hash: image URL}
os.makedirs(TEMP_DIR, exist_ok=True)


def _load_hash_dict(path):
    """Load the persisted hash registry, keyed by ``imagehash`` objects.

    The JSON file maps hex-encoded hashes to the URL of the image that first
    produced them. A missing file yields an empty registry.
    """
    if not os.path.exists(path):
        return {}
    with open(path, "r") as fh:
        raw = json.load(fh)
    return {imagehash.hex_to_hash(hex_key): url for hex_key, url in raw.items()}


# Registry of previously seen images, shared by every request handler.
hash_dict = _load_hash_dict(HASH_FILE)
# Perceptual-hash function used for duplicate detection.
hash_func = imagehash.phash

# FastAPI app
app = FastAPI(title="Image Filter API")

# Request model
class ImageBatchRequest(BaseModel):
    """Request body shared by the ``/filter`` and ``/filter_dev`` endpoints."""

    # URLs to evaluate; ``/filter`` downloads each one with an HTTP GET.
    image_urls: List[str]
    # Passed through unchanged to ``predict_all`` by ``/filter``; its meaning
    # is defined there (not visible in this file).
    activity: str

# Response model
class FilterResponse(BaseModel):
    """Per-image verdict returned by the ``/filter`` and ``/filter_dev`` endpoints."""

    # The URL exactly as it appeared in the request.
    image_url: str
    # True when the image passed filtering.
    accepted: bool
    # Human-readable explanations (download/format/duplicate/error messages or
    # whatever ``predict_all`` reports); the bypass endpoint sends an empty list.
    reasons: List[str]

# Serializes the duplicate check-and-register step and rewrites of HASH_FILE.
# FastAPI runs sync endpoints in a worker thread pool and each request fans out
# to its own ThreadPoolExecutor below. Without this lock, two copies of the
# same image could both pass as "new", json.dump could fail with "dictionary
# changed size during iteration", and writes to the JSON file could interleave.
# NOTE: the lock is per-process only.
_hash_lock = threading.Lock()


def _save_hash_dict() -> None:
    """Persist ``hash_dict`` to ``HASH_FILE`` as ``{hex_hash: image_url}``.

    Must be called with ``_hash_lock`` held. The JSON is written to a sibling
    temporary file and then moved over ``HASH_FILE`` with ``os.replace``, so an
    interrupted write cannot leave a truncated file for the next startup's
    ``json.load`` to choke on.
    """
    tmp_path = f"{HASH_FILE}.tmp"
    with open(tmp_path, "w") as f:
        json.dump({str(k): v for k, v in hash_dict.items()}, f, indent=2)
    os.replace(tmp_path, HASH_FILE)


# POST /filter endpoint with duplicate detection
@app.post("/filter", response_model=List[FilterResponse], summary="Filter multiple images by URL")
def filter_images(request: ImageBatchRequest):
    """Download, de-duplicate, and classify each image in ``request``.

    Each URL is processed independently on a small thread pool. An image is
    rejected without running the classifier when:

    * the download does not return HTTP 200,
    * the URL's extension is not jpg/jpeg/png, or
    * its perceptual hash matches an image registered earlier in ``hash_dict``
      (by this request or a previous one; the registry is persisted to
      ``HASH_FILE``).

    Otherwise ``predict_all(filepath, activity)`` decides acceptance.

    Returns:
        One ``FilterResponse`` per URL, in the same order as
        ``request.image_urls``. Per-image failures are reported as rejected
        responses rather than raised.
    """
    activity = request.activity
    logger.info("Filtering %d image(s) for activity: %s", len(request.image_urls), activity)

    def process_url(image_url: str) -> FilterResponse:
        """Evaluate one URL; never lets an exception from the pipeline escape."""
        filepath = None
        try:
            logger.info("Received image URL: %s", image_url)
            response = requests.get(image_url, timeout=10)
            if response.status_code != 200:
                return FilterResponse(
                    image_url=image_url,
                    accepted=False,
                    reasons=["Failed to download image from URL."]
                )

            # Heuristic: the text after the last '.' in the URL, with any
            # query string removed, is treated as the file extension.
            ext = image_url.split(".")[-1].split("?")[0].lower()
            if ext not in ["jpg", "jpeg", "png"]:
                return FilterResponse(
                    image_url=image_url,
                    accepted=False,
                    reasons=["Unsupported image format."]
                )

            filename = f"{uuid.uuid4()}.{ext}"
            filepath = os.path.join(TEMP_DIR, filename)
            with open(filepath, "wb") as f:
                f.write(response.content)

            # Duplicate check using perceptual hash. Best-effort: if decoding,
            # hashing, or saving the registry fails, log it and fall through
            # to prediction rather than rejecting the image.
            try:
                with Image.open(BytesIO(response.content)) as img:
                    img_hash = hash_func(img)  # computed outside the lock

                # Check-and-register must be atomic across threads/requests.
                with _hash_lock:
                    if img_hash in hash_dict:
                        original_url = hash_dict[img_hash]
                        return FilterResponse(
                            image_url=image_url,
                            accepted=False,
                            reasons=[f"Duplicate image detected. Matches previously processed image: {original_url}"]
                        )

                    # Save new hash
                    hash_dict[img_hash] = image_url
                    _save_hash_dict()

            except Exception as e:
                logger.warning("Hashing failed for %s: %s", image_url, e)

            # Run prediction
            result = predict_all(filepath, activity)
            accepted = result.get("accepted", False)
            return FilterResponse(
                image_url=image_url,
                accepted=accepted,
                # Fall back to a generic rejection reason only for rejected
                # images; an accepted image without "reasons" reports none.
                reasons=result.get("reasons", [] if accepted else ["Image rejected."])
            )

        except Exception as e:
            logger.exception("Error processing image %s: %s", image_url, e)
            return FilterResponse(
                image_url=image_url,
                accepted=False,
                reasons=[f"Error: {str(e)}"]
            )

        finally:
            # Remove the temp copy on every path, including duplicates and
            # failures inside predict_all, so TEMP_DIR does not accumulate files.
            if filepath is not None:
                try:
                    os.remove(filepath)
                except FileNotFoundError:
                    pass  # the write itself failed before the file existed
                except OSError as e:
                    logger.warning("Could not remove temp file %s: %s", filepath, e)

    # Run predictions in parallel. executor.map yields results in input order
    # (as_completed yields them in completion order, which would scramble the
    # response list relative to request.image_urls).
    with ThreadPoolExecutor(max_workers=4) as executor:
        return list(executor.map(process_url, request.image_urls))

# POST /filter_dev endpoint (bypass logic)
@app.post("/filter_dev", response_model=List[FilterResponse], summary="Filter multiple images by URL (bypass logic)")
def filter_images_bypass(request: ImageBatchRequest):
    """Accept every image without downloading or inspecting it.

    Development stand-in for ``/filter``: each URL is logged and echoed back
    as accepted with an empty reasons list, in request order. ``activity`` is
    ignored.
    """
    def _accept(url: str) -> FilterResponse:
        logger.info(f"Bypass filter received image URL: {url}")
        return FilterResponse(image_url=url, accepted=True, reasons=[])

    return [_accept(url) for url in request.image_urls]
















###################################################################################################
# from fastapi import FastAPI
# from pydantic import BaseModel
# import requests
# import uuid
# import os
# import logging
# from typing import List
# from concurrent.futures import ThreadPoolExecutor, as_completed

# # Import your actual prediction logic
# from filter_images import predict_all

# # Setup logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)

# # Constants
# TEMP_DIR = "temp_images"
# os.makedirs(TEMP_DIR, exist_ok=True)

# # FastAPI app
# app = FastAPI(title="Image Filter API")

# # Request model
# class ImageBatchRequest(BaseModel):
#     image_urls: List[str]

# # Response model
# class FilterResponse(BaseModel):
#     image_url: str
#     accepted: bool
#     reasons: List[str]

# # POST /filter endpoint with parallel processing
# @app.post("/filter", response_model=List[FilterResponse], summary="Filter multiple images by URL")
# def filter_images(request: ImageBatchRequest):
#     results = []

#     def process_url(image_url: str) -> FilterResponse:
#         try:
#             logger.info(f"Received image URL: {image_url}")
#             response = requests.get(image_url, timeout=10)
#             if response.status_code != 200:
#                 return FilterResponse(
#                     image_url=image_url,
#                     accepted=False,
#                     reasons=["Failed to download image from URL."]
#                 )

#             ext = image_url.split(".")[-1].split("?")[0].lower()
#             if ext not in ["jpg", "jpeg", "png"]:
#                 return FilterResponse(
#                     image_url=image_url,
#                     accepted=False,
#                     reasons=["Unsupported image format."]
#                 )

#             filename = f"{uuid.uuid4()}.{ext}"
#             filepath = os.path.join(TEMP_DIR, filename)
#             with open(filepath, "wb") as f:
#                 f.write(response.content)

#             result = predict_all(filepath)
#             os.remove(filepath)

#             return FilterResponse(
#                 image_url=image_url,
#                 accepted=result.get("accepted", False),
#                 reasons=result.get("reasons", ["Image rejected."])
#             )

#         except Exception as e:
#             logger.error(f"Error processing image {image_url}: {str(e)}")
#             return FilterResponse(
#                 image_url=image_url,
#                 accepted=False,
#                 reasons=[f"Error: {str(e)}"]
#             )

#     # Run predictions in parallel
#     with ThreadPoolExecutor(max_workers=4) as executor:
#         futures = [executor.submit(process_url, url) for url in request.image_urls]
#         for future in as_completed(futures):
#             results.append(future.result())

#     return results

# # POST /filter_dev endpoint (bypass logic)
# @app.post("/filter_dev", response_model=List[FilterResponse], summary="Filter multiple images by URL (bypass logic)")
# def filter_images_bypass(request: ImageBatchRequest):
#     results = []
#     for image_url in request.image_urls:
#         logger.info(f"Bypass filter received image URL: {image_url}")
#         results.append(FilterResponse(
#             image_url=image_url,
#             accepted=True,
#             reasons=[]
#         ))
#     return results


# # from fastapi import FastAPI
# # from pydantic import BaseModel
# # import requests
# # import uuid
# # import os
# # import logging
# # from typing import List, Dict

# # # Setup logging
# # logging.basicConfig(level=logging.INFO)
# # logger = logging.getLogger(__name__)

# # # Constants
# # TEMP_DIR = "temp_images"
# # os.makedirs(TEMP_DIR, exist_ok=True)

# # # FastAPI app
# # app = FastAPI(title="Image Filter API")

# # # Request model
# # class ImageBatchRequest(BaseModel):
# #     image_urls: List[str]

# # # Response model
# # class FilterResponse(BaseModel):
# #     image_url: str
# #     accepted: bool
# #     reasons: List[str]

# # # Dummy prediction function
# # def predict_all(filepath: str) -> Dict:
# #     logger.info(f"Running prediction on {filepath}")
# #     return {"accepted": True, "reasons": []}

# # # POST /filter_dev endpoint (batch version)
# # @app.post("/filter_dev", response_model=List[FilterResponse], summary="Filter multiple images by URL")
# # def filter_images(request: ImageBatchRequest):
# #     results = []

# #     for image_url in request.image_urls:
# #         try:
# #             logger.info(f"Received image URL: {image_url}")
# #             response = requests.get(image_url, timeout=10)
# #             if response.status_code != 200:
# #                 results.append(FilterResponse(
# #                     image_url=image_url,
# #                     accepted=False,
# #                     reasons=["Failed to download image from URL."]
# #                 ))
# #                 continue

# #             ext = image_url.split(".")[-1].split("?")[0].lower()
# #             if ext not in ["jpg", "jpeg", "png"]:
# #                 results.append(FilterResponse(
# #                     image_url=image_url,
# #                     accepted=False,
# #                     reasons=["Unsupported image format."]
# #                 ))
# #                 continue

# #             filename = f"{uuid.uuid4()}.{ext}"
# #             filepath = os.path.join(TEMP_DIR, filename)
# #             with open(filepath, "wb") as f:
# #                 f.write(response.content)

# #             result = predict_all(filepath)
# #             os.remove(filepath)

# #             accepted = result.get("accepted", False) if isinstance(result, dict) else bool(result)
# #             reasons = result.get("reasons", []) if isinstance(result, dict) else ["Image rejected. No detailed reason provided."]

# #             results.append(FilterResponse(
# #                 image_url=image_url,
# #                 accepted=accepted,
# #                 reasons=reasons
# #             ))

# #         except Exception as e:
# #             logger.error(f"Error processing image {image_url}: {str(e)}")
# #             results.append(FilterResponse(
# #                 image_url=image_url,
# #                 accepted=False,
# #                 reasons=[f"Error: {str(e)}"]
# #             ))

# #     return results

# # # POST /filter endpoint (bypass version)
# # @app.post("/filter", response_model=List[FilterResponse], summary="Filter multiple images by URL (bypass logic)")
# # def filter_images_bypass(request: ImageBatchRequest):
# #     results = []
# #     for image_url in request.image_urls:
# #         logger.info(f"Bypass filter received image URL: {image_url}")
# #         results.append(FilterResponse(
# #             image_url=image_url,
# #             accepted=True,
# #             reasons=[]
# #         ))
# #     return results


# # @app.get("/filter", response_model=FilterResponse, summary="Filter image by URL")
# # def filter_image(image_url: str = Query(..., description="Public image URL to evaluate")):
# #     try:
# #         # Download image
# #         response = requests.get(image_url, timeout=10)
# #         if response.status_code != 200:
# #             return {"accepted": False, "reasons": ["Failed to download image from URL."]}

# #         ext = image_url.split(".")[-1].split("?")[0].lower()
# #         if ext not in ["jpg", "jpeg", "png"]:
# #             return {"accepted": False, "reasons": ["Unsupported image format."]}

# #         filename = f"{uuid.uuid4()}.{ext}"
# #         filepath = os.path.join(TEMP_DIR, filename)
# #         with open(filepath, "wb") as f:
# #             f.write(response.content)

# #         # Run prediction
# #         result = predict_all(filepath)

# #         # Clean up
# #         os.remove(filepath)

# #         # Interpret result
# #         if isinstance(result, bool):
# #             return {
# #                 "accepted": result,
# #                 "reasons": [] if result else ["Image rejected. No detailed reason provided."]
# #             }
# #         elif isinstance(result, dict):
# #             return {
# #                 "accepted": result.get("accepted", False),
# #                 "reasons": result.get("reasons", ["Image rejected."])
# #             }
# #         else:
# #             return {"accepted": False, "reasons": ["Unexpected result format from filter_images."]}

# #     except Exception as e:
# #         return {"accepted": False, "reasons": [f"Error: {str(e)}"]}
    

# # @app.get("/filter_bypass", response_model=FilterResponse, summary="Filter image by URL")
# # def filter_image_demo(image_url: str = Query(..., description="Public image URL to evaluate")):
# #     # Dummy placeholder logic
# #     result = True  # or False, depending on what you want to simulate

# #     return {
# #         "accepted": result,
# #         "reasons": [] if result else ["Image rejected. No detailed reason provided."]
# #     }
