# import torch
# import torch.nn as nn
# import timm
# import onnx
# from onnxsim import simplify
# import os
# import shutil
# import zipfile

# # =========================
# # CONFIG
# # =========================
# ONNX_DIR = "onnx"
# ZIP_NAME = "latest.zip"

# os.makedirs(ONNX_DIR, exist_ok=True)

# # =========================
# # NORMALIZATION WRAPPER
# # =========================
# class NormalizedModel(nn.Module):
#     def __init__(self, backbone):
#         super().__init__()
#         self.backbone = backbone
#         self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
#         self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

#     def forward(self, x):
#         x = (x - self.mean) / self.std
#         return self.backbone(x)

# # =========================
# # AUTO-DETECT ARCHITECTURE
# # =========================
# def get_backbone_from_checkpoint(checkpoint):
#     if not isinstance(checkpoint, dict):
#         print("📦 Full model detected")
#         return checkpoint

#     keys = list(checkpoint.keys())

#     # EfficientNet
#     if any("conv_stem" in k for k in keys):
#         print("🔍 Detected: EfficientNet-B0")
#         return timm.create_model('efficientnet_b0', pretrained=False, num_classes=1)

#     # ResNet
#     if any("layer1" in k for k in keys) or any("conv1.weight" in k for k in keys):
#         print("🔍 Detected: ResNet18")
#         from torchvision import models
#         model = models.resnet18(pretrained=False)
#         model.fc = nn.Linear(model.fc.in_features, 1)
#         return model

#     # Fallback
#     print("⚠️ Unknown architecture → defaulting to EfficientNet")
#     return timm.create_model('efficientnet_b0', pretrained=False, num_classes=1)

# # =========================
# # CONVERT FUNCTION
# # =========================
# def convert_model(pt_path, onnx_name):
#     print(f"\n🔄 Converting {pt_path}...")

#     checkpoint = torch.load(pt_path, map_location="cpu", weights_only=False)

#     backbone = get_backbone_from_checkpoint(checkpoint)

#     try:
#         if isinstance(checkpoint, dict):
#             backbone.load_state_dict(checkpoint)
#         else:
#             backbone = checkpoint
#     except Exception as e:
#         print(f"❌ Skipping {pt_path} due to mismatch:\n{e}")
#         return

#     model = NormalizedModel(backbone)
#     model.eval()

#     dummy_input = torch.randn(1, 3, 224, 224)

#     # Export ONNX
#     torch.onnx.export(
#         model,
#         dummy_input,
#         onnx_name,
#         input_names=['input'],
#         output_names=['output'],
#         dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
#         opset_version=12
#     )

#     # Simplify
#     try:
#         onnx_model = onnx.load(onnx_name)
#         model_simp, check = simplify(onnx_model)

#         if check:
#             simplified_name = onnx_name.replace(".onnx", "_simplified.onnx")
#             onnx.save(model_simp, simplified_name)
#             print(f"✅ Simplified: {simplified_name}")

#             # Move to ONNX folder
#             dest_path = os.path.join(ONNX_DIR, os.path.basename(simplified_name))
#             shutil.move(simplified_name, dest_path)

#         else:
#             print("⚠️ Simplification failed, keeping raw ONNX")

#     except Exception as e:
#         print(f"⚠️ Simplification error: {e}")

# # =========================
# # MODEL LIST
# # =========================
# MODELS_TO_CONVERT = [
#     ("pipe_classifier.pt", "pipe_classifier.onnx"),
#     ("obstacle_classifier.pt", "obstacle_classifier.onnx"),
#     ("clarity_classifier.pt", "clarity_classifier.onnx"),
#     ("ruler_classifier.pt", "ruler_classifier.onnx"),
#     ("soil_visibility_classifier.pt", "soil_visibility_classifier.onnx")
# ]

# # =========================
# # RUN CONVERSION
# # =========================
# for pt, onnx_file in MODELS_TO_CONVERT:
#     if os.path.exists(pt):
#         convert_model(pt, onnx_file)
#     else:
#         print(f"❌ Skip: {pt} not found")

# # =========================
# # ZIP OUTPUT
# # =========================
# print("\n📦 Creating ZIP...")

# with zipfile.ZipFile(ZIP_NAME, 'w', zipfile.ZIP_DEFLATED) as zipf:
#     for file in os.listdir(ONNX_DIR):
#         file_path = os.path.join(ONNX_DIR, file)
#         zipf.write(file_path, arcname=file)

# print(f"✅ ZIP created: {ZIP_NAME}")











import torch
import torch.nn as nn
import timm
import onnx
from onnxsim import simplify
import os
import shutil
import zipfile

# =========================
# CONFIG
# =========================
# Output directory that collects the exported .onnx files; the ZIP step
# below archives everything found in this directory.
ONNX_DIR = "onnx"
# Name of the final archive written next to this script.
ZIP_NAME = "latest.zip"

# Idempotent: safe to re-run when the directory already exists.
os.makedirs(ONNX_DIR, exist_ok=True)

# =========================
# NORMALIZATION WRAPPER
# =========================
class NormalizedModel(nn.Module):
    """Wrap a backbone so ImageNet normalization is baked into the graph.

    Inputs are expected as NCHW float tensors in [0, 1]; the mean/std are
    registered as buffers so they are exported with the ONNX graph and
    follow the model across devices.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        # Standard ImageNet channel statistics, shaped for NCHW broadcasting.
        mean = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)
        self.register_buffer('mean', mean)
        self.register_buffer('std', std)

    def forward(self, x):
        # Normalize first, then delegate to the wrapped backbone.
        normalized = x.sub(self.mean).div(self.std)
        return self.backbone(normalized)

# =========================
# CONVERT FUNCTION
# =========================
def convert_model(pt_path, onnx_name):
    """Convert one PyTorch checkpoint to an ONNX file inside ONNX_DIR.

    Loads `pt_path` (either a state_dict or a pickled full model), wraps it
    with NormalizedModel, exports to `onnx_name`, then tries to simplify it.
    The simplified file (or, on simplification failure, the raw export) is
    moved into ONNX_DIR so the ZIP step packages it.

    Args:
        pt_path: path to the .pt checkpoint.
        onnx_name: output filename for the raw ONNX export.
    """
    print(f"\n🔄 Converting {pt_path}...")

    # Always EfficientNet-B0 now
    backbone = timm.create_model("efficientnet_b0", pretrained=False, num_classes=1)

    # NOTE(security): weights_only=False unpickles arbitrary objects —
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(pt_path, map_location="cpu", weights_only=False)

    if isinstance(checkpoint, dict):
        # A mismatched checkpoint (e.g. a non-EfficientNet state_dict) must
        # not abort the whole batch: skip this model and keep converting.
        try:
            backbone.load_state_dict(checkpoint)
        except Exception as e:
            print(f"❌ Skipping {pt_path} due to mismatch:\n{e}")
            return
    else:
        # Checkpoint is a pickled full model; use it directly.
        backbone = checkpoint

    model = NormalizedModel(backbone)
    model.eval()

    dummy_input = torch.randn(1, 3, 224, 224)

    # Export ONNX (batch dimension kept dynamic for flexible inference).
    torch.onnx.export(
        model,
        dummy_input,
        onnx_name,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
        opset_version=12
    )

    # Simplify; fall back to shipping the raw export if anything goes wrong.
    try:
        onnx_model = onnx.load(onnx_name)
        model_simp, check = simplify(onnx_model)

        if check:
            simplified_name = onnx_name.replace(".onnx", "_simplified.onnx")
            onnx.save(model_simp, simplified_name)
            print(f"✅ Simplified: {simplified_name}")

            # Move to ONNX folder
            dest_path = os.path.join(ONNX_DIR, os.path.basename(simplified_name))
            shutil.move(simplified_name, dest_path)

        else:
            print("⚠️ Simplification failed, keeping raw ONNX")
            # Move the raw export into ONNX_DIR so the ZIP step includes it.
            shutil.move(onnx_name, os.path.join(ONNX_DIR, os.path.basename(onnx_name)))

    except Exception as e:
        print(f"⚠️ Simplification error: {e}")
        # Same fallback: ship the raw (unsimplified) export.
        if os.path.exists(onnx_name):
            shutil.move(onnx_name, os.path.join(ONNX_DIR, os.path.basename(onnx_name)))

# =========================
# MODEL LIST
# =========================
# Each classifier maps <name>.pt (checkpoint) -> <name>.onnx (export target).
_CLASSIFIER_NAMES = [
    "pipe_classifier",
    "obstacle_classifier",
    "clarity_classifier",
    "ruler_classifier",
    "soil_visibility_classifier",
]

MODELS_TO_CONVERT = [(f"{name}.pt", f"{name}.onnx") for name in _CLASSIFIER_NAMES]

# =========================
# RUN CONVERSION
# =========================
# Convert each checkpoint that exists on disk; report missing ones and move on.
for checkpoint_path, onnx_path in MODELS_TO_CONVERT:
    if not os.path.exists(checkpoint_path):
        print(f"❌ Skip: {checkpoint_path} not found")
        continue
    convert_model(checkpoint_path, onnx_path)

# =========================
# ZIP OUTPUT
# =========================
print("\n📦 Creating ZIP...")

# Bundle every file in the output directory, flattened to the archive root.
with zipfile.ZipFile(ZIP_NAME, 'w', zipfile.ZIP_DEFLATED) as archive:
    for entry in os.listdir(ONNX_DIR):
        archive.write(os.path.join(ONNX_DIR, entry), arcname=entry)

print(f"✅ ZIP created: {ZIP_NAME}")




# https://grq-img-store.sgp1.cdn.digitaloceanspaces.com/prod/classifiers_zip/latest.zip
# SECURITY: an access token ("dop_v1_...") was previously committed on this line.
# Revoke/rotate it and load credentials from the environment — never from source control.
###########################################################################
