# coding=utf-8
# +----------------------------------------------------------------------+
# | 波特智控 [ 以价值驱动应用, 以AI赋能控制, 让流程工业从稳态迈向自优化 ]          |
# +----------------------------------------------------------------------+
# | Copyright (c) 2020~2025 https://www.sdqbtech.com All rights reserved.|
# +----------------------------------------------------------------------+
# | Licensed 波特智控并不是自由软件，未经许可不得使用                           |
# +----------------------------------------------------------------------+
# | Author: 波特智控研究团队 <bodecontrol-team@sdqbtech.com>                |
# +----------------------------------------------------------------------+


import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor, wait as wait_futures
from datetime import datetime, timedelta
from typing import Dict, Callable, Type, Tuple

from apscheduler.schedulers.background import BackgroundScheduler
from django.core.cache.backends.base import BaseCache
from django_redis.cache import RedisCache
from loguru import logger
from redis.client import Redis

from qbtools.datetime import get_zoneinfo
from qbtools.datetime import app_now_time
from qbtools.datetime import get_p_time, get_f_time
from qbtools.django.network import get_local_ips

__all__ = ["BaseTask", "TaskAdmin"]


class BaseTask:
    """Base class for schedulable tasks; subclasses implement ``run()``.

    ``start()`` wraps ``run()`` with liveness flags, start/stop timestamps
    and exception logging. ``shutdown()`` calls ``stop()`` and then waits on
    ``stoped_event``, which ``start()`` sets once ``run()`` has finished.
    """

    def __init__(self, name=None) -> None:
        self.is_alive = False
        # Falls back to the concrete class name when no name is supplied.
        self.name = self.__class__.__name__ if name is None else name
        self.desc = ""
        # Placeholder timestamps until start()/shutdown() record real ones.
        self.start_time = get_p_time("1979-12-12 00:00:00", zoninfo=get_zoneinfo())
        self.stop_time = get_p_time("1979-12-12 00:00:00", zoninfo=get_zoneinfo())
        # Set by start() when run() ends; shutdown() waits on it.
        # A freshly created Event is already in the cleared state.
        self.stoped_event = threading.Event()

    def state_payload(self):
        """Return ``(key, payload)`` describing this task in the current process.

        The key has the form ``worker.<name>.<pid>``. The payload holds the
        pid, the local IPs, the formatted start/stop times and the current
        Unix timestamp (seconds).
        """
        key = f"worker.{self.name}.{os.getpid()}"
        payload = {
            "pid": os.getpid(),
            "host": get_local_ips(),
            "start_time": get_f_time(self.start_time),
            "stop_time": get_f_time(self.stop_time),
            "updated": int(time.time()),
        }
        return key, payload

    def start(self):
        """Run the task body while recording liveness and start/stop times.

        Exceptions raised by ``run()`` are logged, not propagated.
        ``stoped_event`` is always set when the run ends, so a concurrent
        ``shutdown()`` returns as soon as the task has actually stopped
        instead of waiting for its full timeout.
        """
        # Reset in case this instance is started more than once.
        self.stoped_event.clear()
        try:
            self.is_alive = True
            self.start_time = app_now_time()
            self.run()
        except Exception as e:
            logger.exception(e)
        finally:
            self.is_alive = False
            self.stop_time = app_now_time()
            self.stoped_event.set()

    def run(self):
        """Task body; must be implemented by subclasses."""
        raise NotImplementedError

    def shutdown(self, timeout: int = 10):
        """Ask the task to stop and wait up to ``timeout`` seconds for ``run()`` to end.

        Logs a warning if ``stoped_event`` is not set within ``timeout``.
        """
        self.stop()
        # Presumably lets run() loops that poll is_alive exit — confirm with subclasses.
        self.is_alive = False
        signaled = self.stoped_event.wait(timeout=timeout)
        if not signaled:
            logger.warning(f"Task {self.name} stop timeout after {timeout} seconds")

        self.stop_time = app_now_time()

    def stop(self):
        """Hook for subclasses to signal ``run()`` to finish; no-op by default."""

# --------------------------------------------------
# TaskAdmin: singleton that schedules tasks and manages their state
# --------------------------------------------------
class TaskAdmin:
    """Process-wide singleton that queues, schedules and stops ``BaseTask`` classes.

    All state lives on the class and every operation is a classmethod, so
    ``TaskAdmin()`` only ever returns the one shared instance.

    Flow implemented below:

    1. ``@TaskAdmin.register(...)`` records a task class in ``_wait_queue``
       (no instance is created).
    2. ``start()`` starts the APScheduler ``BackgroundScheduler`` together
       with a long-running heartbeat job.
    3. Once the admin is both started and ready (``set_ready()``), queued
       classes are added to the scheduler as jobs that call
       ``qbtools.task.task_runner:run_task``.
    4. ``shutdown()`` stops live task instances and the scheduler.
    """

    # ---------- singleton / shared state ----------
    _instance: "TaskAdmin | None" = None
    _started: bool = False
    _redis_client: Redis | BaseCache | None = None

    # Created at class definition; shutdown() drops it and
    # _ensure_scheduler() builds a fresh one on demand.
    _scheduler: BackgroundScheduler | None = BackgroundScheduler()
    _scheduler_config: dict | None = None
    _scheduler_prefix: str | None = None

    _task_status: Dict[str, str] = {}  # task name -> status
    # class_path -> (class_path, init_kwargs, job_kwargs)
    _wait_queue: Dict[str, Tuple[str, dict, dict]] = {}

    _is_ready: bool = False

    # task name -> enabled flag, set explicitly via set_task_cfg();
    # tasks missing from the mapping are treated as enabled.
    _task_cfg: dict = {}
    _task_map: Dict[str, BaseTask] = {}  # job id -> running task instance
    _lock = threading.RLock()
    _shutdown_event = threading.Event()

    # Delay before a "date"-triggered job runs when no run_date was given.
    _DEFAULT_DATE_DELAY = timedelta(seconds=5)

    # ---------- singleton construction ----------
    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)

        return cls._instance

    # ---------- registration ----------
    @classmethod
    def register(cls, **d_kwargs) -> Callable[[Type[BaseTask]], Type[BaseTask]]:
        """Class decorator that queues a ``BaseTask`` subclass for scheduling.

        Example::

            @TaskAdmin.register(trigger="interval", seconds=60, init={"name": "X"})
            class MyTask(BaseTask):
                ...

        ``init`` holds the task constructor kwargs. ``method`` and
        ``call_kwargs`` are passed through to ``run_task``. Every other
        keyword is forwarded to ``BackgroundScheduler.add_job``. Without a
        ``trigger`` the task runs once, shortly after it is added to the
        scheduler.

        Raises:
            TypeError: if the decorated class does not extend ``BaseTask``.
        """
        init_kwargs = d_kwargs.pop("init", {}) or {}

        def decorator(task_cls: Type[BaseTask]) -> Type[BaseTask]:
            if not issubclass(task_cls, BaseTask):
                raise TypeError(f"{task_cls} must extend BaseTask")
            class_path = f"{task_cls.__module__}:{task_cls.__qualname__}"
            cls._register_class(class_path, init_kwargs=init_kwargs, job_kwargs=d_kwargs)
            return task_cls

        return decorator

    @classmethod
    def _register_class(cls, class_path: str, init_kwargs: dict, job_kwargs: dict):
        """Queue a task class path with its constructor and scheduling options.

        Only the class path and kwargs are stored; no instance is created.
        A one-shot ``"date"`` trigger is the default. Its ``run_date`` is
        filled in later by ``_add_job_to_scheduler`` for two reasons:
        registration usually happens at import time, so a run date fixed here
        could already be in the past when the job is added (and be skipped as
        a misfire). Also, non-date triggers do not accept ``run_date`` at all.
        """
        options = {"trigger": "date", "coalesce": True, "replace_existing": True}
        options.update(job_kwargs)

        logger.debug(f"Task {class_path} added to wait queue")
        with cls._lock:
            cls._wait_queue[class_path] = (class_path, init_kwargs, options)

    # ---------- scheduler setup ----------
    @classmethod
    def set_scheduler(cls, cfg: dict, prefix=None):
        """Store and apply APScheduler configuration (``gconfig``/``prefix``).

        The stored values are re-applied whenever ``_ensure_scheduler`` has
        to build a new scheduler (e.g. after ``shutdown()``).
        """
        cls._scheduler_config = cfg
        cls._scheduler_prefix = prefix
        cls._ensure_scheduler()
        cls._scheduler.configure(gconfig=cfg, prefix=prefix)

    @classmethod
    def _ensure_scheduler(cls):
        """Create a scheduler, with any stored configuration, if none exists."""
        if cls._scheduler is None:
            cls._scheduler = BackgroundScheduler()
            if cls._scheduler_config is not None:
                cls._scheduler.configure(
                    gconfig=cls._scheduler_config, prefix=cls._scheduler_prefix
                )

    @classmethod
    def _scheduler_now(cls) -> datetime:
        """Return the current time in the scheduler's timezone.

        Returns naive local time when no scheduler or timezone is available.
        Naive run dates are interpreted in the scheduler's timezone. Using
        that zone explicitly avoids run dates shifted by the offset between
        it and the host's local zone.
        """
        return datetime.now(getattr(cls._scheduler, "timezone", None))

    # ---------- lifecycle ----------
    @classmethod
    def start(cls, ready=True) -> bool:
        """Start the scheduler and the heartbeat job; idempotent.

        Args:
            ready: when True, also call ``set_ready()`` so queued tasks are
                scheduled right away.

        Returns:
            Always True.

        Raises:
            RuntimeError: if no scheduler is available after initialization.
        """
        if cls._started:
            return True

        cls._shutdown_event.clear()
        cls._ensure_scheduler()
        if cls._scheduler is None:
            raise RuntimeError("TaskAdmin scheduler is not initialized")

        # The heartbeat loops until shutdown(), so it occupies one executor
        # worker for the scheduler's lifetime.
        # NOTE(review): requires an executor named "thread" in the scheduler
        # config (see set_scheduler()); confirm it is always configured.
        cls._scheduler.add_job(
            cls._heartbeat,  # type: ignore
            trigger="date",
            run_date=cls._scheduler_now(),
            id="TaskAdminHeartbeat",
            executor="thread",
        )
        cls._scheduler.start(paused=False)  # type: ignore
        cls._started = True

        if ready:
            cls.set_ready()

        return True

    @classmethod
    def shutdown(cls, wait: bool = False):
        """Stop live task instances, then shut down the scheduler.

        Tasks are stopped in parallel, one thread per task. This call blocks
        until every ``_stop_task`` call has returned. ``wait`` is passed to
        the scheduler's ``shutdown``: when True it also waits for all
        currently executing jobs to finish.
        """
        cls._shutdown_event.set()  # ends the heartbeat loop

        with cls._lock:
            tasks = list(cls._task_map.items())
            cls._task_map.clear()

        if tasks:
            # Leaving the with-block waits for every submitted stop to finish.
            with ThreadPoolExecutor(max_workers=len(tasks)) as executor:
                for name, task in tasks:
                    executor.submit(cls._stop_task, name, task)

        if cls._started and cls._scheduler is not None:
            try:
                cls._scheduler.shutdown(wait=wait)
                logger.info("TaskAdmin scheduler shut down")
            except Exception as exc:  # noqa: BLE001
                logger.exception(f"TaskAdmin scheduler shutdown failed: {exc}")
            finally:
                cls._started = False
                cls._scheduler = None

    @classmethod
    def set_ready(cls):
        """Mark the admin ready and flush queued tasks; call from ``AppConfig.ready()``."""
        cls._is_ready = True
        cls._flush_wait_queue()

    # ---------- flush the wait queue ----------
    @classmethod
    def _flush_wait_queue(cls):
        """Move every queued task into the scheduler.

        Returns False (doing nothing) until the admin is both started and
        ready. Otherwise returns True. The queue is snapshotted and cleared
        under ``_lock``, so concurrent callers (the heartbeat thread and
        ``set_ready()``) each handle an entry at most once.
        """
        if not cls._started:
            return False
        if not cls._is_ready:
            return False

        with cls._lock:
            pending = list(cls._wait_queue.values())
            cls._wait_queue.clear()

        for class_path, init_kwargs, job_kwargs in pending:
            cls._add_job_to_scheduler(class_path, init_kwargs, job_kwargs)

        return True

    @classmethod
    def _add_job_to_scheduler(cls, class_path: str, init_kwargs: dict, job_kwargs: dict):
        """Add one queued task class to the scheduler as a ``run_task`` job.

        The job id and status key are the class's qualified name (the part
        of ``class_path`` after ``:``).

        Returns:
            True if the job was added. False if the task is disabled in
            ``_task_cfg`` or ``add_job`` raised (the error is logged).
        """
        task_name = class_path.rsplit(":", 1)[-1]
        if not cls._task_cfg.get(task_name, True):
            logger.warning(f"Task {task_name} disabled by config")
            return False

        # Work on a copy: runner-only keys are popped below.
        job_kwargs = dict(job_kwargs)
        # run_date only applies to the "date" trigger (interval/cron triggers
        # reject it). It is computed here, not at registration, so it is
        # never stale.
        if job_kwargs.get("trigger") == "date":
            job_kwargs.setdefault("run_date", cls._scheduler_now() + cls._DEFAULT_DATE_DELAY)

        # job_kwargs may select an executor (e.g. "thread" or "default").
        runner_kwargs = {
            "class_path": class_path,
            "init_kwargs": init_kwargs,
            # job_kwargs may name a method other than start, plus its kwargs.
            "method": job_kwargs.pop("method", "start"),
            "call_kwargs": job_kwargs.pop("call_kwargs", {}),
            "job_id": task_name,
        }

        try:
            cls._scheduler.add_job(
                "qbtools.task.task_runner:run_task",  # shared runner for every task
                id=task_name,
                kwargs=runner_kwargs,
                **job_kwargs,
            )
            cls._task_status[task_name] = "scheduled"
            logger.info(
                f"Task {task_name} scheduled (executor={job_kwargs.get('executor', 'default')})"
            )
            return True
        except Exception as e:
            logger.exception(f"Add job failed for {task_name}: {e}")
            return False

    # ==================================================
    # ---------- per-task control -----------------------
    # ==================================================
    @classmethod
    def pause_task(cls, name: str):
        """Pause the named job so its future triggers do not fire."""
        cls._scheduler.pause_job(job_id=name)
        cls._task_status[name] = "paused"
        logger.info(f"Task {name} paused")

    @classmethod
    def resume_task(cls, name: str):
        """Resume a previously paused job."""
        cls._scheduler.resume_job(job_id=name)
        cls._task_status[name] = "scheduled"
        logger.info(f"Task {name} resumed")

    @classmethod
    def remove_task(cls, name: str):
        """Remove the named job from the scheduler entirely."""
        cls._scheduler.remove_job(job_id=name)
        cls._task_status[name] = "removed"
        logger.info(f"Task {name} removed")

    # -------- global control --------
    @classmethod
    def pause_all(cls):
        """Pause the whole scheduler; mark non-removed, non-error tasks as paused."""
        cls._scheduler.pause()
        logger.warning("Scheduler paused (all tasks)")
        for name in list(cls._task_status):
            if cls._task_status[name] not in ("removed", "error"):
                cls._task_status[name] = "paused"

    @classmethod
    def resume_all(cls):
        """Resume the whole scheduler; mark paused tasks as scheduled again."""
        cls._scheduler.resume()
        logger.warning("Scheduler resumed")
        for name in list(cls._task_status):
            if cls._task_status[name] == "paused":
                cls._task_status[name] = "scheduled"

    @classmethod
    def get_all_task_status(cls):
        """Return a ``{task name: status}`` snapshot; queued tasks report ``"waiting"``."""
        res = cls._task_status.copy()
        with cls._lock:
            queued = list(cls._wait_queue.values())
        for class_path, _, _ in queued:
            res[class_path.rsplit(":", 1)[-1]] = "waiting"
        return res

    # ---------- heartbeat + automatic flush ----------
    @classmethod
    def _heartbeat(cls):
        """Loop about once per second until shutdown: flush the queue, publish status.

        Errors in one iteration are logged and the loop continues.
        """
        while not cls._shutdown_event.is_set():
            try:
                # 1) Hand queued tasks over to the scheduler.
                cls._flush_wait_queue()

                if not cls._is_ready:
                    logger.warning("Heartbeat: TaskAdmin is not ready.")

                # 2) Publish task statuses to the cache / Redis.
                cls._publish_status()
            except Exception as e:
                logger.critical(e)
            finally:
                # Doubles as the 1 s tick and an early exit on shutdown().
                cls._shutdown_event.wait(timeout=1)

        logger.info("TaskAdmin heartbeat stopped")

    @classmethod
    def _publish_status(cls):
        """Write ``task.<name>.status`` keys with a 3 s TTL; no-op without a client."""
        client = cls._redis_client
        if client is None:
            return

        # Snapshot first: other threads may add tasks while we iterate.
        snapshot = cls._task_status.copy()
        data = {f"task.{k}.status": v for k, v in snapshot.items()}
        if isinstance(client, BaseCache):
            # Any Django cache backend, including django_redis' RedisCache.
            client.set_many(data, timeout=3)
        else:
            # Raw redis-py client: same 3 s expiry so stale statuses vanish.
            for key, val in data.items():
                client.set(key, val, ex=3)

    @classmethod
    def set_redis(cls, client: Redis | BaseCache):
        """Set the Redis client or Django cache that heartbeat statuses are written to."""
        cls._redis_client = client

    @classmethod
    def set_task_cfg(cls, task_cfg: dict):
        """Set the ``{task name: enabled}`` mapping consulted before scheduling."""
        cls._task_cfg = task_cfg

    @classmethod
    def _register_task_instance(cls, job_id: str | None, task: BaseTask):
        """Track a running task instance so ``shutdown()`` can stop it; ignores ``None`` ids."""
        if job_id is None:
            return
        with cls._lock:
            cls._task_map[job_id] = task

    @classmethod
    def _unregister_task_instance(cls, job_id: str | None):
        """Stop tracking a task instance; ignores ``None`` and unknown ids."""
        if job_id is None:
            return
        with cls._lock:
            cls._task_map.pop(job_id, None)

    @classmethod
    def _stop_task(cls, name: str, task: BaseTask):
        """Stop one task (``shutdown()`` if available, else ``stop()``); errors are logged."""
        try:
            shutdown_fn = getattr(task, "shutdown", None)
            if callable(shutdown_fn):
                shutdown_fn()
            else:
                task.stop()
            with cls._lock:
                cls._task_status[name] = "stopped"
            logger.info(f"TaskAdmin stopped task:{name}")
        except Exception as exc:  # noqa: BLE001
            logger.exception(f"TaskAdmin stop task:{name} failed: {exc}")
