package com.elitescloud.boot.auth.provider.config.servlet.oauth2.handler;

import cn.hutool.core.text.CharSequenceUtil;
import com.elitescloud.boot.auth.client.common.OAuth2ClientConstant;
import com.elitescloud.boot.auth.client.config.AuthorizationProperties;
import com.elitescloud.boot.auth.common.AuthSdkConstant;
import com.elitescloud.boot.auth.provider.config.servlet.oauth2.OAuth2AuthorizationCodeRequestCache;
import com.elitescloud.boot.auth.resolver.UniqueRequestResolver;
import com.elitescloud.boot.util.ObjUtil;
import lombok.extern.log4j.Log4j2;
import org.springframework.http.HttpHeaders;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter;
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.util.StringUtils;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.validation.constraints.NotBlank;
import java.io.IOException;
import java.net.URI;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.time.Duration;

/**
 * Authentication entry point used by the OAuth2 authorization server when a request is not authenticated.
 * <p>
 * It redirects the user agent to a login page:
 * <ul>
 *     <li>the client's own login page (client setting {@code OAuth2ClientConstant.SETTING_LOGIN_URL}) when one is
 *     registered for the requested {@code client_id};</li>
 *     <li>otherwise the globally configured {@link AuthorizationProperties#getLoginPage() login page}.</li>
 * </ul>
 * Relative login page URLs are resolved against the origin (scheme + authority) of a source URL, falling back to
 * {@link AuthorizationProperties#getRedirectUriPrefix()}.
 * <p>
 * When a {@link UniqueRequestResolver} is set, the pending authorization request is cached under a sequence id
 * obtained from the resolver, and that id is passed to a client-specific login page as the {@code Urq} parameter.
 *
 * @author Kaiser（wang shao）
 * @date 2022/12/27
 */
@Log4j2
public class OAuth2ServerLoginUrlAuthenticationEntryPointHandler extends LoginUrlAuthenticationEntryPoint {
    /** How long a cached authorization request is kept under its sequence id. */
    private static final Duration AUTHORIZATION_REQUEST_TTL = Duration.ofDays(7);

    private final AuthorizationProperties properties;
    private final RegisteredClientRepository registeredClientRepository;
    private final OAuth2AuthorizationCodeRequestCache authorizationCodeRequestCache;
    private UniqueRequestResolver uniqueRequestResolver;
    private RequestCache requestCache;
    /**
     * Sequence id of the request currently inside {@link #commence}; read by
     * {@link #determineUrlToUseForThisRequest} and always cleared before {@code commence} returns.
     */
    private final ThreadLocal<String> seqLocal = new ThreadLocal<>();

    public OAuth2ServerLoginUrlAuthenticationEntryPointHandler(AuthorizationProperties properties, RegisteredClientRepository registeredClientRepository,
                                                               OAuth2AuthorizationCodeRequestCache authorizationCodeRequestCache) {
        super(properties.getLoginPage());
        this.properties = properties;
        this.registeredClientRepository = registeredClientRepository;
        this.authorizationCodeRequestCache = authorizationCodeRequestCache;
    }

    /**
     * Caches the authorization request (when a {@link UniqueRequestResolver} is configured), saves the request in the
     * {@link RequestCache} (when configured) and then delegates to the superclass to redirect to the login page.
     */
    @Override
    public void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException authException) throws IOException, ServletException {
        log.info("{}未认证，将转向登录页", request.getRequestURI());
        // The try block starts before the ThreadLocal is populated so that it is cleared even when caching,
        // saving the request or redirecting throws; otherwise the value would linger on a pooled thread.
        try {
            if (uniqueRequestResolver != null) {
                OAuth2AuthorizationCodeRequestAuthenticationToken codeRequest =
                        (OAuth2AuthorizationCodeRequestAuthenticationToken) new OAuth2AuthorizationCodeRequestAuthenticationConverter().convert(request);

                String seqId = uniqueRequestResolver.signRequest(response);
                seqLocal.set(seqId);
                authorizationCodeRequestCache.setAuthenticationToken(seqId, codeRequest, AUTHORIZATION_REQUEST_TTL);
            }
            if (requestCache != null) {
                requestCache.saveRequest(request, response);
            }
            super.commence(request, response, authException);
        } finally {
            seqLocal.remove();
        }
    }

    /**
     * Chooses the login URL: the client's custom login page if registered, otherwise the default login page.
     * Query values are URL-encoded because {@code client_id} comes straight from the request.
     */
    @Override
    protected String determineUrlToUseForThisRequest(HttpServletRequest request, HttpServletResponse response, AuthenticationException exception) {
        String clientId = this.obtainClientId(request);
        log.info("authorize client: {}", clientId);

        // Client-specific login page
        if (StringUtils.hasText(clientId)) {
            String clientLoginPage = obtainClientLoginPage(clientId);
            if (StringUtils.hasText(clientLoginPage)) {
                // A custom login page needs the request identifier so it can resume the cached authorization request
                String seqId = seqLocal.get();
                if (StringUtils.hasText(seqId)) {
                    clientLoginPage = appendQueryString(clientLoginPage,
                            "ClientId=" + urlEncode(clientId) + "&" + "Urq=" + urlEncode(seqId));
                }
                String loginUrl = normalizeRedirectUrl(clientLoginPage, request.getHeader(HttpHeaders.REFERER));
                log.info("client customize loginUrl：{}", loginUrl);
                return loginUrl;
            }
        }

        // Default login page
        String loginPage = properties.getLoginPage();
        if (!StringUtils.hasText(loginPage)) {
            log.warn("未配置登录页地址");
            return loginPage;
        }
        loginPage = appendQueryString(loginPage, "ClientId=" + urlEncode(clientId == null ? "" : clientId));

        // Resolve relative login pages against the origin supplied by the request parameter
        String resultPage = this.normalizeRedirectUrl(loginPage, request.getParameter(AuthSdkConstant.OAUTH2_AUTHORIZE_PARAM_AUTH_SERVER));
        log.info("loginUrl:{} -> {}", loginPage, resultPage);
        return resultPage;
    }

    public void setUniqueRequestResolver(UniqueRequestResolver uniqueRequestResolver) {
        this.uniqueRequestResolver = uniqueRequestResolver;
    }

    public void setRequestCache(RequestCache requestCache) {
        this.requestCache = requestCache;
    }

    /**
     * Turns {@code url} into a redirect target. Absolute http(s) URLs are returned unchanged. Relative URLs are
     * prefixed with the origin of {@code referUrl}, or with the configured redirect URI prefix when
     * {@code referUrl} is absent or unparsable.
     * <p>
     * NOTE(review): {@code referUrl} comes from the Referer header or a request parameter, i.e. it is
     * client-controlled. Building a redirect from it can act as an open redirect. Consider validating the origin
     * against an allow-list.
     *
     * @param url      login page URL, absolute or relative
     * @param referUrl source URL whose origin is used as the prefix; may be {@code null}
     * @return the URL to redirect to
     */
    private String normalizeRedirectUrl(@NotBlank String url, String referUrl) {
        if (isAbsoluteHttpUrl(url)) {
            return url;
        }

        String urlPrefix = extractOrigin(referUrl);
        if (!StringUtils.hasText(urlPrefix)) {
            urlPrefix = properties.getRedirectUriPrefix();
        }
        return joinUrl(urlPrefix, url);
    }

    /** Case-insensitive, locale-independent check for an {@code http://} or {@code https://} prefix. */
    private static boolean isAbsoluteHttpUrl(String url) {
        return url.regionMatches(true, 0, "http://", 0, 7) || url.regionMatches(true, 0, "https://", 0, 8);
    }

    /**
     * Returns {@code scheme://authority} of {@code url}, or {@code //authority} when it has no scheme.
     * Returns {@code null} when the URL is blank, unparsable or has no authority.
     */
    private static String extractOrigin(String url) {
        if (!StringUtils.hasText(url)) {
            return null;
        }
        URI uri;
        try {
            uri = URI.create(url);
        } catch (IllegalArgumentException e) {
            // Header/parameter values are client-supplied; a malformed one must not break the redirect
            log.warn("无法解析来源地址：{}", url);
            return null;
        }
        String authority = uri.getRawAuthority();
        if (!StringUtils.hasText(authority)) {
            return null;
        }
        String scheme = uri.getScheme();
        return (scheme == null ? "" : scheme + ":") + "//" + authority;
    }

    /**
     * Joins a prefix and a path with exactly one slash between them. Trailing slashes of the prefix are dropped so
     * that an empty or {@code "/"} prefix never produces a protocol-relative {@code "//path"} URL.
     */
    private static String joinUrl(String prefix, String path) {
        String base = prefix == null ? "" : prefix;
        int end = base.length();
        while (end > 0 && base.charAt(end - 1) == '/') {
            end--;
        }
        base = base.substring(0, end);
        return path.startsWith("/") ? base + path : base + "/" + path;
    }

    /** Appends {@code query} to {@code url}, using {@code &} when the URL already has a query string. */
    private static String appendQueryString(String url, String query) {
        return (url.contains("?") ? url + "&" : url + "?") + query;
    }

    /** URL-encodes a query parameter value as UTF-8. */
    private static String urlEncode(String value) {
        return URLEncoder.encode(value, StandardCharsets.UTF_8);
    }

    /** Returns the client's custom login page setting, or {@code null} if the client or the setting is missing. */
    private String obtainClientLoginPage(String clientId) {
        var client = registeredClientRepository.findByClientId(clientId);
        if (client == null) {
            return null;
        }
        return client.getClientSettings().getSetting(OAuth2ClientConstant.SETTING_LOGIN_URL);
    }

    private String obtainClientId(HttpServletRequest request) {
        return request.getParameter(OAuth2ParameterNames.CLIENT_ID);
    }
}
