package com.elitesland.cloudt.authorization.api.provider.security.handler.oauth2.server;

import com.elitesland.cloudt.authorization.api.client.model.OAuthToken;
import com.elitesland.cloudt.authorization.api.provider.security.handler.oauth2.server.support.OAuth2AuthorizationCodeRequestCache;
import com.elitesland.yst.common.base.ApiResult;
import lombok.extern.log4j.Log4j2;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.DefaultOAuth2AccessTokenResponseMapConverter;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.security.Principal;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
 * Handles successful authentication events for the OAuth2 authorization server.
 * <p>
 * The kind of {@link Authentication} decides what happens:
 * <ul>
 *     <li>{@link OAuth2AuthorizationCodeRequestAuthenticationToken}: an authorization code was issued. It is either
 *     delivered through the standard Spring Security redirect or written as JSON.</li>
 *     <li>{@link OAuth2AccessTokenAuthenticationToken}: an access token was issued. It is written as JSON wrapped
 *     in {@link ApiResult}.</li>
 *     <li>Anything else is treated as a user login. The authorization code request cached under the {@code state}
 *     request parameter is turned into an {@link OAuth2Authorization} and saved.</li>
 * </ul>
 *
 * @author Kaiser（wang shao）
 * @date 2022/7/3
 */
@Log4j2
public class OAuth2ServerAuthenticationSuccessHandler extends AbstractOAuth2ServerHandler implements AuthenticationSuccessHandler {

    private final OAuth2AuthorizationCodeRequestCache authorizationCodeRequestCache;
    private final RegisteredClientRepository clientRepository;
    private final OAuth2AuthorizationService oAuth2AuthorizationService;

    /**
     * @param authorizationCodeRequestCache cache of pending authorization code requests, keyed by {@code state}
     * @param clientRepository              lookup for registered OAuth2 clients
     * @param oAuth2AuthorizationService    store that receives the authorization built after a login
     */
    public OAuth2ServerAuthenticationSuccessHandler(OAuth2AuthorizationCodeRequestCache authorizationCodeRequestCache,
                                                    RegisteredClientRepository clientRepository, OAuth2AuthorizationService oAuth2AuthorizationService) {
        this.authorizationCodeRequestCache = authorizationCodeRequestCache;
        this.clientRepository = clientRepository;
        this.oAuth2AuthorizationService = oAuth2AuthorizationService;
    }

    @Override
    public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws IOException, ServletException {
        if (authentication instanceof OAuth2AuthorizationCodeRequestAuthenticationToken) {
            // The authorization code request succeeded.
            forOAuth2AuthorizationCode(request, response, (OAuth2AuthorizationCodeRequestAuthenticationToken) authentication);
        } else if (authentication instanceof OAuth2AccessTokenAuthenticationToken) {
            // The access token request succeeded.
            forAccessToken(request, response, (OAuth2AccessTokenAuthenticationToken) authentication);
        } else {
            // Any other authentication is treated as a successful user login.
            forLogin(request, response, authentication);
        }
    }

    /**
     * Handles a plain login success.
     * <p>
     * Restores the authorization code request cached under the {@code state} request parameter. Saves it as an
     * {@link OAuth2Authorization} for the logged-in principal, then evicts the cache entry. Nothing is written to
     * the response.
     * <p>
     * In each of these cases the problem is logged and the method returns without saving anything:
     * <ul>
     *     <li>the {@code state} parameter is missing;</li>
     *     <li>no cached request exists for it;</li>
     *     <li>the cached client id no longer resolves to a registered client.</li>
     * </ul>
     */
    private void forLogin(HttpServletRequest request, HttpServletResponse response, Authentication authentication) {
        // Look up the authorization request that was pending before the user logged in.
        String state = request.getParameter(OAuth2ParameterNames.STATE);
        if (!StringUtils.hasText(state)) {
            log.warn("缺少state参数");
            return;
        }

        OAuth2AuthorizationCodeRequestAuthenticationToken codeRequest = authorizationCodeRequestCache.getAuthenticationToken(state);
        if (codeRequest == null) {
            log.error("未找到授权码认证请求信息：{}", state);
            return;
        }
        RegisteredClient client = clientRepository.findByClientId(codeRequest.getClientId());
        if (client == null) {
            // OAuth2Authorization.withRegisteredClient rejects null. Log and bail out like the other failure
            // paths instead of letting an IllegalArgumentException escape the success handler.
            log.error("未找到客户端信息：{}", codeRequest.getClientId());
            return;
        }
        OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
                .authorizationUri(codeRequest.getAuthorizationUri())
                .clientId(codeRequest.getClientId())
                .redirectUri(codeRequest.getRedirectUri())
                .scopes(codeRequest.getScopes())
                .state(codeRequest.getState())
                .additionalParameters(codeRequest.getAdditionalParameters())
                .build();
        OAuth2Authorization authorization = authorizationBuilder(client, authentication, authorizationRequest)
                .attribute(OAuth2ParameterNames.STATE, codeRequest.getState())
                .attribute(OAuth2ParameterNames.CLIENT_ID, codeRequest.getClientId())
                .build();
        // Persist the authorization produced by the successful login.
        oAuth2AuthorizationService.save(authorization);
        // The cached request has been consumed; remove it.
        authorizationCodeRequestCache.removeAuthenticationToken(state);
    }

    /**
     * Delivers the issued authorization code.
     * <p>
     * When {@code supportRedirect(request)} is true, the original Spring Security redirect response is sent.
     * Otherwise a JSON {@link ApiResult} is written containing {@code code} and, if present, {@code state}.
     */
    private void forOAuth2AuthorizationCode(HttpServletRequest request, HttpServletResponse response, OAuth2AuthorizationCodeRequestAuthenticationToken authentication) throws IOException {
        if (supportRedirect(request)) {
            // Redirect-capable clients follow the stock Spring Security flow.
            sendAuthorizationResponse(request, response, authentication);
            return;
        }

        // Otherwise answer with JSON.
        Map<String, Object> result = new HashMap<>(4);
        result.put(OAuth2ParameterNames.CODE, authentication.getAuthorizationCode().getTokenValue());
        if (StringUtils.hasText(authentication.getState())) {
            result.put(OAuth2ParameterNames.STATE, authentication.getState());
        }
        writeResponse(response, ApiResult.ok(result));
    }

    /**
     * Writes the issued access token (and refresh token, if any) as a JSON {@link ApiResult} of {@link OAuthToken}.
     */
    private void forAccessToken(HttpServletRequest request, HttpServletResponse response, OAuth2AccessTokenAuthenticationToken accessTokenAuthentication) throws IOException {
        // Build the standard access token response first.
        OAuth2AccessToken accessToken = accessTokenAuthentication.getAccessToken();
        OAuth2RefreshToken refreshToken = accessTokenAuthentication.getRefreshToken();
        Map<String, Object> additionalParameters = accessTokenAuthentication.getAdditionalParameters();

        OAuth2AccessTokenResponse.Builder builder =
                OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue())
                        .tokenType(accessToken.getTokenType())
                        .scopes(accessToken.getScopes());
        if (accessToken.getIssuedAt() != null && accessToken.getExpiresAt() != null) {
            builder.expiresIn(ChronoUnit.SECONDS.between(accessToken.getIssuedAt(), accessToken.getExpiresAt()));
        }
        if (refreshToken != null) {
            builder.refreshToken(refreshToken.getTokenValue());
        }
        if (!CollectionUtils.isEmpty(additionalParameters)) {
            builder.additionalParameters(additionalParameters);
        }
        OAuth2AccessTokenResponse accessTokenResponse = builder.build();

        var oauthToken = convertOAuth2AccessToken(accessTokenResponse);
        writeResponse(response, ApiResult.ok(oauthToken));
    }

    /**
     * Converts an {@link OAuth2AccessTokenResponse} into this project's {@link OAuthToken} response format.
     * <p>
     * The stock format is produced by {@link DefaultOAuth2AccessTokenResponseMapConverter}. Only the access
     * token value, token type, expires-in, scopes and refresh token are copied. Any additional parameters on the
     * response are not carried over.
     *
     * @param tokenResponse the token response to convert
     * @return the converted token
     */
    private static OAuthToken convertOAuth2AccessToken(OAuth2AccessTokenResponse tokenResponse) {
        OAuthToken token = new OAuthToken();
        token.setAccessToken(tokenResponse.getAccessToken().getTokenValue());
        token.setTokenType(tokenResponse.getAccessToken().getTokenType().getValue());
        token.setExpiresIn(getExpiresIn(tokenResponse));
        token.setScope(Objects.requireNonNullElse(tokenResponse.getAccessToken().getScopes(), Collections.emptySet()));

        var refreshToken = tokenResponse.getRefreshToken();
        if (refreshToken != null) {
            token.setRefreshToken(refreshToken.getTokenValue());
        }

        return token;
    }

    /**
     * Returns the remaining lifetime of the access token in whole seconds, measured from {@link Instant#now()}
     * (truncated toward zero), or {@code -1} when the token has no expiry.
     * <p>
     * NOTE(review): {@code OAuth2AccessTokenResponse.Builder} appears to always populate {@code expiresAt}, so the
     * {@code -1} branch may be unreachable from {@link #forAccessToken} — confirm.
     *
     * @param tokenResponse the token response whose access token is inspected
     * @return remaining seconds, or {@code -1} if unknown
     */
    private static long getExpiresIn(OAuth2AccessTokenResponse tokenResponse) {
        if (tokenResponse.getAccessToken().getExpiresAt() != null) {
            return ChronoUnit.SECONDS.between(Instant.now(), tokenResponse.getAccessToken().getExpiresAt());
        }
        return -1;
    }

    /**
     * Creates an {@link OAuth2Authorization.Builder} for the authorization code grant.
     * <p>
     * Copied from
     * {@link OAuth2AuthorizationCodeRequestAuthenticationProvider#authorizationBuilder(RegisteredClient, Authentication, OAuth2AuthorizationRequest)}.
     *
     * @param registeredClient     the client being authorized; must not be {@code null}
     * @param principal            the authenticated user; its name becomes the principal name, and the object
     *                             itself is stored under the {@link Principal} class-name attribute
     * @param authorizationRequest the original authorization request, stored under the
     *                             {@link OAuth2AuthorizationRequest} class-name attribute
     * @return a builder preset with principal name, grant type and the two attributes above
     */
    private static OAuth2Authorization.Builder authorizationBuilder(RegisteredClient registeredClient, Authentication principal,
                                                                    OAuth2AuthorizationRequest authorizationRequest) {
        return OAuth2Authorization.withRegisteredClient(registeredClient)
                .principalName(principal.getName())
                .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
                .attribute(Principal.class.getName(), principal)
                .attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest);
    }
}
