package com.elitescloud.boot.auth.client.config.security;

import cn.hutool.core.util.BooleanUtil;
import com.elitescloud.boot.auth.client.common.InterceptUri;
import com.elitescloud.boot.auth.client.config.AuthorizationProperties;
import com.elitescloud.boot.auth.client.config.security.configurer.AuthorizationConfigurerCustomizer;
import com.elitescloud.boot.auth.client.config.security.configurer.DefaultAuthorizationConfigurer;
import com.elitescloud.boot.auth.client.config.security.handler.DefaultAccessDeniedHandler;
import com.elitescloud.boot.auth.client.config.security.handler.DefaultAuthenticationEntryPointHandler;
import com.elitescloud.boot.auth.client.config.security.resolver.BearerTokenResolver;
import com.elitescloud.boot.auth.client.config.support.AuthenticationCache;
import com.elitescloud.boot.auth.config.AuthorizationSdkProperties;
import com.elitescloud.boot.auth.sso.SsoProvider;
import com.elitescloud.boot.auth.sso.configurer.SsoFilterConfigurer;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configurers.*;
import org.springframework.security.config.annotation.web.configurers.oauth2.server.resource.OAuth2ResourceServerConfigurer;
import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.rememberme.PersistentTokenRepository;
import org.springframework.security.web.savedrequest.CookieRequestCache;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.savedrequest.SavedRequest;
import org.springframework.util.StringUtils;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.UrlBasedCorsConfigurationSource;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Base class for servlet-based Spring Security configuration.
 * <p>
 * Provides a shared {@link HttpSecurity} setup driven by {@link AuthorizationProperties}. It covers CSRF, CORS,
 * URL authorization, the {@link DefaultAuthorizationConfigurer}, the {@link SsoFilterConfigurer}, exception
 * handling, session management and response headers. Subclasses customize individual parts by overriding the
 * protected hook methods.
 *
 * @author Kaiser（wang shao）
 * @date 2022/6/20
 */
@Log4j2
public abstract class AbstractServletSecurityConfig {
    /**
     * Name of the default securityFilterChain.
     */
    public static final String SECURITY_CHAIN_DEFAULT = "defaultSecurityFilterChain";
    /**
     * Name of the OAuth2 Server securityFilterChain.
     */
    public static final String SECURITY_CHAIN_AUTH2_SERVER = "authorizationServerSecurityFilterChain";

    protected AuthorizationProperties authorizationProperties;
    private AuthorizationSdkProperties authorizationSdkProperties;
    private JwtDecoder jwtDecoder;
    protected ObjectProvider<AuthenticationCache> cacheObjectProvider;
    protected ObjectProvider<SsoProvider> ssoProviderObjectProvider;
    protected ObjectProvider<AuthorizationConfigurerCustomizer> authorizationConfigurerCustomizerObjectProvider;
    protected ObjectProvider<BearerTokenResolver> tokenResolverObjectProvider;

    protected AbstractServletSecurityConfig() {
    }

    /**
     * Applies the default security configuration to {@code http}.
     * <p>
     * CSRF is disabled when {@code csrfEnabled} is explicitly {@code false}. The method then configures CORS,
     * URL authorization, the {@link DefaultAuthorizationConfigurer}, the {@link SsoFilterConfigurer}, exception
     * handling, session management and response headers.
     *
     * @param http the security builder to configure
     * @return the same {@code http} instance, for further chaining
     * @throws Exception if a Spring Security configurer fails
     */
    protected HttpSecurity defaultSecurityConfig(HttpSecurity http) throws Exception {
        if (Boolean.FALSE.equals(authorizationProperties.getCsrfEnabled())) {
            // disable CSRF
            http.csrf().disable();
        }

        // CORS configuration
        corsConfiguration(http);

        http.authorizeRequests(authorizeRequest())
                // default authorization configuration
                .apply(new DefaultAuthorizationConfigurer<>(authorizationProperties, cacheObjectProvider.getIfAvailable(),
                        jwtDecoder, authorizationConfigurerCustomizerObjectProvider))
                .needBearerTokenAuthenticationFilter(needCloudtBearerTokenAuthenticationFilter())
                .bearerTokenResolver(tokenResolverObjectProvider.getIfAvailable())
                .autoRenewalToken(autoRenewalToken())
                .and()
                // single sign-on configuration
                .apply(new SsoFilterConfigurer<>(authorizationSdkProperties))
                .setSsoProvider(ssoProviderObjectProvider.getIfAvailable())
                .and()
                // exception handling
                .exceptionHandling(exceptionHandlingCustomizer())
                // session management
                .sessionManagement(sessionManagementCustomizer())
                .headers(headersCustomizer())
        ;

        return http;
    }

    /**
     * Response header configuration. By default, pages may only be framed by the same origin.
     *
     * @return the headers customizer
     */
    protected Customizer<HeadersConfigurer<HttpSecurity>> headersCustomizer() {
        return customizer -> customizer.frameOptions().sameOrigin();
    }

    /**
     * Session management configuration.
     * <p>
     * Uses {@link SessionCreationPolicy#STATELESS} when {@code sessionEnabled} is explicitly {@code false}, and
     * {@link SessionCreationPolicy#IF_REQUIRED} otherwise.
     *
     * @return the session management customizer
     */
    protected Customizer<SessionManagementConfigurer<HttpSecurity>> sessionManagementCustomizer() {
        if (Boolean.FALSE.equals(authorizationProperties.getSessionEnabled())) {
            // disable sessions
            return customizer -> customizer.sessionCreationPolicy(SessionCreationPolicy.STATELESS);
        }

        return customizer -> customizer.sessionCreationPolicy(SessionCreationPolicy.IF_REQUIRED);
    }

    /**
     * Cross-origin (CORS) configuration.
     * <p>
     * When {@code corsEnabled} is explicitly {@code false}, CORS is disabled. Otherwise one
     * {@link CorsConfiguration} is built per configured entry and registered under that entry's path pattern.
     *
     * @param http the security builder to configure
     * @throws Exception if the CORS configurer fails
     */
    protected void corsConfiguration(HttpSecurity http) throws Exception {
        if (Boolean.FALSE.equals(authorizationProperties.getCorsEnabled())) {
            // disable CORS
            http.cors().disable();
            return;
        }

        UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource();
        for (AuthorizationProperties.CorsConfig cors : authorizationProperties.getCors()) {
            CorsConfiguration corsConfiguration = new CorsConfiguration();
            cors.getAllowedOriginPatterns().forEach(corsConfiguration::addAllowedOriginPattern);
            cors.getAllowedOrigins().forEach(corsConfiguration::addAllowedOrigin);
            cors.getAllowedHeaders().forEach(corsConfiguration::addAllowedHeader);
            cors.getExposeHeaders().forEach(corsConfiguration::addExposedHeader);
            cors.getAllowedMethods().forEach(corsConfiguration::addAllowedMethod);
            corsConfiguration.setAllowCredentials(cors.isAllowCredentials());

            source.registerCorsConfiguration(cors.getPathMatcher(), corsConfiguration);
        }
        http.cors(configurer -> configurer.configurationSource(source));
    }

    /**
     * Remember-me configuration.
     *
     * @param tokenRepository    repository used to persist remember-me tokens
     * @param userDetailsService service used to load the user for a remember-me token
     * @param successHandler     handler invoked after a successful remember-me authentication
     * @return the remember-me customizer, or {@code null} when the configured TTL is missing or shorter than one
     * second (callers must handle {@code null})
     */
    protected Customizer<RememberMeConfigurer<HttpSecurity>> rememberMeConfigurerCustomizer(PersistentTokenRepository tokenRepository,
                                                                                            UserDetailsService userDetailsService,
                                                                                            AuthenticationSuccessHandler successHandler) {
        var rememberMeTtl = authorizationProperties.getRememberMeTtl();
        long ttlSeconds = rememberMeTtl == null ? 0L : rememberMeTtl.toSeconds();
        if (ttlSeconds < 1) {
            return null;
        }
        // tokenValiditySeconds takes an int: clamp instead of letting a very long TTL wrap around to a negative value.
        final int ttl = (int) Math.min(ttlSeconds, Integer.MAX_VALUE);

        return customizer -> customizer
                .rememberMeParameter("remember_me")
                .tokenValiditySeconds(ttl)
                .alwaysRemember(false)
//                .useSecureCookie(true)
                .tokenRepository(tokenRepository)
                .userDetailsService(userDetailsService)
                .authenticationSuccessHandler(successHandler)
                ;
    }

    /**
     * Custom handling of Spring Security exceptions.
     *
     * @return the exception handling customizer
     */
    protected Customizer<ExceptionHandlingConfigurer<HttpSecurity>> exceptionHandlingCustomizer() {
        return configurer -> {
            // handling for unauthenticated requests
            configurer.authenticationEntryPoint(new DefaultAuthenticationEntryPointHandler(authorizationProperties.getLoginPage()));

            // handling for requests without sufficient permission
            configurer.accessDeniedHandler(new DefaultAccessDeniedHandler());
        };
    }

    /**
     * Request authorization configuration.
     * <p>
     * When {@code anonymousEnabled} is true, only the reject-list URIs require authentication and every other
     * request is permitted. Otherwise, only the allow-list URIs are permitted and every other request requires
     * authentication.
     *
     * @return the URL authorization customizer
     */
    protected Customizer<ExpressionUrlAuthorizationConfigurer<HttpSecurity>.ExpressionInterceptUrlRegistry> authorizeRequest() {
        return urlRegistry -> {
            if (BooleanUtil.isTrue(authorizationProperties.getAnonymousEnabled())) {
                log.warn("currently allows anonymous access !");
                // anonymous access allowed: only the reject list requires authentication
                applyAccessRule(urlRegistry, getRejectUris(), false);
                urlRegistry.anyRequest().permitAll();
                return;
            }

            // anonymous access not allowed: only the allow list is public
            applyAccessRule(urlRegistry, getAllowUris(), true);
            urlRegistry.anyRequest().authenticated();
        };
    }

    /**
     * Registers an access rule for {@code patterns}.
     * <p>
     * Patterns accepted by {@link #adapterMvcRequestMatch(String)} use MVC matchers; the rest use Ant matchers.
     * MVC matchers are registered before Ant matchers. {@code null} entries are skipped.
     *
     * @param urlRegistry the registry being configured
     * @param patterns    URL patterns to register
     * @param permitAll   {@code true} to permit the patterns, {@code false} to require authentication
     */
    private void applyAccessRule(ExpressionUrlAuthorizationConfigurer<HttpSecurity>.ExpressionInterceptUrlRegistry urlRegistry,
                                 Set<String> patterns, boolean permitAll) {
        // Drop nulls before partitioning so neither matcher type is ever handed a null pattern.
        var byMatcherType = patterns.stream()
                .filter(pattern -> pattern != null)
                .collect(Collectors.partitioningBy(this::adapterMvcRequestMatch, Collectors.toSet()));

        // MVC matching mode
        Set<String> mvcPatterns = byMatcherType.get(true);
        if (!mvcPatterns.isEmpty()) {
            grant(urlRegistry.mvcMatchers(mvcPatterns.toArray(String[]::new)), permitAll);
        }

        // Ant matching mode
        Set<String> antPatterns = byMatcherType.get(false);
        if (!antPatterns.isEmpty()) {
            grant(urlRegistry.antMatchers(antPatterns.toArray(String[]::new)), permitAll);
        }
    }

    /**
     * Applies either {@code permitAll()} or {@code authenticated()} to the matched URLs.
     *
     * @param authorizedUrl the matched URLs
     * @param permitAll     {@code true} for {@code permitAll()}, {@code false} for {@code authenticated()}
     */
    private static void grant(ExpressionUrlAuthorizationConfigurer<HttpSecurity>.AuthorizedUrl authorizedUrl, boolean permitAll) {
        if (permitAll) {
            authorizedUrl.permitAll();
        } else {
            authorizedUrl.authenticated();
        }
    }

    /**
     * Whether the given pattern should be matched with MVC matchers (as opposed to Ant matchers).
     * <p>
     * The base implementation ignores {@code pattern} and decides solely from the configured
     * {@code requestMatcherType}; subclasses may override this to decide per pattern.
     *
     * @param pattern the URL pattern
     * @return {@code true} to use an MVC matcher
     */
    protected boolean adapterMvcRequestMatch(String pattern) {
        return AuthorizationProperties.RequestMatcherType.MVC.equals(authorizationProperties.getRequestMatcherType());
    }

    /**
     * OAuth2 Resource Server configuration (JWT-based).
     *
     * @return the resource server customizer
     */
    protected Customizer<OAuth2ResourceServerConfigurer<HttpSecurity>> oauth2ResourceServer() {
        return OAuth2ResourceServerConfigurer::jwt;
    }

    /**
     * Builds the reject list: URIs that require authentication even when anonymous access is enabled.
     *
     * @return a new mutable set combining {@link InterceptUri#getRejectUri()} and the configured reject list
     */
    protected Set<String> getRejectUris() {
        Set<String> rejectList = new HashSet<>();
        rejectList.addAll(InterceptUri.getRejectUri());
        rejectList.addAll(authorizationProperties.getRejectList());

        return rejectList;
    }

    /**
     * Builds the allow list: URIs accessible without authentication.
     *
     * @return a new mutable set combining {@link InterceptUri#getAllowUri()}, the configured allow list and,
     * when configured, the login page
     */
    protected Set<String> getAllowUris() {
        Set<String> allowList = new HashSet<>();
        allowList.addAll(InterceptUri.getAllowUri());
        allowList.addAll(authorizationProperties.getAllowList());

        if (StringUtils.hasText(authorizationProperties.getLoginPage())) {
            // the login page must always be reachable
            allowList.add(authorizationProperties.getLoginPage());
        }

        return allowList;
    }

    /**
     * Whether to add the custom bearer-token authentication filter.
     *
     * @return {@code true} by default
     */
    protected boolean needCloudtBearerTokenAuthenticationFilter() {
        return true;
    }

    /**
     * Whether to automatically renew tokens.
     *
     * @return {@code true} by default
     */
    protected boolean autoRenewalToken() {
        return true;
    }

    /**
     * Creates the request cache used to remember the request that triggered authentication.
     *
     * @return a new cache backed by the HTTP session and a cookie
     */
    public static RequestCache getRequestCache() {
        return new DelegateRequestCache();
    }

    @Autowired
    public void setAuthorizationProperties(AuthorizationProperties authorizationProperties) {
        this.authorizationProperties = authorizationProperties;
    }

    @Autowired
    public void setAuthorizationSdkProperties(AuthorizationSdkProperties authorizationSdkProperties) {
        this.authorizationSdkProperties = authorizationSdkProperties;
    }

    @Autowired
    public void setCacheObjectProvider(ObjectProvider<AuthenticationCache> cacheObjectProvider) {
        this.cacheObjectProvider = cacheObjectProvider;
    }

    @Autowired
    public void setSsoProviderObjectProvider(ObjectProvider<SsoProvider> ssoProviderObjectProvider) {
        this.ssoProviderObjectProvider = ssoProviderObjectProvider;
    }

    @Autowired
    public void setAuthorizationConfigurerCustomizerObjectProvider(ObjectProvider<AuthorizationConfigurerCustomizer> authorizationConfigurerCustomizerObjectProvider) {
        this.authorizationConfigurerCustomizerObjectProvider = authorizationConfigurerCustomizerObjectProvider;
    }

    @Autowired
    public void setJwtDecoder(JwtDecoder jwtDecoder) {
        this.jwtDecoder = jwtDecoder;
    }

    @Autowired
    public void setTokenResolverObjectProvider(ObjectProvider<BearerTokenResolver> tokenResolverObjectProvider) {
        this.tokenResolverObjectProvider = tokenResolverObjectProvider;
    }

    /**
     * {@link RequestCache} that saves to and removes from every delegate.
     * <p>
     * The delegates are an {@link HttpSessionRequestCache} and a {@link CookieRequestCache}. Reads query them in
     * that order (session first, then cookie) and return the first non-null result.
     */
    static class DelegateRequestCache implements RequestCache {
        private final List<RequestCache> requestCaches = new ArrayList<>();

        public DelegateRequestCache() {
            requestCaches.add(new HttpSessionRequestCache());
            requestCaches.add(new CookieRequestCache());
        }

        @Override
        public void saveRequest(HttpServletRequest request, HttpServletResponse response) {
            for (RequestCache cache : requestCaches) {
                cache.saveRequest(request, response);
            }
        }

        @Override
        public SavedRequest getRequest(HttpServletRequest request, HttpServletResponse response) {
            for (RequestCache cache : requestCaches) {
                var req = cache.getRequest(request, response);
                if (req != null) {
                    return req;
                }
            }
            return null;
        }

        @Override
        public HttpServletRequest getMatchingRequest(HttpServletRequest request, HttpServletResponse response) {
            for (RequestCache cache : requestCaches) {
                var req = cache.getMatchingRequest(request, response);
                if (req != null) {
                    // saveRequest() stored the request in every delegate, but each delegate clears at most its own
                    // copy on a match. Clear all of them so a stale saved request is not replayed later.
                    removeRequest(request, response);
                    return req;
                }
            }
            return null;
        }

        @Override
        public void removeRequest(HttpServletRequest request, HttpServletResponse response) {
            for (RequestCache cache : requestCaches) {
                cache.removeRequest(request, response);
            }
        }
    }
}
