package com.elitescloud.boot.websocket.support;

import cn.hutool.core.text.CharSequenceUtil;
import com.elitescloud.boot.auth.CommonAuthenticationToken;
import com.elitescloud.boot.provider.UserDetailProvider;
import com.elitescloud.boot.util.ObjectMapperFactory;
import com.elitescloud.boot.websocket.CloudtWebSocketHandler;
import com.elitescloud.boot.websocket.common.MsgType;
import com.elitescloud.boot.websocket.common.WebSocketConstants;
import com.elitescloud.boot.websocket.model.BaseParameterType;
import com.elitescloud.boot.websocket.util.WebSocketUtil;
import com.elitescloud.cloudt.common.base.ApiResult;
import com.elitescloud.cloudt.security.entity.GeneralUserDetails;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.lang.NonNull;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.util.Assert;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.AbstractWebSocketHandler;

import java.io.IOException;
import java.io.Serializable;
import java.util.Collections;
import java.util.Map;

/**
 * WebSocket handler proxy.
 * <p>
 * Adapts raw WebSocket text frames to a {@link CloudtWebSocketHandler}: each frame is parsed as a JSON object,
 * {@code AUTH} messages are handled here (token lookup and session registration), and every other message is
 * converted to the delegate's parameter type and passed to {@link CloudtWebSocketHandler#handle}. Binary frames
 * are rejected by closing the session.
 *
 * @param <T> parameter type accepted by the delegate
 * @param <R> payload type of the delegate's {@link ApiResult}
 * @author Kaiser（wang shao）
 * @date 2023/5/31
 */
@Slf4j
public class CloudtWebSocketHandlerDelegate<T extends BaseParameterType, R extends Serializable> extends AbstractWebSocketHandler {
    private final CloudtWebSocketHandler<T, R> delegate;
    private final UserDetailProvider userDetailProvider;
    private final Class<T> parameterType;
    private final ObjectMapper objectMapper = ObjectMapperFactory.instance();

    /**
     * Creates the proxy.
     *
     * @param delegate           business handler that receives converted, non-auth messages; must not be null
     * @param userDetailProvider resolves the user for the access token carried by {@code AUTH} messages
     * @throws IllegalArgumentException if {@code delegate} is null or reports no parameter type
     */
    public CloudtWebSocketHandlerDelegate(CloudtWebSocketHandler<T, R> delegate, UserDetailProvider userDetailProvider) {
        // Fail fast with a clear message instead of an NPE from delegate.parameterType().
        Assert.notNull(delegate, "delegate不能为空");
        this.delegate = delegate;
        this.userDetailProvider = userDetailProvider;
        this.parameterType = delegate.parameterType();
        Assert.notNull(parameterType, () -> "获取" + delegate.getClass().getName() + "参数类型失败");
    }

    /**
     * Registers the new session as authenticated when a user was already attached to its attributes
     * under {@link WebSocketConstants#ATTRIBUTE_USER_DETAIL} (presumably by a handshake interceptor — confirm).
     */
    @Override
    public void afterConnectionEstablished(@NonNull WebSocketSession session) throws Exception {
        super.afterConnectionEstablished(session);

        // Already authenticated. The instanceof guard avoids registering a null user or throwing
        // ClassCastException when the attribute holds something unexpected.
        Object userDetail = session.getAttributes().get(WebSocketConstants.ATTRIBUTE_USER_DETAIL);
        if (userDetail instanceof GeneralUserDetails) {
            WebSocketSessionManager.addAuthSession(session, (GeneralUserDetails) userDetail);
        }
    }

    /**
     * Parses a text frame as a JSON object and dispatches it.
     * <p>
     * Blank frames are ignored. Frames that are not a JSON object, or that lack a string {@code type}
     * field, get an error reply. {@code AUTH} messages are handled by {@link #attemptAuth}; everything
     * else goes to {@link #adapterToHandler}.
     */
    @Override
    protected void handleTextMessage(@NonNull WebSocketSession session, @NonNull TextMessage message) throws Exception {
        String msg = message.getPayload();
        if (CharSequenceUtil.isBlank(msg)) {
            log.info("消息为空，忽略！");
            return;
        }

        // Parse the message
        Map<String, Object> payload;
        try {
            payload = objectMapper.readValue(msg, new TypeReference<>() {
            });
        } catch (Exception e) {
            log.info("解析消息异常：", e);
            session.sendMessage(WebSocketUtil.convertTextMessage(false, "消息格式不正确", null));
            return;
        }
        // The JSON literal "null" deserializes to null instead of throwing; reject it like other malformed input.
        if (payload == null) {
            session.sendMessage(WebSocketUtil.convertTextMessage(false, "消息格式不正确", null));
            return;
        }

        // Resolve the message type
        var payloadType = payload.get(WebSocketConstants.PAYLOAD_TYPE);
        if (!(payloadType instanceof String)) {
            session.sendMessage(WebSocketUtil.convertTextMessage(false, "缺少字符串参数type", null));
            return;
        }
        var type = (String) payloadType;

        // Auth messages carry the access token, so only their type is logged, never the raw payload.
        if (MsgType.AUTH.name().equals(type)) {
            log.info("WebSocket Msg: type={}", type);
        } else {
            log.info("WebSocket Msg:{}", msg);
        }

        // Auth message
        var dealed = this.attemptAuth(session, type, payload);
        if (dealed) {
            return;
        }

        // Business handling
        this.adapterToHandler(session, payload);
    }

    /**
     * Rejects binary frames by closing the session with {@link CloseStatus#NOT_ACCEPTABLE}.
     */
    @Override
    protected void handleBinaryMessage(@NonNull WebSocketSession session, @NonNull BinaryMessage message) throws Exception {
        log.warn("暂不支持binary message");
        try {
            session.close(CloseStatus.NOT_ACCEPTABLE.withReason("Binary messages not supported"));
        } catch (IOException ignored) {
            // Best effort: the session is being discarded anyway, so a failed close needs no further handling.
        }
    }

    /**
     * Logs the transport error and unregisters the session before delegating to the superclass.
     */
    @Override
    public void handleTransportError(@NonNull WebSocketSession session, @NonNull Throwable exception) throws Exception {
        log.error("WebSocket 通讯异常：", exception);
        WebSocketSessionManager.removeSession(session);

        super.handleTransportError(session, exception);
    }

    /**
     * Unregisters the session once the connection is closed.
     */
    @Override
    public void afterConnectionClosed(@NonNull WebSocketSession session, @NonNull CloseStatus status) throws Exception {
        super.afterConnectionClosed(session, status);

        WebSocketSessionManager.removeSession(session);
    }

    /**
     * Converts the payload to the delegate's parameter type, invokes the delegate, and sends back its result.
     * <p>
     * If the session has an authenticated user, a fresh security context carrying that user is installed for
     * the duration of the call and cleared afterwards. A conversion failure or a delegate failure produces an
     * error reply. A {@code null} result sends nothing.
     *
     * @throws IOException if sending a reply fails
     */
    private void adapterToHandler(WebSocketSession session, Map<String, Object> payload) throws IOException {
        // Convert the incoming message to the parameter type
        T parameter;
        try {
            parameter = objectMapper.convertValue(payload, parameterType);
        } catch (Exception e) {
            log.error("转换消息异常：", e);
            session.sendMessage(WebSocketUtil.convertTextMessage(false, "消息格式不正确", null));
            return;
        }

        // Install the user context. Build a new SecurityContext rather than mutating the one currently
        // bound to this thread, as recommended by Spring Security, so no shared context object is altered.
        var userDetail = WebSocketSessionManager.obtainAuthUser(session);
        if (userDetail != null) {
            var previousAuthentication = SecurityContextHolder.getContext().getAuthentication();
            var context = SecurityContextHolder.createEmptyContext();
            context.setAuthentication(new CommonAuthenticationToken(previousAuthentication, userDetail, Collections.emptyList()));
            SecurityContextHolder.setContext(context);
        }
        ApiResult<R> apiResult;
        try {
            apiResult = delegate.handle(parameter);
        } catch (Throwable e) {
            log.error("处理消息异常：", e);
            // Reply like every other failure path so the client is not left waiting without feedback.
            session.sendMessage(WebSocketUtil.convertTextMessage(false, "处理消息异常", null));
            return;
        } finally {
            if (userDetail != null) {
                SecurityContextHolder.clearContext();
            }
        }
        if (apiResult != null) {
            session.sendMessage(WebSocketUtil.convertTextMessage(apiResult));
        }
    }

    /**
     * Handles an {@code AUTH} message: resolves the user for its access token and registers the session.
     *
     * @return {@code true} if the message was an auth message (whether or not authentication succeeded),
     *         {@code false} if it should be passed on to business handling
     * @throws IOException if sending the reply fails
     */
    private boolean attemptAuth(WebSocketSession session, String type, Map<String, Object> payload) throws IOException {
        if (!MsgType.AUTH.name().equals(type)) {
            return false;
        }

        var accessToken = payload.get(WebSocketConstants.PAYLOAD_TOKEN);
        if (!(accessToken instanceof String) || CharSequenceUtil.isBlank(accessToken.toString())) {
            session.sendMessage(WebSocketUtil.convertTextMessage(false, "认证失败，缺少参数accessToken", null));
            return true;
        }
        var token = accessToken.toString();

        var user = userDetailProvider.getByToken(token);
        log.info("WebSocket Authorized User：{}", user == null ? null : user.getUsername());
        if (user == null) {
            session.sendMessage(WebSocketUtil.convertTextMessage(false, "认证已过期，请重新登录", null));
            return true;
        }

        // Register the authenticated session
        WebSocketSessionManager.addAuthSession(session, user);
        session.sendMessage(WebSocketUtil.convertTextMessage(true, "认证成功", null));

        return true;
    }
}
