# coding=utf-8
# +----------------------------------------------------------------------+
# | 波特智控 [ 以价值驱动应用, 以AI赋能控制, 让流程工业从稳态迈向自优化 ]          |
# +----------------------------------------------------------------------+
# | Copyright (c) 2020~2025 https://www.sdqbtech.com All rights reserved.|
# +----------------------------------------------------------------------+
# | Licensed 波特智控并不是自由软件，未经许可不得使用                           |
# +----------------------------------------------------------------------+
# | Author: 波特智控研究团队 <bodecontrol-team@sdqbtech.com>                |
# +----------------------------------------------------------------------+

from __future__ import annotations

import threading
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Callable, Optional, Literal
from zoneinfo import ZoneInfo

Reason = Literal["deadline", "woken"]


@dataclass(frozen=True)
class Tick:
    reason: Reason  # "deadline" 到点; "woken" 被 wake/stop/外部事件打断
    scheduled: datetime  # 本次计划触发的“自然分钟对齐”时间点（墙钟时间）
    returned_at: datetime  # 实际返回时刻（墙钟时间）
    lateness_s: float  # returned_at - scheduled 的秒数（到点触发时用于观测调度抖动）


class NaturalAlignedTimer:
    """
    对齐规则（核心）：
      以“自然分钟的 00 秒”为锚点，在该分钟内触发秒数为 period 的整数倍：
        :00, :period, :2*period, ... （< 60）
      若越过 60 秒，则落到下一分钟 :00

    例如：
      11:00:02, period=3  -> 下一次 11:00:03
      11:00:06, period=5  -> 下一次 11:00:10
      若 period=40，则每分钟触发 :00 和 :40；11:00:59 -> 下一次 11:01:00

    睡眠：
      使用 threading.Event.wait(timeout)，可被 wake() / stop() 打断。
      为避免 clear/set 竞态导致丢信号，内部用 “序号计数 + event”。
    """

    def __init__(
            self,
            period_seconds: float,
            *,
            now_func: Optional[Callable] = None,
            wake_event: Optional[threading.Event] = None,
            tz="Asia/Shanghai"
    ) -> None:
        self.tz = ZoneInfo(tz)
        if now_func is None:
            now_func = self._default_now

        self._now = now_func
        self._ev = wake_event or threading.Event()

        self._lock = threading.Lock()
        self._wake_seq = 0
        self._wake_consumed = 0
        self._stopped = False

        self._set_period(period_seconds)

        self._last_deadline: Optional[datetime] = None  # 上一次“到点触发”的计划时间（防止重复触发）

    def _default_now(self) -> datetime:
        return datetime.now(self.tz)

    # ---------- 对外接口 ----------

    def stop(self) -> None:
        """停止：让 wait() 立即返回 None。"""
        with self._lock:
            self._stopped = True
            self._wake_seq += 1
            self._ev.set()

    def wake(self) -> None:
        """打断等待：让 wait() 立即返回 reason='woken'。"""
        with self._lock:
            self._wake_seq += 1
            self._ev.set()

    def update_period(self, period_seconds: float) -> None:
        """
        更新周期。下一次 wait() 会基于“当前墙钟时间 + 新周期”计算最近边界。
        注意：不主动 wake，避免你在循环体内更新周期后导致下一次 wait() 立刻被“误唤醒”。
        若你确实在其他线程更新周期且希望立刻生效，可在更新后调用 wake()。
        """
        with self._lock:
            self._set_period(period_seconds)

    def wait(self) -> Optional[Tick]:
        """
        阻塞直到：
          - 到达下一自然分钟对齐边界（reason='deadline'）
          - 被 wake()/stop() 打断（reason='woken'）
          - stop 后直接返回 None
        """
        if self._is_stopped():
            return None

        while True:
            if self._is_stopped():
                return None

            now = self._now_checked()
            period_us = self._get_period_us()

            # 计算下一对齐触发点：>= now 的最近边界
            target = self._next_aligned_deadline(now, period_us)

            # 防止“同一个边界重复触发”（例如业务极短且刚好卡在边界）
            if self._last_deadline is not None and target <= self._last_deadline:
                target = self._advance_aligned(self._last_deadline, period_us)

            # 若计算出来仍在过去（极端情况下时钟跳变/代码执行耗时），继续推进到未来
            while target < now:
                target = self._advance_aligned(target, period_us)

            # 若有未消费的唤醒信号，直接返回 woken（不推进 last_deadline）
            if self._consume_pending_wake():
                returned = self._now_checked()
                return Tick(
                    reason="woken",
                    scheduled=target,
                    returned_at=returned,
                    lateness_s=(returned - target).total_seconds(),
                )

            # 开始等待：在锁内“确认无 pending wake 并清 event，记录 seq 快照”，避免丢信号
            with self._lock:
                if self._stopped:
                    return None
                if self._wake_seq > self._wake_consumed:
                    # 刚拿到锁发现有 wake，直接消费并返回
                    self._wake_consumed = self._wake_seq
                    self._ev.clear()
                    returned = self._now_checked()
                    return Tick(
                        reason="woken",
                        scheduled=target,
                        returned_at=returned,
                        lateness_s=(returned - target).total_seconds(),
                    )

                seq_snapshot = self._wake_seq
                self._ev.clear()

            # 进入等待：为了应对系统时间跳变、spurious wake，采用循环校验
            while True:
                if self._is_stopped():
                    return None

                now2 = self._now_checked()
                remaining = (target - now2).total_seconds()

                # 已到点（或略过点）：到点触发
                if remaining <= 0:
                    self._last_deadline = target
                    returned = self._now_checked()
                    return Tick(
                        reason="deadline",
                        scheduled=target,
                        returned_at=returned,
                        lateness_s=(returned - target).total_seconds(),
                    )

                # 等待剩余时间或被唤醒
                self._ev.wait(remaining)

                if self._is_stopped():
                    return None

                # 检查是否发生 wake（用 seq 判断）
                with self._lock:
                    if self._wake_seq != seq_snapshot:
                        self._wake_consumed = self._wake_seq
                        self._ev.clear()
                        returned = self._now_checked()
                        return Tick(
                            reason="woken",
                            scheduled=target,
                            returned_at=returned,
                            lateness_s=(returned - target).total_seconds(),
                        )
                # 否则可能是 spurious wake 或时间跳变，继续循环重新计算 remaining

    # ---------- 内部：时间/对齐计算 ----------

    def _set_period(self, period_seconds: float) -> None:
        sec = float(period_seconds)
        if sec <= 0:
            raise ValueError("period_seconds must be > 0")
        us = int(round(sec * 1_000_000))
        if us <= 0:
            raise ValueError("period_seconds too small after conversion")
        self._period_us = us

    def _get_period_us(self) -> int:
        with self._lock:
            return self._period_us

    def _is_stopped(self) -> bool:
        with self._lock:
            return self._stopped

    def _consume_pending_wake(self) -> bool:
        with self._lock:
            if self._wake_seq > self._wake_consumed:
                self._wake_consumed = self._wake_seq
                self._ev.clear()
                return True
            return False

    def _now_checked(self) -> datetime:
        now = self._now()
        if now.tzinfo is None:
            raise ValueError("now_func() must return timezone-aware datetime")
        return now

    @staticmethod
    def _minute_start(dt: datetime) -> datetime:
        return dt.replace(second=0, microsecond=0)

    @staticmethod
    def _ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    def _next_aligned_deadline(self, now: datetime, period_us: int) -> datetime:
        """
        计算 >= now 的最近对齐边界（按“每分钟从 :00 重新起算”的规则）。
        """
        m0 = self._minute_start(now)
        delta = now - m0
        in_minute_us = delta.seconds * 1_000_000 + delta.microseconds  # [0, 60s)

        # k = ceil(in_minute / period)，得到本分钟内的下一倍数偏移
        k = self._ceil_div(in_minute_us, period_us) if in_minute_us > 0 else 0
        offset_us = k * period_us

        if offset_us >= 60 * 1_000_000:
            # 本分钟已无边界：落到下一分钟 :00
            return m0 + timedelta(minutes=1)
        return m0 + timedelta(microseconds=offset_us)

    def _advance_aligned(self, aligned_dt: datetime, period_us: int) -> datetime:
        """
        给定一个“已对齐”的触发点，推进到下一个对齐点（同样按分钟内倍数，溢出到下分钟 :00）。
        """
        m0 = self._minute_start(aligned_dt)
        sec_us = (aligned_dt - m0).seconds * 1_000_000 + (aligned_dt - m0).microseconds

        next_offset = sec_us + period_us
        if next_offset >= 60 * 1_000_000:
            return m0 + timedelta(minutes=1)
        return m0 + timedelta(microseconds=next_offset)
