package com.elitescloud.boot.websocket.support;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.text.CharSequenceUtil;
import com.elitescloud.boot.provider.UserDetailProvider;
import com.elitescloud.boot.websocket.common.WebSocketConstants;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.lang.NonNull;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

import java.util.Map;

/**
 * Authentication handshake interceptor.
 * <p>
 * Looks for a token in the {@code Authorization} request header, falling back to the
 * {@code Access-Token} header when {@code Authorization} is missing or blank.
 * <ul>
 *     <li>No token: the handshake proceeds without an authenticated user.</li>
 *     <li>Token resolved to a user via {@link UserDetailProvider#getByToken}: the user is stored in the
 *     handshake attributes under {@link WebSocketConstants#ATTRIBUTE_USER_DETAIL} and the handshake proceeds.</li>
 *     <li>Token that cannot be resolved: the handshake is rejected with {@code 401 Unauthorized}.</li>
 * </ul>
 *
 * @author Kaiser（wang shao）
 * @date 2023/5/31
 */
@Slf4j
public class AuthHandshakeInterceptor implements HandshakeInterceptor {

    /** Fallback header consulted when {@code Authorization} carries no token. */
    private static final String HEADER_ACCESS_TOKEN = "Access-Token";

    private final UserDetailProvider userDetailProvider;

    /**
     * @param userDetailProvider resolves a token to the user it belongs to
     */
    public AuthHandshakeInterceptor(UserDetailProvider userDetailProvider) {
        this.userDetailProvider = userDetailProvider;
    }

    /**
     * Authenticates the handshake request if it carries a token.
     *
     * @return {@code true} to continue the handshake (no token, or a token resolved to a user);
     * {@code false} after setting {@code 401} when the token could not be resolved
     */
    @Override
    public boolean beforeHandshake(@NonNull ServerHttpRequest request, @NonNull ServerHttpResponse response, @NonNull WebSocketHandler wsHandler,
                                   @NonNull Map<String, Object> attributes) throws Exception {
        var token = this.obtainToken(request);
        if (!StringUtils.hasText(token)) {
            // No credentials supplied: allow the connection without an authenticated user.
            return true;
        }

        var user = userDetailProvider.getByToken(token);
        if (user == null) {
            // Returning false only aborts the handshake; set an explicit status so the client can tell
            // an authentication failure apart from an unexplained empty response. Never log the token itself.
            log.warn("WebSocket handshake rejected: token could not be resolved to a user");
            response.setStatusCode(HttpStatus.UNAUTHORIZED);
            return false;
        }

        log.info("WebSocket authenticated user：{}", user.getUsername());
        attributes.put(WebSocketConstants.ATTRIBUTE_USER_DETAIL, user);
        return true;
    }

    @Override
    public void afterHandshake(@NonNull ServerHttpRequest request, @NonNull ServerHttpResponse response, @NonNull WebSocketHandler wsHandler, Exception exception) {
        // Nothing to do after the handshake.
    }

    /**
     * Reads the token from the request headers.
     *
     * @return the first {@code Authorization} value, or, if that is absent or blank, the first
     * {@code Access-Token} value; {@code null} when neither header is present
     */
    private String obtainToken(ServerHttpRequest request) {
        var headers = request.getHeaders();

        var token = firstHeaderValue(headers, HttpHeaders.AUTHORIZATION);
        if (CharSequenceUtil.isBlank(token)) {
            token = firstHeaderValue(headers, HEADER_ACCESS_TOKEN);
        }
        return token;
    }

    /** Returns the first value of the named header, or {@code null} if it is absent or has no values. */
    private static String firstHeaderValue(HttpHeaders headers, String name) {
        var values = headers.get(name);
        return CollUtil.isEmpty(values) ? null : values.get(0);
    }
}
