package com.elitesland.cbpl.tool.websocket.handler;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import com.corundumstudio.socketio.SocketIOClient;
import com.elitesland.cbpl.tool.extra.spring.SpringUtils;
import com.elitesland.cbpl.tool.websocket.constant.WebSocketConstant;
import com.elitesland.cbpl.tool.websocket.handler.domain.OnlineUser;
import com.elitesland.cbpl.tool.websocket.spi.OnlineListener;
import com.elitesland.cbpl.tool.websocket.util.WebSocketUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.socket.sockjs.SockJsTransportFailureException;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.stream.Collectors;

/**
 * Websocket Session 统一管理
 *
 * @author eric.hao
 * @since 2024/11/18
 */
@Slf4j
public class WebSocketSessionManager {

    /**
     * UserId <=> Session 映射关系
     */
    private static final ConcurrentHashMap<String, List<OnlineUser>> SESSION_USERS = new ConcurrentHashMap<>();

    /**
     * 添加 session
     */
    public static void add(SocketIOClient session) {
        var visitorId = WebSocketUtil.getVisitorId(session);
        if (StrUtil.isBlank(visitorId)) {
            logger.warn("[PHOENIX-WS] VisitorId is Missing.");
            return;
        }

        var sid = session.getSessionId();
        var userId = getUserId(session);
        var ip = getClientIP(session);
        // 当前活跃用户
        var user = OnlineUser.of(session, ip);
        SESSION_USERS.compute(userId, (uid, users) -> {
            // 1.空数据，直接返回新数组
            if (CollUtil.isEmpty(users)) {
                return CollUtil.newArrayList(user);
            }
            // 2.如果存在，先删除
            users.removeIf(row -> row.getSessionId().equals(sid));
            // 3.再更新session数据，主要是最后活跃时间
            users.add(user);
            return users;
        });

        // 扩展实现
        if (SpringUtils.isPresent(OnlineListener.class)) {
            var listener = SpringUtils.getBean(OnlineListener.class);
            listener.add(user);
        }
    }

    /**
     * 删除 session
     */
    public static void remove(SocketIOClient session) {
        var visitorId = WebSocketUtil.getVisitorId(session);
        if (StrUtil.isBlank(visitorId)) {
            logger.warn("[PHOENIX-WS] VisitorId is Missing.");
            return;
        }

        var sid = session.getSessionId();
        var userId = getUserId(session);
        var ip = getClientIP(session);
        // 当前活跃用户
        var user = OnlineUser.of(session, ip);
        SESSION_USERS.compute(userId, (uid, users) -> {
            // 1.空数据，直接返回
            if (CollUtil.isEmpty(users)) {
                return users;
            }
            // 2.如果存在，则删除
            users.removeIf(row -> row.getSessionId().equals(sid));
            return users;
        });

        // 3.扩展实现
        if (SpringUtils.isPresent(OnlineListener.class)) {
            var listener = SpringUtils.getBean(OnlineListener.class);
            listener.remove(user);
        }
    }

    /**
     * 获取当前 WebSocket Session 的 UserId
     */
    public static String getUserId(SocketIOClient session) {
        // 如果对接了云梯，使用云梯的用户对象解析
        if (SpringUtils.isPresent(OnlineListener.class)) {
            var listener = SpringUtils.getBean(OnlineListener.class);
            return listener.getUserId(session.getHandshakeData());
        }
        // 如果没有，统一返回游客
        return WebSocketConstant.GUEST_USER_ID;
    }

    /**
     * 获取当前 WebSocket Session 的登录账号
     */
    public static String getUsername(SocketIOClient session) {
        // 如果对接了云梯，使用云梯的用户对象解析
        if (SpringUtils.isPresent(OnlineListener.class)) {
            var listener = SpringUtils.getBean(OnlineListener.class);
            return listener.getUsername(session.getHandshakeData());
        }
        // 如果没有，统一返回游客
        return WebSocketConstant.GUEST_USERNAME;
    }

    /**
     * 获取当前 WebSocket Session 的昵称
     */
    public static String getNickname(SocketIOClient session) {
        // 如果对接了云梯，使用云梯的用户对象解析
        if (SpringUtils.isPresent(OnlineListener.class)) {
            var listener = SpringUtils.getBean(OnlineListener.class);
            return listener.getNickname(session.getHandshakeData());
        }
        // 如果没有，统一返回游客
        return WebSocketConstant.GUEST_NICKNAME;
    }

    /**
     * 获取当前 WebSocket Session 的租户编码
     */
    public static String getTenantCode(SocketIOClient session) {
        // 如果对接了云梯，使用云梯的用户对象解析
        if (SpringUtils.isPresent(OnlineListener.class)) {
            var listener = SpringUtils.getBean(OnlineListener.class);
            return listener.getTenantCode(session.getHandshakeData());
        }
        // 如果没有，统一返回无租户
        return WebSocketConstant.GUEST_TENANT_CODE;
    }

    /**
     * 获取客户端IP【临时方法】
     */
    private static String getClientIP(SocketIOClient session) {
        if (session.getRemoteAddress() == null) {
            return "unknown";
        }
        return session.getHandshakeData().getAddress().getAddress().getHostAddress();
    }

    /**
     * 推送消息
     *
     * @param recipients 接收人清单(UserIds)
     * @param consumer   推送方法
     * @return 实际推送的UserIds
     */
    public static List<String> push(Set<String> recipients, Consumer<SocketIOClient> consumer) {
        var recipientsPushed = new ArrayList<String>();
        SESSION_USERS.forEach((uid, users) -> {
            if (recipients.contains(uid)) {
                for (OnlineUser user : users) {
                    if (user.getSession().isChannelOpen()) {
                        try {
                            consumer.accept(user.getSession());
                        } catch (SockJsTransportFailureException e) {
                            logger.error("[PHOENIX-WS] {}-pushToClient failed.", user.getSessionId());
                        }
                    }
                }
                recipientsPushed.add(uid);
            }
        });
        return recipientsPushed;
    }

    /**
     * 获取所有在线用户
     */
    public static List<OnlineUser> getOnlineUsers() {
        return SESSION_USERS.values().stream().flatMap(Collection::stream).collect(Collectors.toList());
    }
}
