package com.elitescloud.boot.auth.provider.security.grant;

import cn.hutool.core.collection.CollUtil;
import com.elitescloud.boot.SpringContextHolder;
import com.elitescloud.boot.auth.client.common.AuthenticateBeforeHandler;
import com.elitescloud.boot.auth.client.common.LoginType;
import com.elitescloud.boot.auth.client.token.AbstractCustomAuthenticationToken;
import com.elitescloud.boot.auth.provider.common.AuthorizationConstant;
import com.elitescloud.boot.auth.provider.common.LoginParameterNames;
import org.springframework.http.HttpMethod;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Custom grant login filter.
 * <p>
 * Handles {@code POST} requests to the configured login URI (by default {@value #DEFAULT_FILTER_PROCESS_URI}).
 * It selects the registered token converter matching the request's login-type parameter, runs every
 * {@link AuthenticateBeforeHandler} available from the Spring context, and then delegates to the
 * authentication manager.
 *
 * @author Kaiser（wang shao）
 * @date 2022/6/27
 */
public class CustomGrantAuthenticationFilter extends AbstractAuthenticationProcessingFilter {

    /**
     * Default login path.
     */
    public static final String DEFAULT_FILTER_PROCESS_URI = "/oauth/login";

    /**
     * Registered token prototypes keyed by {@link LoginType#getType()}; each one converts a request into the
     * token that is passed to the authentication manager.
     * NOTE(review): this map is not thread-safe; it is presumably only populated via
     * {@link #addAuthenticationTokenConvert} during configuration, before requests are served — confirm.
     */
    private final Map<String, AbstractCustomAuthenticationToken<?>> authenticationTokenMap = new HashMap<>();

    /**
     * Lazily resolved pre-authentication handlers. {@code volatile} so that a list built by one request thread
     * is safely published to other threads using the same filter instance.
     */
    private volatile List<AuthenticateBeforeHandler> authenticateBeforeHandlers;

    /**
     * Creates a filter that processes {@code POST} requests to {@link #DEFAULT_FILTER_PROCESS_URI}.
     */
    public CustomGrantAuthenticationFilter() {
        this(DEFAULT_FILTER_PROCESS_URI);
    }

    /**
     * Creates a filter that processes {@code POST} requests matching the given Ant-style path.
     *
     * @param defaultFilterProcessesUrl Ant-style path pattern of the login endpoint
     */
    public CustomGrantAuthenticationFilter(String defaultFilterProcessesUrl) {
        super(new AntPathRequestMatcher(defaultFilterProcessesUrl, HttpMethod.POST.name()));
    }

    /**
     * Registers a token converter under the login type it declares. A later registration with the same
     * type replaces the earlier one.
     *
     * @param authenticationToken token whose {@code loginType()} is the lookup key and whose
     *                            {@code convert(request)} builds the per-request token
     * @param <T>                 concrete token type
     * @throws IllegalArgumentException if the token's login type is {@code null}
     */
    public <T extends AbstractCustomAuthenticationToken> void addAuthenticationTokenConvert(T authenticationToken) {
        LoginType loginType = authenticationToken.loginType();
        Assert.notNull(loginType, authenticationToken.getClass().getName() + "的loginType为空");

        authenticationTokenMap.put(loginType.getType(), authenticationToken);
    }

    /**
     * Converts the request into a login-type specific token, runs the pre-authentication handlers, and
     * authenticates the token.
     *
     * @param request  the login request; must carry the {@link LoginParameterNames#LOGIN_TYPE} parameter
     * @param response the current response (passed to the pre-authentication handlers)
     * @return the authenticated result from the authentication manager
     * @throws AuthenticationException if the login type is missing or unsupported, the request cannot be
     *                                 converted, or authentication fails
     */
    @Override
    public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) throws AuthenticationException, IOException, ServletException {
        request.setAttribute(AuthorizationConstant.REQUEST_ATTRIBUTE_LOGIN_START_TIME, LocalDateTime.now());

        // Determine the login type
        String loginType = request.getParameter(LoginParameterNames.LOGIN_TYPE);
        if (!StringUtils.hasText(loginType)) {
            throw new AuthenticationServiceException("未知登录类型");
        }

        // Convert the request into an authentication token
        AbstractCustomAuthenticationToken<?> authenticationToken = authenticationTokenMap.get(loginType);
        if (authenticationToken == null) {
            throw new AuthenticationServiceException("不支持的登录类型");
        }
        AbstractCustomAuthenticationToken<?> authentication = authenticationToken.convert(request);
        if (authentication == null) {
            // Surface as an AuthenticationException so the failure handler runs, instead of an NPE in setDetails
            throw new AuthenticationServiceException("登录参数解析失败");
        }
        setDetails(request, authentication);

        // Pre-authentication processing; the original token is exposed to handlers via a request attribute
        request.setAttribute(AuthorizationConstant.REQUEST_ATTRIBUTE_AUTHENTICATION_ORIGINAL, authentication);
        beforeAuthenticate(request, response);

        // Perform authentication
        return this.getAuthenticationManager().authenticate(authentication);
    }

    /**
     * Handles a failed login. It mirrors the parent implementation (clear the security context, notify
     * remember-me services, delegate to the failure handler) and additionally logs an info-level message.
     */
    @Override
    protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed) throws IOException, ServletException {
        // The parent clears the context here; without this call the trace message below would be false
        SecurityContextHolder.clearContext();
        this.logger.trace("Failed to process authentication request", failed);
        this.logger.trace("Cleared SecurityContextHolder");
        this.logger.trace("Handling authentication failure");
        logger.info("认证不通过");
        this.getRememberMeServices().loginFail(request, response);
        this.getFailureHandler().onAuthenticationFailure(request, response, failed);
    }

    /**
     * Attaches request details, built by the configured details source, to the token.
     */
    private void setDetails(HttpServletRequest request, AbstractAuthenticationToken authentication) {
        authentication.setDetails(this.authenticationDetailsSource.buildDetails(request));
    }

    /**
     * Returns the pre-authentication handlers, resolving them from the Spring context on first use.
     * Concurrent first calls may each resolve the list. That is harmless because the result is equivalent,
     * and the volatile field guarantees safe publication.
     */
    private List<AuthenticateBeforeHandler> getAuthenticateBeforeHandlers() {
        List<AuthenticateBeforeHandler> handlers = this.authenticateBeforeHandlers;
        if (handlers == null) {
            handlers = SpringContextHolder.getObjectProvider(AuthenticateBeforeHandler.class).stream().collect(Collectors.toList());
            this.authenticateBeforeHandlers = handlers;
        }
        return handlers;
    }

    /**
     * Invokes every pre-authentication handler, in the order returned by the object provider's stream,
     * then clears the {@link SecurityContextHolder}.
     */
    private void beforeAuthenticate(HttpServletRequest request, HttpServletResponse response) {
        List<AuthenticateBeforeHandler> handlers = getAuthenticateBeforeHandlers();
        if (CollUtil.isNotEmpty(handlers)) {
            for (AuthenticateBeforeHandler handler : handlers) {
                handler.handle(request, response);
            }
        }
        // NOTE(review): presumably discards any context left by a handler so authentication starts clean — confirm
        SecurityContextHolder.clearContext();
    }
}
