package com.elitescloud.boot.util;

import cn.hutool.core.io.IORuntimeException;
import cn.hutool.core.io.resource.ResourceUtil;
import cn.hutool.core.util.RuntimeUtil;
import lombok.extern.log4j.Log4j2;
import org.springframework.data.redis.connection.ReturnType;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.util.Assert;

import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Rate-limiting utility backed by Redis Lua scripts.
 *
 * <p>One instance is cached per business key. Each instance also tracks, in memory, the net quota this JVM has
 * successfully added to Redis for its key, so that quota can be handed back to Redis from a JVM shutdown hook.
 *
 * @author Kaiser（wang shao）
 * @date 2/17/2023
 */
@Log4j2
public final class LimiterUtil {
    /** Cached instances, one per business key. */
    private static final ConcurrentHashMap<String, LimiterUtil> INSTANCE_MAP = new ConcurrentHashMap<>();
    /** Set once the shutdown hook has run so the release logic executes at most once. */
    private static volatile boolean destroy = false;
    /** Passed as the single Redis key (KEYS[1]) to both scripts; the business key is passed as the first ARGV. */
    private static final byte[] LIMITER_BUSINESS = "{cloudt}:limiter".getBytes(StandardCharsets.UTF_8);
    /** Contents of classpath:redis/limiter_add.lua; stays null if loading failed. */
    private static byte[] luaScriptAdd;
    /** Contents of classpath:redis/limiter_subtract.lua; stays null if loading failed. */
    private static byte[] luaScriptSubtract;

    private final String key;
    private final byte[] keyBytes;
    /**
     * Template used for this key. It is held per instance (not static) so that creating a limiter with a different
     * template does not silently rewire every other existing limiter.
     */
    private final RedisTemplate<Object, Object> redisTemplate;
    /** Net quota this JVM has added to Redis for {@link #key} via successful updates and not yet subtracted. */
    private final AtomicLong localCount = new AtomicLong(0L);

    private LimiterUtil(String key, RedisTemplate<Object, Object> redisTemplate) {
        this.key = key;
        this.keyBytes = key.getBytes(StandardCharsets.UTF_8);
        this.redisTemplate = redisTemplate;
    }

    static {
        // Load the limiter scripts once. A failure is logged here and reported again, as an
        // IllegalStateException, when updateLimiter is called.
        try {
            luaScriptAdd = ResourceUtil.readBytes("classpath:redis/limiter_add.lua");
            luaScriptSubtract = ResourceUtil.readBytes("classpath:redis/limiter_subtract.lua");
        } catch (IORuntimeException e) {
            log.error("加载限流资源文件失败", e);
        }

        // Hand back the quota held by this JVM when it exits.
        RuntimeUtil.addShutdownHook(LimiterUtil::shutdown);
    }

    /**
     * Returns the limiter instance for a business key, creating it on first use.
     *
     * <p>If an instance already exists for {@code key}, it is returned unchanged and keeps the template it was
     * created with; the {@code redisTemplate} argument is then ignored.
     *
     * @param redisTemplate Redis template used to execute the limiter scripts
     * @param key           business identifier
     * @return the (possibly cached) instance for {@code key}
     * @throws IllegalArgumentException if {@code redisTemplate} is null or {@code key} is blank
     */
    public static LimiterUtil getInstance(@NotNull RedisTemplate<Object, Object> redisTemplate, @NotBlank String key) {
        Assert.notNull(redisTemplate, "redisTemplate为空");
        Assert.hasText(key, "业务标识为空");

        // Atomic get-or-create; the constructor does not touch INSTANCE_MAP, so this is safe inside computeIfAbsent.
        return INSTANCE_MAP.computeIfAbsent(key, k -> new LimiterUtil(k, redisTemplate));
    }

    /**
     * Updates the limiter counter in Redis.
     *
     * @param max  maximum allowed amount (only sent to the "add" script)
     * @param step amount to add or subtract; values below 1 are treated as 1
     * @param add  {@code true} to add, {@code false} to subtract
     * @return whether the update succeeded; {@code false} means the limit would be exceeded
     * @throws IllegalStateException if the required Lua script could not be loaded at class initialization
     */
    public boolean updateLimiter(int max, int step, boolean add) {
        byte[] script = add ? luaScriptAdd : luaScriptSubtract;
        if (script == null) {
            throw new IllegalStateException("限流脚本未加载，无法执行限流");
        }

        int effectiveStep = Math.max(step, 1);
        byte[] stepBytes = String.valueOf(effectiveStep).getBytes(StandardCharsets.UTF_8);

        // Update the counter in Redis.
        Long result;
        if (add) {
            byte[] maxBytes = String.valueOf(max).getBytes(StandardCharsets.UTF_8);
            result = redisTemplate.execute((RedisCallback<Long>) connection -> connection.eval(script,
                    ReturnType.INTEGER, 1, LIMITER_BUSINESS, keyBytes, maxBytes, stepBytes));
        } else {
            result = redisTemplate.execute((RedisCallback<Long>) connection -> connection.eval(script,
                    ReturnType.INTEGER, 1, LIMITER_BUSINESS, keyBytes, stepBytes));
        }

        // A null reply or -1 is treated as a rejection; the local counter is only touched on success.
        if (result == null || result == -1L) {
            return false;
        }
        localCount.addAndGet(add ? effectiveStep : -(long) effectiveStep);
        return true;
    }

    /**
     * Shutdown hook: subtracts from Redis whatever quota each instance still holds locally.
     * Runs at most once. Every key is attempted even if releasing an earlier key fails.
     */
    private static synchronized void shutdown() {
        if (destroy) {
            return;
        }
        destroy = true;
        log.info("限流工具销毁，释放本地持有的限流额度...");

        for (LimiterUtil instance : INSTANCE_MAP.values()) {
            try {
                instance.releaseLocalCount();
            } catch (RuntimeException e) {
                // Best effort: Redis may already be unavailable while the JVM is shutting down.
                log.warn("释放限流额度失败：{}", instance.key, e);
            }
        }
    }

    /**
     * Subtracts the locally held quota for this key from Redis and resets the local counter.
     * Does nothing if the count is not positive or the subtract script was never loaded.
     */
    private void releaseLocalCount() {
        long held = localCount.get();
        if (held <= 0 || luaScriptSubtract == null) {
            return;
        }
        // Send the full long value; the previous int conversion could wrap for large counts.
        byte[] valueBytes = Long.toString(held).getBytes(StandardCharsets.UTF_8);
        redisTemplate.execute((RedisCallback<Long>) connection -> connection.eval(luaScriptSubtract,
                ReturnType.INTEGER, 1, LIMITER_BUSINESS, keyBytes, valueBytes));
        localCount.addAndGet(-held);
    }
}
