package com.el.coordinator.boot.fsm.support;

import cn.hutool.core.io.IORuntimeException;
import cn.hutool.core.io.resource.ResourceUtil;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.RuntimeUtil;
import com.el.coordinator.core.common.api.ApiResult;
import com.el.coordinator.core.common.constant.ConstantFsm;
import com.el.coordinator.core.common.exception.BusinessException;
import com.el.coordinator.file.api.FileUser;
import com.el.coordinator.file.api.FsmApiConstant;
import com.el.coordinator.file.business.dto.ImportRateDTO;
import com.el.coordinator.file.business.dto.TmplDTO;
import com.el.coordinator.file.business.param.ImportResultDTO;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.io.InputStreamResource;
import org.springframework.data.redis.connection.ReturnType;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.http.*;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RequestCallback;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.multipart.MultipartFile;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Client-side helper for the file service (FSM) template API.
 * <p>
 * Wraps the remote template endpoints (template download, template info, import/export records,
 * import progress) behind {@link RestTemplate} calls, caches import progress in Redis, and offers a
 * Redis/Lua based limiter ({@link #updateLimiter(String, String, int, boolean)}).
 * <p>
 * NOTE(review): the {@link RedisTemplate} is kept in a static field that every constructor call
 * overwrites, so the static limiter methods use whichever instance was constructed last.
 *
 * @author Kaiser（wang shao）
 * @date 2021/6/3
 */
@Slf4j
public class FsmTmplSupport {

    private final RestTemplate restTemplate;

    // Limiter state. Static so the static updateLimiter/shutdown methods can reach it;
    // redisTemplate is (re)assigned by the constructor.
    private static RedisTemplate<Object, Object> redisTemplate;
    private static byte[] luaLimiterAdd = null;
    private static byte[] luaLimiterSubtract = null;
    private static volatile boolean destroy = false;
    /**
     * Running net count per limiter key of permits taken (+) / returned (-) through
     * {@link #updateLimiter(String, String, int, boolean)} in this process.
     */
    private static final ConcurrentMap<String, AtomicLong> LOCAL_LIMITER = new ConcurrentHashMap<>();
    /** Passed to the limiter Lua scripts as KEYS[1]. */
    private static final byte[] LIMITER_BUSINESS = "{yst}:el_fsm_tmpl".getBytes(StandardCharsets.UTF_8);
    private static final String CACHE_KEY_IMPORT_RATE_PREFIX = "el_fsm_tmpl_import_";

    /**
     * Creates the helper.
     *
     * @param restTemplate  client used for all file-service calls
     * @param redisTemplate Redis client; stored statically and shared by all instances
     */
    public FsmTmplSupport(RestTemplate restTemplate, RedisTemplate<Object, Object> redisTemplate) {
        this.restTemplate = restTemplate;
        FsmTmplSupport.redisTemplate = redisTemplate;
    }

    static {
        // Load the limiter scripts once. If they are missing, updateLimiter throws instead of
        // failing class initialization.
        try {
            luaLimiterAdd = ResourceUtil.readBytes("classpath:redis/limiter_add.lua");
            luaLimiterSubtract = ResourceUtil.readBytes("classpath:redis/limiter_subtract.lua");
        } catch (IORuntimeException e) {
            log.warn("加载限流资源文件失败", e);
        }
    }

    /**
     * Downloads a template file by its code.
     * <p>
     * The whole remote body is buffered in memory and returned together with the upstream
     * response headers.
     *
     * @param code template code
     * @return 200 with the file content, or 400 (empty body) if the download failed for any reason
     */
    public ResponseEntity<InputStreamResource> downloadByCode(String code) {
        RequestCallback requestCallback = request -> request.getHeaders().setAccept(List.of(MediaType.APPLICATION_OCTET_STREAM, MediaType.ALL));

        var outputStream = new ByteArrayOutputStream();
        AtomicReference<HttpHeaders> headers = new AtomicReference<>();
        try {
            restTemplate.execute(FsmApiConstant.API_TMPL_DOWNLOAD, HttpMethod.GET, requestCallback, clientHttpResponse -> {
                headers.set(clientHttpResponse.getHeaders());
                IOUtils.copy(clientHttpResponse.getBody(), outputStream);
                return null;
            }, code);
        } catch (Exception e) {
            log.error("下载模板文件失败:{}", code, e);
            return ResponseEntity.badRequest().build();
        }

        return ResponseEntity.ok()
                .headers(headers.get())
                .body(new InputStreamResource(new ByteArrayInputStream(outputStream.toByteArray())));
    }

    /**
     * Gets template information by template code.
     * <p>
     * Looks in the Redis hash {@code ConstantFsm.CACHE_KEY_TMPL_DETAIL} first, using the code (or
     * {@code "unknown"} when the code is blank) as the hash field. On a miss, it queries the file
     * service. This method does not write the fetched value back to the cache.
     *
     * @param code template code
     * @return template information
     * @throws BusinessException if the remote lookup fails
     */
    public TmplDTO getTmplByCode(String code) {
        String cacheKey = StringUtils.hasText(code) ? code : "unknown";
        TmplDTO tmplDTO = (TmplDTO) redisTemplate.opsForHash().get(ConstantFsm.CACHE_KEY_TMPL_DETAIL, cacheKey);
        if (tmplDTO != null) {
            return tmplDTO;
        }

        ApiResult<TmplDTO> result = remoteFsmExchange(FsmApiConstant.API_TMPL_INFO, HttpMethod.GET,
                null, new ParameterizedTypeReference<>() {
                }, code);
        if (result == null || !result.isSuccess()) {
            log.error("查询模板【{}】信息失败：{}", code, result);
            throw new BusinessException("查询模板信息失败");
        }

        return result.getData();
    }

    /**
     * Saves an import/export record on the file service.
     * <p>
     * Sends a multipart POST containing the user id/name (if a user is given), every entry of
     * {@code param}, and the data file (if given).
     *
     * @param code        template code
     * @param dataFile    data file, may be {@code null}
     * @param currentUser current user, may be {@code null}
     * @param param       extra form fields, may be {@code null} or empty
     * @return record id
     * @throws BusinessException if the remote call fails
     */
    public Long saveRecord(String code, MultipartFile dataFile, FileUser currentUser, Map<String, Object> param) {
        MultiValueMap<String, Object> postParam = new LinkedMultiValueMap<>(8);
        if (currentUser != null) {
            postParam.add("userId", currentUser.getUserId());
            postParam.add("userName", currentUser.getUserName());
        }
        if (param != null && !param.isEmpty()) {
            param.forEach(postParam::add);
        }
        if (dataFile != null) {
            postParam.add("file", dataFile.getResource());
        }

        ApiResult<Long> result = remoteFsmExchange(FsmApiConstant.API_TMPL_IMPORT_RECORD_SAVE, HttpMethod.POST, new HttpEntity<>(postParam)
                , new ParameterizedTypeReference<>() {
                }, code);
        if (result == null || !result.isSuccess()) {
            log.error("保存模板【{}】导入数据记录失败：{}", code, result);
            throw new BusinessException("保存导入记录失败");
        }

        return result.getData();
    }

    /**
     * Updates the total number of rows of an import record.
     *
     * @param importId import record id
     * @param numTotal total number of rows; {@code null} is sent as 0
     * @throws BusinessException if the remote call fails
     */
    public void updateImportNum(Long importId, Long numTotal) {
        ApiResult<Long> result = remoteFsmExchange(FsmApiConstant.API_TMPL_IMPORT_TOTAL, HttpMethod.PATCH, null,
                new ParameterizedTypeReference<>() {
                }, importId, ObjectUtil.defaultIfNull(numTotal, 0L));
        if (result == null || !result.isSuccess()) {
            throw new BusinessException("更新导入记录失败");
        }
    }

    /**
     * Updates the result of an import record.
     *
     * @param importId        import record id
     * @param importResultDTO import result, sent as the request body
     * @throws BusinessException if the remote call fails
     */
    public void updateImportResult(Long importId, ImportResultDTO importResultDTO) {
        ApiResult<Long> result = remoteFsmExchange(FsmApiConstant.API_TMPL_IMPORT_RESULT,
                HttpMethod.PATCH, new HttpEntity<>(importResultDTO), new ParameterizedTypeReference<>() {
                }, importId);
        if (result == null || !result.isSuccess()) {
            throw new BusinessException("更新导入结果失败");
        }
    }

    /**
     * Saves an export file entry for an import record.
     *
     * @param importId import record id
     * @param fileCode code of the exported file
     * @param order    order of the file within the record
     * @return the id returned by the file service for the saved entry
     * @throws BusinessException if the remote call fails
     */
    public Long saveExportFile(Long importId, String fileCode, int order) {
        ApiResult<Long> result = remoteFsmExchange(FsmApiConstant.API_TMPL_EXPORT_FILE, HttpMethod.POST, null
                , new ParameterizedTypeReference<>() {
                }, importId, fileCode, order);
        if (result == null || !result.isSuccess()) {
            log.error("保存模板【{}, {}】导出数据记录失败：{}", importId, order, result);
            throw new BusinessException("保存导出记录失败");
        }

        return result.getData();
    }

    /**
     * Updates limiter data in Redis via a Lua script.
     * <p>
     * Runs {@code limiter_add.lua} (with {@code max} and {@code step}) or
     * {@code limiter_subtract.lua} (with {@code step} only). A script result of {@code null} or
     * {@code -1} is treated as failure. On success, the change is also recorded in
     * {@link #LOCAL_LIMITER} so it can be rolled back on shutdown.
     *
     * @param key  unique limiter key
     * @param max  maximum allowed value (decimal string); only used when {@code add} is true
     * @param step amount to add/subtract; values below 1 are raised to 1
     * @param add  {@code true} to add, {@code false} to subtract
     * @return {@code true} if the update succeeded; {@code false} means the script reported failure
     *         (per the original docs: the limit was exceeded)
     * @throws BusinessException if the Lua script was not loaded or no RedisTemplate is configured
     */
    public static boolean updateLimiter(String key, String max, int step, boolean add) {
        step = Math.max(step, 1);
        byte[] stepBytes = (step + "").getBytes(StandardCharsets.UTF_8);

        Long result;
        if (add) {
            if (ArrayUtil.isEmpty(luaLimiterAdd) || redisTemplate == null) {
                throw new BusinessException("设置限流失败");
            }
            result = redisTemplate.execute((RedisCallback<Long>) connection -> connection.eval(luaLimiterAdd,
                    ReturnType.INTEGER, 1, LIMITER_BUSINESS, key.getBytes(StandardCharsets.UTF_8),
                    max.getBytes(StandardCharsets.UTF_8), stepBytes));
        } else {
            if (ArrayUtil.isEmpty(luaLimiterSubtract) || redisTemplate == null) {
                throw new BusinessException("设置限流失败");
            }
            result = redisTemplate.execute((RedisCallback<Long>) connection -> connection.eval(luaLimiterSubtract,
                    ReturnType.INTEGER, 1, LIMITER_BUSINESS, key.getBytes(StandardCharsets.UTF_8), stepBytes));
        }

        if (result != null && result != -1) {
            LOCAL_LIMITER.computeIfAbsent(key, k -> new AtomicLong(0L)).addAndGet(step * (add ? 1L : -1L));
            return true;
        }
        return false;
    }

    /**
     * Caches import progress in Redis with a 20-minute TTL.
     *
     * @param importId      import record id
     * @param importRateDTO import progress
     */
    public void storeImportRate(Long importId, ImportRateDTO importRateDTO) {
        redisTemplate.opsForValue().set(CACHE_KEY_IMPORT_RATE_PREFIX + importId, importRateDTO, Duration.ofMinutes(20));
    }

    /**
     * Removes cached import progress.
     *
     * @param importId import record id
     */
    public void removeImportRate(Long importId) {
        redisTemplate.delete(CACHE_KEY_IMPORT_RATE_PREFIX + importId);
    }

    /**
     * Reads import progress from the Redis cache only.
     *
     * @param importId import record id
     * @return cached import progress, or {@code null} if absent or expired
     */
    public ImportRateDTO getImportRateFromCache(Long importId) {
        return (ImportRateDTO) redisTemplate.opsForValue().get(CACHE_KEY_IMPORT_RATE_PREFIX + importId);
    }

    /**
     * Queries import progress from the file service.
     *
     * @param importId import record id
     * @return import progress
     * @throws BusinessException if the remote call fails
     */
    public ImportRateDTO getImportRate(Long importId) {
        ApiResult<ImportRateDTO> result = remoteFsmExchange(FsmApiConstant.API_TMPL_IMPORT_RATE, HttpMethod.GET,
                null, new ParameterizedTypeReference<>() {
                }, importId);
        if (result == null || !result.isSuccess()) {
            log.error("查询导入进度【{}】失败：{}", importId, result);
            throw new BusinessException("查询导入进度失败");
        }

        return result.getData();
    }

    /**
     * Gets the file code attached to a record.
     *
     * @param importId import record id
     * @return file code of the record
     * @throws BusinessException if the remote call fails
     */
    public String getRecordFileCode(Long importId) {
        ApiResult<String> result = remoteFsmExchange(FsmApiConstant.API_TMPL_RECORD_FILE_CODE, HttpMethod.GET,
                null, new ParameterizedTypeReference<>() {
                }, importId);
        if (result == null || !result.isSuccess()) {
            log.error("查询记录的文件编号【{}】失败：{}", importId, result);
            throw new BusinessException("获取文件失败");
        }

        return result.getData();
    }

    /**
     * Gets the ids of imports that have not finished.
     *
     * @return import record ids
     * @throws BusinessException if the remote call fails
     */
    public List<Long> getUnFinished() {
        ApiResult<List<Long>> result = remoteFsmExchange(FsmApiConstant.API_TMPL_IMPORT_UNFINISHED_ID, HttpMethod.GET,
                null, new ParameterizedTypeReference<>() {
                });
        if (result == null || !result.isSuccess()) {
            log.error("查询未导入结束的失败：{}", result);
            throw new BusinessException("查询未导入结束的失败");
        }

        return result.getData();
    }

    /**
     * Returns every limiter permit this process still holds, according to {@link #LOCAL_LIMITER}.
     * <p>
     * Runs at most once. NOTE: this is currently not registered anywhere (the former
     * {@code RuntimeUtil.addShutdownHook(FsmTmplSupport::shutdown)} registration was commented out).
     */
    private static synchronized void shutdown() {
        if (destroy) {
            return;
        }
        destroy = true;
        log.info("数据导入服务销毁...");

        for (Map.Entry<String, AtomicLong> entry : LOCAL_LIMITER.entrySet()) {
            long held = entry.getValue().get();
            // updateLimiter raises step to at least 1, so a zero/negative balance would otherwise
            // release a permit this process never acquired.
            if (held <= 0) {
                continue;
            }
            try {
                updateLimiter(entry.getKey(), "1", (int) Math.min(held, Integer.MAX_VALUE), false);
            } catch (RuntimeException e) {
                // Best effort: one failing key must not prevent releasing the others.
                log.warn("释放限流失败：{}", entry.getKey(), e);
            }
        }
    }

    /**
     * Performs a {@link RestTemplate#exchange} call against the file service.
     *
     * @param url          URL template
     * @param httpMethod   HTTP method
     * @param httpEntity   request entity, may be {@code null}
     * @param responseType response body type
     * @param param        URL template variables
     * @return response body (may be {@code null})
     * @throws BusinessException on a {@link RestClientException} or any status other than 200 OK
     */
    private <T> T remoteFsmExchange(String url, HttpMethod httpMethod, HttpEntity<?> httpEntity, ParameterizedTypeReference<T> responseType,
                                    Object... param) {
        final ResponseEntity<T> response;
        try {
            response = restTemplate.exchange(url, httpMethod, httpEntity, responseType, param);
        } catch (RestClientException e) {
            log.error("文件服务器调用失败：", e);
            throw new BusinessException("文件服务器异常");
        }

        // Only 200 counts as success; other 2xx codes are also rejected.
        if (response.getStatusCode() != HttpStatus.OK) {
            log.error("调用文件服务器接口失败：{}", response);
            throw new BusinessException("调用文件服务器接口失败");
        }

        return response.getBody();
    }
}
