import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
from sklearn.model_selection import train_test_split
from collections import Counter
import timm

# Paths
DATA_DIR = "data_obstacle"
MODEL_PATH = "obstacle_classifier.pt"

# Standard ImageNet per-channel mean/std used for input normalization.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training transforms: resize to 224x224 plus random augmentation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(5),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation transforms: same resize/normalization but NO random augmentation,
# so validation accuracy is deterministic and measured on unaltered images.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Load dataset. The same folder is indexed twice (one view per transform);
# ImageFolder sorts class folders and file names, so index i refers to the
# same image in both views. The check below guards against the folder
# changing between the two scans.
dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
eval_dataset = datasets.ImageFolder(DATA_DIR, transform=eval_transform)
if eval_dataset.samples != dataset.samples:
    raise RuntimeError(f"{DATA_DIR} changed while it was being indexed; re-run training")
print("✅ Class mapping:", dataset.class_to_idx)

# Label mapping (raw ImageFolder class index -> BCE target):
# no_obstacle → 0, obstacle_present → 1
label_map = {
    dataset.class_to_idx['no_obstacle']: 0,
    dataset.class_to_idx['obstacle_present']: 1,
}
labels = [label_map[label] for _, label in dataset.samples]
label_counts = Counter(labels)
print("📊 Class distribution:", label_counts)

# Weighted sampling: inverse class-frequency weights. Because the split below is
# stratified, each class carries roughly equal total weight in the train subset.
weights = [1.0 / label_counts[label] for label in labels]
indices = list(range(len(dataset)))
train_idx, val_idx = train_test_split(indices, test_size=0.2, stratify=labels, random_state=42)

# Training subset reads augmented images; validation subset reads deterministic ones.
train_dataset = Subset(dataset, train_idx)
val_dataset = Subset(eval_dataset, val_idx)

train_sampler = WeightedRandomSampler([weights[i] for i in train_idx], len(train_idx), replacement=True)
train_loader = DataLoader(train_dataset, batch_size=16, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=16)

# Model setup using timm: a single output logit for binary classification.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=1)
model = model.to(device)

# Loss and optimizer. BCEWithLogitsLoss applies the sigmoid itself, so the
# model's raw logits are passed to it directly.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Evaluation function with prediction logging
def evaluate(model, loader, device=None, target_map=None, n_log=3):
    """Return the binary classification accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and run without gradient tracking.
    Each prediction is ``sigmoid(logit) > 0.5`` and is compared with the label
    after remapping it through ``target_map``. The first ``n_log`` predictions
    of every batch are printed as a sanity check.

    Args:
        model: network emitting one logit per sample; its output is expected
            to be (batch, 1) so it lines up with the (batch, 1) targets built
            here.
        loader: iterable of ``(images, labels)`` batches, where ``labels`` is a
            tensor of raw ImageFolder class indices.
        device: device the inputs are moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).
        target_map: dict from raw class index to 0/1 target. Defaults to the
            module-level ``label_map``.
        n_log: number of predictions per batch to print; 0 disables logging.

    Returns:
        Fraction of samples predicted correctly, in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.
        KeyError: if a label is not present in ``target_map``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    if target_map is None:
        target_map = label_map  # script-level raw-index -> 0/1 mapping

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            # Remap raw class indices to 0/1 and shape as (batch, 1) like the logits.
            y = torch.tensor(
                [target_map[label] for label in y.tolist()],
                dtype=torch.float32,
                device=device,
            ).unsqueeze(1)
            pred = (torch.sigmoid(model(x)) > 0.5).float()
            correct += (pred == y).sum().item()
            total += y.size(0)
            for i in range(min(n_log, x.size(0))):
                print(f"🧪 Pred: {pred[i].item():.2f}, Label: {y[i].item():.0f}")
    if total == 0:
        raise ValueError("evaluate() received an empty loader; cannot compute accuracy")
    return correct / total

# Training loop
NUM_EPOCHS = 10  # single source of truth for the loop bound and the progress log

for epoch in range(NUM_EPOCHS):
    model.train()  # evaluate() puts the model in eval mode, so re-enable training mode
    total_loss = 0.0
    for x, y in train_loader:
        x = x.to(device)
        # Remap raw ImageFolder indices to 0/1 targets shaped (batch, 1) like the logits.
        y = torch.tensor(
            [label_map[label] for label in y.tolist()],
            dtype=torch.float32,
            device=device,
        ).unsqueeze(1)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    val_acc = evaluate(model, val_loader)
    # total_loss sums per-batch mean losses, so dividing by the batch count gives the epoch mean.
    print(f"📦 Epoch {epoch+1}/{NUM_EPOCHS} - Loss: {total_loss/len(train_loader):.4f} - Val Acc: {val_acc:.4f}")

# Save model (weights only; rebuild the same timm architecture before loading)
torch.save(model.state_dict(), MODEL_PATH)
print(f"✅ Model weights saved to {MODEL_PATH}")







# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torchvision import datasets, transforms, models
# from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
# from sklearn.model_selection import train_test_split
# from collections import Counter

# # Paths
# DATA_DIR = "data_obstacle"
# MODEL_PATH = "obstacle_classifier.pt"

# # Transforms
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomRotation(5),
#     transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
#     transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

# # Load dataset
# dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
# print("✅ Class mapping:", dataset.class_to_idx)

# # Flip labels: obstacle_present → 0, no_obstacle → 1
# label_map = {
#     dataset.class_to_idx['obstacle_present']: 0,
#     dataset.class_to_idx['no_obstacle']: 1
# }
# labels = [label_map[label] for _, label in dataset.samples]
# label_counts = Counter(labels)
# print("📊 Class distribution:", label_counts)

# # Weighted sampling
# weights = [1.0 / label_counts[label] for label in labels]
# indices = list(range(len(dataset)))
# train_idx, val_idx = train_test_split(indices, test_size=0.2, stratify=labels, random_state=42)

# train_dataset = Subset(dataset, train_idx)
# val_dataset = Subset(dataset, val_idx)

# train_sampler = WeightedRandomSampler([weights[i] for i in train_idx], len(train_idx), replacement=True)
# train_loader = DataLoader(train_dataset, batch_size=16, sampler=train_sampler)
# val_loader = DataLoader(val_dataset, batch_size=16)

# # Model setup
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = models.resnet18(pretrained=True)
# model.fc = nn.Linear(model.fc.in_features, 1)
# model = model.to(device)

# # Loss and optimizer
# criterion = nn.BCEWithLogitsLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)

# # Evaluation function with prediction logging
# def evaluate(model, loader):
#     model.eval()
#     correct = 0
#     total = 0
#     with torch.no_grad():
#         for x, y in loader:
#             x = x.to(device)
#             y = torch.tensor([label_map[label.item()] for label in y]).float().unsqueeze(1).to(device)
#             out = model(x)
#             pred = (torch.sigmoid(out) > 0.5).float()
#             correct += (pred == y).sum().item()
#             total += y.size(0)
#             # Log sample predictions
#             for i in range(min(3, x.size(0))):
#                 print(f"🧪 Pred: {pred[i].item():.2f}, Label: {y[i].item():.0f}")
#     return correct / total

# # Training loop
# for epoch in range(10):
#     model.train()
#     total_loss = 0
#     for x, y in train_loader:
#         x = x.to(device)
#         y = torch.tensor([label_map[label.item()] for label in y]).float().unsqueeze(1).to(device)
#         optimizer.zero_grad()
#         out = model(x)
#         loss = criterion(out, y)
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item()
#     val_acc = evaluate(model, val_loader)
#     print(f"📦 Epoch {epoch+1}/10 - Loss: {total_loss/len(train_loader):.4f} - Val Acc: {val_acc:.4f}")

# # Save model
# torch.save(model, MODEL_PATH)
# print(f"✅ Model saved to {MODEL_PATH}")





# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torchvision import datasets, transforms, models
# from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
# from sklearn.model_selection import train_test_split
# from collections import Counter

# # Paths
# DATA_DIR = "data_obstacle"
# MODEL_PATH = "obstacle_classifier.pt"

# # Transforms
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomRotation(10),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

# # Load dataset
# dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
# print("✅ Class mapping:", dataset.class_to_idx)

# # Label mapping: obstacle_present → 1, no_obstacle → 0
# label_map = {dataset.class_to_idx['obstacle_present']: 1, dataset.class_to_idx['no_obstacle']: 0}
# labels = [label_map[label] for _, label in dataset.samples]
# label_counts = Counter(labels)
# print("📊 Class distribution:", label_counts)

# # Weighted sampling
# weights = [1.0 / label_counts[label_map[label]] for _, label in dataset.samples]
# indices = list(range(len(dataset)))
# train_idx, val_idx = train_test_split(indices, test_size=0.2, stratify=labels, random_state=42)

# train_dataset = Subset(dataset, train_idx)
# val_dataset = Subset(dataset, val_idx)

# train_sampler = WeightedRandomSampler([weights[i] for i in train_idx], len(train_idx), replacement=True)
# train_loader = DataLoader(train_dataset, batch_size=16, sampler=train_sampler)
# val_loader = DataLoader(val_dataset, batch_size=16)

# # Model setup
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = models.resnet18(pretrained=True)
# model.fc = nn.Linear(model.fc.in_features, 1)
# model = model.to(device)

# # Loss and optimizer
# criterion = nn.BCEWithLogitsLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)

# # Evaluation function
# def evaluate(model, loader):
#     model.eval()
#     correct = 0
#     total = 0
#     with torch.no_grad():
#         for x, y in loader:
#             x = x.to(device)
#             y = torch.tensor([label_map[label.item()] for label in y]).float().unsqueeze(1).to(device)
#             out = model(x)
#             pred = (torch.sigmoid(out) > 0.5).float()
#             correct += (pred == y).sum().item()
#             total += y.size(0)
#     return correct / total

# # Training loop
# for epoch in range(10):
#     model.train()
#     total_loss = 0
#     for x, y in train_loader:
#         x = x.to(device)
#         y = torch.tensor([label_map[label.item()] for label in y]).float().unsqueeze(1).to(device)
#         optimizer.zero_grad()
#         out = model(x)
#         loss = criterion(out, y)
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item()
#     val_acc = evaluate(model, val_loader)
#     print(f"📦 Epoch {epoch+1}/10 - Loss: {total_loss/len(train_loader):.4f} - Val Acc: {val_acc:.4f}")

# # Save model
# torch.save(model, MODEL_PATH)
# print(f"✅ Model saved to {MODEL_PATH}")





#######################################################################
# from torchvision import datasets, transforms, models
# from torch.utils.data import DataLoader
# import torch, torch.nn as nn, torch.optim as optim

# DATA_DIR = "data_obstacle"
# MODEL_PATH = "obstacle_classifier.pt"


# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomRotation(10),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

# dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
# dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# model = models.resnet18(pretrained=True)
# model.fc = nn.Linear(model.fc.in_features, 1)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = model.to(device)

# criterion = nn.BCEWithLogitsLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)

# for epoch in range(10):
#     model.train()
#     for inputs, labels in dataloader:
#         inputs, labels = inputs.to(device), labels.float().unsqueeze(1).to(device)
#         optimizer.zero_grad()
#         outputs = model(inputs)
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()

# torch.save(model, MODEL_PATH)
# print(f"✅ Saved to {MODEL_PATH}")
