package com.elitescloud.boot.auth.provider.config.servlet.oauth2.handler;

import com.elitescloud.boot.auth.model.OAuthToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.DefaultOAuth2AccessTokenResponseMapConverter;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;

/**
 * AccessToken handler.
 * <p>
 * On successful authentication at the token endpoint, converts the issued access token (plus any refresh token and
 * OIDC id_token) into an {@link OAuthToken} and writes it to the HTTP response.
 *
 * @author Kaiser（wang shao）
 * @date 2022/7/10
 */
public class OAuth2AccessTokenResponseHandler extends AbstractOAuth2ServerHandler implements AuthenticationSuccessHandler {

    /**
     * {@code expires_in} value reported when the access token does not carry both an issue time and an expiry time.
     */
    private static final long UNKNOWN_EXPIRES_IN = -1L;

    /**
     * Builds the access token response and writes it via {@code writeResponse} (inherited from
     * {@code AbstractOAuth2ServerHandler}).
     *
     * @param request        the current request (not used)
     * @param response       the response the token is written to
     * @param authentication the authentication result; cast to {@link OAuth2AccessTokenAuthenticationToken}
     * @throws IOException      if writing the response fails
     * @throws ServletException declared by {@link AuthenticationSuccessHandler}
     */
    @Override
    public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws IOException, ServletException {
        // NOTE(review): assumes this handler is only registered where the authentication is an
        // OAuth2AccessTokenAuthenticationToken; any other type fails here with ClassCastException.
        OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) authentication;
        // Build the access token result
        OAuth2AccessToken accessToken = accessTokenAuthentication.getAccessToken();
        OAuth2RefreshToken refreshToken = accessTokenAuthentication.getRefreshToken();
        Map<String, Object> additionalParameters = accessTokenAuthentication.getAdditionalParameters();

        // Compute the lifetime once, from the issued token itself. Re-deriving it later as
        // (expiresAt - Instant.now()) truncates to whole seconds after some time has elapsed and
        // under-reports the lifetime by one second.
        Instant issuedAt = accessToken.getIssuedAt();
        Instant expiresAt = accessToken.getExpiresAt();
        boolean lifetimeKnown = issuedAt != null && expiresAt != null;
        long expiresIn = lifetimeKnown ? ChronoUnit.SECONDS.between(issuedAt, expiresAt) : UNKNOWN_EXPIRES_IN;

        OAuth2AccessTokenResponse.Builder builder =
                OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue())
                        .tokenType(accessToken.getTokenType())
                        .scopes(accessToken.getScopes());
        if (lifetimeKnown) {
            builder.expiresIn(expiresIn);
        }
        if (refreshToken != null) {
            builder.refreshToken(refreshToken.getTokenValue());
        }
        if (!CollectionUtils.isEmpty(additionalParameters)) {
            builder.additionalParameters(additionalParameters);
        }
        OAuth2AccessTokenResponse accessTokenResponse = builder.build();

        var oauthToken = convertOAuth2AccessToken(accessTokenResponse, expiresIn);
        writeResponse(response, oauthToken);
    }

    /**
     * Converts an OAuth2 access token response into the {@link OAuthToken} format returned to clients.
     * <p>
     * For the framework's default output format, see {@link DefaultOAuth2AccessTokenResponseMapConverter}.
     *
     * @param tokenResponse the built access token response
     * @param expiresIn     the access token lifetime in seconds, or {@code -1} if unknown
     * @return the token payload to write to the response
     */
    private static OAuthToken convertOAuth2AccessToken(OAuth2AccessTokenResponse tokenResponse, long expiresIn) {
        OAuthToken token = new OAuthToken();
        token.setAccessToken(tokenResponse.getAccessToken().getTokenValue());
        token.setTokenType(tokenResponse.getAccessToken().getTokenType().getValue());
        token.setExpiresIn(expiresIn);
        token.setScope(Objects.requireNonNullElse(tokenResponse.getAccessToken().getScopes(), Collections.emptySet()));

        var refreshToken = tokenResponse.getRefreshToken();
        if (refreshToken != null) {
            token.setRefreshToken(refreshToken.getTokenValue());
        }

        // The OIDC id_token, when present, is carried in the additional parameters.
        String idToken = (String) tokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN);
        if (StringUtils.hasText(idToken)) {
            token.setIdToken(idToken);
        }

        return token;
    }
}
