package com.elitescloud.boot.auth.client.config.security.configurer.filter;

import com.elitescloud.boot.auth.cas.provider.UserTransferHelper;
import com.elitescloud.boot.auth.client.config.AuthorizationProperties;
import com.elitescloud.boot.auth.client.config.security.resolver.BearerTokenResolver;
import com.elitescloud.boot.auth.client.config.security.resolver.impl.DefaultBearerTokenResolver;
import com.elitescloud.boot.auth.client.config.support.AuthenticationCache;
import com.elitescloud.boot.constant.AuthenticationClaim;
import com.elitescloud.boot.constant.OpenFeignConstant;
import com.elitescloud.boot.threadpool.common.ThreadPoolHolder;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import lombok.extern.log4j.Log4j2;
import org.springframework.lang.NonNull;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * token续期filter.
 *
 * @author Kaiser（wang shao）
 * @date 2022/7/7
 */
@Log4j2
public class AccessTokenRenewalFilter extends OncePerRequestFilter {

    private final AuthorizationProperties authorizationProperties;
    private final AuthenticationCache authenticationCache;
    private final JwtDecoder jwtDecoder;
    private final UserTransferHelper userTransferHelper;
    private BearerTokenResolver bearerTokenResolver = new DefaultBearerTokenResolver();

    private Cache<String, String> tokenRefreshCache = null;
    private final List<RequestMatcher> requestMatcherIgnore;
    private final ThreadPoolExecutor threadPoolExecutor;

    public AccessTokenRenewalFilter(AuthorizationProperties authorizationProperties, AuthenticationCache authenticationCache, JwtDecoder jwtDecoder) {
        this.authorizationProperties = authorizationProperties;
        this.authenticationCache = authenticationCache;
        this.jwtDecoder = jwtDecoder;
        this.userTransferHelper = UserTransferHelper.getInstance(authorizationProperties.getIssuerUrl());

        tokenRefreshCache = Caffeine.newBuilder()
                .expireAfterWrite(authorizationProperties.getTokenRenewalRate())
                .maximumSize(5000)
                .build();
        requestMatcherIgnore = List.of(new AntPathRequestMatcher(OpenFeignConstant.URI_PREFIX + "/**"),
                new AntPathRequestMatcher("/actuator/**"));
        threadPoolExecutor = ThreadPoolHolder.createThreadPool("renewToken-", 4, 8);
    }

    public void setBearerTokenResolver(@NonNull BearerTokenResolver bearerTokenResolver) {
        this.bearerTokenResolver = bearerTokenResolver;
    }

    @Override
    protected void doFilterInternal(@NonNull HttpServletRequest request, @NonNull HttpServletResponse response,
                                    @NonNull FilterChain filterChain) throws ServletException, IOException {
        if (!this.needRefresh(request)) {
            log.debug("无需刷新token的uri：{}", request.getRequestURI());
            filterChain.doFilter(request, response);
            return;
        }

        // 需要刷新
        String token = bearerTokenResolver.resolve(request);
        if (StringUtils.hasText(token)) {
            this.refreshTokenTtl(token);
        }

        filterChain.doFilter(request, response);
    }

    private void refreshTokenTtl(String token) {
        var exists = tokenRefreshCache.getIfPresent(token);
        if (exists == null) {
            // 本地已过期，需要刷新
            CompletableFuture.runAsync(() -> {
                Jwt jwt = null;
                try {
                    jwt = jwtDecoder.decode(token);
                } catch (JwtException e) {
                    log.error("续期token异常：{}", e.getMessage());
                    return;
                }

                if (!AuthenticationClaim.VALUE_PRINCIPAL_USER.equals(jwt.getClaimAsString(AuthenticationClaim.KEY_PRINCIPAL_TYPE))) {
                    // 非用户token，忽略
                    return;
                }
                tokenRefreshCache.put(token, "true");

                Duration tokenTtl = null;
                var casUserId = jwt.getClaimAsString(AuthenticationClaim.KEY_CAS_USERID);
                var clientId = jwt.getClaimAsString(AuthenticationClaim.KEY_CLIENT_ID);
                if (StringUtils.hasText(clientId) && StringUtils.hasText(casUserId)) {
                    // 获取身份认证端信息
                    var resValidate = userTransferHelper.validateClientUser(clientId, Long.parseLong(casUserId));
                    if (resValidate.getData() == null) {
                        log.warn("验证用户不通过：{}, {}", casUserId, resValidate.getMsg());
                        authenticationCache.expireAt(token, Duration.ofSeconds(1));
                        return;
                    }
                    if (Boolean.FALSE.equals(resValidate.getData().getTokenRenewal())) {
                        // 不支持自动续期
                        log.debug("不支持自动续期：{}", casUserId);
                        return;
                    }
                    tokenTtl = resValidate.getData().getTokenTtl();
                    log.debug("根据CAS自动续期");
                } else {
                    tokenTtl = authorizationProperties.getTokenRenewal();
                    log.debug("根据应用配置自动续期");
                }
                if (tokenTtl != null && tokenTtl.toSeconds() > 0) {
                    log.info("自动续期token：{}, {}min", token, tokenTtl.toMinutes());
                    authenticationCache.expireAt(token, tokenTtl);
                }
            }, threadPoolExecutor);
        }
    }

    private boolean needRefresh(HttpServletRequest request) {
        for (RequestMatcher requestMatcher : requestMatcherIgnore) {
            if (requestMatcher.matches(request)) {
                return false;
            }
        }
        return true;
    }

    private Duration localTokenTtl() {
        var duration = authorizationProperties.getTokenRenewal();
        if (duration != null && duration.toSeconds() > 0) {
            return duration;
        }
        log.debug("不支持自动续期");
        return null;
    }
}
