package com.elitescloud.cloudt.authorization.api.provider.config.servlet.oauth2.configurer.filter;

import com.elitescloud.cloudt.authorization.sdk.resolver.UniqueRequestResolver;
import lombok.extern.log4j.Log4j2;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.StringUtils;

import javax.annotation.Nonnull;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.security.Principal;

/**
 * State-based authentication for the OAuth2 authorization-code flow.
 * <p>
 * Supports clients that cannot follow redirects, i.e. clients for which the authorization
 * request saved before form login cannot be recovered. When a request to the authorization
 * endpoint arrives without an authenticated principal, the request identifier resolved by the
 * configured {@link UniqueRequestResolver} is looked up as a {@code state} token in the
 * {@link OAuth2AuthorizationService}. If an authorization is found and its stored
 * {@code client_id} attribute equals the request's {@code client_id}, the stored principal is
 * placed into the {@link SecurityContextHolder} for the duration of this request only. The
 * previous security context is restored afterwards.
 *
 * @author Kaiser（wang shao）
 * @date 2022/7/3
 */
@Log4j2
public class OAuth2AuthorizationCodeStateAuthenticationFilter extends AbstractOAuth2Filter {
    /** Token type used to look up the request identifier in the authorization service. */
    private static final OAuth2TokenType TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE);

    private final RequestMatcher oauth2AuthorizationEndpointRequestMatcher;
    private final OAuth2AuthorizationService authorizationService;
    private UniqueRequestResolver uniqueRequestResolver;

    /**
     * Creates the filter.
     *
     * @param oauth2AuthorizationEndpoint the authorization endpoint this filter applies to
     * @param authorizationService        service used to look up authorizations by request identifier
     */
    public OAuth2AuthorizationCodeStateAuthenticationFilter(String oauth2AuthorizationEndpoint, OAuth2AuthorizationService authorizationService) {
        this.oauth2AuthorizationEndpointRequestMatcher = super.buildOAuth2AuthorizationEndpointRequestMatcher(oauth2AuthorizationEndpoint);
        this.authorizationService = authorizationService;
    }

    @Override
    protected void doFilterInternal(@Nonnull HttpServletRequest request, @Nonnull HttpServletResponse response, @Nonnull FilterChain filterChain) throws ServletException, IOException {
        if (!oauth2AuthorizationEndpointRequestMatcher.matches(request)) {
            filterChain.doFilter(request, response);
            return;
        }

        SecurityContext previousContext = SecurityContextHolder.getContext();
        Authentication currentAuthentication = previousContext.getAuthentication();
        if (super.isPrincipalAuthenticated(currentAuthentication)) {
            // Already authenticated: nothing to do.
            log.info("认证过的用户：{}", currentAuthentication.getName());
            filterChain.doFilter(request, response);
            return;
        }

        if (uniqueRequestResolver == null) {
            // Misconfiguration: without a resolver the request identifier cannot be determined,
            // so fall back to the regular (login) flow instead of failing with an NPE.
            log.warn("UniqueRequestResolver is not configured, skipping state authentication");
            filterChain.doFilter(request, response);
            return;
        }

        String reqId = uniqueRequestResolver.analyze(request);
        String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID);
        if (!StringUtils.hasText(clientId) || !StringUtils.hasText(reqId)) {
            log.info("缺少必要参数，认证请求忽略：{}，{}", reqId, clientId);

            filterChain.doFilter(request, response);
            return;
        }

        // Check whether this request identifier was already authenticated earlier.
        Authentication stateAuthentication = authentication(clientId, reqId);
        if (stateAuthentication == null) {
            // Nothing installed, so leave the security context untouched (like the other
            // pass-through branches) and let downstream filters handle authentication.
            filterChain.doFilter(request, response);
            return;
        }

        SecurityContext context = SecurityContextHolder.createEmptyContext();
        context.setAuthentication(stateAuthentication);
        SecurityContextHolder.setContext(context);
        try {
            filterChain.doFilter(request, response);
        } finally {
            // The state-based authentication is scoped to this request only: put back the
            // context that was present before this filter installed its own.
            SecurityContextHolder.setContext(previousContext);
        }
    }

    /**
     * Sets the resolver used to derive the unique request identifier from the request.
     * If it is never set, the filter passes requests through unchanged.
     *
     * @param uniqueRequestResolver the resolver to use
     */
    public void setUniqueRequestResolver(UniqueRequestResolver uniqueRequestResolver) {
        this.uniqueRequestResolver = uniqueRequestResolver;
    }

    /**
     * Looks up a previously stored authentication for the given request identifier.
     *
     * @param clientId the {@code client_id} of the current request
     * @param reqId    the request identifier, looked up as a {@code state} token
     * @return the stored authentication, or {@code null} if there is no authorization, the
     *         stored {@code client_id} attribute differs, or no {@link Authentication} is stored
     */
    private Authentication authentication(String clientId, String reqId) {
        OAuth2Authorization oAuth2Authorization = authorizationService.findByToken(reqId, TOKEN_TYPE);
        if (oAuth2Authorization == null) {
            log.debug("未找到OAuth2Authorization：{}，需登录认证", reqId);
            return null;
        }

        Object authorizedClientId = oAuth2Authorization.getAttribute(OAuth2ParameterNames.CLIENT_ID);
        if (!clientId.equals(authorizedClientId)) {
            // Log both the requesting and the stored client id so the mismatch is diagnosable.
            log.info("客户端{}, {}不一致，需登录认证", clientId, authorizedClientId);
            return null;
        }

        // Check the type explicitly rather than relying on the unchecked generic cast, which
        // would throw ClassCastException if something other than an Authentication was stored.
        Object principal = oAuth2Authorization.getAttribute(Principal.class.getName());
        if (!(principal instanceof Authentication)) {
            log.info("No Authentication stored for authorization {}, login required", reqId);
            return null;
        }
        return (Authentication) principal;
    }
}
