package com.elitescloud.boot.spi.support;

import cn.hutool.core.util.ArrayUtil;
import com.elitescloud.cloudt.common.annotation.context.spi.Spi;
import com.elitescloud.boot.spi.common.BaseSpiService;
import com.elitescloud.boot.SpringContextHolder;
import lombok.extern.log4j.Log4j2;
import org.springframework.lang.NonNull;
import org.springframework.util.Assert;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
 * Discovers and caches SPI provider instances per service type.
 * <p>
 * Providers are collected from {@link ServiceLoader} and, when the Spring context is initialized,
 * from the Spring container as well. Results are cached by the service type's fully qualified name.
 *
 * @author Kaiser（wang shao）
 * @date 2022/2/16
 */
@Log4j2
public final class ProviderInstanceHolder {

    /**
     * Cached provider wrappers, keyed by the fully qualified service type name.
     */
    private static final Map<String, List<ProviderWrapper<?>>> CACHE = new ConcurrentHashMap<>();
    /**
     * Whether the Spring container has already been consulted for a service type, keyed by the
     * fully qualified service type name. Until it has, cached entries are ignored so that Spring
     * beans are picked up once the context becomes available.
     */
    private static final Map<String, Boolean> CACHE_SPRING_LOAD = new ConcurrentHashMap<>();

    private ProviderInstanceHolder() {
    }

    /**
     * Returns all provider instances of the given service type, in provider order
     * (descending {@code order}).
     *
     * @param serviceType service type; must not be {@code null}
     * @param reload      {@code true} to bypass the cache and reload; {@code null}/{@code false} to use it
     * @param <T>         service type
     * @return all provider instances; an empty list when none are found
     */
    @SuppressWarnings("unchecked")
    public static <T> List<T> loadProviderInstances(@NonNull Class<T> serviceType, Boolean reload) {
        Assert.notNull(serviceType, "serviceType为空");

        var providerWrappers = findProvider(serviceType, reload);
        if (providerWrappers.isEmpty()) {
            return Collections.emptyList();
        }
        return (List<T>) providerWrappers.stream()
                .map(ProviderWrapper::getProvider)
                .collect(Collectors.toList());
    }

    /**
     * Returns a single provider instance of the given service type.
     * <p>
     * The provider marked as primary is preferred; otherwise the first provider in provider order
     * (descending {@code order}) is returned.
     *
     * @param serviceType service type; must not be {@code null}
     * @param reload      {@code true} to bypass the cache and reload; {@code null}/{@code false} to use it
     * @param <T>         service type
     * @return the selected instance, or {@link Optional#empty()} when no provider is found
     */
    @SuppressWarnings("unchecked")
    public static <T> Optional<T> loadProviderInstance(@NonNull Class<T> serviceType, Boolean reload) {
        Assert.notNull(serviceType, "serviceType为空");

        var providerWrappers = findProvider(serviceType, reload);
        if (providerWrappers.isEmpty()) {
            return Optional.empty();
        }

        for (ProviderWrapper<?> providerWrapper : providerWrappers) {
            if (providerWrapper.isPrimary()) {
                return (Optional<T>) Optional.of(providerWrapper.getProvider());
            }
        }
        return (Optional<T>) Optional.of(providerWrappers.get(0).getProvider());
    }

    /**
     * Returns the provider wrappers for a service type, loading and caching them when needed.
     * <p>
     * The cache is used only when {@code reload} is not {@code true} AND the Spring container has
     * already been consulted for this type; otherwise providers are reloaded and the cache entry
     * is replaced.
     *
     * @param serviceType service type
     * @param reload      {@code true} to force a reload
     * @param <T>         service type
     * @return provider wrappers, sorted by descending order (the cached list instance itself)
     */
    static <T> List<ProviderWrapper<?>> findProvider(Class<T> serviceType, Boolean reload) {
        String serviceName = serviceType.getName();

        // If the Spring container has not been consulted yet, reload so that Spring beans are included.
        boolean useCache = !Boolean.TRUE.equals(reload)
                && Boolean.TRUE.equals(CACHE_SPRING_LOAD.get(serviceName));
        if (useCache) {
            List<ProviderWrapper<?>> cached = CACHE.get(serviceName);
            if (cached != null) {
                return cached;
            }
        }

        List<ProviderWrapper<?>> providers = load(serviceType);
        CACHE.put(serviceName, providers);

        return providers;
    }

    /**
     * Loads the providers of a service type from {@link ServiceLoader} and, if the Spring context is
     * initialized, from the Spring container; marks the configured primary provider and sorts the
     * result by descending order.
     *
     * @param serviceType service type
     * @param <T>         service type
     * @return a new mutable list of provider wrappers
     */
    private static <T> List<ProviderWrapper<?>> load(Class<T> serviceType) {
        List<ProviderWrapper<?>> providers = new ArrayList<>();

        Spi spi = SpiUtil.obtainAnnotationSingle(serviceType, Spi.class);
        // Class names found via ServiceLoader, used to skip duplicate Spring beans of the same class.
        Set<String> loaded = new HashSet<>();
        for (T t : ServiceLoader.load(serviceType)) {
            providers.add(new ProviderWrapper<>(t, spi, false));
            loaded.add(t.getClass().getName());
        }

        // Providers exposed as Spring beans
        if (SpringContextHolder.initialized()) {
            for (T t : SpringContextHolder.getObjectProvider(serviceType)) {
                // Skip beans whose class was already loaded via ServiceLoader, and BaseSpiService beans.
                if (loaded.contains(t.getClass().getName()) || t instanceof BaseSpiService) {
                    continue;
                }
                providers.add(new ProviderWrapper<>(t, spi, true));
            }
            // Record that Spring was consulted even if it contributed no beans. Previously the flag
            // was only set when a bean was added, so types without Spring beans bypassed the cache
            // and were reloaded (with fresh instances) on every call.
            CACHE_SPRING_LOAD.put(serviceType.getName(), true);
        }

        // Mark the configured primary implementation
        dealSpiPrimary(providers, spi);

        // Sort by descending order; List.sort is stable, so equal orders keep discovery order.
        providers.sort(Comparator.comparing(ProviderWrapper::getOrder, Comparator.reverseOrder()));

        return providers;
    }

    /**
     * Marks the first provider (in discovery order) whose class name matches one of the
     * {@link Spi#primary()} names or {@link Spi#primaryClass()} classes as primary.
     * Does nothing when {@code spi} is {@code null}, no primary is configured, or there are no providers.
     *
     * @param providers provider wrappers to inspect
     * @param spi       SPI annotation of the service type; may be {@code null}
     */
    private static void dealSpiPrimary(List<ProviderWrapper<?>> providers, Spi spi) {
        if (spi == null ||
                (ArrayUtil.isEmpty(spi.primary()) && ArrayUtil.isEmpty(spi.primaryClass()))
                || providers.isEmpty()) {
            // No primary configured
            return;
        }

        Set<String> primaryClassName = new HashSet<>();
        if (ArrayUtil.isNotEmpty(spi.primary())) {
            primaryClassName.addAll(Arrays.asList(spi.primary()));
        }
        for (Class<?> clazz : spi.primaryClass()) {
            primaryClassName.add(clazz.getName());
        }
        for (ProviderWrapper<?> provider : providers) {
            if (primaryClassName.contains(provider.getClazz().getName())) {
                provider.setPrimary(true);
                return;
            }
        }
    }

}
