import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
from sklearn.model_selection import train_test_split
from collections import Counter
import timm

# Filesystem locations: image folder to train on, and where the learned
# weights are written at the end of the script.
DATA_DIR = "data_pipe"
MODEL_PATH = "pipe_classifier.pt"

# Preprocessing pipeline handed to the ImageFolder dataset: fixed 224x224
# resize, light random augmentation (flip, small rotation, brightness/contrast
# jitter), then conversion to a tensor normalized with the standard ImageNet
# per-channel mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.3, contrast=0.3),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Load dataset: ImageFolder treats each sub-folder of DATA_DIR as one class.
dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
print("✅ Class mapping:", dataset.class_to_idx)

# Label mapping: pipe_present → 1, no_pipe → 0.
# ImageFolder assigns its own class indices from the folder names, so they are
# remapped here to the fixed binary targets used with BCEWithLogitsLoss.
_CLASS_TARGETS = {"pipe_present": 1, "no_pipe": 0}
_found_classes = set(dataset.class_to_idx)
if _found_classes != set(_CLASS_TARGETS):
    # Fail early with a readable message. Otherwise a missing folder, or an
    # extra sub-folder picked up as a class, surfaces later as a bare KeyError
    # on an integer index.
    raise ValueError(
        f"Expected exactly the class folders {sorted(_CLASS_TARGETS)} in "
        f"{DATA_DIR!r}, found {sorted(_found_classes)}"
    )
label_map = {
    dataset.class_to_idx[class_name]: target
    for class_name, target in _CLASS_TARGETS.items()
}
# Binary target of every sample, in dataset order (used for stratification
# and sampler weights).
labels = [label_map[label] for _, label in dataset.samples]
label_counts = Counter(labels)
print("📊 Class distribution:", label_counts)

# Weighted sampling: inverse class frequency, so each class carries the same
# total weight and the sampler draws both classes about equally often.
weights = [1.0 / label_counts[label] for label in labels]
indices = list(range(len(dataset)))
# Stratified 80/20 split keeps the class ratio the same in both subsets.
train_idx, val_idx = train_test_split(indices, test_size=0.2, stratify=labels, random_state=42)

# Validation must not see random augmentation (flip / rotation / colour
# jitter); otherwise accuracy is noisy and not representative of inference.
# A second ImageFolder over the same directory lists samples in the same
# order, so the same indices select the same images, here with deterministic
# preprocessing. Keep this pipeline in sync with `transform` minus augmentation.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
eval_dataset = datasets.ImageFolder(DATA_DIR, transform=eval_transform)
if eval_dataset.samples != dataset.samples:
    # Guard the index alignment assumed above (e.g. files changed between loads).
    raise RuntimeError(
        f"{DATA_DIR!r} changed while loading; train/val indices no longer align"
    )

train_dataset = Subset(dataset, train_idx)
val_dataset = Subset(eval_dataset, val_idx)

train_sampler = WeightedRandomSampler([weights[i] for i in train_idx], len(train_idx), replacement=True)
train_loader = DataLoader(train_dataset, batch_size=16, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=16)

# Model setup: pick the GPU when one is available and build an EfficientNet-B0
# (pretrained weights) whose head emits a single logit per image.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = timm.create_model(
    "efficientnet_b0",
    pretrained=True,
    num_classes=1,
).to(device)

# Loss and optimizer: the single logit is scored with BCEWithLogitsLoss
# (sigmoid folded into the loss); all parameters are trained with Adam.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Evaluation function
def evaluate(model, dataloader, *, threshold=0.5, target_map=None, device=None):
    """Return the binary accuracy of ``model`` over ``dataloader``.

    Switches the model to eval mode (and leaves it there) and runs without
    gradient tracking. Each batch is ``(inputs, labels)`` where ``labels``
    holds raw class indices; they are remapped to 0/1 targets via
    ``target_map`` before comparison.

    Args:
        model: network producing one logit per sample; output is assumed to
            be shaped ``(batch, 1)`` to match the targets built below.
        dataloader: any iterable of ``(inputs, labels)`` batches.
        threshold: sigmoid probability above which a sample is predicted 1.
        target_map: dict from raw class index to 0/1 target; defaults to the
            module-level ``label_map``.
        device: device for inputs and targets; defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        Fraction of correctly classified samples, or ``nan`` when the
        dataloader yields no samples (instead of dividing by zero).
    """
    if target_map is None:
        target_map = label_map
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, batch_labels in dataloader:
            inputs = inputs.to(device)
            targets = torch.tensor(
                [target_map[label.item()] for label in batch_labels],
                dtype=torch.float32,
                device=device,
            ).unsqueeze(1)
            preds = (torch.sigmoid(model(inputs)) > threshold).float()
            correct += (preds == targets).sum().item()
            total += targets.size(0)
    if total == 0:
        return float("nan")
    return correct / total

# Training loop
NUM_EPOCHS = 10  # single source for both the loop bound and the progress print
for epoch in range(NUM_EPOCHS):
    model.train()  # evaluate() leaves the model in eval mode, so re-enable training each epoch
    running_loss = 0.0
    # `batch_labels` (not `labels`) so the module-level `labels` list is not clobbered.
    for inputs, batch_labels in train_loader:
        inputs = inputs.to(device)
        # Remap raw ImageFolder class indices to the 0/1 float targets, shaped
        # (batch, 1) to match the single-logit model output.
        mapped_labels = torch.tensor([label_map[label.item()] for label in batch_labels]).float().unsqueeze(1).to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, mapped_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    val_acc = evaluate(model, val_loader)
    print(f"📦 Epoch {epoch+1}/{NUM_EPOCHS} - Loss: {running_loss/len(train_loader):.4f} - Val Acc: {val_acc:.4f}")

# Persist only the learned parameters (the state_dict), not the pickled
# module object; reload by rebuilding the same architecture first.
torch.save(obj=model.state_dict(), f=MODEL_PATH)
print("✅ Model weights saved to {}".format(MODEL_PATH))