package com.elitescloud.boot.auth.provider.config.servlet.oauth2.handler;

import cn.hutool.core.text.CharSequenceUtil;
import com.elitescloud.boot.auth.CommonAuthenticationToken;
import com.elitescloud.boot.auth.client.common.AuthorizationException;
import com.elitescloud.boot.auth.client.common.SecurityConstants;
import com.elitescloud.boot.auth.client.config.AuthorizationProperties;
import com.elitescloud.boot.auth.client.config.support.AuthenticationCallable;
import com.elitescloud.boot.auth.model.OAuthToken;
import com.elitescloud.boot.auth.provider.common.AuthorizationConstant;
import com.elitescloud.boot.auth.provider.common.LoginParameterNames;
import com.elitescloud.boot.auth.provider.config.servlet.oauth2.OAuth2AuthorizationCodeRequestCache;
import com.elitescloud.boot.auth.provider.security.generator.token.TokenGenerator;
import com.elitescloud.boot.auth.resolver.UniqueRequestResolver;
import com.elitescloud.cloudt.common.base.ApiResult;
import lombok.extern.log4j.Log4j2;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.savedrequest.DefaultSavedRequest;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.savedrequest.SavedRequest;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.security.Principal;
import java.time.Duration;
import java.time.Instant;
import java.util.*;

/**
 * Handles a successful authentication.
 * <p>
 * When the login belongs to an OAuth2 authorization-code flow, the response continues that flow. The flow is
 * recognised by one of three signals:
 * <ul>
 *     <li>a cached request to the authorization endpoint;</li>
 *     <li>a request id resolvable by the {@link UniqueRequestResolver};</li>
 *     <li>a {@code client_id} login parameter.</li>
 * </ul>
 * Depending on which signal matched, an {@link OAuth2Authorization} may be saved, and the client is either
 * redirected or receives the URL to visit next.
 * <p>
 * Otherwise a token is generated (only when a {@link TokenGenerator} is configured) and written back as an
 * {@link ApiResult}.
 *
 * @author Kaiser（wang shao）
 * @date 2022/7/3
 */
@Log4j2
public class OAuth2ServerAuthenticationSuccessHandler extends AbstractOAuth2ServerHandler implements AuthenticationSuccessHandler {
    /**
     * Source of the request that was interrupted by the login; replaceable via {@link #setRequestCache(RequestCache)}.
     */
    private RequestCache requestCache = new HttpSessionRequestCache();

    /**
     * Path of the OAuth2 authorization endpoint.
     */
    private final String authorizationEndpoint;
    private final AuthorizationProperties authorizationProperties;
    private final OAuth2AuthorizationCodeRequestCache authorizationCodeRequestCache;
    private final RegisteredClientRepository clientRepository;
    private final OAuth2AuthorizationService oAuth2AuthorizationService;

    /**
     * Token generator for non-OAuth2 logins; optional. When unset, no token is returned.
     */
    private TokenGenerator tokenGenerator;
    /**
     * Callback notified after a successful login (described by the original author as the authentication-info
     * cache); optional.
     */
    private AuthenticationCallable authenticationCallable;
    /**
     * Resolves the id of the pre-login authorization request from the current request.
     */
    private UniqueRequestResolver uniqueRequestResolver;

    public OAuth2ServerAuthenticationSuccessHandler(String authorizationEndpoint,
                                                    AuthorizationProperties authorizationProperties,
                                                    OAuth2AuthorizationCodeRequestCache authorizationCodeRequestCache,
                                                    RegisteredClientRepository clientRepository,
                                                    OAuth2AuthorizationService oAuth2AuthorizationService) {
        this.authorizationEndpoint = authorizationEndpoint;
        this.authorizationProperties = authorizationProperties;
        this.authorizationCodeRequestCache = authorizationCodeRequestCache;
        this.clientRepository = clientRepository;
        this.oAuth2AuthorizationService = oAuth2AuthorizationService;
    }

    @Override
    public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws IOException, ServletException {
        // Unwrap the project-specific token so the OAuth2 handling below sees the original authentication
        if (authentication instanceof CommonAuthenticationToken) {
            authentication = ((CommonAuthenticationToken) authentication).getOriginal();
        }
        // Pre-response processing (exposes the session id header)
        this.beforeReturn(request, response);
        // The request that was interrupted by the login, if any
        SavedRequest savedRequest = requestCache.getRequest(request, response);

        if (this.attemptToHandleOAuth2Response(request, response, authentication, savedRequest)) {
            // OAuth2 flow: the response has already been written
            if (authenticationCallable != null) {
                authenticationCallable.onLogin(request, response, null, authentication);
            }
            return;
        }

        // Not an OAuth2 flow: return a token directly
        OAuthToken authToken = generateToken(authentication);
        var result = ApiResult.ok(authToken);
        writeResponse(response, result);

        request.setAttribute(AuthorizationConstant.REQUEST_ATTRIBUTE_LOGIN_RESULT, result);
        if (authenticationCallable != null) {
            authenticationCallable.onLogin(request, response, authToken == null ? null : authToken.getAccessToken(), authentication);
        }
    }

    /**
     * Adds the current session id (if a session already exists) as a response header.
     */
    private void beforeReturn(HttpServletRequest request, HttpServletResponse response) {
        var session = request.getSession(false);
        if (session != null) {
            response.addHeader(SecurityConstants.HEADER_SESSION_ID, session.getId());
        }
    }

    /**
     * Generates a token for the authentication, or returns {@code null} when no {@link TokenGenerator} is configured.
     */
    private OAuthToken generateToken(Authentication authentication) {
        if (tokenGenerator == null) {
            return null;
        }

        return tokenGenerator.generate(authentication);
    }

    /**
     * Writes an OAuth2-specific response if the login belongs to an OAuth2 flow.
     *
     * @return {@code true} if a response was written, {@code false} if the caller should fall back to returning a token
     * @throws AuthorizationException if the saved request is not a {@link DefaultSavedRequest}
     */
    private boolean attemptToHandleOAuth2Response(HttpServletRequest request, HttpServletResponse response,
                                                  Authentication authentication, SavedRequest savedRequest) throws IOException {
        if (savedRequest == null) {
            return this.attemptHandleOAuth2Response(request, response, authentication, null);
        }

        // Resolve the path BEFORE casting. obtainServletPath rejects unsupported SavedRequest types with a
        // descriptive AuthorizationException, which an earlier cast would pre-empt with a ClassCastException.
        String servletPath = obtainServletPath(savedRequest);
        var defaultSavedRequest = (DefaultSavedRequest) savedRequest;
        if (!CharSequenceUtil.equals(servletPath, authorizationEndpoint)) {
            // The interrupted request was not an OAuth2 authorization request
            return this.attemptHandleOAuth2Response(request, response, authentication, defaultSavedRequest);
        }
        String[] clientIds = savedRequest.getParameterValues(OAuth2ParameterNames.CLIENT_ID);
        String clientId = clientIds != null && clientIds.length > 0 ? clientIds[0] : null;

        // Save the authentication result
        this.updateOAuth2Authorization(request, clientId, authentication, defaultSavedRequest);

        if (super.supportRedirect(request)) {
            // Redirect is supported
            super.sendRedirect(super.getUrlPrefix(defaultSavedRequest, authorizationProperties), defaultSavedRequest, request, response);
            return true;
        }
        // Redirect not supported: return the URL in the response body instead
        this.writeOAuth2Response(request, response, defaultSavedRequest);
        return true;
    }

    /**
     * Identifies the OAuth2 flow through the request id first, then through the {@code client_id} login parameter.
     *
     * @param savedRequest the interrupted request, may be {@code null}
     * @return {@code true} if a response was written
     */
    private boolean attemptHandleOAuth2Response(HttpServletRequest request, HttpServletResponse response,
                                                Authentication authentication, DefaultSavedRequest savedRequest) throws IOException {
        // Look up the authorization request issued before the login
        var reqId = uniqueRequestResolver.analyze(request);
        if (StringUtils.hasText(reqId)) {
            this.handleOAuth2ResponseByReqId(request, response, authentication, savedRequest, reqId);
            return true;
        }

        // Fall back to the client named in the login request
        var clientId = request.getParameter(LoginParameterNames.CLIENT_ID);
        if (StringUtils.hasText(clientId)) {
            this.handleOAuth2ResponseByClientId(request, response, authentication, clientId);
            return true;
        }

        return false;
    }

    /**
     * Responds with an authorization-endpoint URL (response_type=code, scope=openid) for the given client.
     * <p>
     * The redirect URI is the first configured callback that does not contain {@code 127.0.0.1}, falling back to
     * the first configured callback.
     */
    private void handleOAuth2ResponseByClientId(HttpServletRequest request, HttpServletResponse response, Authentication authentication,
                                                String clientId) throws IOException {
        log.info("OAuth2 Authorization ：{}, {}", clientId, authentication.getName());

        var client = clientRepository.findByClientId(clientId);
        if (client == null) {
            log.error("客户端不存在：{}", clientId);
            writeResponse(response, ApiResult.fail("客户端不存在，请联系管理员确认配置正常"));
            return;
        }

        var redirectUris = client.getRedirectUris();
        if (CollectionUtils.isEmpty(redirectUris)) {
            log.error("客户端配置错误，缺少配置回调地址：{}", clientId);
            writeResponse(response, ApiResult.fail("客户端配置错误，缺少配置回调地址，请联系管理员"));
            return;
        }
        // Prefer a callback that does not point at 127.0.0.1; the fallback is computed only when needed
        var redirectUrl = redirectUris.stream()
                .filter(t -> !t.contains("127.0.0.1"))
                .findFirst()
                .orElseGet(() -> redirectUris.stream().findFirst().orElse(null));

        var urlPrefix = super.getUrlPrefix(request, authorizationProperties);
        var authorizeUrl = UriComponentsBuilder.fromUriString(urlPrefix)
                .path(authorizationEndpoint)
                .queryParam(OAuth2ParameterNames.CLIENT_ID, clientId)
                .queryParam(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2ParameterNames.CODE)
                .queryParam(OAuth2ParameterNames.SCOPE, OidcScopes.OPENID)
                .queryParam(OAuth2ParameterNames.REDIRECT_URI, redirectUrl)
                .toUriString();
        writeResponse(response, ApiResult.ok(authorizeUrl));
    }

    /**
     * Completes the authorization-code request cached under {@code reqId}.
     * <p>
     * Saves an {@link OAuth2Authorization} for it, then either redirects to the client's first non-blank callback
     * (when redirects are supported) or responds with the authorization request URI.
     */
    private void handleOAuth2ResponseByReqId(HttpServletRequest request, HttpServletResponse response, Authentication authentication,
                                             DefaultSavedRequest savedRequest, String reqId) throws IOException {
        log.info("OAuth2 Authorization ：{}, {}", reqId, authentication.getName());
        OAuth2AuthorizationCodeRequestAuthenticationToken codeRequest = authorizationCodeRequestCache.getAuthenticationToken(reqId);
        if (codeRequest == null) {
            log.error("未找到授权码认证请求信息：{}", reqId);
            writeResponse(response, ApiResult.fail("认证信息已超时，请刷新后重试"));
            return;
        }

        RegisteredClient client = clientRepository.findByClientId(codeRequest.getClientId());
        if (client == null) {
            log.error("Registered client not found or disabled: {}", codeRequest.getClientId());
            writeResponse(response, ApiResult.fail("客户端不存在或已禁用"));
            return;
        }
        OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
                .authorizationUri(codeRequest.getAuthorizationUri())
                .clientId(codeRequest.getClientId())
                .redirectUri(codeRequest.getRedirectUri())
                .scopes(codeRequest.getScopes())
                .state(codeRequest.getState())
                .additionalParameters(codeRequest.getAdditionalParameters())
                .build();
        OAuth2Authorization authorization = authorizationBuilder(client, authentication, authorizationRequest)
                .attribute(OAuth2ParameterNames.STATE, reqId)
                .build();
        // Save the successful authorization-code authentication
        oAuth2AuthorizationService.save(authorization);

        // Return the result to the client
        String redirectUri = CollectionUtils.isEmpty(client.getRedirectUris()) ? null :
                client.getRedirectUris().stream().filter(CharSequenceUtil::isNotBlank).findFirst().orElse(null);
        if (StringUtils.hasText(redirectUri) && supportRedirect(request)) {
            sendRedirect(request, response, redirectUri);
            return;
        }

        String urlPrefix = savedRequest == null ? super.getUrlPrefix(request, authorizationProperties) : super.getUrlPrefix(savedRequest, authorizationProperties);
        String authorizeUrl = normalizeUrl(urlPrefix, authorizationRequest.getAuthorizationRequestUri());
        writeResponse(response, ApiResult.ok(authorizeUrl));
    }

    /**
     * Saves an {@link OAuth2Authorization} built from the parameters of the cached authorize request.
     * <p>
     * Does nothing when no request id can be resolved, when the client id is missing, or when the client is unknown.
     */
    private void updateOAuth2Authorization(HttpServletRequest request, String clientId,
                                           Authentication authentication, DefaultSavedRequest savedRequest) {
        // Look up the authorization request issued before the login
        var reqId = uniqueRequestResolver.analyze(request);
        if (!StringUtils.hasText(reqId)) {
            log.debug("缺少state参数，无法确定为OAuth2请求");
            return;
        }

        log.info("OAuth2认证请求：{}，{}", savedRequest.getRequestURL(), savedRequest.getQueryString());
        // Spring's RegisteredClientRepository implementations reject blank client ids, so bail out early
        if (!StringUtils.hasText(clientId)) {
            log.error("OAuth2 authorization request is missing client_id: {}", reqId);
            return;
        }
        RegisteredClient client = clientRepository.findByClientId(clientId);
        if (client == null) {
            log.error("客户端不存在或已禁用：{}", clientId);
            return;
        }

        Set<String> scopes = null;
        String scope = super.getParameter(savedRequest, OAuth2ParameterNames.SCOPE);
        if (StringUtils.hasText(scope)) {
            // OAuth2 scopes are space-delimited
            scopes = new HashSet<>(
                    Arrays.asList(StringUtils.delimitedListToStringArray(scope, " ")));
        }

        // Every parameter other than the standard authorize parameters is carried over as an additional parameter
        Map<String, Object> additionalParameters = new HashMap<>();
        savedRequest.getParameterMap().forEach((key, value) -> {
            if (value != null && value.length > 0 &&
                    !key.equals(OAuth2ParameterNames.RESPONSE_TYPE) &&
                    !key.equals(OAuth2ParameterNames.CLIENT_ID) &&
                    !key.equals(OAuth2ParameterNames.REDIRECT_URI) &&
                    !key.equals(OAuth2ParameterNames.SCOPE) &&
                    !key.equals(OAuth2ParameterNames.STATE)) {
                additionalParameters.put(key, value[0]);
            }
        });
        OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
                .authorizationUri(savedRequest.getRequestURL())
                .clientId(clientId)
                .redirectUri(super.getParameter(savedRequest, OAuth2ParameterNames.REDIRECT_URI))
                .scopes(scopes)
                .state(super.getParameter(savedRequest, OAuth2ParameterNames.STATE))
                .additionalParameters(additionalParameters)
                .build();

        OAuth2Authorization authorization = authorizationBuilder(client, authentication, authorizationRequest)
                .attribute(OAuth2ParameterNames.STATE, reqId)
                .build();
        // Save the successful authorization-code authentication
        oAuth2AuthorizationService.save(authorization);
    }

    /**
     * Rebuilds {@code url} so that its path and query are appended to {@code uriPrefix}.
     * <p>
     * Returns {@code url} unchanged when the prefix is blank.
     */
    private String normalizeUrl(String uriPrefix, String url) {
        if (!StringUtils.hasText(uriPrefix)) {
            return url;
        }

        var uri = UriComponentsBuilder.fromUriString(url)
                .build();
        return UriComponentsBuilder.fromUriString(uriPrefix)
                .path(uri.getPath())
                .query(uri.getQuery())
                .build().toString();
    }

    /**
     * Writes the URL the client should visit next as an {@link ApiResult}; used when redirects are not supported.
     */
    private void writeOAuth2Response(HttpServletRequest request, HttpServletResponse response, DefaultSavedRequest savedRequest) throws IOException {
        String urlPrefix = savedRequest == null ? super.getUrlPrefix(request, authorizationProperties) : super.getUrlPrefix(savedRequest, authorizationProperties);
        String redirectUrl = super.obtainRedirectUrl(urlPrefix, savedRequest);
        super.writeResponse(response, ApiResult.ok(redirectUrl));
    }

    /**
     * Returns the servlet path of the saved request, falling back to its request URI when the servlet path is blank.
     *
     * @throws AuthorizationException if the saved request is not a {@link DefaultSavedRequest}
     */
    private String obtainServletPath(SavedRequest savedRequest) {
        if (savedRequest instanceof DefaultSavedRequest) {
            var request = (DefaultSavedRequest) savedRequest;
            return CharSequenceUtil.blankToDefault(request.getServletPath(), request.getRequestURI());
        }

        throw new AuthorizationException("暂不支持的SavedRequest类型");
    }

    /**
     * Builds an authorization-code {@link OAuth2Authorization}.
     * <p>
     * Copied from {@link OAuth2AuthorizationCodeRequestAuthenticationProvider#authorizationBuilder(RegisteredClient, Authentication, OAuth2AuthorizationRequest)}
     *
     * @param registeredClient     the client being authorized
     * @param principal            the authenticated user; its name becomes the principal name
     * @param authorizationRequest the authorization request stored as an attribute
     * @return a builder pre-populated with the principal, request, client id and an expiry 5 minutes from now
     *         (epoch seconds)
     */
    private static OAuth2Authorization.Builder authorizationBuilder(RegisteredClient registeredClient, Authentication principal,
                                                                    OAuth2AuthorizationRequest authorizationRequest) {
        return OAuth2Authorization.withRegisteredClient(registeredClient)
                .principalName(principal.getName())
                .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
                .attribute(Principal.class.getName(), principal)
                .attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest)
                .attribute(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
                .attribute(OAuth2ParameterNames.EXPIRES_IN, Instant.now().plus(Duration.ofMinutes(5)).getEpochSecond())
                ;
    }

    public void setTokenGenerator(TokenGenerator tokenGenerator) {
        this.tokenGenerator = tokenGenerator;
    }

    public void setAuthenticationCallable(AuthenticationCallable authenticationCallable) {
        this.authenticationCallable = authenticationCallable;
    }

    public void setUniqueRequestResolver(UniqueRequestResolver uniqueRequestResolver) {
        this.uniqueRequestResolver = uniqueRequestResolver;
    }

    public void setRequestCache(RequestCache requestCache) {
        this.requestCache = requestCache;
    }
}
