package com.elitesland.cbpl.tool.websocket.notifier;

import cn.hutool.core.net.url.UrlQuery;
import cn.hutool.core.text.StrPool;
import cn.hutool.core.util.ObjectUtil;
import com.elitesland.cbpl.tool.core.bean.BeanUtils;
import com.elitesland.cbpl.tool.core.http.HttpServletUtil;
import com.elitesland.cbpl.tool.websocket.domain.NotifierPayload;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import lombok.var;
import org.apache.commons.lang3.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.socket.sockjs.SockJsTransportFailureException;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * 站内消息推送者(通过WebSocket推送至前端)
 *
 * @author eric.hao
 * @since 2022/09/14
 */
@Slf4j
@RequiredArgsConstructor
public class WebSocketNotifier extends TextWebSocketHandler {

    private static final String URI_PARAM_USER_ID = "userId";
    private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<>();
    private final Map<String, String> userSessionIds = new ConcurrentHashMap<>();

    public <T extends NotifierPayload> void notify(T payload) {
        var message = new TextMessage(BeanUtils.toJsonOrThrow(payload));
        var recipients = payload.getTos();
        var recipientsPushed = new ArrayList<String>();
        userSessionIds.forEach((uid, sids) -> {
            if (recipients.contains(uid)) {
                for (var sid : sids.split(StrPool.COMMA)) {
                    if (ObjectUtil.isNotNull(sessions.get(sid)) && sessions.get(sid).isOpen()) {
                        try {
                            pushToClient(sessions.get(sid), message);
                        } catch (SockJsTransportFailureException e) {
                            logger.error("[PHOENIX-WS] {}-pushToClient failed.", sid);
                        }
                    }
                }
                recipientsPushed.add(uid);
            }
        });
        logger.debug("[PHOENIX-WS] recipients has been pushed: {}", recipientsPushed);
    }

    @SneakyThrows
    private void pushToClient(WebSocketSession session, TextMessage message) {
        session.sendMessage(message);
    }

    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        var sessionId = session.getId();
        var userId = getUserId(session);
        logger.info("[PHOENIX-WS] websocket connection established: {} of u-{}", sessionId, userId);
        sessions.computeIfAbsent(sessionId, sid -> {
            userSessionIds.compute(userId, (uid, sids) ->
                    sids == null ? sessionId : (sids + StrPool.COMMA + sessionId));
            return session;
        });
    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        var sessionId = session.getId();
        var userId = getUserId(session);
        logger.info("[PHOENIX-WS] afterConnectionClosed u-{}", userId);
        userSessionIds.computeIfPresent(userId, (uid, sids) -> {
            String[] split = sids.split(StrPool.COMMA);
            List<String> strings = new ArrayList<>(Arrays.asList(split));
            for (var sid : split) {
                if (sessionId.equals(sid)) {
                    sessions.remove(sid);
                    strings.remove(sid);
                }
            }
            return StringUtils.join(strings, StrPool.COMMA);
        });
        logger.info("[PHOENIX-WS] websocket connection closed: {}", sessionId);
    }

    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        // 目前WebSocket的使用场景仅用于单向发送站内消息，故消息接收处理被忽略。
        logger.trace("[PHOENIX-WS] message from websocket client: {}", message);
    }

    private String getUserId(WebSocketSession session) {
        UrlQuery queryParameters = HttpServletUtil.getQueryParameters(session.getUri());
        return queryParameters.get(URI_PARAM_USER_ID).toString();
    }
}
