package com.elitescloud.boot.auth.provider.config.servlet.oauth2.handler;

import com.elitescloud.boot.auth.client.common.OAuth2ClientConstant;
import com.elitescloud.boot.auth.provider.config.servlet.oauth2.OAuth2AuthorizationCodeRequestCache;
import com.elitescloud.boot.auth.resolver.UniqueRequestResolver;
import lombok.extern.log4j.Log4j2;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter;
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.util.StringUtils;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;

/**
 * Login entry point used by the OAuth2 authorization server when an unauthenticated request
 * has to be sent to a login page.
 * <p>
 * On top of the default {@link LoginUrlAuthenticationEntryPoint} behavior it can:
 * <ul>
 *     <li>convert the incoming request into an {@link OAuth2AuthorizationCodeRequestAuthenticationToken}
 *     and cache it under a request id issued by the configured {@link UniqueRequestResolver};</li>
 *     <li>save the original request in the configured {@link RequestCache};</li>
 *     <li>redirect to a client-specific login page read from the registered client's settings
 *     ({@link OAuth2ClientConstant#SETTING_LOGIN_URL}), passing the request id as the {@code Urq}
 *     query parameter.</li>
 * </ul>
 *
 * @author Kaiser（wang shao）
 * @date 2022/12/27
 */
@Log4j2
public class OAuth2ServerLoginUrlAuthenticationEntryPointHandler extends LoginUrlAuthenticationEntryPoint {
    /** Query parameter that carries the request id to a client-specific login page. */
    private static final String PARAM_REQUEST_ID = "Urq";

    private final RegisteredClientRepository registeredClientRepository;
    private final OAuth2AuthorizationCodeRequestCache authorizationCodeRequestCache;
    /** Optional; when {@code null} no request id is issued and nothing is cached. */
    private UniqueRequestResolver uniqueRequestResolver;
    /** Optional; when {@code null} the original request is not saved. */
    private RequestCache requestCache;
    /**
     * Carries the request id issued in {@link #commence} to {@link #determineUrlToUseForThisRequest},
     * which the parent's {@code commence} invokes on the same thread. Always cleared when
     * {@link #commence} returns or throws.
     */
    private final ThreadLocal<String> seqLocal = new ThreadLocal<>();

    /**
     * @param loginFormUrl                  default login page URL, used when the client has no custom one
     * @param registeredClientRepository    used to look up a client's custom login page by {@code client_id}
     * @param authorizationCodeRequestCache stores the pending authorization request under its request id
     */
    public OAuth2ServerLoginUrlAuthenticationEntryPointHandler(String loginFormUrl, RegisteredClientRepository registeredClientRepository, OAuth2AuthorizationCodeRequestCache authorizationCodeRequestCache) {
        super(loginFormUrl);
        this.registeredClientRepository = registeredClientRepository;
        this.authorizationCodeRequestCache = authorizationCodeRequestCache;
    }

    /**
     * Optionally caches the pending authorization request and saves the original request,
     * then delegates to {@link LoginUrlAuthenticationEntryPoint#commence} to send the user to
     * the login page.
     */
    @Override
    public void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException authException) throws IOException, ServletException {
        log.info("{}未认证，将转向认证页", request.getRequestURI());
        // Everything runs inside try/finally so the ThreadLocal is cleared even when caching or
        // saving the request fails after the request id has been set.
        try {
            if (uniqueRequestResolver != null) {
                OAuth2AuthorizationCodeRequestAuthenticationToken codeRequest = (OAuth2AuthorizationCodeRequestAuthenticationToken) new OAuth2AuthorizationCodeRequestAuthenticationConverter().convert(request);

                String seqId = uniqueRequestResolver.signRequest(response);
                seqLocal.set(seqId);
                authorizationCodeRequestCache.setAuthenticationToken(seqId, codeRequest, null);
            }
            if (requestCache != null) {
                requestCache.saveRequest(request, response);
            }
            super.commence(request, response, authException);
        } finally {
            seqLocal.remove();
        }
    }

    /**
     * Returns the client-specific login page when the requesting client defines one. The request
     * id issued in {@link #commence}, if any, is appended as the {@code Urq} query parameter.
     * Otherwise returns the default login URL.
     */
    @Override
    protected String determineUrlToUseForThisRequest(HttpServletRequest request, HttpServletResponse response, AuthenticationException exception) {
        String clientLoginPage = obtainClientLoginPage(request);
        if (!StringUtils.hasText(clientLoginPage)) {
            return super.determineUrlToUseForThisRequest(request, response, exception);
        }

        // A custom client login page must receive the request id.
        String seqId = seqLocal.get();
        if (StringUtils.hasText(seqId)) {
            String separator = clientLoginPage.contains("?") ? "&" : "?";
            // Encode so reserved characters in the id cannot break or alter the query string.
            clientLoginPage = clientLoginPage + separator + PARAM_REQUEST_ID + "="
                    + URLEncoder.encode(seqId, StandardCharsets.UTF_8);
        }
        return clientLoginPage;
    }

    /**
     * @param uniqueRequestResolver issues request ids; when set, authorization requests are cached
     */
    public void setUniqueRequestResolver(UniqueRequestResolver uniqueRequestResolver) {
        this.uniqueRequestResolver = uniqueRequestResolver;
    }

    /**
     * @param requestCache when set, the original request is saved before redirecting to login
     */
    public void setRequestCache(RequestCache requestCache) {
        this.requestCache = requestCache;
    }

    /**
     * Looks up the custom login page configured for the client named by the first
     * {@code client_id} parameter.
     *
     * @return the configured login URL, or {@code null} when there is no usable client id,
     * the client is unknown, or no custom page is configured
     */
    private String obtainClientLoginPage(HttpServletRequest request) {
        String[] clientIds = request.getParameterValues(OAuth2ParameterNames.CLIENT_ID);
        // Skip blank ids: Spring's built-in repositories reject them with IllegalArgumentException,
        // which would turn a login redirect into a server error.
        if (clientIds == null || clientIds.length == 0 || !StringUtils.hasText(clientIds[0])) {
            return null;
        }

        var client = registeredClientRepository.findByClientId(clientIds[0]);
        if (client == null) {
            return null;
        }
        return client.getClientSettings().getSetting(OAuth2ClientConstant.SETTING_LOGIN_URL);
    }
}
