package com.elitescloud.boot.auth.provider.config.servlet;

import com.elitescloud.boot.auth.client.OidcClaim;
import com.elitescloud.boot.auth.client.common.SecurityConstants;
import com.elitescloud.boot.auth.client.config.AuthorizationProperties;
import com.elitescloud.boot.auth.client.config.security.AbstractServletSecurityConfig;
import com.elitescloud.boot.auth.client.config.security.handler.DefaultAccessDeniedHandler;
import com.elitescloud.boot.auth.client.config.security.handler.DelegateAuthenticationCallable;
import com.elitescloud.boot.auth.client.config.support.AuthenticationCallable;
import com.elitescloud.boot.auth.client.token.AbstractCustomAuthenticationToken;
import com.elitescloud.boot.auth.client.tool.RedisHelper;
import com.elitescloud.boot.auth.provider.config.LoginSupportConfig;
import com.elitescloud.boot.auth.provider.config.servlet.oauth2.OAuth2AuthorizationCodeRequestCache;
import com.elitescloud.boot.auth.provider.config.servlet.oauth2.OAuth2AuthorizationCodeUserVerifier;
import com.elitescloud.boot.auth.provider.config.servlet.oauth2.configurer.OAuth2AuthorizationCodeStateFilterSecurityConfigurer;
import com.elitescloud.boot.auth.provider.config.servlet.oauth2.configurer.OAuth2AuthorizationCodeUserSecurityConfigurer;
import com.elitescloud.boot.auth.provider.config.servlet.oauth2.handler.*;
import com.elitescloud.boot.auth.provider.security.configurer.LoginFilterSecurityConfigurer;
import com.elitescloud.boot.auth.provider.security.configurer.support.LoginFilterCustomizer;
import com.elitescloud.boot.auth.provider.security.generator.token.TokenGenerator;
import com.elitescloud.boot.auth.provider.security.impl.RedisOAuth2AuthorizationCodeRequestCache;
import com.elitescloud.boot.auth.resolver.UniqueRequestResolver;
import com.elitescloud.boot.auth.resolver.impl.DefaultUniquestResolver;
import com.elitescloud.boot.util.JwtUtil;
import com.elitescloud.cloudt.security.entity.GeneralUserDetails;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.SecurityContext;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.config.annotation.web.configurers.OAuth2AuthorizationServerConfigurer;
import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationContext;
import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestHeaderRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.StringUtils;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.security.Principal;
import java.util.function.Function;

/**
 * OAuth2认证方式服务端.
 *
 * @author Kaiser（wang shao）
 * @date 2022/6/21
 */
@Log4j2
@ConditionalOnProperty(prefix = AuthorizationProperties.CONFIG_PREFIX, name = "type", havingValue = "oauth2_server")
@ConditionalOnWebApplication(type = ConditionalOnWebApplication.Type.SERVLET)
@Import({LoginSupportConfig.class})
public class ServletOAuth2ServerConfig extends AbstractServletSecurityConfig {
    // Login-filter customizers; injected through setLoginFilterCustomizers and handed to
    // LoginFilterSecurityConfigurer when the default chain is built.
    private ObjectProvider<LoginFilterCustomizer<HttpSecurity>> loginFilterCustomizers;
    // Authorization-code user verifiers; injected through setAuth2AuthorizeUserVerifiers and handed to
    // OAuth2AuthorizationCodeUserSecurityConfigurer when the authorization-server chain is built.
    private ObjectProvider<OAuth2AuthorizationCodeUserVerifier> auth2AuthorizeUserVerifiers;

    // One resolver instance shared by every handler and configurer created in this class.
    private final UniqueRequestResolver uniqueRequestResolver = ServletOAuth2ServerConfig.getOAuth2UniqueRequestResolver();

    /**
     * Builds the {@link SecurityFilterChain} that guards the OAuth2 authorization-server endpoints.
     * <p>
     * The chain only applies to requests matched by the endpoints matcher of
     * {@link OAuth2AuthorizationServerConfigurer} and requires authentication for all of them.
     * It is registered with {@link Ordered#HIGHEST_PRECEDENCE}.
     *
     * @param http                          security builder to configure
     * @param authorizationCodeRequestCache cache passed to the unauthenticated entry-point handlers
     * @param serverSettings                authorization-server settings; supplies the authorization endpoint
     * @param registeredClientRepository    client repository passed to the login-page entry point
     * @return the built filter chain
     * @throws Exception if configuring or building the chain fails
     */
    @Bean(SECURITY_CHAIN_AUTH2_SERVER)
    @Order(Ordered.HIGHEST_PRECEDENCE)
    @ConditionalOnMissingBean(name = SECURITY_CHAIN_AUTH2_SERVER)
    public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http,
                                                                      OAuth2AuthorizationCodeRequestCache authorizationCodeRequestCache,
                                                                      AuthorizationServerSettings serverSettings,
                                                                      RegisteredClientRepository registeredClientRepository
    ) throws Exception {
        var savedRequestCache = AbstractServletSecurityConfig.getRequestCache();

        // Handler for failed client / token authentication.
        var authFailureHandler = new OAuth2AuthenticationFailHandler();
        // Entry point used when the caller is not authenticated.
        var jsonEntryPoint = new OAuth2ServerJsonAuthenticationEntryPointHandler(
                authorizationCodeRequestCache, serverSettings.getAuthorizationEndpoint());
        jsonEntryPoint.setUniqueRequestResolver(uniqueRequestResolver);
        // Handler used when the caller is authenticated but lacks permission.
        var forbiddenHandler = new DefaultAccessDeniedHandler();
        // Handler for failures on the authorization (authorization-code) endpoint.
        var authorizeErrorHandler = new OAuth2AuthorizationErrorResponseHandler();

        var serverConfigurer = new OAuth2AuthorizationServerConfigurer();
        RequestMatcher serverEndpoints = serverConfigurer.getEndpointsMatcher();

        // Per-endpoint response handlers of the authorization server.
        serverConfigurer.authorizationEndpoint(endpoint -> endpoint
                .errorResponseHandler(authorizeErrorHandler)
                .authorizationResponseHandler(new OAuth2AuthorizationResponseHandler()));
        serverConfigurer.clientAuthentication(client -> client
                .errorResponseHandler(authFailureHandler));
        serverConfigurer.tokenEndpoint(endpoint -> endpoint
                .errorResponseHandler(authFailureHandler)
                .accessTokenResponseHandler(new OAuth2AccessTokenResponseHandler()));
        serverConfigurer.oidc(oidc -> oidc
                .userInfoEndpoint(userInfo -> userInfo.userInfoMapper(oidcUserInfoMapper())));

        http.requestMatcher(serverEndpoints);
        http.authorizeRequests(requests -> requests.anyRequest().authenticated());
        http.csrf(csrf -> {
            if (Boolean.FALSE.equals(authorizationProperties.getCsrfEnabled())) {
                csrf.disable();
            } else {
                // CSRF stays on in general but is not enforced for the authorization-server endpoints.
                csrf.ignoringRequestMatchers(serverEndpoints);
            }
        });
        http.apply(serverConfigurer);
        // State-based flow for terminals that cannot follow redirects.
        http.apply(new OAuth2AuthorizationCodeStateFilterSecurityConfigurer<>(serverSettings))
                .uniqueRequestResolver(uniqueRequestResolver);
        // Verification of the authenticated user.
        http.apply(new OAuth2AuthorizationCodeUserSecurityConfigurer<>(serverSettings))
                .userVerifiers(auth2AuthorizeUserVerifiers);
        http.exceptionHandling(handling -> {
            var loginPage = authorizationProperties.getLoginPage();
            if (!StringUtils.hasText(loginPage)) {
                // No login page configured: every unauthenticated request gets the JSON entry point.
                handling.authenticationEntryPoint(jsonEntryPoint);
            } else {
                // Login page configured: requests whose redirect header is "false" keep the JSON entry point,
                // all other requests go through the login-page entry point.
                var loginPageEntryPoint = new OAuth2ServerLoginUrlAuthenticationEntryPointHandler(
                        loginPage, registeredClientRepository, authorizationCodeRequestCache);
                loginPageEntryPoint.setUniqueRequestResolver(uniqueRequestResolver);
                loginPageEntryPoint.setRequestCache(savedRequestCache);
                handling.defaultAuthenticationEntryPointFor(jsonEntryPoint,
                        new RequestHeaderRequestMatcher(SecurityConstants.HEADER_AUTH_REDIRECT, "false"));
                handling.defaultAuthenticationEntryPointFor(loginPageEntryPoint,
                        new AntPathRequestMatcher("/**"));
            }

            handling.accessDeniedHandler(forbiddenHandler);
        });

        // Resource-server support, configured by the parent class.
        http.oauth2ResourceServer(super.oauth2ResourceServer());

        // CORS configuration.
        corsConfiguration(http);

        return http.build();
    }

    /**
     * Builds the {@link SecurityFilterChain} used for user authentication (login).
     * <p>
     * Starts from the parent's default security configuration and adds the login filter configurer
     * together with the OAuth2-server specific success and failure handlers.
     *
     * @param http                               security builder to configure
     * @param tokenGenerator                     token generator given to the success handler
     * @param auth2AuthorizationCodeRequestCache cache given to the success and failure handlers
     * @param clientRepository                   client repository given to the success handler
     * @param authorizationService               authorization service given to the success handler
     * @param serverSettings                     authorization-server settings; supplies the authorization endpoint
     * @return the built filter chain
     * @throws Exception if configuring or building the chain fails
     */
    @Bean(SECURITY_CHAIN_DEFAULT)
    @ConditionalOnMissingBean(name = SECURITY_CHAIN_DEFAULT)
    public SecurityFilterChain defaultSecurityFilterChain(HttpSecurity http,
                                                          TokenGenerator tokenGenerator,
                                                          OAuth2AuthorizationCodeRequestCache auth2AuthorizationCodeRequestCache,
                                                          RegisteredClientRepository clientRepository,
                                                          OAuth2AuthorizationService authorizationService,
                                                          AuthorizationServerSettings serverSettings
    ) throws Exception {
        var savedRequestCache = AbstractServletSecurityConfig.getRequestCache();

        // Handler invoked after a successful login.
        var successHandler = new OAuth2ServerAuthenticationSuccessHandler(
                serverSettings.getAuthorizationEndpoint(),
                authorizationProperties,
                auth2AuthorizationCodeRequestCache,
                clientRepository,
                authorizationService);
        successHandler.setTokenGenerator(tokenGenerator);
        AuthenticationCallable callable = DelegateAuthenticationCallable.getInstance();
        successHandler.setAuthenticationCallable(callable);
        successHandler.setUniqueRequestResolver(uniqueRequestResolver);
        successHandler.setRequestCache(savedRequestCache);

        // Handler invoked after a failed login.
        var loginFailureHandler = new OAuth2ServerAuthenticationFailureHandler(callable, auth2AuthorizationCodeRequestCache);
        loginFailureHandler.setUniqueRequestResolver(uniqueRequestResolver);
        loginFailureHandler.setRequestCache(savedRequestCache);

        super.defaultSecurityConfig(http)
                .apply(new LoginFilterSecurityConfigurer<>(loginFilterCustomizers))
                .successHandler(successHandler)
                .failureHandler(loginFailureHandler);

        return http.build();
    }

    /**
     * Provides the authorization-code request cache.
     * <p>
     * Intended as the caching scheme for clients that do not support redirects.
     * The default implementation is backed by Redis and is only registered when no other
     * bean of this type exists.
     *
     * @param redisHelper helper passed to the Redis-backed cache
     * @return the Redis-backed request cache
     */
    @Bean
    @ConditionalOnMissingBean
    public OAuth2AuthorizationCodeRequestCache oAuth2AuthorizationCodeRequestCache(RedisHelper redisHelper) {
        OAuth2AuthorizationCodeRequestCache redisBackedCache = new RedisOAuth2AuthorizationCodeRequestCache(redisHelper);
        return redisBackedCache;
    }

    /**
     * Exposes the JWK source built from the given RSA key.
     *
     * @param rsaKey RSA key passed to {@link JwtUtil#generateJwkSource}
     * @return the JWK source produced by {@link JwtUtil}
     */
    @Bean
    public JWKSource<SecurityContext> jwkSource(RSAKey rsaKey) {
        JWKSource<SecurityContext> source = JwtUtil.generateJwkSource(rsaKey);
        return source;
    }

    /**
     * Creates the authorization-server settings.
     * <p>
     * The issuer is only set when an issuer URL with text is configured in the
     * authorization properties; otherwise the builder defaults are kept.
     *
     * @return the authorization-server settings
     */
    @Bean
    public AuthorizationServerSettings authorizationServerSettings() {
        var settings = AuthorizationServerSettings.builder();

        var issuerUrl = authorizationProperties.getIssuerUrl();
        if (StringUtils.hasText(issuerUrl)) {
            settings.issuer(issuerUrl);
        }

        return settings.build();
    }

    /**
     * Injects the customizers applied to the login filter.
     * <p>
     * The provider is stored as-is and later passed to {@link LoginFilterSecurityConfigurer}.
     *
     * @param loginFilterCustomizers provider of login-filter customizers
     */
    @Autowired
    public void setLoginFilterCustomizers(ObjectProvider<LoginFilterCustomizer<HttpSecurity>> loginFilterCustomizers) {
        this.loginFilterCustomizers = loginFilterCustomizers;
    }

    /**
     * Injects the verifiers used to check the authenticated user.
     * <p>
     * The provider is stored as-is and later passed to
     * {@link OAuth2AuthorizationCodeUserSecurityConfigurer}.
     *
     * @param auth2AuthorizeUserVerifiers provider of authorization-code user verifiers
     */
    @Autowired
    public void setAuth2AuthorizeUserVerifiers(ObjectProvider<OAuth2AuthorizationCodeUserVerifier> auth2AuthorizeUserVerifiers) {
        this.auth2AuthorizeUserVerifiers = auth2AuthorizeUserVerifiers;
    }

    /**
     * Authentication callback for the OAuth2 server.
     * <p>
     * On logout it resolves the unique request id from the request and, when one is present,
     * asks the resolver to clear it on the response.
     *
     * @return the callback bean
     */
    @Bean
    public AuthenticationCallable oauth2AuthenticationCallable() {
        return new AuthenticationCallable() {
            @Override
            public void onLogout(HttpServletRequest request, HttpServletResponse response, String token, Object principal) {
                var uniqueRequestId = uniqueRequestResolver.analyze(request);
                if (!StringUtils.hasText(uniqueRequestId)) {
                    // Nothing was resolved from the request, so there is nothing to clear.
                    return;
                }
                uniqueRequestResolver.clear(response, uniqueRequestId);
            }
        };
    }

    /**
     * Creates the handler that redirects to the login page after logout.
     *
     * @param auth2AuthorizationCodeRequestCache cache given to the handler
     * @param clientRepository                   client repository given to the handler
     * @return the configured logout redirect handler
     */
    @Bean
    public OAuth2ServerLogoutRedirectHandler oAuth2ServerLogoutRedirectHandler(OAuth2AuthorizationCodeRequestCache auth2AuthorizationCodeRequestCache,
                                                                               RegisteredClientRepository clientRepository) {
        var logoutRedirectHandler = new OAuth2ServerLogoutRedirectHandler(
                clientRepository, authorizationProperties, auth2AuthorizationCodeRequestCache);

        var savedRequestCache = AbstractServletSecurityConfig.getRequestCache();
        logoutRedirectHandler.setRequestCache(savedRequestCache);
        logoutRedirectHandler.setUniqueRequestResolver(uniqueRequestResolver);

        return logoutRedirectHandler;
    }

    /**
     * Creates the unique-request resolver used by the OAuth2 server configuration.
     * <p>
     * The resolver is constructed with the name {@code X-OAuth2-Urq} and a cookie max age of {@code -1}.
     * NOTE(review): a max age of -1 presumably yields a session-scoped cookie — confirm in
     * {@link DefaultUniquestResolver}.
     *
     * @return a new resolver instance
     */
    public static UniqueRequestResolver getOAuth2UniqueRequestResolver() {
        var oauth2Resolver = new DefaultUniquestResolver("X-OAuth2-Urq");
        oauth2Resolver.setCookieMaxAge(-1);
        return oauth2Resolver;
    }

    /**
     * Creates the mapper that turns an OIDC UserInfo authentication context into {@link OidcUserInfo}.
     * <p>
     * When the authorization holds an {@link AbstractCustomAuthenticationToken} whose principal is a
     * {@link GeneralUserDetails}, the response carries the subject, name, phone number, email,
     * user id and terminal. In every other case only the subject (the authorization's principal name)
     * is returned.
     *
     * @return the UserInfo mapper
     */
    private Function<OidcUserInfoAuthenticationContext, OidcUserInfo> oidcUserInfoMapper() {
        return authenticationContext -> {
            OAuth2Authorization authorization = authenticationContext.getAuthorization();
            Object principal = authorization.getAttribute(Principal.class.getName());
            if (principal instanceof AbstractCustomAuthenticationToken) {
                AbstractCustomAuthenticationToken authenticationToken = (AbstractCustomAuthenticationToken) principal;
                Object tokenPrincipal = authenticationToken.getPrincipal();
                // Check the type before casting: any other principal type falls through to the
                // subject-only response below instead of failing with a ClassCastException.
                if (tokenPrincipal instanceof GeneralUserDetails) {
                    GeneralUserDetails userDetails = (GeneralUserDetails) tokenPrincipal;
                    // NOTE(review): assumes getUser() is never null for an authenticated user — confirm.
                    var user = userDetails.getUser();
                    return OidcUserInfo.builder()
                            .subject(userDetails.getUsername())
                            .name(user.getPrettyName())
                            .phoneNumber(user.getMobile())
                            .email(user.getEmail())
                            .claim(OidcClaim.KEY_USER_ID, userDetails.getUserId())
                            .claim(OidcClaim.KEY_TERMINAL, authenticationToken.getTerminal())
                            .build();
                }
            }

            return OidcUserInfo.builder()
                    .subject(authorization.getPrincipalName())
                    .build();
        };
    }
}
