package com.elitescloud.boot.auth.provider.config.servlet.oauth2.handler;

import cn.hutool.core.text.CharSequenceUtil;
import com.elitescloud.boot.auth.client.common.OAuth2ClientConstant;
import com.elitescloud.boot.auth.client.config.AuthorizationProperties;
import com.elitescloud.boot.auth.provider.config.servlet.oauth2.OAuth2AuthorizationCodeRequestCache;
import com.elitescloud.boot.auth.provider.security.handler.LogoutRedirectHandler;
import com.elitescloud.boot.auth.resolver.UniqueRequestResolver;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.savedrequest.SavedRequest;
import org.springframework.util.StringUtils;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

/**
 * Logout redirect handler for the OAuth2 authorization server: decides where the browser
 * is sent after logout.
 * <p>
 * Resolution order used by {@link #determineUrlToUseForThisRequest}:
 * <ol>
 *     <li>the {@code redirect_uri} parameter of the request saved in the configured
 *     {@link RequestCache}, when present and non-blank;</li>
 *     <li>the login URL configured on the registered client. The client id is taken from the
 *     saved request's {@code client_id} parameter, or else from the authorization found via
 *     the configured {@link UniqueRequestResolver};</li>
 *     <li>the default login page from {@link AuthorizationProperties}.</li>
 * </ol>
 * {@code null} is returned when no client id can be determined or no registered client
 * matches it.
 *
 * @author Kaiser（wang shao）
 * @date 3/2/2023
 */
public class OAuth2ServerLogoutRedirectHandler implements LogoutRedirectHandler {
    /** Token type used to look up the authorization bound to the resolved request id. */
    private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE);

    private final RegisteredClientRepository registeredClientRepository;
    private final AuthorizationProperties authorizationProperties;
    private final OAuth2AuthorizationService authorizationService;
    /** Optional; when unset, no client id is derived from the current request. */
    private UniqueRequestResolver uniqueRequestResolver;
    /** Optional; when unset, no saved request is consulted. */
    private RequestCache requestCache;

    /**
     * Creates the handler.
     *
     * @param registeredClientRepository repository used to look up the client by its client id
     * @param authorizationProperties    supplies the default login page
     * @param authorizationService       used to find the authorization for a resolved request id
     */
    public OAuth2ServerLogoutRedirectHandler(RegisteredClientRepository registeredClientRepository,
                                             AuthorizationProperties authorizationProperties,
                                             OAuth2AuthorizationService authorizationService) {
        this.registeredClientRepository = registeredClientRepository;
        this.authorizationProperties = authorizationProperties;
        this.authorizationService = authorizationService;
    }

    /**
     * Determines the URL to redirect to after logout.
     *
     * @param request  the current request
     * @param response the current response
     * @return the saved request's {@code redirect_uri}, the client's configured login URL, or the
     *         default login page; {@code null} when no registered client can be identified
     */
    @Override
    public String determineUrlToUseForThisRequest(HttpServletRequest request, HttpServletResponse response) {
        SavedRequest saved = requestCache == null ? null : requestCache.getRequest(request, response);

        String clientId = null;
        if (saved != null) {
            // A redirect_uri carried by the saved request takes precedence over everything else
            String savedRedirectUri = this.firstParameter(saved, OAuth2ParameterNames.REDIRECT_URI);
            if (StringUtils.hasText(savedRedirectUri)) {
                return savedRedirectUri;
            }
            clientId = this.firstParameter(saved, OAuth2ParameterNames.CLIENT_ID);
        }

        if (CharSequenceUtil.isBlank(clientId)) {
            // Fall back to the client recorded on the authorization tied to this request
            clientId = this.resolveClientIdFromAuthorization(request);
            if (CharSequenceUtil.isBlank(clientId)) {
                return null;
            }
        }

        var registeredClient = registeredClientRepository.findByClientId(clientId);
        if (registeredClient == null) {
            return null;
        }
        // Prefer the client-specific login page; otherwise use the global default
        String clientLoginUrl = registeredClient.getClientSettings().getSetting(OAuth2ClientConstant.SETTING_LOGIN_URL);
        return StringUtils.hasText(clientLoginUrl) ? clientLoginUrl : authorizationProperties.getLoginPage();
    }

    /**
     * Returns the first value of the named parameter of a saved request.
     *
     * @return the first value, or {@code null} when the parameter is missing or has no values
     */
    private String firstParameter(SavedRequest saved, String name) {
        String[] values = saved.getParameterValues(name);
        return values == null || values.length == 0 ? null : values[0];
    }

    /**
     * Looks up the client id via the request id produced by the {@link UniqueRequestResolver}.
     *
     * @return the {@code client_id} attribute of the matching authorization, or {@code null} when
     *         no resolver is configured, no request id is resolved, or no authorization is found
     */
    private String resolveClientIdFromAuthorization(HttpServletRequest request) {
        if (uniqueRequestResolver == null) {
            return null;
        }
        String requestId = uniqueRequestResolver.analyze(request);
        if (!StringUtils.hasText(requestId)) {
            return null;
        }
        OAuth2Authorization authorization = authorizationService.findByToken(requestId, STATE_TOKEN_TYPE);
        if (authorization == null) {
            return null;
        }
        return authorization.getAttribute(OAuth2ParameterNames.CLIENT_ID);
    }

    /**
     * Sets the resolver used to derive a request id when the saved request yields no client id.
     *
     * @param uniqueRequestResolver the resolver; may be {@code null} to disable this lookup
     */
    public void setUniqueRequestResolver(UniqueRequestResolver uniqueRequestResolver) {
        this.uniqueRequestResolver = uniqueRequestResolver;
    }

    /**
     * Sets the cache from which the saved request is read.
     *
     * @param requestCache the cache; may be {@code null} to skip the saved-request lookup
     */
    public void setRequestCache(RequestCache requestCache) {
        this.requestCache = requestCache;
    }
}
