package com.elitescloud.boot.auth.provider.config.servlet.oauth2.handler;

import com.elitescloud.boot.auth.provider.config.servlet.oauth2.OAuth2AuthorizationCodeRequestCache;
import com.elitescloud.boot.auth.resolver.UniqueRequestResolver;
import com.elitescloud.cloudt.common.base.ApiCode;
import com.elitescloud.cloudt.common.base.ApiResult;
import lombok.extern.log4j.Log4j2;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter;
import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.OrRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;

/**
 * Entry point invoked when an unauthenticated request needs to start authentication.
 * <p>
 * Instead of redirecting, it answers the front end in JSON form with HTTP 401, in order to
 * support clients that cannot follow redirects; the front end is expected to navigate to the
 * login page itself.
 * <p>
 * When the rejected request targets the OAuth2 authorization endpoint and a
 * {@link UniqueRequestResolver} has been configured, the authorization-code request is also
 * cached under a request id so that it can be found again after login.
 *
 * @author Kaiser（wang shao）
 * @date 2022/7/3
 */
@Log4j2
public class OAuth2ServerJsonAuthenticationEntryPointHandler extends AbstractOAuth2ServerHandler implements AuthenticationEntryPoint {

    /**
     * Shared converter used to parse authorization-code requests.
     * Hoisted to a constant so that no new instance is allocated per unauthenticated request.
     * NOTE(review): assumes the converter holds no per-request state, as its no-arg construction
     * and single {@code convert(request)} call suggest - confirm against the library version in use.
     */
    private static final OAuth2AuthorizationCodeRequestAuthenticationConverter AUTHORIZATION_CODE_REQUEST_CONVERTER =
            new OAuth2AuthorizationCodeRequestAuthenticationConverter();

    private final OAuth2AuthorizationCodeRequestCache authorizationCodeRequestCache;
    private final RequestMatcher oauth2AuthorizationEndpointRequestMatcher;
    /** Optional; when {@code null}, authorization requests are not cached. */
    private UniqueRequestResolver uniqueRequestResolver;

    /**
     * Creates the entry point.
     *
     * @param authorizationCodeRequestCache cache that stores pending authorization-code requests
     * @param oauth2AuthorizationEndpoint   path pattern of the OAuth2 authorization endpoint
     */
    public OAuth2ServerJsonAuthenticationEntryPointHandler(OAuth2AuthorizationCodeRequestCache authorizationCodeRequestCache, String oauth2AuthorizationEndpoint) {
        this.authorizationCodeRequestCache = authorizationCodeRequestCache;
        this.oauth2AuthorizationEndpointRequestMatcher = buildOAuth2AuthorizationEndpointRequestMatcher(oauth2AuthorizationEndpoint);
    }

    /**
     * Caches the authorization request (when applicable) and then writes the
     * {@link ApiCode#UNAUTHORIZED} failure result with HTTP status 401.
     *
     * @param request       the request that was rejected
     * @param response      the response to write the failure result to
     * @param authException the exception that triggered this entry point (not inspected here)
     * @throws IOException      declared by {@link AuthenticationEntryPoint#commence}
     * @throws ServletException declared by {@link AuthenticationEntryPoint#commence}
     */
    @Override
    public void commence(HttpServletRequest request, HttpServletResponse response, AuthenticationException authException) throws IOException, ServletException {
        log.info("{}未认证，需前端转向登录页", request.getRequestURI());
        cacheAuthorizeRequest(request, response);

        writeResponse(response, ApiResult.fail(ApiCode.UNAUTHORIZED, "未认证或身份认证已过期，请重新登录"), HttpStatus.UNAUTHORIZED);
    }

    /**
     * Marks the request and caches its authorization-code request so that the request
     * information can be found again after login.
     * <p>
     * Does nothing when no {@link UniqueRequestResolver} is configured or when the request does
     * not match the authorization endpoint. Both conditions are checked before the request is
     * parsed, so no parsing work is done (and no parsing error can surface) when the result
     * would be discarded anyway.
     *
     * @param request  the rejected request
     * @param response the response passed to the resolver for signing
     */
    private void cacheAuthorizeRequest(HttpServletRequest request, HttpServletResponse response) {
        // Read the mutable field once so the null check and the later use see the same instance.
        UniqueRequestResolver resolver = this.uniqueRequestResolver;
        if (resolver == null || !oauth2AuthorizationEndpointRequestMatcher.matches(request)) {
            return;
        }

        // Parse first, sign second: same order as before, so a parsing failure leaves the response unsigned.
        OAuth2AuthorizationCodeRequestAuthenticationToken codeRequest =
                (OAuth2AuthorizationCodeRequestAuthenticationToken) AUTHORIZATION_CODE_REQUEST_CONVERTER.convert(request);
        String reqId = resolver.signRequest(response);
        authorizationCodeRequestCache.setAuthenticationToken(reqId, codeRequest, null);
    }

    /**
     * Builds a matcher that accepts GET or POST requests to the authorization endpoint.
     *
     * @param oauth2AuthorizationEndpoint path pattern of the authorization endpoint
     * @return matcher for GET and POST requests on that pattern
     */
    private RequestMatcher buildOAuth2AuthorizationEndpointRequestMatcher(String oauth2AuthorizationEndpoint) {
        return new OrRequestMatcher(new AntPathRequestMatcher(oauth2AuthorizationEndpoint, HttpMethod.GET.name()), new AntPathRequestMatcher(oauth2AuthorizationEndpoint, HttpMethod.POST.name()));
    }

    /**
     * Sets the resolver used to sign the response with a request id.
     *
     * @param uniqueRequestResolver the resolver, or {@code null} to disable caching of authorization requests
     */
    public void setUniqueRequestResolver(UniqueRequestResolver uniqueRequestResolver) {
        this.uniqueRequestResolver = uniqueRequestResolver;
    }
}
