# coding=utf8
import json
import time
import websocket
from datetime import datetime


def log(message, *args, sep=' '):
    """Print *message* and any extra *args* to stdout, prefixed with the local time.

    *sep* is forwarded to ``print`` and separates the timestamped message from
    each extra argument.
    """
    stamped = f'{datetime.now()} {message}'
    print(stamped, *args, sep=sep)


class BaiduVoiceCloneWebSocketSDK:
    """
    初始化 SDK 实例
    :param access_token: 鉴权令牌
    :param voice_id: 音色ID参数
    :param base_url: WebSocket 服务的基础 URL
    """

    def __init__(
            self,
            authorization: str,
            voice_id: int = 102630,
            base_url: str = "wss://aip.baidubce.com/ws/2.0/speech/publiccloudspeech/v1/voice/clone/tts",
            timeout: int = 10,
    ):
        self.authorization = authorization
        self.voice_id = voice_id
        self.base_url = base_url
        self.timeout = timeout
        self.ws = None
        self.output_file = None
        self.last_error = None
        self.is_connected = False
        self.audio_data = bytearray()  # 存储音频数据

    def connect(self):
        """建立 WebSocket 连接（失败时直接抛出异常）"""
        url = f"{self.base_url}?access_token={self.authorization}&voice_id={self.voice_id}"
        log(f"Connecting to: {url}")

        headers = {
            # 'Authorization': f'Bearer {self.authorization}',
        }

        try:
            log("==== Connecting ====")
            self.ws = websocket.WebSocket()
            self.ws.connect(url, header=headers, timeout=self.timeout)
            self.is_connected = True
            log("==== Connected ====")
        except websocket.WebSocketBadStatusException as e:
            self.last_error = f"{e.status_code}, {e.resp_body.decode('utf-8')}"
            self.is_connected = False
            raise RuntimeError(self.last_error)
        except Exception as e:
            self.last_error = f"Connect Failed: {str(e)}"
            self.is_connected = False
            raise RuntimeError(self.last_error)

    def receive_audio(self, timeout=10):
        """接收音频数据"""
        if not self.is_connected:
            raise RuntimeError("WebSocket Not Connected")

        start_time = time.time()
        self.audio_data = bytearray()  # 清空之前的音频数据

        try:
            while time.time() - start_time < timeout:
                try:
                    # 设置超时时间为1秒
                    message = self.ws.recv()

                    if isinstance(message, bytes):
                        now = time.time()
                        log(f'==== Received {len(message):6d} bytes ====')
                        self.audio_data.extend(message)
                    else:
                        # JSON 消息
                        data = json.loads(message)
                        log('==== Received message ====', json.dumps(data, indent=2), sep='\n')

                        # 检查是否结束
                        if data.get("type") == "system.finish":
                            break

                        # 检查错误
                        if data.get("type") == "system.error":
                            error_code = data.get("code", -1)
                            error_msg = data.get("message", "未知错误")
                            self.last_error = f"Server Error: {error_code} - {error_msg}"
                            raise RuntimeError(self.last_error)

                except websocket.WebSocketTimeoutException:
                    continue  # 继续等待
                except Exception as e:
                    raise RuntimeError(f"接收数据错误: {str(e)}")

            # 保存音频文件
            if self.audio_data and self.output_file:
                log(f"==== Start Save Audio ====")
                with open(self.output_file, "wb") as f:
                    f.write(self.audio_data)
                log(f"=== Saved Audio Done: {self.output_file} ====")

        except Exception as e:
            self.last_error = str(e)
            raise

    def send_json(self, payload: dict):
        """发送 JSON 数据（仅在连接成功时发送）"""
        if not self.is_connected:
            raise RuntimeError("WebSocket Not Connected")

        try:
            self.ws.send(json.dumps(payload))
            log("Send text:", json.dumps(payload, indent=2, ensure_ascii=False))
        except Exception as e:
            self.last_error = f"Send Text Failed: {str(e)}"
            raise RuntimeError(self.last_error)

    def send_start_request(self, spd=5, pit=5, vol=5, aue=3,
                           dialect="wuu-CN-shanghai", emotion="happy", sample_rate=24000):
        """
        发送开始合成请求
        :param spd: 语速，默认值为 5
        :param pit: 音调，默认值为 5
        :param vol: 音量，默认值为 5
        :param aue: 音频格式，默认值为 3 (mp3)
        :param dialect: 方言控制参数，默认值为 wuu-CN-shanghai（上海话）
        :param emotion: 情绪控制参数，默认值为 happy（开心）
        :param sample_rate: 采样率控制参数，默认值为 24000
        """
        payload = {
            "type": "system.start",
            "payload": {
                "spd": spd,
                "pit": pit,
                "vol": vol,
                "aue": aue,
                "dialect": dialect,
                "emotion": emotion,
                "sample_rate": sample_rate
            },
        }
        self.send_json(payload)

    def send_text_request(self, text: str):
        """发送文本合成请求"""
        payload = {
            "type": "text",
            "payload": {
                "text": text,
            },
        }
        self.send_json(payload)

    def send_finish_request(self):
        """发送结束合成请求"""
        payload = {
            "type": "system.finish",
        }
        self.send_json(payload)
        log("Sent system.finish")

    def synthesize(self, text: str, output_file: str = "output.mp3", **kwargs):
        """
        完整语音合成流程
        """
        self.output_file = output_file
        self.last_error = None

        try:
            # 1. 连接 WebSocket
            self.connect()

            # 2. 发送开始请求（支持传入dialect/emotion/sample_rate等参数）
            self.send_start_request(**kwargs)

            # 3. 发送文本
            self.send_text_request(text)

            # 4. 接收音频数据
            self.receive_audio(timeout=10)

            # 5. 发送结束请求
            self.send_finish_request()

        except Exception as e:
            if self.last_error:
                log(f"{self.last_error}")
            else:
                log(f"{str(e)}")
        finally:
            if self.ws:
                self.ws.close()
                self.is_connected = False


if __name__ == "__main__":
    import os

    # Credential: an IAM API key or an access token (use one or the other).
    # It can be supplied via the environment so secrets stay out of the source.
    AUTHORIZATION = os.environ.get("BAIDU_VOICE_CLONE_AUTHORIZATION", "YOUR_TOKEN/YOUR_API_KEY")
    VOICE_ID = 102630  # replace with your own voice ID

    sdk = BaiduVoiceCloneWebSocketSDK(
        authorization=AUTHORIZATION,
        voice_id=VOICE_ID,
    )

    # Full synthesis flow: connect -> start synthesis -> send text
    #   -> receive audio -> finish synthesis -> close connection.
    #   text:        text to synthesize
    #   output_file: file name the audio is saved to
    #   spd:         speech speed
    #   pit:         pitch
    #   vol:         volume
    #   aue:         audio format
    #   dialect:     dialect control
    #   emotion:     emotion control
    #   sample_rate: sample-rate control
    sdk.synthesize(
        text="欢迎使用百度语音合成服务。",
        output_file="output.mp3",
        spd=5,
        pit=5,
        vol=5,
        aue=3,
        dialect="wuu-CN-shanghai",
        emotion="happy",
        sample_rate=16000
    )

    # synthesize() logs errors instead of raising, so report failure through
    # the process exit status.
    if sdk.last_error:
        raise SystemExit(1)