package com.elitescloud.cloudt.authorization.api.provider.config.servlet.oauth2.configurer.filter;

import com.elitescloud.cloudt.authorization.api.client.common.SecurityConstants;
import com.elitescloud.cloudt.authorization.api.provider.config.servlet.oauth2.OAuth2AuthorizationCodeUserVerifier;
import lombok.extern.log4j.Log4j2;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.StringUtils;

import javax.annotation.Nonnull;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.List;

/**
 * Verifies, for the OAuth2 authorization code flow, that the currently authenticated user may
 * authorize the requesting client.
 *
 * <p>For requests matching the authorization endpoint that carry a non-blank {@code client_id}
 * parameter and an authenticated principal, every configured
 * {@link OAuth2AuthorizationCodeUserVerifier} is consulted. If any verifier rejects the
 * user/client pair, the rest of the chain runs with
 * {@link SecurityConstants#AUTHENTICATION_ANONYMOUS} as the current authentication. The filter
 * always continues the filter chain.
 *
 * @author Kaiser（wang shao）
 * @date 2/28/2023
 */
@Log4j2
public class OAuth2AuthorizationCodeUserVerifierFilter extends AbstractOAuth2Filter {

    /** Matches requests to the OAuth2 authorization endpoint; all other requests pass through untouched. */
    private final RequestMatcher oauth2AuthorizationEndpointRequestMatcher;
    /** Verifiers that must all accept the user/client pair; never {@code null}. */
    private final List<OAuth2AuthorizationCodeUserVerifier> auth2AuthorizeUserVerifiers;

    /**
     * Creates the filter.
     *
     * @param oauth2AuthorizationEndpointRequestMatcher the authorization endpoint, converted into a
     *     {@link RequestMatcher} by the superclass helper (presumably a path or pattern — confirm
     *     against {@code AbstractOAuth2Filter})
     * @param auth2AuthorizeUserVerifiers verifiers to consult for each authorization request;
     *     {@code null} is treated as an empty list
     */
    public OAuth2AuthorizationCodeUserVerifierFilter(String oauth2AuthorizationEndpointRequestMatcher, List<OAuth2AuthorizationCodeUserVerifier> auth2AuthorizeUserVerifiers) {
        this.oauth2AuthorizationEndpointRequestMatcher = super.buildOAuth2AuthorizationEndpointRequestMatcher(oauth2AuthorizationEndpointRequestMatcher);
        // Without this, a null list would only surface as an NPE on the first matching request.
        this.auth2AuthorizeUserVerifiers = auth2AuthorizeUserVerifiers == null ? List.of() : auth2AuthorizeUserVerifiers;
    }

    /**
     * Checks the user/client binding on authorization requests, then continues the chain.
     *
     * <p>Requests are passed through unchanged when any of the following holds:
     * <ul>
     *   <li>they do not match the authorization endpoint;</li>
     *   <li>they lack a {@code client_id} parameter;</li>
     *   <li>the current authentication is not an authenticated principal;</li>
     *   <li>no verifiers are configured.</li>
     * </ul>
     * Otherwise, if any verifier rejects the pair, the security context is replaced by one holding
     * {@link SecurityConstants#AUTHENTICATION_ANONYMOUS}.
     *
     * @param request     the current request
     * @param response    the current response
     * @param filterChain the remaining filter chain, always invoked exactly once
     * @throws ServletException if a downstream filter fails
     * @throws IOException      if a downstream filter fails with an I/O error
     */
    @Override
    protected void doFilterInternal(@Nonnull HttpServletRequest request, @Nonnull HttpServletResponse response, @Nonnull FilterChain filterChain) throws ServletException, IOException {
        if (!oauth2AuthorizationEndpointRequestMatcher.matches(request)) {
            filterChain.doFilter(request, response);
            return;
        }
        String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID);
        if (!StringUtils.hasText(clientId)) {
            // Missing client parameter: there is nothing to verify the user against.
            filterChain.doFilter(request, response);
            return;
        }

        var authentication = SecurityContextHolder.getContext().getAuthentication();
        if (!super.isPrincipalAuthenticated(authentication)) {
            // Not authenticated: there is no user to verify.
            filterChain.doFilter(request, response);
            return;
        }

        if (!auth2AuthorizeUserVerifiers.isEmpty()) {
            // allMatch stops at the first verifier that rejects the pair.
            boolean bound = auth2AuthorizeUserVerifiers.stream()
                    .allMatch(verifier -> verifier.verify(clientId, authentication));
            if (bound) {
                // Logged once, only after every verifier has accepted the pair.
                log.info("已认证用户{}与客户端{}已绑定", authentication.getName(), clientId);
            } else {
                // Verification failed: continue the request as the anonymous user.
                log.info("已认证用户{}与客户端{}未绑定", authentication.getName(), clientId);
                // Install a fresh context instead of mutating the current instance. That instance
                // may be shared with other threads (e.g. concurrent requests of the same session).
                var anonymousContext = SecurityContextHolder.createEmptyContext();
                anonymousContext.setAuthentication(SecurityConstants.AUTHENTICATION_ANONYMOUS);
                SecurityContextHolder.setContext(anonymousContext);
            }
        }
        filterChain.doFilter(request, response);
    }
}
