#!/usr/bin/env python3
# Copyright (C) 2025-2026 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import argparse
import numpy as np
import cv2
import openvino_genai
from openvino import Tensor
from pathlib import Path


def streamer(subword: str) -> bool:
    """Emit one generated sub-word to stdout without a trailing newline.

    Args:
        subword: sub-word of the generated text.

    Returns:
        Nothing. Returning None is treated the same as returning
        openvino_genai.StreamingStatus.RUNNING, so generation is never
        stopped from this callback.
    """
    # Flush immediately so the user sees tokens as they are produced.
    print(subword, end="", flush=True)


def read_video(path: str, num_frames: int = 8) -> tuple[Tensor, openvino_genai.VideoMetadata]:
    """Decode every frame of a video and build sampling metadata for the pipeline.

    Args:
        path: The path to the video.
        num_frames: Number of frames used to calculate frame indices for
            further sampling within the pipeline.

    Returns:
        Tuple of a Tensor containing the original video frames and the
        corresponding VideoMetadata.

    Raises:
        ValueError: If the video cannot be opened, reports zero frames, or the
            number of decoded frames does not match the reported frame count.
    """
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError("Cannot open video file: {}".format(path))

    total_num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_num_frames <= 0:
        cap.release()
        raise ValueError("Video reports no frames: {}".format(path))

    # Evenly spaced indices across the clip. Clamp the step to >= 1 so a clip
    # shorter than num_frames does not produce duplicate indices (a fractional
    # step would repeat values after astype(int)).
    step = max(total_num_frames / num_frames, 1)
    indices = np.arange(0, total_num_frames, step).astype(int)

    video_metadata = openvino_genai.VideoMetadata()
    video_metadata.fps = cap.get(cv2.CAP_PROP_FPS)
    # Passing video metadata with frame indices defined enables sampling based on provided indices within the pipeline,
    # and any model-specific sampling logic will be skipped (if defined).
    # Leave frames_indices empty to apply model-specific sampling (e.g. for Qwen3-VL).
    video_metadata.frames_indices = indices.tolist()

    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # cap.read() already returns a numpy array; no extra copy needed.
        frames.append(frame)

    cap.release()

    # CAP_PROP_FRAME_COUNT comes from container metadata and can disagree with
    # the number of frames that actually decode; fail loudly rather than hand
    # inconsistent indices to the pipeline. Raise instead of assert so the
    # check survives `python -O`.
    if len(frames) != total_num_frames:
        raise ValueError(
            "Frame count mismatch: expected {}, got {}".format(total_num_frames, len(frames))
        )

    return Tensor(frames), video_metadata


def read_videos(path: str) -> tuple[list[Tensor], list[openvino_genai.VideoMetadata]]:
    """Read a single video file or every video file inside a directory.

    Args:
        path: Path to a video file, or to a directory containing video files.

    Returns:
        Tuple of (videos, metadata) lists; directory entries are processed in
        sorted order for deterministic results.
    """
    entry = Path(path)
    if not entry.is_dir():
        video, video_metadata = read_video(path)
        return [video], [video_metadata]

    videos = []
    videos_metadata = []
    for file in sorted(entry.iterdir()):
        # Skip nested directories and other non-regular entries so read_video
        # only ever receives actual files.
        if not file.is_file():
            continue
        video, video_metadata = read_video(str(file))
        videos.append(video)
        videos_metadata.append(video_metadata)
    return videos, videos_metadata


def main():
    """Run an interactive multi-turn chat about one or more videos."""
    parser = argparse.ArgumentParser()
    parser.add_argument("model_dir", help="Path to the model directory")
    parser.add_argument("video_dir", help="Path to a video file or a directory with video files.")
    parser.add_argument("device", nargs="?", default="CPU", help="Device to run the model on (default: CPU)")
    args = parser.parse_args()

    videos, videos_metadata = read_videos(args.video_dir)

    # GPU and NPU can be used as well.
    # Note: If NPU is selected, only the language model will be run on the NPU.
    enable_compile_cache = dict()
    if args.device == "GPU":
        # Cache compiled models on disk for GPU to save time on the next run.
        # It's not beneficial for CPU.
        enable_compile_cache["CACHE_DIR"] = "vlm_cache"

    pipe = openvino_genai.VLMPipeline(args.model_dir, args.device, **enable_compile_cache)

    config = openvino_genai.GenerationConfig()
    config.max_new_tokens = 100

    history = openvino_genai.ChatHistory()
    first_turn = True
    while True:
        # Guard every prompt (including the first) against EOF so a closed or
        # piped stdin exits cleanly instead of crashing with a traceback.
        try:
            prompt = input("question:\n" if first_turn else "\n----------\nquestion:\n")
        except EOFError:
            break

        history.append({"role": "user", "content": prompt})
        if first_turn:
            # Videos are attached on the first turn only, matching the
            # original flow of this sample.
            decoded_results = pipe.generate(
                history, videos=videos, videos_metadata=videos_metadata, generation_config=config, streamer=streamer
            )
            first_turn = False
        else:
            # New images and videos can be passed at each turn
            decoded_results = pipe.generate(history, generation_config=config, streamer=streamer)
        history.append({"role": "assistant", "content": decoded_results.texts[0]})


# Run the sample only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
