package com.elitescloud.boot.auth.provider.config.servlet.oauth2.handler;

import com.elitescloud.boot.auth.provider.config.servlet.oauth2.OAuth2AuthorizationCodeRequestCache;
import com.elitescloud.boot.auth.resolver.UniqueRequestResolver;
import com.elitescloud.cloudt.common.base.ApiResult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.util.StringUtils;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
 * Success handler for the OAuth2 authorization endpoint.
 * <p>
 * After a successful authorization request it removes the cached authorization-code request
 * associated with the current request id (as resolved by {@link UniqueRequestResolver}), then either:
 * <ul>
 *     <li>delegates to {@code sendAuthorizationResponse} when {@code supportRedirect(request)} is true,
 *     i.e. follows the original Spring Security redirect logic; or</li>
 *     <li>writes an {@link ApiResult} JSON body containing the authorization {@code code} and, when
 *     present, the {@code state}.</li>
 * </ul>
 *
 * @author Kaiser（wang shao）
 * @date 2022/7/10
 */
public class OAuth2AuthorizationResponseHandler extends AbstractOAuth2ServerHandler implements AuthenticationSuccessHandler {
    private static final Logger logger = LoggerFactory.getLogger(OAuth2AuthorizationResponseHandler.class);

    private final UniqueRequestResolver uniqueRequestResolver;
    private final OAuth2AuthorizationCodeRequestCache authorizationCodeRequestCache;

    /**
     * @param uniqueRequestResolver         resolves the id under which the authorization request was cached
     * @param authorizationCodeRequestCache cache holding pending authorization-code requests
     */
    public OAuth2AuthorizationResponseHandler(UniqueRequestResolver uniqueRequestResolver,
                                              OAuth2AuthorizationCodeRequestCache authorizationCodeRequestCache) {
        this.uniqueRequestResolver = uniqueRequestResolver;
        this.authorizationCodeRequestCache = authorizationCodeRequestCache;
    }

    /**
     * Handles a successful OAuth2 authorization request.
     *
     * @param request        the current request
     * @param response       the current response
     * @param authentication expected to be an {@link OAuth2AuthorizationCodeRequestAuthenticationToken}
     * @throws IOException      if writing the response fails
     * @throws ServletException as declared by the {@link AuthenticationSuccessHandler} contract
     */
    @Override
    public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws IOException, ServletException {
        OAuth2AuthorizationCodeRequestAuthenticationToken authenticationToken = (OAuth2AuthorizationCodeRequestAuthenticationToken) authentication;
        // May be null; every use below is guarded.
        var authorizationCode = authenticationToken.getAuthorizationCode();
        if (authorizationCode != null) {
            // The authorization code is a credential that can be exchanged for tokens: never log its value.
            logger.info("OAuth2 authorize success");
        }

        // Clear the cached authorization request for this request id
        var reqId = uniqueRequestResolver.analyze(request);
        if (StringUtils.hasText(reqId)) {
            authorizationCodeRequestCache.removeAuthenticationToken(reqId);
        }

        // Respond
        if (supportRedirect(request)) {
            // Redirect is supported: follow the original Spring Security logic
            sendAuthorizationResponse(request, response, authenticationToken);
            return;
        }

        // Otherwise respond with JSON data
        Map<String, Object> result = new HashMap<>(4);
        if (authorizationCode != null) {
            result.put(OAuth2ParameterNames.CODE, authorizationCode.getTokenValue());
        }
        if (StringUtils.hasText(authenticationToken.getState())) {
            result.put(OAuth2ParameterNames.STATE, authenticationToken.getState());
        }
        writeResponse(response, ApiResult.ok(result));
    }
}
