package com.elitesland.cbpl.tool.websocket.handler;

import cn.hutool.core.text.StrPool;
import cn.hutool.core.util.ObjectUtil;
import com.elitesland.cbpl.tool.websocket.constant.WebSocketInstant;
import com.elitesland.cbpl.tool.websocket.util.WebSocketUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.SockJsTransportFailureException;

import java.io.IOException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

/**
 * Websocket Session 统一管理
 *
 * @author eric.hao
 * @since 2024/11/18
 */
@Slf4j
public class WsSessionManager {

    /**
     * 保存连接 session 的地方
     */
    private static final ConcurrentHashMap<String, WebSocketSession> SESSION_POOL = new ConcurrentHashMap<>();
    /**
     * UserId <=> sessionIds 关系
     */
    private static final Map<String, String> SESSION_USER_IDS = new ConcurrentHashMap<>();

    /**
     * 添加 session
     */
    public static void add(WebSocketSession session) {
        // 添加 session
        String sessionId = session.getId();
        String userId = getUserId(session);
        String clientIp = WebSocketUtil.getAttrStr(session, WebSocketInstant.URI_PARAM_IP);
        SESSION_POOL.computeIfAbsent(sessionId, sid -> {
            SESSION_USER_IDS.compute(userId, (uid, sids) ->
                    sids == null ? sessionId : (sids + StrPool.COMMA + sessionId)
            );
            return session;
        });
    }

    /**
     * 删除 session
     */
    public static void remove(WebSocketSession session) {
        // 删除 session
        String sessionId = session.getId();
        String userId = WebSocketUtil.getAttrStr(session, WebSocketInstant.URI_PARAM_USER_ID);
        SESSION_USER_IDS.computeIfPresent(userId, (uid, sids) -> {
            String[] split = sids.split(StrPool.COMMA);
            List<String> sessionIds = new ArrayList<>(Arrays.asList(split));
            for (var sid : split) {
                if (sessionId.equals(sid)) {
                    SESSION_POOL.remove(sid);
                    sessionIds.remove(sid);
                }
            }
            return StringUtils.join(sessionIds, StrPool.COMMA);
        });
    }

    /**
     * 删除并同步关闭连接
     */
    public static void removeAndClose(WebSocketSession session) {
        remove(session);
        try {
            // 关闭连接
            session.close();
        } catch (IOException e) {
            // todo: 关闭出现异常处理
            e.printStackTrace();
        }
    }

    /**
     * 获取当前 WebSocket Session 的 UserId
     */
    public static String getUserId(WebSocketSession session) {
        return WebSocketUtil.getAttrStr(session, WebSocketInstant.URI_PARAM_USER_ID);
    }

    /**
     * 推送消息
     *
     * @param recipients 接收人清单(UserIds)
     * @param consumer   推送方法
     * @return 实际推送的UserIds
     */
    public static List<String> push(Set<String> recipients, Consumer<WebSocketSession> consumer) {
        var recipientsPushed = new ArrayList<String>();
        SESSION_USER_IDS.forEach((uid, sessionIds) -> {
            if (recipients.contains(uid)) {
                for (var sid : sessionIds.split(StrPool.COMMA)) {
                    if (ObjectUtil.isNotNull(SESSION_POOL.get(sid)) && SESSION_POOL.get(sid).isOpen()) {
                        try {
                            consumer.accept(SESSION_POOL.get(sid));
                        } catch (SockJsTransportFailureException e) {
                            logger.error("[PHOENIX-WS] {}-pushToClient failed.", sid);
                        }
                    }
                }
                recipientsPushed.add(uid);
            }
        });
        return recipientsPushed;
    }
}
