package com.elitescloud.boot.auth.provider.config.servlet.oauth2.handler;

import cn.hutool.core.text.CharSequenceUtil;
import com.elitescloud.boot.auth.client.common.AuthorizationException;
import com.elitescloud.boot.auth.client.common.SecurityConstants;
import com.elitescloud.boot.auth.client.config.AuthorizationProperties;
import com.elitescloud.boot.auth.client.config.support.AuthenticationCallable;
import com.elitescloud.boot.auth.model.OAuthToken;
import com.elitescloud.boot.auth.provider.common.AuthorizationConstant;
import com.elitescloud.boot.auth.provider.config.servlet.oauth2.OAuth2AuthorizationCodeRequestCache;
import com.elitescloud.boot.auth.provider.security.generator.token.TokenGenerator;
import com.elitescloud.boot.auth.resolver.UniqueRequestResolver;
import com.elitescloud.cloudt.common.base.ApiResult;
import lombok.extern.log4j.Log4j2;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.savedrequest.DefaultSavedRequest;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.savedrequest.SavedRequest;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.security.Principal;

/**
 * Handles a successful authentication.
 * <p>
 * Two cases are distinguished:
 * <ul>
 *     <li>OAuth2 authorization-code flow: an authorization request was pending when the user logged in.
 *     It is either saved in the {@link RequestCache}, or cached in the
 *     {@link OAuth2AuthorizationCodeRequestCache} under the id resolved by the {@link UniqueRequestResolver}.
 *     The handler either redirects (when the request supports it) or writes the URL to continue with.</li>
 *     <li>Plain login: a token is generated (when a {@link TokenGenerator} is configured) and written to the
 *     response as an {@link ApiResult}.</li>
 * </ul>
 * In both cases the configured {@link AuthenticationCallable}, if any, is notified afterwards.
 *
 * @author Kaiser（wang shao）
 * @date 2022/7/3
 */
@Log4j2
public class OAuth2ServerAuthenticationSuccessHandler extends AbstractOAuth2ServerHandler implements AuthenticationSuccessHandler {

    /**
     * Cache holding the request that was interrupted by the login.
     */
    private RequestCache requestCache = new HttpSessionRequestCache();

    /**
     * OAuth2 authorization endpoint path.
     */
    private final String authorizationEndpoint;
    private final AuthorizationProperties authorizationProperties;
    private final OAuth2AuthorizationCodeRequestCache authorizationCodeRequestCache;
    private final RegisteredClientRepository clientRepository;
    private final OAuth2AuthorizationService oAuth2AuthorizationService;

    /**
     * Token generator; optional. When absent, no token is generated for plain logins.
     */
    private TokenGenerator tokenGenerator;
    /**
     * Login callback; optional. When absent, no callback is invoked.
     */
    private AuthenticationCallable authenticationCallable;
    /**
     * Resolves the id under which a pending authorization-code request was cached; optional.
     * When absent, only requests saved in the {@link RequestCache} are treated as OAuth2 requests.
     */
    private UniqueRequestResolver uniqueRequestResolver;

    public OAuth2ServerAuthenticationSuccessHandler(String authorizationEndpoint,
                                                    AuthorizationProperties authorizationProperties,
                                                    OAuth2AuthorizationCodeRequestCache authorizationCodeRequestCache,
                                                    RegisteredClientRepository clientRepository,
                                                    OAuth2AuthorizationService oAuth2AuthorizationService) {
        this.authorizationEndpoint = authorizationEndpoint;
        this.authorizationProperties = authorizationProperties;
        this.authorizationCodeRequestCache = authorizationCodeRequestCache;
        this.clientRepository = clientRepository;
        this.oAuth2AuthorizationService = oAuth2AuthorizationService;
    }

    @Override
    public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws IOException, ServletException {
        // Called first so the session header is added before any body is written.
        this.beforeReturn(request, response);
        // The request that was interrupted by the login, if any.
        SavedRequest savedRequest = requestCache.getRequest(request, response);

        if (this.attemptToHandleOAuth2Response(request, response, authentication, savedRequest)) {
            // OAuth2 flow: the response has already been written.
            if (authenticationCallable != null) {
                authenticationCallable.onLogin(request, response, null, authentication);
            }
            return;
        }

        // Not an OAuth2 flow: respond with a token.
        OAuthToken authToken = generateToken(authentication);
        var result = ApiResult.ok(authToken);
        writeResponse(response, result);

        request.setAttribute(AuthorizationConstant.REQUEST_ATTRIBUTE_LOGIN_RESULT, result);
        if (authenticationCallable != null) {
            authenticationCallable.onLogin(request, response, authToken == null ? null : authToken.getAccessToken(), authentication);
        }
    }

    /**
     * Ensures a session exists and exposes its id in the {@link SecurityConstants#HEADER_SESSION_ID} header.
     */
    private void beforeReturn(HttpServletRequest request, HttpServletResponse response) {
        var session = request.getSession(true);
        if (session != null) {
            response.addHeader(SecurityConstants.HEADER_SESSION_ID, session.getId());
        }
    }

    /**
     * Generates a token for the authenticated user.
     *
     * @return the token, or {@code null} when no {@link TokenGenerator} is configured
     */
    private OAuthToken generateToken(Authentication authentication) {
        if (tokenGenerator == null) {
            return null;
        }

        return tokenGenerator.generate(authentication);
    }

    /**
     * Tries to complete a pending OAuth2 authorization-code request.
     *
     * @param request        the current (login) request
     * @param response       the current response
     * @param authentication the successful authentication
     * @param savedRequest   the request interrupted by the login, may be {@code null}
     * @return {@code true} if this was treated as an OAuth2 request and a response (success or error) has been
     * written; {@code false} if the caller should continue with the plain-login response
     * @throws IOException if writing the response fails
     */
    private boolean attemptToHandleOAuth2Response(HttpServletRequest request, HttpServletResponse response,
                                                  Authentication authentication, SavedRequest savedRequest) throws IOException {
        if (savedRequest != null) {
            if (!CharSequenceUtil.equals(obtainServletPath(savedRequest), authorizationEndpoint)) {
                // The saved request was not aimed at the authorization endpoint: not OAuth2.
                return false;
            }
            String[] clientIds = savedRequest.getParameterValues(OAuth2ParameterNames.CLIENT_ID);
            if (clientIds != null && clientIds.length > 0) {
                request.setAttribute(AuthorizationConstant.REQUEST_ATTRIBUTE_CLIENT_ID, clientIds[0]);
            }

            // obtainServletPath above only accepts DefaultSavedRequest, so the casts below are safe.
            if (super.supportRedirect(request)) {
                super.sendRedirect(authorizationProperties.getRedirectUriPrefix(), (DefaultSavedRequest) savedRequest, request, response);
                return true;
            }
            // Redirect not supported: return the URL in the response body instead.
            this.writeOAuth2Response(response, savedRequest);
            return true;
        }

        // Look up the authorization request cached before the login.
        if (uniqueRequestResolver == null) {
            log.debug("UniqueRequestResolver not configured, cannot identify an OAuth2 request");
            return false;
        }
        var reqId = uniqueRequestResolver.analyze(request);
        if (!StringUtils.hasText(reqId)) {
            log.debug("缺少state参数，无法确定为OAuth2请求");
            return false;
        }

        OAuth2AuthorizationCodeRequestAuthenticationToken codeRequest = authorizationCodeRequestCache.getAuthenticationToken(reqId);
        if (codeRequest == null) {
            log.error("未找到授权码认证请求信息：{}", reqId);
            writeResponse(response, ApiResult.fail("认证信息已超时，请刷新后重试"));
            return true;
        }
        request.setAttribute(AuthorizationConstant.REQUEST_ATTRIBUTE_CLIENT_ID, codeRequest.getClientId());
        RegisteredClient client = clientRepository.findByClientId(codeRequest.getClientId());
        if (client == null) {
            writeResponse(response, ApiResult.fail("客户端不存在或已禁用"));
            return true;
        }
        OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
                .authorizationUri(codeRequest.getAuthorizationUri())
                .clientId(codeRequest.getClientId())
                .redirectUri(codeRequest.getRedirectUri())
                .scopes(codeRequest.getScopes())
                .state(codeRequest.getState())
                .additionalParameters(codeRequest.getAdditionalParameters())
                .build();
        OAuth2Authorization authorization = authorizationBuilder(client, authentication, authorizationRequest)
                .attribute(OAuth2ParameterNames.STATE, reqId)
                .attribute(OAuth2ParameterNames.CLIENT_ID, codeRequest.getClientId())
                .build();
        // Persist the authorization resulting from the successful login.
        oAuth2AuthorizationService.save(authorization);

        // Respond to the client: redirect to its first non-blank registered redirect URI when possible,
        // otherwise return the authorization URL to continue with.
        String redirectUri = CollectionUtils.isEmpty(client.getRedirectUris()) ? null :
                client.getRedirectUris().stream().filter(CharSequenceUtil::isNotBlank).findFirst().orElse(null);
        if (StringUtils.hasText(redirectUri) && supportRedirect(request)) {
            sendRedirect(request, response, redirectUri);
            return true;
        }
        String authorizeUrl = normalizeUrl(authorizationProperties.getRedirectUriPrefix(), authorizationRequest.getAuthorizationRequestUri());
        writeResponse(response, ApiResult.ok(authorizeUrl));
        return true;
    }

    /**
     * Rebases the path and query of {@code url} onto {@code uriPrefix}.
     * <p>
     * Prefix and path are joined with exactly one {@code '/'}, regardless of trailing/leading slashes on either
     * side. A URL without a path contributes only its query.
     *
     * @param uriPrefix prefix (typically scheme + host, optionally a base path); blank means "no rebasing"
     * @param url       the URL whose path and query are kept
     * @return the rebased URL, or {@code url} unchanged when {@code uriPrefix} is blank
     */
    private String normalizeUrl(String uriPrefix, String url) {
        if (!StringUtils.hasText(uriPrefix)) {
            return url;
        }

        var uri = UriComponentsBuilder.fromUriString(url)
                .build();

        String prefix = uriPrefix;
        while (prefix.endsWith("/")) {
            prefix = prefix.substring(0, prefix.length() - 1);
        }
        String path = uri.getPath();
        if (!StringUtils.hasText(path)) {
            // Avoid appending a literal "null" segment when the URL has no path.
            path = "";
        } else if (!path.startsWith("/")) {
            path = "/" + path;
        }

        return UriComponentsBuilder.fromUriString(prefix + path)
                .query(uri.getQuery())
                .build().toString();
    }

    /**
     * Writes the URL to continue the saved OAuth2 request with, wrapped in a successful {@link ApiResult}.
     */
    private void writeOAuth2Response(HttpServletResponse response, SavedRequest savedRequest) throws IOException {
        String redirectUrl = super.obtainRedirectUrl(authorizationProperties.getRedirectUriPrefix(), (DefaultSavedRequest) savedRequest);
        super.writeResponse(response, ApiResult.ok(redirectUrl));
    }

    /**
     * Returns the servlet path of the saved request, falling back to its request URI when the servlet path is blank.
     *
     * @throws AuthorizationException if {@code savedRequest} is not a {@link DefaultSavedRequest}
     */
    private String obtainServletPath(SavedRequest savedRequest) {
        if (savedRequest instanceof DefaultSavedRequest) {
            var request = (DefaultSavedRequest) savedRequest;
            return CharSequenceUtil.blankToDefault(request.getServletPath(), request.getRequestURI());
        }

        throw new AuthorizationException("暂不支持的SavedRequest类型");
    }

    /**
     * Builds an {@link OAuth2Authorization} for the authorization-code grant.
     * <p>
     * Copied from {@link OAuth2AuthorizationCodeRequestAuthenticationProvider#authorizationBuilder(RegisteredClient, Authentication, OAuth2AuthorizationRequest)}
     *
     * @param registeredClient     the client being authorized
     * @param principal            the authenticated user
     * @param authorizationRequest the original authorization request
     * @return a builder pre-populated with principal, grant type and request attributes
     */
    private static OAuth2Authorization.Builder authorizationBuilder(RegisteredClient registeredClient, Authentication principal,
                                                                    OAuth2AuthorizationRequest authorizationRequest) {
        return OAuth2Authorization.withRegisteredClient(registeredClient)
                .principalName(principal.getName())
                .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
                .attribute(Principal.class.getName(), principal)
                .attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest);
    }

    /**
     * Sets the token generator used for plain (non-OAuth2) logins.
     */
    public void setTokenGenerator(TokenGenerator tokenGenerator) {
        this.tokenGenerator = tokenGenerator;
    }

    /**
     * Sets the callback notified after every successful login.
     */
    public void setAuthenticationCallable(AuthenticationCallable authenticationCallable) {
        this.authenticationCallable = authenticationCallable;
    }

    /**
     * Sets the resolver used to find a cached authorization-code request.
     */
    public void setUniqueRequestResolver(UniqueRequestResolver uniqueRequestResolver) {
        this.uniqueRequestResolver = uniqueRequestResolver;
    }

    /**
     * Replaces the cache used to look up the request interrupted by the login.
     */
    public void setRequestCache(RequestCache requestCache) {
        this.requestCache = requestCache;
    }
}
