import os
from PIL import Image
import torch
import torch.nn.functional as F
import timm
import torchvision.transforms as T
import json

# -----------------------------
# Helper functions
# -----------------------------
def get_all_image_paths(
    root_dir,
    extensions=frozenset({".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"}),
):
    """Recursively collect image file paths under *root_dir*.

    Args:
        root_dir: Directory to walk with ``os.walk``. A nonexistent directory
            yields an empty list, because ``os.walk`` ignores errors by default.
        extensions: Iterable of file extensions including the leading dot
            (e.g. ``".jpg"``). Matching is case-insensitive on both sides.

    Returns:
        Sorted list of paths, each built as ``os.path.join(subdir, filename)``.
        They are therefore relative or absolute exactly as *root_dir* is.
    """
    # Normalise once, so callers may pass e.g. {".JPG"} or a list.
    wanted = frozenset(ext.lower() for ext in extensions)

    paths = []
    for subdir, _, files in os.walk(root_dir):
        for name in files:
            if os.path.splitext(name)[1].lower() in wanted:
                paths.append(os.path.join(subdir, name))

    # os.walk order depends on the filesystem; sort for reproducible output.
    return sorted(paths)

# Path to the class-index -> species-name mapping used to label predictions.
ID2LABEL_JSON = "id2label.json"

# JSON object keys are always strings, so lookups must use str(class_id).
with open(ID2LABEL_JSON, encoding="utf-8") as label_file:
    id_to_species = json.load(label_file)


# -----------------------------
# Load FungiTastic via timm
# -----------------------------
model_name = "BVRA/vit_base_patch16_384.in1k_ft_fungitastic_384"

# Fetch the pretrained weights from the Hugging Face Hub. nn.Module.eval()
# returns the module itself, so the switch to inference mode can be chained.
model = timm.create_model(f"hf-hub:{model_name}", pretrained=True).eval()

# -----------------------------
# Define image transforms
# -----------------------------
# The model is a ViT-B/16 built for 384x384 inputs (the "_384" in its name).
# timm's patch embedding rejects any other spatial size, so resizing to 224
# made every forward pass fail. Resize to 384 to match.
transform = T.Compose([
    T.Resize((384, 384)),
    T.ToTensor(),  # PIL image -> float tensor (C, H, W) in [0, 1]
    # Maps [0, 1] to [-1, 1]. Kept from the original; presumably matches the
    # model's fine-tuning preprocessing -- confirm against its pretrained_cfg.
    T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# -----------------------------
# Directory with images
# -----------------------------
# Scan the current working directory (recursively) for images to classify.
fungi_root = "./"
paths = get_all_image_paths(root_dir=fungi_root)

# -----------------------------
# Predict loop
# -----------------------------
for img_path in paths:
    try:
        # convert() returns a new, fully loaded image, so the source file can
        # be closed right away instead of leaking one handle per image.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
        x = transform(img).unsqueeze(0)  # add batch dimension -> (1, C, H, W)

        # Only the forward pass needs gradient tracking disabled.
        with torch.no_grad():
            logits = model(x)

        probs = F.softmax(logits, dim=-1)[0]  # probabilities for the single image
        class_id = probs.argmax().item()
        prob = probs[class_id].item()

        # Keys of a JSON-loaded mapping are strings.
        species_name = id_to_species.get(str(class_id), f"Unknown_{class_id}")
        print(f"{img_path:<30} → {species_name:<40} ({prob:.3f})")

    except Exception as e:
        # Per-image boundary: report unreadable/corrupt files and keep going.
        print("Error with", img_path, e)
