package com.elitescloud.boot.auth.client.config.security.configurer.filter;

import com.elitescloud.boot.auth.client.config.AuthorizationProperties;
import com.elitescloud.boot.auth.client.config.security.resolver.BearerTokenResolver;
import com.elitescloud.boot.auth.client.config.security.resolver.impl.DefaultBearerTokenResolver;
import com.elitescloud.boot.auth.client.config.support.AuthenticationCache;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import lombok.extern.log4j.Log4j2;
import org.springframework.lang.NonNull;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;

/**
 * token续期filter.
 *
 * @author Kaiser（wang shao）
 * @date 2022/7/7
 */
@Log4j2
public class AccessTokenRenewalFilter extends OncePerRequestFilter {

    private final AuthorizationProperties authorizationProperties;
    private final AuthenticationCache authenticationCache;
    private final JwtDecoder jwtDecoder;
    private BearerTokenResolver bearerTokenResolver = new DefaultBearerTokenResolver();

    private Cache<String, String> tokenRefreshCache = null;

    public AccessTokenRenewalFilter(AuthorizationProperties authorizationProperties, AuthenticationCache authenticationCache, JwtDecoder jwtDecoder) {
        this.authorizationProperties = authorizationProperties;
        this.authenticationCache = authenticationCache;
        this.jwtDecoder = jwtDecoder;

        tokenRefreshCache = Caffeine.newBuilder()
                .expireAfterWrite(authorizationProperties.getTokenRenewalRate())
                .maximumSize(5000)
                .build();
    }

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        if (renewal()) {
            // 需要刷新
            String token = bearerTokenResolver.resolve(request);
            if (StringUtils.hasText(token)) {
                refreshTokenTtl(token);
            }
        }

        filterChain.doFilter(request, response);
    }

    private void refreshTokenTtl(String token) {
        var principal = tokenRefreshCache.getIfPresent(token);
        if (principal == null) {
            // 本地已过期，需要刷新
            CompletableFuture.runAsync(() -> {
                var user = authenticationCache.getUserDetail(token);
                if (user == null) {
                    // 用户token已过期，无法自动续期
                    return;
                }

                Jwt jwt = null;
                try {
                    jwt = jwtDecoder.decode(token);
                } catch (JwtException e) {
                    log.error("续期token异常：{}", e.getMessage());
                    return;
                }
                if (jwt.getExpiresAt() == null) {
                    // 没有设置过期时间，无需续期
                    return;
                }

                authenticationCache.setUserDetail(token, user, authorizationProperties.getTokenRenewal());
                tokenRefreshCache.put(token, "true");
            });
        }
    }

    private boolean renewal() {
        return authorizationProperties.getTokenRenewal() != null && authorizationProperties.getTokenRenewal().getSeconds() > 0;
    }

    public void setBearerTokenResolver(@NonNull BearerTokenResolver bearerTokenResolver) {
        this.bearerTokenResolver = bearerTokenResolver;
    }
}
