from __future__ import annotations

import concurrent.futures
import re
import time
from typing import Type, TypeVar, cast, Callable, Any

from loguru import logger
from opcua import Client, ua
from qbtools.json import cvt_to_bool

# Return type of OpcClient.read(); constrained to the same Python types that
# OpcClient.d_types accepts as ``data_type``.
T = TypeVar("T", int, float, str, bool)


class OpcConnectionError(RuntimeError):
    """Communication with the OPC server failed.

    Raised when connecting or reconnecting fails, or when a request still
    fails with a connection/session-type error after the reconnect-and-retry.
    """


class OpcNodeNotFoundError(RuntimeError):
    """The requested node does not exist on the server.

    Signals a NodeId problem (unknown or invalid NodeId) that has been told
    apart from a connection problem.
    """


class OpcClient(object):
    """Synchronous OPC UA client wrapper with a single reconnect-and-retry.

    Builds NodeId strings of the form
    ``ns=<namespace>;<identifier_type>=<prefix>.<identifier>`` and routes every
    server round-trip through :meth:`_rpc`. When a call fails in a way that
    looks like a lost connection or session, :meth:`_rpc` reconnects the same
    ``opcua.Client`` instance once and retries the call once.

    Args:
        endpoint: OPC UA endpoint URL handed to ``opcua.Client``.
        namespace: Namespace index placed in every generated NodeId.
        identifier_prefix: Prefix joined to identifiers with ``"."``; pass an
            empty string for no prefix.
        readonly: If True, :meth:`write` raises ``PermissionError``.
        timeout: Timeout forwarded to ``opcua.Client``.

    Raises:
        OpcConnectionError: If the initial connection fails.
    """

    # Signed decimal number with optional fraction and exponent, e.g. "-12",
    # "3.5", ".5", "1e-05", "2.5E+3". Used to pull a number out of read() text.
    _NUMBER_RE = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")

    def __init__(
        self,
        endpoint: str,
        namespace: int,
        identifier_prefix: str,
        readonly: bool = False,
        timeout: float = 1,
    ):
        self.endpoint = endpoint
        self.namespace = namespace
        self.readonly = readonly
        self.identifier_prefix = identifier_prefix

        # NodeId identifier types accepted by _gen_nodeid (OPC UA i/s/g/b syntax).
        self.idtmap = ["i", "s", "g", "b"]
        # Python types accepted as ``data_type`` by read()/write().
        self.d_types = [float, int, str, bool]
        self.timeout = timeout

        # True once a connection-type error has been seen; the next successful
        # reconnect logs a recovery message and clears it.
        self._last_connect_error = False

        # Create the Client exactly once; every reconnect reuses this instance.
        self.client: Client = Client(self.endpoint, timeout=self.timeout)
        self._connect_initial()

    # -------------------------
    # Connection maintenance
    # -------------------------

    def _connect_initial(self) -> None:
        """Open the first connection.

        Raises:
            OpcConnectionError: Wrapping whatever ``Client.connect`` raised.
        """
        try:
            self.client.connect()
        except Exception as e:
            # __init__ is about to raise, so the caller never gets an object to
            # close(); release anything connect() may have partially opened.
            self._disconnect_quietly()
            raise OpcConnectionError(
                f"OPC 初始连接失败: endpoint={self.endpoint}, err={e}"
            ) from e

    def _disconnect_quietly(self) -> None:
        """Disconnect best-effort; errors are logged at debug level and ignored.

        Called on links that are frequently already broken, so a failure here
        is expected and must not mask the original problem.
        """
        try:
            self.client.disconnect()
        except Exception as e:
            logger.debug(
                f"OPC 断开连接时出现异常(已忽略): endpoint={self.endpoint}, err={e}"
            )

    def _reconnect(self) -> None:
        """Wait 1 s, then disconnect and reconnect the existing ``self.client``.

        Logs a recovery message on success if ``_last_connect_error`` was set.

        Raises:
            OpcConnectionError: If ``connect()`` fails (the error flag stays set).
        """
        logger.warning(f"OPC 连接异常，1秒后尝试重连: endpoint={self.endpoint}")
        time.sleep(1)
        # Keep the same Client object; only disconnect + connect it again.
        self._disconnect_quietly()
        try:
            self.client.connect()
        except Exception as e:
            self._last_connect_error = True
            raise OpcConnectionError(
                f"OPC 重连失败: endpoint={self.endpoint}, err={e}"
            ) from e
        if self._last_connect_error:
            logger.info(f"OPC重连成功: endpoint={self.endpoint}")
            self._last_connect_error = False

    def _get_status_code_value(self, err: Exception) -> int | None:
        """Extract an integer OPC UA status code from *err*, or return None.

        Tries, in order, ``err.code`` (python-opcua status errors), ``err.status``
        and ``err.args[0]``; the first value convertible with ``int()`` wins.
        Later candidates are only evaluated if earlier ones fail.
        """

        def _candidates():
            yield getattr(err, "code", None)
            yield getattr(err, "status", None)
            if getattr(err, "args", None):
                yield err.args[0]

        for candidate in _candidates():
            if candidate is None:
                continue
            try:
                return int(candidate)
            except (TypeError, ValueError, OverflowError):
                continue
        return None

    def _is_connection_like_error(self, err: Exception) -> bool:
        """Return True if *err* looks like a lost connection/session.

        Pure predicate: it does not modify any client state.
        """
        # 1) Transport-level errors. concurrent.futures.TimeoutError is only an
        #    alias of the builtin TimeoutError from Python 3.11 on (and its str()
        #    is empty), so it is listed explicitly for older interpreters.
        if isinstance(
            err,
            (OSError, ConnectionError, TimeoutError, concurrent.futures.TimeoutError),
        ):
            return True

        # 2) python-opcua sometimes raises this AttributeError when
        #    UASocketClient._socket is None.
        if isinstance(err, AttributeError):
            msg = str(err)
            if "NoneType" in msg and "write" in msg:
                return True

        # 3) Status-code errors: only session / channel / transport codes count.
        if isinstance(err, ua.UaStatusCodeError):
            code = self._get_status_code_value(err)
            if code is None:
                return False

            bad = ua.StatusCodes
            conn_like = {
                bad.BadSessionIdInvalid,
                bad.BadSessionClosed,
                bad.BadSecureChannelIdInvalid,
                bad.BadSecureChannelClosed,
                bad.BadConnectionClosed,
                bad.BadServerNotConnected,
                bad.BadCommunicationError,
                bad.BadTcpInternalError,
                bad.BadTimeout,
            }
            return code in conn_like

        # 4) Last resort: match well-known fragments in the message text.
        msg = str(err)
        keywords = (
            "BadSessionIdInvalid",
            "BadSecureChannelIdInvalid",
            "BadConnectionClosed",
            "BadServerNotConnected",
            "Broken pipe",
            "Connection reset",
            "timed out",
        )
        return any(k in msg for k in keywords)

    def _rpc(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """Call *func* with one reconnect-and-retry on connection-type errors.

        - Errors that are not connection-like propagate unchanged.
        - On a connection-like error: set the error flag, reconnect once
          (``OpcConnectionError`` if that fails) and retry once.
        - If the retry fails with a connection-like error, raise
          ``OpcConnectionError``; any other retry error propagates unchanged.
        """
        try:
            return func(*args, **kwargs)
        except Exception as e:
            if not self._is_connection_like_error(e):
                raise
            # Remember the outage so _reconnect() logs the recovery.
            self._last_connect_error = True
            # Raises OpcConnectionError (message includes endpoint) on failure.
            self._reconnect()
            try:
                return func(*args, **kwargs)
            except Exception as e2:
                if self._is_connection_like_error(e2):
                    raise OpcConnectionError(
                        f"OPC 通信失败(重试后仍失败): endpoint={self.endpoint}, err={e2}"
                    ) from e2
                raise

    # -------------------------
    # NodeId / Node
    # -------------------------
    def _gen_nodeid(self, identifier: str, identifier_type: str) -> str:
        """Build ``ns=<namespace>;<identifier_type>=<prefix.>identifier``.

        Raises:
            ValueError: If *identifier_type* is not in ``self.idtmap``.
        """
        if identifier_type not in self.idtmap:
            raise ValueError(
                f"未知的 identifier_type={identifier_type}, 可选值={self.idtmap}"
            )

        prefix = f"{self.identifier_prefix}." if self.identifier_prefix else ""
        return f"ns={self.namespace};{identifier_type}={prefix}{identifier}"

    def get_node(self, identifier: str, identifier_type: str = "s"):
        """Return the ``opcua`` Node for *identifier* after checking it exists.

        Raises:
            ValueError: Unknown *identifier_type*.
            OpcNodeNotFoundError: The server reports the NodeId as unknown/invalid.
            OpcConnectionError: The connection could not be (re)established.
        """
        nodeid = self._gen_nodeid(
            identifier=identifier, identifier_type=identifier_type
        )
        self.check_node_exists(nodeid=nodeid)
        return self.client.get_node(nodeid)

    def check_node_exists(self, nodeid: str) -> None:
        """Verify that *nodeid* exists by reading its NodeClass attribute.

        Raises:
            OpcNodeNotFoundError: Server answered BadNodeIdUnknown/BadNodeIdInvalid.
            OpcConnectionError: The connection could not be (re)established.
            Any other exception raised by python-opcua propagates unchanged.
        """

        def _check():
            # NodeClass is a mandatory attribute of every node, so it is a cheap probe.
            node = self.client.get_node(nodeid)
            return node.get_attribute(ua.AttributeIds.NodeClass)

        try:
            self._rpc(_check)
        except ua.UaStatusCodeError as e:
            # _rpc already turned connection-type failures into
            # OpcConnectionError, so a status error here is a genuine request error.
            code = self._get_status_code_value(e)
            if code in {ua.StatusCodes.BadNodeIdUnknown, ua.StatusCodes.BadNodeIdInvalid}:
                raise OpcNodeNotFoundError(f"{nodeid} 不存在: {e}") from e
            raise

    # -------------------------
    # Value conversion helpers
    # -------------------------
    @staticmethod
    def _to_int(value: Any) -> int:
        """Convert *value* to int, truncating toward zero.

        Integers and integer strings are converted exactly; anything else goes
        through ``float()``. The fast paths avoid the precision loss of a float
        round-trip for magnitudes above 2**53.
        """
        if isinstance(value, int):
            return int(value)
        if isinstance(value, str):
            try:
                return int(value)
            except ValueError:
                pass
        return int(float(value))

    @classmethod
    def _extract_last_number(cls, text: str) -> str | None:
        """Return the last numeric token found in *text*, or None if there is none."""
        matches = cls._NUMBER_RE.findall(text)
        return matches[-1] if matches else None

    # -------------------------
    # Read / write
    # -------------------------
    def read(
        self, identifier: str, data_type: Type[T] = str, identifier_type: str = "s"
    ) -> T:
        """Read a node's value and convert it to *data_type*.

        Only string identifiers (``identifier_type="s"``) are supported.
        ``str`` returns ``str(value)``; ``bool`` applies ``cvt_to_bool`` to it.
        For ``int``/``float`` the node's native number is used when the value
        is an int/float; otherwise the last number in ``str(value)`` is taken
        (sign, decimals and exponents such as ``1e-05`` are understood).
        ``int`` truncates toward zero and logs a warning when the raw text
        contains ".".

        Raises:
            ValueError: Unsupported *data_type*, or no number found for int/float.
            NotImplementedError: *identifier_type* is not ``"s"``.
            OpcNodeNotFoundError / OpcConnectionError: From node lookup or reading.
        """
        if data_type not in self.d_types:
            raise ValueError(f"data_type 参数错误, 可选值={self.d_types}")
        # Reject unsupported identifier types before any network round-trip
        # (same ordering as write()).
        if identifier_type != "s":
            raise NotImplementedError(
                "仅支持 identifier_type='s' 的读取，其他类型等待完善"
            )

        node = self.get_node(identifier=identifier, identifier_type=identifier_type)
        raw_value = self._rpc(node.get_value)
        raw_str = str(raw_value)

        if data_type is str:
            return cast(T, raw_str)
        if data_type is bool:
            return cast(T, cvt_to_bool(raw_str))

        # Numeric target. Prefer the native number: str() renders very small or
        # large floats in exponent form, and big ints lose precision via float.
        if isinstance(raw_value, (int, float)) and not isinstance(raw_value, bool):
            number: Any = raw_value
        else:
            number = self._extract_last_number(raw_str)
            if number is None:
                raise ValueError(f"原始值:{raw_str} 中未找到数字，请确认类型是否正确")

        if data_type is float:
            return cast(T, float(number))

        if "." in raw_str:
            logger.warning(f"原始值为 float 形式 {raw_str}，强制转换为整数")
        return cast(T, self._to_int(number))

    def write(
        self,
        identifier: str,
        value: Any,
        data_type: object | None = str,
        identifier_type: str = "s",
    ) -> bool:
        """Convert *value* to *data_type* and write it to a node.

        Args:
            identifier: Node identifier (prefix is added automatically).
            value: Value to write.
            data_type: ``str``/``int``/``float``/``bool`` to convert *value*
                first (``bool`` uses ``cvt_to_bool``), or None to write as-is.
            identifier_type: Only ``"s"`` is supported.

        Returns:
            True after ``node.set_value`` returns without raising.

        Raises:
            PermissionError: The client is read-only.
            NotImplementedError: *identifier_type* is not ``"s"``.
            ValueError: Unsupported *data_type* or an unconvertible *value*.
            OpcNodeNotFoundError / OpcConnectionError: From node lookup or writing.
        """
        if self.readonly:
            raise PermissionError("当前连接为只读模式, 不允许写入")
        if identifier_type != "s":
            raise NotImplementedError(
                "仅支持 identifier_type='s' 的写入，其他类型等待完善"
            )

        # Convert before touching the server so bad input fails fast.
        if data_type is None:
            new_val = value
        elif data_type is str:
            new_val = str(value)
        elif data_type is int:
            new_val = self._to_int(value)
        elif data_type is float:
            new_val = float(value)
        elif data_type is bool:
            new_val = cvt_to_bool(value)
        else:
            raise ValueError(f"数据类型参数错误, 可选值={self.d_types}")

        node = self.get_node(identifier=identifier, identifier_type=identifier_type)

        # Choose the VariantType from the converted value, not the raw input;
        # bool is tested before int because bool is an int subclass.
        # NOTE(review): every int is sent as Int64; servers that enforce the
        # node's declared DataType (e.g. Int32) may reject this — confirm.
        if isinstance(new_val, bool):
            variant = ua.Variant(new_val, ua.VariantType.Boolean)
        elif isinstance(new_val, int):
            variant = ua.Variant(new_val, ua.VariantType.Int64)
        elif isinstance(new_val, float):
            variant = ua.Variant(new_val, ua.VariantType.Double)
        else:
            variant = ua.Variant(str(new_val), ua.VariantType.String)

        self._rpc(node.set_value, ua.DataValue(variant))
        return True

    # -------------------------
    # Lifecycle
    # -------------------------
    def close(self):
        """Disconnect from the server, ignoring errors."""
        self._disconnect_quietly()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always disconnect; returning False lets any exception propagate.
        self.close()
        return False
