import gradio as gr
import whisper
from PIL import Image
import cv2
import numpy as np
import torch

# Choose the compute device once, at import time: prefer CUDA whenever
# PyTorch can see a GPU, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# Single-slot cache for the most recently used Whisper model. Repeated runs
# with the same size skip the (slow) reload. Only one model is kept so that
# switching sizes does not pile several large models up in memory.
_model_cache = {"name": None, "model": None}


def _load_model(model_name):
    """Return Whisper model ``model_name`` on ``device``, reusing the cached one."""
    if _model_cache["name"] != model_name:
        # Clear both slots first: this drops the cache's reference to the old
        # weights before the new load, and a failed load cannot leave a stale
        # name pointing at a ``None`` model.
        _model_cache["name"] = None
        _model_cache["model"] = None
        print(f"Loading Whisper model: {model_name}...")
        model = whisper.load_model(model_name, device=device)
        _model_cache["model"] = model
        _model_cache["name"] = model_name
    return _model_cache["model"]


def _format_timestamp(seconds, decimals=2):
    """Format ``seconds`` as ``MM:SS.f`` with ``decimals`` fractional digits.

    The whole value is rounded first, so e.g. 59.996 s becomes ``01:00.00``
    instead of the invalid ``00:60.00`` that results from rounding the
    seconds part on its own. Minutes are not folded into hours
    (75 minutes -> ``75:00.00``).
    """
    scale = 10 ** decimals
    # float() also accepts numpy scalars; int() guarantees an int for :02d.
    minutes, remainder = divmod(int(round(float(seconds) * scale)), 60 * scale)
    width = 3 + decimals if decimals > 0 else 2
    return f"{minutes:02d}:{remainder / scale:0{width}.{decimals}f}"


def _extract_frames(video, num_frames):
    """Grab ``num_frames`` evenly spaced stills from ``video`` as RGB PIL images.

    Returns a list of ``(PIL.Image, caption)`` pairs. The list is empty when
    OpenCV reports no usable frame rate or frame count. Frames that fail to
    decode are skipped.
    """
    cap = cv2.VideoCapture(video)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps if fps > 0 else 0
        if duration <= 0 or num_frames <= 0:
            return []

        # Cut the clip into num_frames + 1 equal gaps and sample only the
        # interior points, so the very first and last frames are not used.
        interval = duration / (num_frames + 1)
        stills = []
        for i in range(1, num_frames + 1):
            timestamp = i * interval
            cap.set(cv2.CAP_PROP_POS_FRAMES, max(0, int(fps * timestamp)))
            ok, frame = cap.read()
            if not ok:
                continue
            # OpenCV decodes to BGR; PIL expects RGB.
            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            caption = f"{timestamp:.1f}s ({_format_timestamp(timestamp, decimals=1)})"
            stills.append((pil_img, caption))
        return stills
    finally:
        # Release the decoder even if reading or conversion raises.
        cap.release()


def process_video(video, model_name, num_frames):
    """Transcribe ``video`` with Whisper and extract evenly spaced still frames.

    Args:
        video: Filepath from the ``gr.Video`` input, or None if nothing was
            uploaded.
        model_name: Whisper model size passed to ``whisper.load_model``.
        num_frames: Number of stills to extract (coerced to int).

    Returns:
        ``(transcription, stills)``:
        - ``transcription`` has one ``[MM:SS.ss → MM:SS.ss] text`` line per
          Whisper segment.
        - ``stills`` is a list of ``(PIL.Image, caption)`` pairs.

        ``("No video uploaded", None)`` when ``video`` is None.
    """
    if video is None:
        return "No video uploaded", None

    model = _load_model(model_name)

    # Whisper reads the audio track straight from the video file itself.
    print("Transcribing...")
    result = model.transcribe(video, language=None, word_timestamps=False)
    transcription = "".join(
        f"[{_format_timestamp(seg['start'])} → {_format_timestamp(seg['end'])}] "
        f"{seg['text'].strip()}\n"
        for seg in result["segments"]
    )

    print("Extracting frames...")
    # Gradio documents Slider values as float, but range() requires an int.
    images_with_captions = _extract_frames(video, int(num_frames))

    return transcription, images_with_captions


# Gradio UI: one video input, model/frame-count controls, and a button that
# feeds both into process_video, which fills the transcript box and the gallery.
with gr.Blocks(title="Video → Transcription + Stills") as demo:
    gr.Markdown("# 🎬 Video to Transcription + Key Stills (100% Local)")
    gr.Markdown("Upload any video → get perfectly timestamped subtitles + beautiful representative images.")

    with gr.Row():
        video_input = gr.Video(label="Upload Video")

    with gr.Row():
        model_choice = gr.Dropdown(
            choices=["tiny", "base", "small", "medium", "large-v3", "turbo"],
            value="turbo",
            label="Whisper Model (turbo = best speed/quality balance)"
        )
        num_frames_slider = gr.Slider(5, 100, value=25, step=1, label="Number of still images")

    btn = gr.Button("🚀 Process Video", variant="primary")

    transcription_output = gr.Textbox(label="Timestamped Transcription", lines=20)
    gallery = gr.Gallery(label="Key Frames (evenly spaced)", columns=5, height="auto")

    btn.click(
        fn=process_video,
        inputs=[video_input, model_choice, num_frames_slider],
        outputs=[transcription_output, gallery]
    )

    gr.Markdown("""
    ### Recommended settings
    - Short videos (<5 min): use **turbo** or **large-v3** + 30–50 frames
    - Long videos (1h+): use **small** or **base** for speed
    - Best quality: **large-v3** (needs good GPU + 10–12GB VRAM)
    """)

if __name__ == "__main__":
    # share=False keeps the app on this machine, matching the "100% Local"
    # header. Pass share=True to get a temporary public link through Gradio's
    # share servers; anyone with that URL can then run jobs on this machine.
    demo.launch(share=False)
