package com.elitescloud.boot.auth.util;

import com.elitescloud.boot.SpringContextHolder;
import com.elitescloud.boot.auth.AuthorizedClient;
import com.elitescloud.boot.auth.BearerTokenAuthenticationToken;
import com.elitescloud.boot.auth.CommonAuthenticationToken;
import com.elitescloud.boot.auth.client.config.support.AuthenticationCache;
import com.elitescloud.boot.auth.client.config.support.AuthenticationContext;
import com.elitescloud.boot.constant.AuthenticationClaim;
import com.elitescloud.cloudt.security.entity.GeneralUserDetails;
import lombok.extern.log4j.Log4j2;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.time.Duration;
import java.time.Instant;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

/**
 * Utilities for reading and updating the current security context.
 *
 * <p>User and client information is resolved from the {@link Authentication} held by
 * {@link SecurityContextHolder}. Once resolved from a JWT, the result is stored back into the
 * security context as a {@link CommonAuthenticationToken} so later calls do not repeat the lookup.
 *
 * @author Kaiser（wang shao）
 * @date 2022/01/08
 */
@Log4j2
public class SecurityContextUtil {

    /** Looks up and caches the Spring bean providers this utility depends on. */
    private static final ContextBeanHolder CONTEXT_BEAN_HOLDER = new ContextBeanHolder();

    protected SecurityContextUtil() {
    }

    /**
     * Returns the ID of the current user.
     *
     * <p>The JWT claim is consulted first; if it is absent, the registered
     * {@link AuthenticationContext} beans are asked in turn.
     *
     * @return the current user ID, or {@code null} if none could be determined
     */
    @Nullable
    public static Long currentUserId() {
        Long userId = claimAsLong(currentUserClaims().get(AuthenticationClaim.KEY_USERID));
        if (userId != null) {
            return userId;
        }

        // Fall back to the contexts of other frameworks.
        return currentAuthenticationProperty(AuthenticationContext::getUserId);
    }

    /**
     * Returns the login name of the current user.
     *
     * <p>The JWT claim is consulted first; if it is absent or blank, the registered
     * {@link AuthenticationContext} beans are asked in turn.
     *
     * @return the current login name, or {@code null} if none could be determined
     */
    @Nullable
    public static String currentUserName() {
        Object username = currentUserClaims().get(AuthenticationClaim.KEY_USERNAME);
        // A non-String claim is ignored rather than cast, so it cannot raise ClassCastException.
        if (username instanceof String && StringUtils.hasText((String) username)) {
            return (String) username;
        }

        // Fall back to the contexts of other frameworks.
        return currentAuthenticationProperty(AuthenticationContext::getUsername);
    }

    /**
     * Returns the ID of the current tenant.
     *
     * <p>The JWT claim is consulted first; if it is absent, the registered
     * {@link AuthenticationContext} beans are asked in turn.
     *
     * @return the current tenant ID, or {@code null} if none could be determined
     */
    @Nullable
    public static Long currentTenantId() {
        Long tenantId = claimAsLong(currentUserClaims().get(AuthenticationClaim.KEY_TENANT_ID));
        if (tenantId != null) {
            return tenantId;
        }

        // Fall back to the contexts of other frameworks.
        return currentAuthenticationProperty(AuthenticationContext::getTenantId);
    }

    /**
     * Returns the details of the current user.
     *
     * @return the current user details, or {@code null} if there is no resolvable user
     */
    @Nullable
    public static GeneralUserDetails currentUser() {
        return currentUser(false);
    }

    /**
     * Returns the details of the current user, raising the unauthorized error via
     * {@code SecurityUtil.throwUnauthorizedException()} when there is none.
     *
     * @return the current user details, never {@code null}
     */
    @NonNull
    public static GeneralUserDetails currentUserIfUnauthorizedThrow() {
        return Objects.requireNonNull(currentUser(true));
    }

    /**
     * Returns the raw token value of the current authentication.
     *
     * @return the token value, or {@code null} if no JWT is associated with the current authentication
     */
    @Nullable
    public static String currentToken() {
        Jwt jwt = currentAuthenticationJwt();
        if (jwt == null) {
            return null;
        }

        return jwt.getTokenValue();
    }

    /**
     * Replaces the cached login information of the current user.
     *
     * <p>Note that this clears {@code extendInfo} on the passed-in object before caching it.
     *
     * @param user the new user information; must not be {@code null}
     */
    public static void updateCurrentUser(@NonNull GeneralUserDetails user) {
        Assert.notNull(user, "用户信息为空");

        Jwt jwt = currentAuthenticationJwt();
        if (jwt == null) {
            // Not authenticated.
            SecurityUtil.throwUnauthorizedException();
            return;
        }

        // Clear the custom info so other business domains can still deserialize the cached value.
        user.setExtendInfo(null);

        // Cache the new user info for the remaining lifetime of the token.
        // The expiry is an Instant, so "now" must be an Instant too: Duration.between cannot
        // combine a LocalDateTime with an Instant and would throw DateTimeException.
        Duration ttl = null;
        Instant expiresAt = jwt.getExpiresAt();
        if (expiresAt != null) {
            ttl = Duration.between(Instant.now(), expiresAt);
        }
        CONTEXT_BEAN_HOLDER.getAuthenticationCache().setUserDetail(jwt.getTokenValue(), user, ttl);
    }

    /**
     * Looks up the user information cached for the given token value.
     *
     * @param token the token value
     * @return the user information returned by the authentication cache for that token
     */
    @Nullable
    public static GeneralUserDetails convertToken(String token) {
        return CONTEXT_BEAN_HOLDER.getAuthenticationCache().getUserDetail(token);
    }

    /**
     * Returns the currently authenticated client.
     *
     * @return the client information, or {@code null} if none could be resolved
     */
    @Nullable
    public static AuthorizedClient currentAuthorizedClient() {
        return currentAuthorizedClient(false);
    }

    /**
     * Returns the currently authenticated client, raising the unauthorized error via
     * {@code SecurityUtil.throwUnauthorizedException()} when there is none.
     *
     * @return the client information
     */
    public static AuthorizedClient currentAuthorizedClientIfUnauthorizedThrow() {
        return currentAuthorizedClient(true);
    }

    /**
     * Resolves the current user from the security context.
     *
     * @param required whether a user must be present; if {@code true} and none is found, the
     *                 unauthorized error is raised
     * @return the user details, or {@code null} when not required and not found
     */
    private static GeneralUserDetails currentUser(boolean required) {
        Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
        if (authentication == null) {
            if (required) {
                SecurityUtil.throwUnauthorizedException();
            }
            return null;
        }
        if (authentication instanceof CommonAuthenticationToken) {
            // Our own token type: the user details have already been resolved.
            var user = ((CommonAuthenticationToken) authentication).getUserDetails();
            if (required && user == null) {
                SecurityUtil.throwUnauthorizedException();
            }
            return user;
        }
        if (authentication instanceof BearerTokenAuthenticationToken
                && authentication.getPrincipal() instanceof GeneralUserDetails) {
            // The principal already is the user; return it instead of decoding the token again.
            return (GeneralUserDetails) authentication.getPrincipal();
        }

        // Resolve from the JWT.
        GeneralUserDetails userDetails = null;
        AuthorizedClient authorizedClient = null;
        Jwt jwt = currentAuthenticationJwt(authentication);
        if (jwt != null) {
            userDetails = convertToken(jwt.getTokenValue());
            authorizedClient = AuthorizedClient.buildByJwt(jwt);
        }
        if (userDetails == null) {
            log.debug("获取当前用户信息失败，token不存在或已过期");
            if (required) {
                SecurityUtil.throwUnauthorizedException();
            }
            return null;
        }

        // Store as our own Authentication type to avoid repeating the lookup.
        CommonAuthenticationToken authenticationToken =
                new CommonAuthenticationToken(authentication, userDetails, authentication.getAuthorities());
        authenticationToken.setAuthorizedClient(authorizedClient);
        SecurityContextHolder.getContext().setAuthentication(authenticationToken);

        return userDetails;
    }

    /**
     * Resolves the currently authenticated client from the security context.
     *
     * @param required whether a client must be present; if {@code true} and none is found, the
     *                 unauthorized error is raised
     * @return the client, or {@code null} when not required and not found
     */
    private static AuthorizedClient currentAuthorizedClient(boolean required) {
        Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
        if (authentication == null) {
            // Honour "required" here as well, so the IfUnauthorizedThrow variant never
            // silently returns null.
            if (required) {
                SecurityUtil.throwUnauthorizedException();
            }
            return null;
        }
        if (authentication instanceof CommonAuthenticationToken) {
            // Our own token type: the client has already been resolved.
            var client = ((CommonAuthenticationToken) authentication).getAuthorizedClient();
            if (required && client == null) {
                SecurityUtil.throwUnauthorizedException();
            }
            return client;
        } else if (authentication instanceof BearerTokenAuthenticationToken) {
            var client = ((BearerTokenAuthenticationToken) authentication).getAuthorizedClient();
            if (required && client == null) {
                SecurityUtil.throwUnauthorizedException();
            }
            return client;
        }

        // Resolve from the JWT.
        AuthorizedClient authorizedClient = null;
        Jwt jwt = currentAuthenticationJwt(authentication);
        if (jwt != null) {
            GeneralUserDetails userDetails = convertToken(jwt.getTokenValue());
            authorizedClient = AuthorizedClient.buildByJwt(jwt);

            // Store as our own Authentication type to avoid repeating the lookup. This is done
            // only when a JWT was actually resolved, so anonymous or unsupported authentications
            // are left untouched in the security context and keep their original type.
            CommonAuthenticationToken authenticationToken =
                    new CommonAuthenticationToken(authentication, userDetails, authentication.getAuthorities());
            authenticationToken.setAuthorizedClient(authorizedClient);
            SecurityContextHolder.getContext().setAuthentication(authenticationToken);
        }

        if (required && authorizedClient == null) {
            SecurityUtil.throwUnauthorizedException();
        }
        return authorizedClient;
    }

    /**
     * Returns the claims of the JWT behind the current authentication.
     *
     * @return the claims, or an empty map if there is no JWT
     */
    private static Map<String, Object> currentUserClaims() {
        Jwt jwt = currentAuthenticationJwt();
        if (jwt == null) {
            return Collections.emptyMap();
        }

        return jwt.getClaims();
    }

    /**
     * Converts a claim value to a {@link Long}.
     *
     * <p>Accepts any {@link Number} as well as a numeric string, because the concrete type of a
     * numeric claim depends on how the token was parsed.
     *
     * @param value the raw claim value, may be {@code null}
     * @return the value as a {@code Long}, or {@code null} if it is absent or not numeric
     */
    private static Long claimAsLong(Object value) {
        if (value instanceof Long) {
            return (Long) value;
        }
        if (value instanceof Number) {
            return ((Number) value).longValue();
        }
        if (value instanceof String && StringUtils.hasText((String) value)) {
            try {
                return Long.valueOf(((String) value).trim());
            } catch (NumberFormatException e) {
                // Treated as absent so that callers can fall back to other sources.
                log.debug("Ignoring non-numeric claim value: {}", value);
            }
        }
        return null;
    }

    /**
     * Returns the JWT behind the authentication currently held by the security context.
     *
     * @return the JWT, or {@code null} if there is no authentication or it carries no JWT
     */
    private static Jwt currentAuthenticationJwt() {
        Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
        if (authentication == null) {
            return null;
        }
        if (authentication instanceof CommonAuthenticationToken) {
            // Our own token type: decide based on the original authentication it wraps.
            CommonAuthenticationToken authenticationToken = (CommonAuthenticationToken) authentication;
            return currentAuthenticationJwt(authenticationToken.getOriginal());
        }

        return currentAuthenticationJwt(authentication);
    }

    /**
     * Extracts the JWT from the given authentication.
     *
     * @param authentication the authentication to inspect, may be {@code null}
     * @return the JWT, or {@code null} for a null, anonymous or unsupported authentication
     */
    private static Jwt currentAuthenticationJwt(Authentication authentication) {
        if (authentication == null || authentication instanceof AnonymousAuthenticationToken) {
            return null;
        }
        if (authentication instanceof JwtAuthenticationToken) {
            // OAuth2: the already-authenticated identity carries the JWT.
            JwtAuthenticationToken jwtAuthenticationToken = (JwtAuthenticationToken) authentication;
            return jwtAuthenticationToken.getToken();
        }

        if (authentication instanceof BearerTokenAuthenticationToken) {
            // Monolith deployment: only the raw token is held, so decode it.
            BearerTokenAuthenticationToken bearerTokenAuthenticationToken = (BearerTokenAuthenticationToken) authentication;
            String token = bearerTokenAuthenticationToken.getToken();
            return CONTEXT_BEAN_HOLDER.getJwtDecoder().decode(token);
        }

        log.debug("暂不支持的Authentication类型：{}", authentication.getClass().getName());
        return null;
    }

    /**
     * Asks each registered {@link AuthenticationContext} for a property and returns the first
     * non-null answer.
     *
     * @param apply the accessor to invoke on each context
     * @param <T>   the property type
     * @return the first non-null value, or {@code null} if no context supplies one
     */
    private static <T> T currentAuthenticationProperty(Function<AuthenticationContext, T> apply) {
        Iterator<AuthenticationContext> contextIterator = CONTEXT_BEAN_HOLDER.getAuthenticationContext();
        while (contextIterator.hasNext()) {
            T value = apply.apply(contextIterator.next());
            if (value != null) {
                return value;
            }
        }

        return null;
    }

    /**
     * Obtains bean providers from {@link SpringContextHolder} and caches them per bean type.
     */
    static class ContextBeanHolder {

        // The single holder instance is static and used from whichever thread calls the utility,
        // so the cache must tolerate concurrent reads and writes.
        private final Map<Class<?>, ObjectProvider<Object>> sharedBeans = new ConcurrentHashMap<>();

        /**
         * @return the {@link JwtDecoder} bean, or {@code null} if none is available
         */
        public JwtDecoder getJwtDecoder() {
            return getSharedBean(JwtDecoder.class).getIfAvailable();
        }

        /**
         * @return the {@link AuthenticationCache} bean, or {@code null} if none is available
         */
        public AuthenticationCache getAuthenticationCache() {
            return getSharedBean(AuthenticationCache.class).getIfAvailable();
        }

        /**
         * @return an iterator over the {@link AuthenticationContext} beans; empty when no
         *         provider could be obtained
         */
        public Iterator<AuthenticationContext> getAuthenticationContext() {
            ObjectProvider<AuthenticationContext> provider = getSharedBean(AuthenticationContext.class);
            if (provider == null) {
                return Collections.emptyIterator();
            }
            return provider.iterator();
        }

        /**
         * Returns the cached provider for the given bean type, looking it up on first use.
         *
         * @param clazz the bean type
         * @param <T>   the bean type
         * @return the provider, or {@code null} if {@link SpringContextHolder} returned none
         */
        @SuppressWarnings("unchecked")
        private <T> ObjectProvider<T> getSharedBean(Class<T> clazz) {
            var objectProvider = sharedBeans.get(clazz);
            if (objectProvider != null) {
                return (ObjectProvider<T>) objectProvider;
            }

            objectProvider = SpringContextHolder.getObjectProvider((Class<Object>) clazz);
            if (objectProvider != null) {
                // Only non-null providers are cached; a failed lookup is retried on the next call.
                sharedBeans.put(clazz, objectProvider);
            }
            return (ObjectProvider<T>) objectProvider;
        }
    }
}
