#!/usr/bin/env python3

import os
import sys
import logging
import argparse
import mysql.connector

"""
kill 所有数据库上的连接
"""

# File name of this script; passed to argparse as the program name.
name = os.path.basename(__file__)

# Configure the root logger: INFO and above, each record prefixed with a
# timestamp and its level name.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)

def parser_cmd_args(argv=None) -> argparse.Namespace:
    """
    Parse the command-line arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read the arguments from ``sys.argv[1:]``.

    Returns:
        The parsed namespace with the attributes ``host``, ``port``,
        ``user``, ``password`` and ``username``. ``username`` is ``None``
        when the positional argument is omitted.
    """
    parser = argparse.ArgumentParser(name)
    parser.add_argument('--host', type=str, default='127.0.0.1', help='mysql host')
    parser.add_argument('--port', type=int, default=3306, help="mysql port")
    parser.add_argument('--user', type=str, default="opsuser", help="mysql user")
    # NOTE(review): a password is hard-coded as the default value here;
    # consider supplying it from the environment or an option file instead.
    parser.add_argument('--password', type=str, default="mtls0352", help="mysql user's password")
    # nargs='?' is what makes this positional optional. Without it argparse
    # requires a value, the default of None is never used, and the
    # "no user filter" code path can never be reached from the CLI.
    parser.add_argument('username', nargs='?', type=str, default=None, help="target user's session")
    return parser.parse_args(argv)


def _select_target_ids(cursor, username):
    """Return the processlist ids to kill, excluding this script's own session."""
    # Look up our own connection id so that we never kill ourselves.
    cursor.execute("select connection_id();")
    own_id, *_ = cursor.fetchone()

    # List the sessions on the server, skipping daemon and binlog-dump threads.
    if username is None:
        cursor.execute("select id from information_schema.processlist where command not in ('Daemon','Binlog Dump') ;")
    else:
        cursor.execute("select id from information_schema.processlist where command not in ('Daemon','Binlog Dump') and user=%s;", (username,))

    return [session_id for session_id, *_ in cursor.fetchall() if session_id != own_id]


def _kill_sessions(cursor, session_ids):
    """Issue KILL for every id; a failure is logged and the loop carries on."""
    for session_id in session_ids:
        try:
            # The ids come from the server's own processlist, so formatting
            # them into the statement is safe.
            cursor.execute(f"kill {session_id};")
            logging.info("kill %s;", session_id)
        except mysql.connector.Error as err:
            # Best effort: the session may already be gone by the time we
            # try to kill it.
            logging.error("kill %s failed '%s' ", session_id, err)


def kill_all_connections(args: argparse.Namespace) -> None:
    """
    Kill client sessions on the MySQL server described by *args*.

    Connects using ``args.host``, ``args.port``, ``args.user`` and
    ``args.password``, then selects the ids in
    ``information_schema.processlist`` whose command is neither ``Daemon``
    nor ``Binlog Dump``. When ``args.username`` is not ``None`` only that
    user's sessions are selected. Every selected session except this
    script's own connection is killed.

    Any ``mysql.connector.Error`` is logged instead of being raised: a
    failed KILL does not stop the remaining kills, while an error while
    connecting or querying ends the run.

    Args:
        args: Parsed command-line options.
    """
    cnx = None
    try:
        cnx = mysql.connector.connect(host=args.host, port=args.port, user=args.user, password=args.password)
        cnx.autocommit = True
        cursor = cnx.cursor()
        try:
            _kill_sessions(cursor, _select_target_ids(cursor, args.username))
        finally:
            # Release the cursor explicitly instead of relying on the
            # connection teardown to do it.
            cursor.close()
    except mysql.connector.Error as err:
        logging.error(str(err))
    finally:
        if cnx is not None:
            cnx.close()


if __name__ == "__main__":
    # Script entry point: parse the command line and act on it right away.
    kill_all_connections(parser_cmd_args())
