package com.elitescloud.boot.task.retry;

import cn.hutool.core.collection.CollUtil;
import com.lmax.disruptor.util.DaemonThreadFactory;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.redisson.api.RedissonClient;
import org.springframework.lang.NonNull;
import org.springframework.scheduling.annotation.SchedulingConfigurer;
import org.springframework.scheduling.config.ScheduledTaskRegistrar;
import org.springframework.util.Assert;

import java.time.Duration;
import java.time.LocalDateTime;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.DelayQueue;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;

/**
 * 重试服务.
 *
 * @author Kaiser（wang shao）
 * @date 2023/9/14
 */
@Slf4j
public abstract class AbstractRetryService<T extends RetryTask> implements SchedulingConfigurer, RetryableService<T> {
    private final RetryTaskProvider<T> retryTaskProvider;
    private final RedissonClient redissonClient;
    private RetryTaskQueueWrapper<T> retryTaskQueueWrapper = null;

    protected AbstractRetryService(RetryTaskProvider<T> retryTaskProvider, RedissonClient redissonClient) {
        this.retryTaskProvider = retryTaskProvider;
        this.redissonClient = redissonClient;
    }

    /**
     * 是否支持重试
     *
     * @return 是否支持重试
     */
    protected abstract boolean supportRetry();

    /**
     * 最大重试次数
     * <p>
     * 小于1则不重试
     *
     * @return 重试次数
     */
    protected abstract int retryTimes();

    /**
     * 重试间隔
     *
     * @return 重试间隔
     */
    protected abstract List<Duration> retryIntervals();

    /**
     * 执行任务
     *
     * @param task 待执行任务
     */
    protected abstract void executeTask(T task);

    /**
     * 线程前缀
     *
     * @return 前缀
     */
    protected String threadPrefix() {
        return "common-";
    }

    /**
     * 定时任务的间隔
     *
     * @return 默认30分钟
     */
    protected Duration scheduleDelay() {
        return Duration.ofMinutes(30);
    }

    @Override
    public void configureTasks(@NonNull ScheduledTaskRegistrar taskRegistrar) {
        taskRegistrar.addFixedDelayTask(() -> {
            if (!supportRetry()) {
                return;
            }
            try {
                addRetryTaskToQueue();
            } catch (Exception e) {
                log.error("定时执行添加任务重试异常：", e);
            }
        }, scheduleDelay().toMillis());
    }

    @Override
    public void addRetryTask(T retryTask) {
        Assert.hasText(retryTask.getTaskId(), "任务ID为空");
        if (retryTask.getRetryTime() == null) {
            var msg = supportRetry() ? "已达最大重试次数" : "重试功能已禁用";
            log.info("删除重试任务：{}，{}，{}", retryTask.getTaskId(), retryTask.getRetryTimes(), msg);
            retryTaskProvider.deleteTask(retryTask.getTaskId(), msg);
            return;
        }

        if (retryTaskQueueWrapper == null) {
            // 初始化队列
            this.initQueue();
        }
        retryTaskQueueWrapper.addTask(retryTask);
    }

    @Override
    public LocalDateTime generateNextRetryTime(LocalDateTime lastSendTime, int retryTimes) {
        if (!this.supportRetry()) {
            log.info("重试已关闭，无需重试");
            return null;
        }

        var max = this.retryTimes();
        if (retryTimes >= max) {
            // 已达到最大重试次数
            log.info("已达最大重试次数，不再重试");
            return null;
        }
        var intervals = this.retryIntervals();
        if (CollUtil.isEmpty(intervals)) {
            log.error("消息重试间隔未设置，无法重试");
            return null;
        }

        var interval = retryTimes > intervals.size() - 1 ? intervals.get(intervals.size() - 1) : intervals.get(retryTimes);
        return lastSendTime.plusSeconds(interval.toSeconds());
    }

    private void addRetryTaskToQueue() {
        String lastTaskId = null;
        while (true) {
            var taskList = retryTaskProvider.queryTask(lastTaskId, 50);
            if (CollUtil.isEmpty(taskList)) {
                break;
            }
            for (T task : taskList) {
                this.addRetryTask(task);
                lastTaskId = task.getTaskId();
            }
        }
    }

    private void initQueue() {
        retryTaskQueueWrapper = new RetryTaskQueueWrapper<>(threadPrefix(), 2000, (task, size) -> {
            // 重试发送消息
            Throwable exp = null;
            var lock = redissonClient.getLock(threadPrefix() + "retry-" + task.getTaskId() + "-" + task.getVersion());
            try {
                if (lock.tryLock(1, TimeUnit.MINUTES)) {
                    if (!retryTaskProvider.trySend(task.getTaskId(), task.getVersion())) {
                        log.info("任务{}, {}不需要再重试", task.getTaskId(), task.getVersion());
                        return;
                    }

                    log.info("重试任务：{}", task.getTaskId());
                    this.executeTask(task);
                    log.info("重试任务成功：{}", task.getTaskId());
                }
            } catch (Throwable e) {
                log.error("执行重试任务异常：{}", task.getTaskId(), e);
                exp = e;
            } finally {
                // 更新发送结果
                try {
                    retryTaskProvider.updateRetryResult(task.getTaskId(), exp == null, exp == null ? null : exp.getMessage());
                } catch (Exception e) {
                    log.error("更新任务重试结果异常：", e);
                }

                lock.unlock();
            }
        });
    }

    static class RetryTaskQueueWrapper<T extends RetryTask> {
        private final String threadPrefix;
        private final DelayQueue<DelayTask<T>> queue = new DelayQueue<>();
        private final Set<String> taskIdsAll = new HashSet<>();
        private final int size;
        private final BiConsumer<T, Integer> consumer;

        public RetryTaskQueueWrapper(String threadPrefix, int size, BiConsumer<T, Integer> consumer) {
            this.threadPrefix = threadPrefix;
            this.size = size;
            this.consumer = consumer;

            this.consumeMessage();
        }

        /**
         * 添加任务
         *
         * @param task 任务
         */
        public void addTask(T task) {
            var taskId = task.getTaskId();
            Assert.hasText(taskId, "添加重试队列失败，任务ID为空");
            if (taskIdsAll.contains(taskId)) {
                return;
            }
            if (queue.size() >= size) {
                log.info("重试队列已满");
                return;
            }
            queue.add(new DelayTask<>(task));
            taskIdsAll.add(taskId);
        }

        private void consumeMessage() {
            Runnable runnable = () -> {
                DelayTask<T> task;
                while (true) {
                    try {
                        task = queue.take();
                    } catch (InterruptedException e) {
                        log.error("从消息队列获取延迟任务异常", e);
                        continue;
                    }
                    taskIdsAll.remove(task.getRetryTask().getTaskId());
                    try {
                        consumer.accept(task.getRetryTask(), queue.size());
                    } catch (Exception e) {
                        log.error("延时任务处理异常：", e);
                    }
                }
            };
            Thread threadConsumer = DaemonThreadFactory.INSTANCE.newThread(runnable);
            threadConsumer.setName(threadPrefix + "retry");
            threadConsumer.setDaemon(true);
            threadConsumer.setUncaughtExceptionHandler((t, e) -> log.error("重试服务异常：", e));
            threadConsumer.start();
        }
    }

    @Getter
    static class DelayTask<T extends RetryTask> implements Delayed {
        private final T retryTask;
        private final LocalDateTime sendTime;

        public DelayTask(T retryTask) {
            this.retryTask = retryTask;
            this.sendTime = retryTask.getRetryTime();
            Assert.notNull(sendTime, "重试时间为空");
        }

        @Override
        public long getDelay(@NonNull TimeUnit unit) {
            return unit.convert(Duration.between(LocalDateTime.now(), sendTime));
        }

        @Override
        public int compareTo(@NonNull Delayed o) {
            if (o == this) {
                return 0;
            }
            if (o instanceof DelayTask) {
                DelayTask<T> msg = (DelayTask<T>) o;
                return getSendTime().compareTo(msg.getSendTime());
            }
            return 0;
        }
    }
}
