package com.elitescloud.boot.util;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;

import javax.validation.constraints.NotNull;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Data splitter.
 * <p>
 * Pulls data page by page from a producer and hands it to a {@link SplitterListener} in shards of
 * (about) {@code shardSize} records. Pages are requested with a 1-based, increasing page number until
 * the producer returns {@code null} or an empty list.
 * <p>
 * Optionally, a master ID generator maps every record to a master ID so that records sharing a master
 * ID can be kept in the same shard; see {@link PaddingMode} for the available strategies.
 * <p>
 * Callback order for each shard: {@code onShardInitialize(i)}, one or more {@code onConsume(...)},
 * {@code onShardFinish(i)}; after the last shard {@code onFinish(total)} is called exactly once.
 *
 * @author Kaiser（wang shao）
 * @date 2024/10/11
 */
public class DataSplitterUtil<T> {

    /** Returns the page for a 1-based page number; {@code null} or empty means "no more data". */
    private final Function<Integer, List<T>> dataProducer;
    /** Target number of records per shard, always greater than 0. */
    private final int shardSize;
    /** Maps a record to its master ID; {@code null} disables grouping. */
    private final Function<T, String> masterGenerator;
    /** Effective padding mode, never {@code null}. */
    private final PaddingMode paddingMode;

    /**
     * Creates a splitter that splits strictly by shard size.
     *
     * @param dataProducer returns the page for a 1-based page number; {@code null} or empty ends the data
     * @param shardSize    number of records per shard, must be greater than 0
     */
    public DataSplitterUtil(@NotNull Function<Integer, List<T>> dataProducer, int shardSize) {
        this(dataProducer, shardSize, null, null);
    }

    /**
     * Creates a splitter that keeps records with the same master ID together using
     * {@link PaddingMode#FORWARD} (or {@link PaddingMode#NOOP} when {@code masterGenerator} is null).
     *
     * @param dataProducer    returns the page for a 1-based page number; {@code null} or empty ends the data
     * @param shardSize       number of records per shard, must be greater than 0
     * @param masterGenerator maps a record to its master ID, may be {@code null}
     */
    public DataSplitterUtil(@NotNull Function<Integer, List<T>> dataProducer, int shardSize, Function<T, String> masterGenerator) {
        this(dataProducer, shardSize, masterGenerator, null);
    }

    /**
     * Creates a splitter.
     *
     * @param dataProducer    returns the page for a 1-based page number; {@code null} or empty ends the data
     * @param shardSize       number of records per shard, must be greater than 0
     * @param masterGenerator maps a record to its master ID; {@code null} means splitting strictly by size
     *                        regardless of {@code paddingMode}
     * @param paddingMode     grouping strategy; {@code null} selects {@link PaddingMode#FORWARD} when a
     *                        master generator is given and {@link PaddingMode#NOOP} otherwise
     * @throws IllegalArgumentException if {@code shardSize <= 0} or {@code dataProducer} is null
     */
    public DataSplitterUtil(@NotNull Function<Integer, List<T>> dataProducer, int shardSize, Function<T, String> masterGenerator, PaddingMode paddingMode) {
        Assert.isTrue(shardSize > 0, "分片大小必须大于0");
        Assert.notNull(dataProducer, "数据生产者不能为空");

        this.dataProducer = dataProducer;
        this.shardSize = shardSize;
        this.masterGenerator = masterGenerator;
        if (paddingMode != null) {
            this.paddingMode = paddingMode;
        } else {
            this.paddingMode = masterGenerator == null ? PaddingMode.NOOP : PaddingMode.FORWARD;
        }
    }

    /**
     * Reads all pages from the data producer and delivers them to {@code listener} shard by shard.
     *
     * @param listener receives the shard and data callbacks
     * @throws IllegalArgumentException if {@code listener} is null
     */
    public void consume(@NotNull SplitterListener<T> listener) {
        Assert.notNull(listener, "监听器不能为空");

        // Without master IDs there is nothing to group by: split strictly by size.
        if (masterGenerator == null) {
            this.consumeForPaddingNoop(listener);
            return;
        }

        switch (paddingMode) {
            case NOOP:
                this.consumeForPaddingNoop(listener);
                return;
            case FORWARD:
                this.consumeForPaddingForward(listener);
                return;
            case BACKWARD:
                this.consumeForPaddingBackward(listener);
                return;
            default:
                throw new IllegalStateException("暂不支持的模式");
        }
    }

    /** NOOP mode: fill every shard up to exactly {@code shardSize} records. */
    private void consumeForPaddingNoop(SplitterListener<T> listener) {
        this.consumeWithStrictShardSize(listener, (records, capacity) -> Math.min(records.size(), capacity));
    }

    /** BACKWARD mode: move a master group that does not fit the current shard to the next shard. */
    private void consumeForPaddingBackward(SplitterListener<T> listener) {
        this.consumeWithStrictShardSize(listener, this::sizeOfWholeGroupsWithin);
    }

    /**
     * FORWARD mode: a shard is filled up to {@code shardSize}; afterwards every following record whose
     * master ID is already in the shard is appended to it as well, so shards may exceed {@code shardSize}.
     * The shard is closed once a record of another master ID shows up (or the data ends).
     */
    private void consumeForPaddingForward(SplitterListener<T> listener) {
        ShardState state = new ShardState();
        // Master IDs already placed in the current shard.
        Set<String> currentMasterIds = new HashSet<>();

        int pageNo = 0;
        List<T> records = dataProducer.apply(++pageNo);
        while (CollUtil.isNotEmpty(records)) {
            // The whole page fits into the current shard: take it and look at the next page.
            if (state.shardCount + records.size() <= shardSize) {
                this.consumeIntoOpenShard(records, state, listener, currentMasterIds);
                records = dataProducer.apply(++pageNo);
                continue;
            }

            // Fill the free slots first (there are none when the shard is already full).
            int free = shardSize - state.shardCount;
            if (free > 0) {
                List<T> head = records.subList(0, free);
                records = records.subList(free, records.size());
                this.consumeIntoOpenShard(head, state, listener, currentMasterIds);
            }

            // Split the rest into records of masters already in this shard and records of other masters.
            List<T> sameMaster = new ArrayList<>(records.size());
            List<T> otherMaster = new ArrayList<>(records.size());
            for (T record : records) {
                if (currentMasterIds.contains(masterGenerator.apply(record))) {
                    sameMaster.add(record);
                } else {
                    otherMaster.add(record);
                }
            }
            if (!sameMaster.isEmpty()) {
                this.consumeIntoOpenShard(sameMaster, state, listener, currentMasterIds);
            }

            if (otherMaster.isEmpty()) {
                // The whole remainder belonged to this shard's masters and the next page may continue
                // them, so keep the shard open instead of closing it prematurely.
                records = dataProducer.apply(++pageNo);
                continue;
            }

            this.finishShard(state, listener);
            records = otherMaster;
        }
        this.finishAll(state, listener);
    }

    /**
     * Shared loop of the NOOP and BACKWARD modes: a shard never holds more than {@code shardSize}
     * records and is closed as soon as it is full.
     *
     * @param listener       receives the shard and data callbacks
     * @param consumableSize given the pending records and the free capacity of the current shard (which is
     *                       smaller than the number of pending records), returns how many leading records
     *                       go into the current shard; a value below 1 means "start a new shard"
     */
    private void consumeWithStrictShardSize(SplitterListener<T> listener, BiFunction<List<T>, Integer, Integer> consumableSize) {
        ShardState state = new ShardState();

        int pageNo = 0;
        List<T> records = dataProducer.apply(++pageNo);
        while (CollUtil.isNotEmpty(records)) {
            int capacity = shardSize - state.shardCount;

            // Everything fits into the current shard: consume it and fetch the next page.
            if (records.size() <= capacity) {
                this.consumeAndFinishWhenFull(records, state, listener);
                records = dataProducer.apply(++pageNo);
                continue;
            }

            int consumeSize = consumableSize.apply(records, capacity);
            if (consumeSize < 1) {
                if (state.shardCount > 0) {
                    // The leading records do not fit the rest of this shard: close it, retry in a new one.
                    this.finishShard(state, listener);
                    continue;
                }
                // Even an empty shard cannot hold the leading group (it is larger than shardSize).
                // Split it at the shard boundary, otherwise the loop would never make progress.
                consumeSize = capacity;
            }

            if (consumeSize >= records.size()) {
                // Defensive: the whole pending list may be consumed, continue with the next page.
                this.consumeAndFinishWhenFull(records, state, listener);
                records = dataProducer.apply(++pageNo);
            } else {
                this.consumeAndFinishWhenFull(records.subList(0, consumeSize), state, listener);
                records = records.subList(consumeSize, records.size());
            }
        }
        this.finishAll(state, listener);
    }

    /**
     * BACKWARD mode helper: returns the length of the longest prefix of {@code records} that consists only
     * of complete master groups (runs of consecutive records with equal master IDs) and is not longer
     * than {@code capacity}; 0 when not even the first group fits.
     */
    private int sizeOfWholeGroupsWithin(List<T> records, int capacity) {
        if (capacity <= 0) {
            return 0;
        }

        // End of the last complete group that fits.
        int fitted = 0;
        // Records examined so far, i.e. the end of the group currently being scanned.
        int scanned = 0;
        String groupMasterId = null;
        for (T record : records) {
            String masterId = masterGenerator.apply(record);
            if (scanned > 0 && !Objects.equals(masterId, groupMasterId)) {
                // The previous group ends right before this record.
                if (scanned > capacity) {
                    return fitted;
                }
                fitted = scanned;
            }
            groupMasterId = masterId;
            scanned++;
        }
        // The last group ends at the end of the list.
        return scanned > capacity ? fitted : scanned;
    }

    /** Delivers {@code records} to the current shard and closes the shard once it is full. */
    private void consumeAndFinishWhenFull(List<T> records, ShardState state, SplitterListener<T> listener) {
        if (CollUtil.isEmpty(records)) {
            return;
        }
        if (state.shardCount == 0) {
            listener.onShardInitialize(state.shardIndex);
        }

        listener.onConsume(records);
        state.total += records.size();

        if (state.shardCount + records.size() >= shardSize) {
            this.finishShard(state, listener);
        } else {
            state.shardCount += records.size();
        }
    }

    /** Delivers {@code records} to the current shard (FORWARD mode), never closing it. */
    private void consumeIntoOpenShard(List<T> records, ShardState state, SplitterListener<T> listener,
                                      Set<String> currentMasterIds) {
        if (CollUtil.isEmpty(records)) {
            return;
        }
        if (state.shardCount == 0) {
            listener.onShardInitialize(state.shardIndex);
            currentMasterIds.clear();
        }

        listener.onConsume(records);
        // Sequential on purpose: masterGenerator is caller code and need not be thread-safe.
        currentMasterIds.addAll(records.stream().map(masterGenerator).collect(Collectors.toSet()));
        state.total += records.size();
        state.shardCount += records.size();
    }

    /** Closes the current shard and moves on to the next shard index. */
    private void finishShard(ShardState state, SplitterListener<T> listener) {
        listener.onShardFinish(state.shardIndex);
        state.shardIndex++;
        state.shardCount = 0;
    }

    /** Closes the last shard (if it holds data) and reports the total. */
    private void finishAll(ShardState state, SplitterListener<T> listener) {
        if (state.shardCount > 0) {
            listener.onShardFinish(state.shardIndex);
        }
        listener.onFinish(state.total);
    }

    /** Mutable progress of a single {@link #consume} run. */
    private static final class ShardState {
        /** Total number of records delivered so far. */
        private long total;
        /** 0-based index of the shard currently being filled. */
        private int shardIndex;
        /** Number of records already delivered to the current shard. */
        private int shardCount;
    }

    /**
     * Receives the split data.
     *
     * @param <T> record type
     */
    public interface SplitterListener<T> {

        /**
         * Called before the first records of a shard are delivered.
         *
         * @param shardIndex 0-based shard index
         */
        void onShardInitialize(int shardIndex);

        /**
         * Called after the last records of a shard have been delivered.
         *
         * @param shardIndex 0-based shard index
         */
        void onShardFinish(int shardIndex);

        /**
         * Called once after all data has been delivered.
         *
         * @param total total number of records delivered
         */
        void onFinish(long total);

        /**
         * Delivers a batch of records of the current shard.
         *
         * @param dataList records to consume; may be a sub-list view of a page returned by the data producer
         */
        void onConsume(List<T> dataList);
    }

    /** Strategy for keeping records with the same master ID in the same shard. */
    public enum PaddingMode {
        /**
         * No grouping: every shard holds exactly {@code shardSize} records (except the last one).
         */
        NOOP,
        /**
         * Append forward: after a shard is full, following records whose master ID already occurs in the
         * shard are still appended to it; shards therefore hold at least {@code shardSize} records
         * (except the last one) and may hold more.
         */
        FORWARD,
        /**
         * Append backward: if a master group would overflow the current shard, the whole group is moved to
         * the next shard; shards never exceed {@code shardSize}. A single group larger than
         * {@code shardSize} is split across shards.
         */
        BACKWARD
    }
}
