import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
from sklearn.model_selection import train_test_split
from collections import Counter
import timm

# Paths
DATA_DIR = "data_clarity"
MODEL_PATH = "clarity_classifier.pt"

# Training transform: light augmentation + ImageNet normalization (the
# stats the pretrained EfficientNet backbone was trained with).
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Validation transform: deterministic — no random augmentation.
# BUG FIX: the original applied the training augmentations to the
# validation split too (val was a Subset of the same augmented dataset),
# which makes the reported validation accuracy non-reproducible.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Keep the original name in scope for backward compatibility.
transform = train_transform

# Load the same directory twice with the two transforms. ImageFolder
# enumerates files deterministically, so both datasets share the same
# sample ordering and index-based splits line up between them.
dataset = datasets.ImageFolder(DATA_DIR, transform=train_transform)
eval_dataset = datasets.ImageFolder(DATA_DIR, transform=eval_transform)
print("✅ Class mapping:", dataset.class_to_idx)

# Label mapping: clear_image → 1, hazy_image → 0
label_map = {
    dataset.class_to_idx['clear_image']: 1,
    dataset.class_to_idx['hazy_image']: 0
}
labels = [label_map[label] for _, label in dataset.samples]
label_counts = Counter(labels)
print("📊 Class distribution:", label_counts)

# Inverse-frequency weights so the sampler draws both classes equally often.
weights = [1.0 / label_counts[label] for label in labels]
indices = list(range(len(dataset)))
train_idx, val_idx = train_test_split(indices, test_size=0.2, stratify=labels, random_state=42)

train_dataset = Subset(dataset, train_idx)       # augmented view
val_dataset = Subset(eval_dataset, val_idx)      # deterministic view

train_sampler = WeightedRandomSampler([weights[i] for i in train_idx], len(train_idx), replacement=True)
train_loader = DataLoader(train_dataset, batch_size=16, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=16)

# Model setup using timm: EfficientNet-B0 with a single-logit head for
# binary classification (paired with BCEWithLogitsLoss below).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=1)
model = model.to(device)

# Loss and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Evaluation function
def evaluate(model, dataloader):
    """Return binary accuracy of `model` over `dataloader`.

    Raw ImageFolder class indices are remapped through the module-level
    `label_map` before comparison; a prediction is positive when
    sigmoid(logit) exceeds 0.5. Puts the model in eval mode as a side effect.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            targets = torch.tensor(
                [label_map[lbl.item()] for lbl in batch_labels],
                dtype=torch.float32,
            ).unsqueeze(1).to(device)
            logits = model(batch_inputs)
            predictions = (torch.sigmoid(logits) > 0.5).float()
            hits += (predictions == targets).sum().item()
            seen += targets.size(0)
    return hits / seen

# Training loop: ten epochs over the class-balanced sampler, reporting
# mean batch loss and validation accuracy after each epoch.
for epoch in range(10):
    model.train()
    epoch_loss = 0.0
    for batch_inputs, batch_labels in train_loader:
        batch_inputs = batch_inputs.to(device)
        # Remap raw ImageFolder indices to the 0/1 targets BCE expects.
        targets = (
            torch.tensor([label_map[lbl.item()] for lbl in batch_labels])
            .float()
            .unsqueeze(1)
            .to(device)
        )

        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()
        epoch_loss += batch_loss.item()

    val_acc = evaluate(model, val_loader)
    print(f"📦 Epoch {epoch+1}/10 - Loss: {epoch_loss/len(train_loader):.4f} - Val Acc: {val_acc:.4f}")

# Save model weights only (restore later with model.load_state_dict).
torch.save(model.state_dict(), MODEL_PATH)
print(f"✅ Model weights saved to {MODEL_PATH}")


# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torchvision import datasets, transforms, models
# from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
# from sklearn.model_selection import train_test_split
# from collections import Counter

# # Paths
# DATA_DIR = "data_clarity"
# MODEL_PATH = "clarity_classifier.pt"

# # Transforms
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomRotation(10),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

# # Load dataset
# dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
# print("✅ Class mapping:", dataset.class_to_idx)

# # Force correct label mapping: clear_image → 1, hazy_image → 0
# label_map = {dataset.class_to_idx['clear_image']: 1, dataset.class_to_idx['hazy_image']: 0}
# labels = [label_map[label] for _, label in dataset.samples]
# label_counts = Counter(labels)
# print("📊 Class distribution:", label_counts)

# # Weighted sampling to balance classes
# class_weights = [1.0 / label_counts[label_map[label]] for _, label in dataset.samples]
# train_idx, val_idx = train_test_split(
#     list(range(len(dataset))),
#     test_size=0.2,
#     stratify=labels,
#     random_state=42
# )

# train_dataset = Subset(dataset, train_idx)
# val_dataset = Subset(dataset, val_idx)

# train_sampler = WeightedRandomSampler(
#     [class_weights[i] for i in train_idx],
#     num_samples=len(train_idx),
#     replacement=True
# )

# train_loader = DataLoader(train_dataset, batch_size=16, sampler=train_sampler)
# val_loader = DataLoader(val_dataset, batch_size=16)

# # Model setup
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = models.resnet18(pretrained=True)
# model.fc = nn.Linear(model.fc.in_features, 1)
# model = model.to(device)

# # Loss and optimizer
# criterion = nn.BCEWithLogitsLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)

# # Evaluation function
# def evaluate(model, dataloader):
#     model.eval()
#     correct = 0
#     total = 0
#     with torch.no_grad():
#         for inputs, labels in dataloader:
#             inputs = inputs.to(device)
#             mapped_labels = torch.tensor([label_map[label.item()] for label in labels]).float().unsqueeze(1).to(device)
#             outputs = model(inputs)
#             preds = (torch.sigmoid(outputs) > 0.5).float()
#             correct += (preds == mapped_labels).sum().item()
#             total += mapped_labels.size(0)
#     return correct / total

# # Training loop
# for epoch in range(10):
#     model.train()
#     running_loss = 0.0
#     for inputs, labels in train_loader:
#         inputs = inputs.to(device)
#         mapped_labels = torch.tensor([label_map[label.item()] for label in labels]).float().unsqueeze(1).to(device)

#         optimizer.zero_grad()
#         outputs = model(inputs)
#         loss = criterion(outputs, mapped_labels)
#         loss.backward()
#         optimizer.step()

#         running_loss += loss.item()

#     val_acc = evaluate(model, val_loader)
#     print(f"📦 Epoch {epoch+1}/10 - Loss: {running_loss/len(train_loader):.4f} - Val Acc: {val_acc:.4f}")

# # Save model
# torch.save(model, MODEL_PATH)
# print(f"✅ Model saved to {MODEL_PATH}")













#################################################################################
# import os
# import torch
# from torchvision import datasets, transforms
# from torch.utils.data import DataLoader
# from PIL import Image

# # Paths
# MODEL_PATH = "clarity_classifier.pt"
# TEST_DIR = "dataclarity"

# # Load model
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = torch.load(MODEL_PATH, map_location=device)
# model.eval()

# # Transform (must match training)
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

# # Load test dataset
# test_dataset = datasets.ImageFolder(TEST_DIR, transform=transform)
# test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# # Class mapping
# idx_to_class = {v: k for k, v in test_dataset.class_to_idx.items()}
# print("Class mapping:", idx_to_class)

# # Run predictions
# print("\n🔍 Predictions:")
# for i, (inputs, labels) in enumerate(test_loader):
#     inputs = inputs.to(device)

#     with torch.no_grad():
#         outputs = model(inputs)
#         probs = torch.sigmoid(outputs)
#         preds = (probs > 0.5).float()

#     predicted_class = idx_to_class[int(preds.item())]
#     true_class = idx_to_class[int(labels.item())]
#     print(f"[{i+1}] True: {true_class} | Predicted: {predicted_class} | Confidence: {probs.item():.4f}")


# from torchvision import datasets, transforms

# # Path to your training data folder
# DATA_DIR = "data_clarity"  # Update this if your training folder is named differently

# # Use the same transform as training (optional here, but good practice)
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

# # Load dataset
# dataset = datasets.ImageFolder(DATA_DIR, transform=transform)

# # Print class-to-index mapping
# print("✅ Class to index mapping:")
# print(dataset.class_to_idx)


# from torchvision import datasets, transforms, models
# from torch.utils.data import DataLoader
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from collections import Counter

# # Paths
# DATA_DIR = "data_clarity"
# MODEL_PATH = "clarity_classifier.pt"

# # Transforms
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomRotation(10),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

# # Dataset and DataLoader
# dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
# dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# # Check class mapping
# print("Class to index mapping:", dataset.class_to_idx)
# # Example: {'clear': 0, 'unclear': 1}

# # Model setup
# model = models.resnet18(pretrained=True)
# model.fc = nn.Linear(model.fc.in_features, 1)  # Binary classification
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = model.to(device)

# # Loss and optimizer
# criterion = nn.BCEWithLogitsLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)

# # Training loop
# for epoch in range(10):
#     model.train()
#     running_loss = 0.0
#     for inputs, labels in dataloader:
#         inputs = inputs.to(device)
#         labels = labels.float().unsqueeze(1).to(device)  # BCE expects float labels

#         optimizer.zero_grad()
#         outputs = model(inputs)
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()

#         running_loss += loss.item()

#     print(f"Epoch {epoch+1}/10 - Loss: {running_loss / len(dataloader):.4f}")

# # Save model
# torch.save(model, MODEL_PATH)
# print(f"✅ Model saved to {MODEL_PATH}")