package com.elitesland.cloudt.context.spi.support;

import com.elitesland.yst.common.annotation.context.spi.Spi;
import lombok.extern.log4j.Log4j2;
import org.springframework.lang.NonNull;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
 * Static holder that discovers SPI provider instances through {@link ServiceLoader}
 * and caches the resulting wrappers keyed by the service type's class name.
 * <p>
 * The cache is a {@link ConcurrentHashMap}; loading itself is not otherwise synchronized.
 *
 * @author Kaiser（wang shao）
 * @date 2022/2/16
 */
@Log4j2
public final class ProviderInstanceHolder {

    /**
     * Loaded provider wrappers, keyed by the fully qualified name of the service type.
     */
    private static final Map<String, List<ProviderWrapper<?>>> CACHE = new ConcurrentHashMap<>();

    private ProviderInstanceHolder() {
    }

    /**
     * Returns every provider instance of the given service type.
     *
     * @param serviceType service type to look up, must not be {@code null}
     * @param reload      {@code true} to bypass the cache and load again;
     *                    {@code null} or {@code false} to use cached providers when present
     * @param <T>         service type
     * @return the provider instances in the order produced by loading (sorted by descending order value),
     * or an empty list when none were found
     */
    @SuppressWarnings("unchecked")
    public static <T> List<T> loadProviderInstances(@NonNull Class<T> serviceType, Boolean reload) {
        Assert.notNull(serviceType, "serviceType为空");

        List<ProviderWrapper<?>> wrappers = findProvider(serviceType, reload);
        if (wrappers.isEmpty()) {
            return Collections.emptyList();
        }

        // Unwrap into a fresh list so the cached wrapper list is never handed out.
        List<T> instances = new ArrayList<>();
        for (ProviderWrapper<?> wrapper : wrappers) {
            instances.add((T) wrapper.getProvider());
        }
        return instances;
    }

    /**
     * Returns a single provider instance of the given service type.
     * <p>
     * The first wrapper flagged as primary wins; when none is flagged,
     * the first wrapper in the loaded list is used.
     *
     * @param serviceType service type to look up, must not be {@code null}
     * @param reload      {@code true} to bypass the cache and load again;
     *                    {@code null} or {@code false} to use cached providers when present
     * @param <T>         service type
     * @return the selected provider, or an empty {@link Optional} when none were found
     */
    @SuppressWarnings("unchecked")
    public static <T> Optional<T> loadProviderInstance(@NonNull Class<T> serviceType, Boolean reload) {
        Assert.notNull(serviceType, "serviceType为空");

        List<ProviderWrapper<?>> wrappers = findProvider(serviceType, reload);
        if (wrappers.isEmpty()) {
            return Optional.empty();
        }

        ProviderWrapper<?> selected = wrappers.stream()
                .filter(ProviderWrapper::isPrimary)
                .findFirst()
                .orElseGet(() -> wrappers.get(0));
        return (Optional<T>) Optional.of(selected.getProvider());
    }

    /**
     * Resolves the wrappers for a service type, consulting the cache unless a reload is requested.
     * Freshly loaded wrappers always replace the cache entry.
     */
    private static <T> List<ProviderWrapper<?>> findProvider(Class<T> serviceType, Boolean reload) {
        String cacheKey = serviceType.getName();

        // Only an explicit TRUE forces a reload; null is treated like false.
        if (!Boolean.TRUE.equals(reload)) {
            List<ProviderWrapper<?>> cached = CACHE.get(cacheKey);
            if (cached != null) {
                return cached;
            }
        }

        List<ProviderWrapper<?>> loaded = load(serviceType);
        CACHE.put(cacheKey, loaded);
        return loaded;
    }

    /**
     * Loads all providers of the service type via {@link ServiceLoader}, marks the primary one
     * (if the {@link Spi} annotation names one) and sorts them by descending order value.
     */
    private static <T> List<ProviderWrapper<?>> load(Class<T> serviceType) {
        Spi spi = SpiUtil.obtainAnnotationSingle(serviceType, Spi.class);

        List<ProviderWrapper<?>> providers = new ArrayList<>();
        Iterator<T> discovered = ServiceLoader.load(serviceType).iterator();
        while (discovered.hasNext()) {
            providers.add(new ProviderWrapper<>(discovered.next(), spi));
        }

        // Flag the default implementation before ordering.
        dealSpiPrimary(providers, spi);

        // Highest order value first.
        providers.sort(Comparator.comparing(ProviderWrapper::getOrder, Comparator.reverseOrder()));

        return providers;
    }

    /**
     * Marks the wrapper whose class name equals {@code spi.primary()} as primary.
     * Does nothing when there is no annotation, no primary name, or no providers.
     *
     * @throws IllegalStateException if a primary name is configured but no loaded provider matches it
     */
    private static void dealSpiPrimary(List<ProviderWrapper<?>> providers, Spi spi) {
        if (spi == null) {
            return;
        }
        String primaryName = spi.primary();
        if (!StringUtils.hasText(primaryName) || providers.isEmpty()) {
            return;
        }

        for (ProviderWrapper<?> candidate : providers) {
            if (primaryName.equals(candidate.getClazz().getName())) {
                candidate.setPrimary(true);
                return;
            }
        }
        throw new IllegalStateException("未找到默认的实现类：" + primaryName);
    }

}
