package com.elitesland.cloudt.authorization.api.provider.security.configurer.filter;

import lombok.extern.log4j.Log4j2;
import org.springframework.http.HttpMethod;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.OAuth2TokenType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.OrRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.security.Principal;

/**
 * State-based authentication for the OAuth2 authorization code grant.
 * <p>
 * Handles GET and POST requests to the configured authorization endpoint that carry both a
 * {@code state} and a {@code client_id} parameter. The filter looks up an
 * {@link OAuth2Authorization} by the {@code state} value (token type {@code state}). An
 * authorization is reused only when:
 * <ul>
 *     <li>one is found,</li>
 *     <li>its {@code client_id} attribute equals the request's {@code client_id}, and</li>
 *     <li>it stores an {@link Authentication} under the {@link Principal} class-name attribute.</li>
 * </ul>
 * In that case the stored authentication is placed into a fresh {@link SecurityContext}.
 * In every case the request then continues down the filter chain.
 *
 * @author Kaiser（wang shao）
 * @date 2022/7/3
 */
@Log4j2
public class OAuth2AuthorizationCodeStateAuthenticationFilter extends OncePerRequestFilter {
    /** Token type used to look up an authorization by its {@code state} value. */
    private static final OAuth2TokenType TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE);

    private final RequestMatcher oauth2AuthorizationEndpointRequestMatcher;
    private final OAuth2AuthorizationService authorizationService;

    /**
     * Creates the filter.
     *
     * @param oauth2AuthorizationEndpoint Ant-style path of the authorization endpoint this filter applies to
     * @param authorizationService        service used to find an authorization by its {@code state} value
     */
    public OAuth2AuthorizationCodeStateAuthenticationFilter(String oauth2AuthorizationEndpoint, OAuth2AuthorizationService authorizationService) {
        this.oauth2AuthorizationEndpointRequestMatcher = buildOAuth2AuthorizationEndpointRequestMatcher(oauth2AuthorizationEndpoint);
        this.authorizationService = authorizationService;
    }

    /**
     * Restores a previous authentication for matching authorization-endpoint requests.
     * Non-matching requests, and requests missing {@code state} or {@code client_id}, are
     * passed through untouched.
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        if (!oauth2AuthorizationEndpointRequestMatcher.matches(request)) {
            filterChain.doFilter(request, response);
            return;
        }

        String state = request.getParameter(OAuth2ParameterNames.STATE);
        String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID);
        if (!StringUtils.hasText(state) || !StringUtils.hasText(clientId)) {
            log.warn("缺少必要参数，认证请求忽略：{}，{}", state, clientId);

            filterChain.doFilter(request, response);
            return;
        }

        // Check whether this client/state pair has already been authenticated.
        Authentication authentication = authentication(clientId, state);
        if (authentication != null) {
            // Already authenticated: expose the stored authentication to the rest of the chain.
            SecurityContext context = SecurityContextHolder.createEmptyContext();
            context.setAuthentication(authentication);
            SecurityContextHolder.setContext(context);
        }

        filterChain.doFilter(request, response);
    }

    /**
     * Resolves a previously stored authentication for the given client and state.
     *
     * @param clientId {@code client_id} from the request (non-blank)
     * @param state    {@code state} from the request (non-blank)
     * @return the stored {@link Authentication}, or {@code null} when the user must log in.
     * That happens when no authorization exists for the state, the client id does not match,
     * or no {@code Authentication} is stored under the {@link Principal} attribute.
     */
    private Authentication authentication(String clientId, String state) {
        OAuth2Authorization oAuth2Authorization = authorizationService.findByToken(state, TOKEN_TYPE);
        if (oAuth2Authorization == null) {
            log.info("未找到OAuth2Authorization：{}，需登录认证", state);
            return null;
        }
        Object storedClientId = oAuth2Authorization.getAttribute(OAuth2ParameterNames.CLIENT_ID);
        if (!clientId.equals(storedClientId)) {
            // Log both the requested and the stored client id so the mismatch is diagnosable.
            log.info("客户端{}, {}不一致，需登录认证", clientId, storedClientId);
            return null;
        }

        // getAttribute is generic; check the type explicitly rather than risk a ClassCastException
        // if something other than an Authentication was stored under this key.
        Object principal = oAuth2Authorization.getAttribute(Principal.class.getName());
        if (principal instanceof Authentication) {
            return (Authentication) principal;
        }
        return null;
    }

    /**
     * Builds a matcher accepting GET or POST requests to the given authorization endpoint path.
     */
    private RequestMatcher buildOAuth2AuthorizationEndpointRequestMatcher(String oauth2AuthorizationEndpoint) {
        return new OrRequestMatcher(
                new AntPathRequestMatcher(
                        oauth2AuthorizationEndpoint,
                        HttpMethod.GET.name()),
                new AntPathRequestMatcher(
                        oauth2AuthorizationEndpoint,
                        HttpMethod.POST.name()));
    }
}
