package com.elitescloud.boot.auth.provider.config.servlet.oauth2.handler;

import com.elitescloud.cloudt.common.base.ApiCode;
import com.elitescloud.cloudt.common.base.ApiResult;
import com.elitescloud.cloudt.context.util.HttpServletUtil;
import lombok.extern.log4j.Log4j2;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationException;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
 * Authentication failure handler for the OAuth2 authorization server.
 * <p>
 * Adapted from the original handling in
 * {@link OAuth2AuthorizationEndpointFilter#sendErrorResponse(HttpServletRequest, HttpServletResponse, AuthenticationException)}.
 * Authorization-code request failures are answered with an {@code ApiResult} body; any other failure goes
 * through the redirect-capable path when {@code supportRedirect} allows it, or the JSON path otherwise.
 *
 * @author Kaiser（wang shao）
 * @date 2022/7/3
 */
@Log4j2
public class OAuth2AuthorizationErrorResponseHandler extends AbstractOAuth2ServerHandler implements AuthenticationFailureHandler {

    /**
     * Human-readable (Chinese) descriptions keyed by OAuth2 error code. Immutable after class initialization.
     */
    private static final Map<String, String> ERROR_DESCRIPTION = Map.ofEntries(
            Map.entry(OAuth2ErrorCodes.INVALID_REQUEST, "请求方式或请求参数错误"),
            Map.entry(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT, "未认证的客户端"),
            Map.entry(OAuth2ErrorCodes.ACCESS_DENIED, "无权访问"),
            Map.entry(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE, "不支持的response_type"),
            Map.entry(OAuth2ErrorCodes.INVALID_SCOPE, "无效的scope"),
            Map.entry(OAuth2ErrorCodes.INSUFFICIENT_SCOPE, "授权不足"),
            Map.entry(OAuth2ErrorCodes.INVALID_TOKEN, "token无效"),
            Map.entry(OAuth2ErrorCodes.SERVER_ERROR, "服务端错误"),
            Map.entry(OAuth2ErrorCodes.TEMPORARILY_UNAVAILABLE, "暂时不可用"),
            Map.entry(OAuth2ErrorCodes.INVALID_CLIENT, "客户端无效"),
            Map.entry(OAuth2ErrorCodes.INVALID_GRANT, "授权无效"),
            Map.entry(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, "不支持的grant_type"),
            Map.entry(OAuth2ErrorCodes.UNSUPPORTED_TOKEN_TYPE, "不支持的token_type"),
            Map.entry(OAuth2ErrorCodes.INVALID_REDIRECT_URI, "无效的redirect_uri")
    );

    /**
     * Logs the failure and request details, then writes the error response.
     *
     * @param request   the request that failed authentication
     * @param response  the response to write the error to
     * @param exception the authentication failure; may be any {@link AuthenticationException} subtype
     * @throws IOException      if writing the response fails
     * @throws ServletException declared by {@link AuthenticationFailureHandler}
     */
    @Override
    public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, AuthenticationException exception) throws IOException, ServletException {
        log.info("OAuth2认证异常：", exception);
        super.printParameters(request);
        super.printHeaders(request);

        // Authorization-code request failures are handled by a dedicated path
        if (exception instanceof OAuth2AuthorizationCodeRequestAuthenticationException) {
            this.sendErrorResponseForAuthorizationCodeRequest(request, response, (OAuth2AuthorizationCodeRequestAuthenticationException) exception);
            return;
        }

        if (super.supportRedirect(request)) {
            // Redirect supported: reuse the original OAuth2 error-response logic
            this.sendErrorResponse(request, response, exception);
            return;
        }

        // Redirect not supported: respond with JSON data
        sendErrorResponseByJson(request, response, exception);
    }

    /**
     * Writes a 400 {@code ApiResult} for a failed authorization-code request.
     *
     * @param request   the current request (unused, kept for symmetry with the other senders)
     * @param response  the response to write to
     * @param exception the authorization-code request failure
     * @throws IOException if writing the response fails
     */
    private void sendErrorResponseForAuthorizationCodeRequest(HttpServletRequest request, HttpServletResponse response,
                                                              OAuth2AuthorizationCodeRequestAuthenticationException exception) throws IOException {
        var error = exception.getError();
        if (error == null) {
            writeResponse(response, ApiResult.fail(ApiCode.BAD_REQUEST, "认证失败：" + exception.getMessage()), HttpStatus.BAD_REQUEST);
            return;
        }

        var errorCode = error.getErrorCode();
        var exceptionDescription = error.getDescription();
        // The description is optional on OAuth2Error; guard before splitting it.
        if (OAuth2ErrorCodes.INVALID_REQUEST.equals(errorCode) && StringUtils.hasText(exceptionDescription)) {
            // Presumably descriptions of the form "<prefix>: <parameter>"; surface only the parameter part.
            // TODO confirm the format against the authorization provider in use.
            var descArray = exceptionDescription.split(":");
            if (descArray.length == 2) {
                writeResponse(response, ApiResult.fail(ApiCode.BAD_REQUEST, "请求参数有误：" + descArray[1].trim()), HttpStatus.BAD_REQUEST);
                return;
            }
        }

        // Fall back to the raw error code when no translation exists, and omit a missing description,
        // so the message never contains the literal text "null".
        var message = ERROR_DESCRIPTION.getOrDefault(errorCode, errorCode);
        if (StringUtils.hasText(exceptionDescription)) {
            message = message + ":" + exceptionDescription;
        }
        writeResponse(response, ApiResult.fail(ApiCode.BAD_REQUEST, message), HttpStatus.BAD_REQUEST);
    }

    /**
     * Handles an authentication failure by redirecting to the request's redirect URI when one is available,
     * or by writing the OAuth2 error as JSON otherwise.
     * <p>
     * Adapted from {@link OAuth2AuthorizationEndpointFilter#sendErrorResponse(HttpServletRequest, HttpServletResponse, AuthenticationException)}.
     * Unlike the original, this accepts any {@link AuthenticationException}. Without an authorization-code
     * request there is no redirect target, so the JSON branch is used.
     *
     * @param request   the current request
     * @param response  the response to redirect or write to
     * @param exception the authentication failure
     * @throws IOException if writing the response or redirecting fails
     */
    private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
                                   AuthenticationException exception) throws IOException {
        OAuth2Error error = resolveError(exception);
        OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
                resolveAuthorizationCodeRequest(exception);

        if (authorizationCodeRequestAuthentication == null ||
                !StringUtils.hasText(authorizationCodeRequestAuthentication.getRedirectUri())) {
            // No redirect target available: write the error body instead of redirecting
            HttpServletUtil.writeJsonIgnoreException(response, error);
            return;
        }

        UriComponentsBuilder uriBuilder = UriComponentsBuilder
                .fromUriString(authorizationCodeRequestAuthentication.getRedirectUri())
                .queryParam(OAuth2ParameterNames.ERROR, error.getErrorCode());
        if (StringUtils.hasText(error.getDescription())) {
            uriBuilder.queryParam(OAuth2ParameterNames.ERROR_DESCRIPTION, error.getDescription());
        }
        if (StringUtils.hasText(error.getUri())) {
            uriBuilder.queryParam(OAuth2ParameterNames.ERROR_URI, error.getUri());
        }
        if (StringUtils.hasText(authorizationCodeRequestAuthentication.getState())) {
            uriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState());
        }
        super.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString());
    }

    /**
     * Writes the authentication failure as an {@code ApiResult} body.
     * <p>
     * Without an authorization-code request (including every non-code-request failure), this responds with 400.
     * Otherwise it responds with 401 and includes the OAuth2 error fields and the request state.
     *
     * @param request   the current request (unused)
     * @param response  the response to write to
     * @param exception the authentication failure
     * @throws IOException if writing the response fails
     */
    private void sendErrorResponseByJson(HttpServletRequest request, HttpServletResponse response, AuthenticationException exception) throws IOException {
        OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
                resolveAuthorizationCodeRequest(exception);
        // Wrong request method or parameters
        if (authorizationCodeRequestAuthentication == null) {
            writeResponse(response, ApiResult.fail("请求方式或参数有误"), HttpStatus.BAD_REQUEST);
            return;
        }

        // Invalid request: expose the OAuth2 error fields
        OAuth2Error error = resolveError(exception);
        Map<String, Object> result = new HashMap<>();
        result.put(OAuth2ParameterNames.ERROR, error.getErrorCode());
        if (StringUtils.hasText(error.getDescription())) {
            result.put(OAuth2ParameterNames.ERROR_DESCRIPTION, error.getDescription());
        }
        if (StringUtils.hasText(error.getUri())) {
            result.put(OAuth2ParameterNames.ERROR_URI, error.getUri());
        }
        if (StringUtils.hasText(authorizationCodeRequestAuthentication.getState())) {
            result.put(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState());
        }

        writeResponse(response, ApiResult.fail(ApiCode.UNAUTHORIZED, result), HttpStatus.UNAUTHORIZED);
    }

    /**
     * Returns the OAuth2 error carried by the exception.
     * <p>
     * For failures that are not {@link OAuth2AuthenticationException}s, or that carry no error, this
     * synthesizes an {@code invalid_request} error described by the exception message.
     */
    private static OAuth2Error resolveError(AuthenticationException exception) {
        if (exception instanceof OAuth2AuthenticationException) {
            OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
            if (error != null) {
                return error;
            }
        }
        return new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, exception.getMessage(), null);
    }

    /**
     * Returns the authorization-code request attached to the exception.
     * <p>
     * Returns {@code null} when the exception is not an authorization-code request failure, or when that
     * failure carries no request.
     */
    private static OAuth2AuthorizationCodeRequestAuthenticationToken resolveAuthorizationCodeRequest(AuthenticationException exception) {
        if (exception instanceof OAuth2AuthorizationCodeRequestAuthenticationException) {
            return ((OAuth2AuthorizationCodeRequestAuthenticationException) exception).getAuthorizationCodeRequestAuthentication();
        }
        return null;
    }
}
