package com.elitescloud.boot.wrapper;

import cn.hutool.core.collection.IteratorEnumeration;
import cn.hutool.core.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedCaseInsensitiveMap;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Date;
import java.util.Enumeration;
import java.util.List;
import java.util.Locale;

/**
 * Cloudt custom request wrapper.
 *
 * <p>Adds two capabilities on top of the wrapped {@link HttpServletRequest}:
 * <ul>
 *   <li>extra headers can be attached via {@link #addHeader(String, String)}; they take
 *       precedence over (and are merged with) the headers of the wrapped request;</li>
 *   <li>the request body is read once from the wrapped request, cached in memory, and can
 *       then be consumed repeatedly through {@link #getInputStream()}, {@link #getReader()}
 *       and {@link #getBodyString()}.</li>
 * </ul>
 *
 * <p>Instances are not synchronized; the lazily cached body is a plain field.
 *
 * @author Kaiser（wang shao）
 * @date 2022/8/21
 */
public class CloudtRequestWrapper extends HttpServletRequestWrapper {
    private static final Logger LOG = LoggerFactory.getLogger(CloudtRequestWrapper.class);

    /**
     * Custom headers added on top of the wrapped request.
     *
     * <p>Backed by a case-insensitive map because the Servlet API defines header names as
     * case-insensitive; insertion order of names is preserved.
     */
    private final MultiValueMap<String, String> headers =
            CollectionUtils.toMultiValueMap(new LinkedCaseInsensitiveMap<List<String>>(8, Locale.ROOT));

    /**
     * Lazily populated copy of the request body; {@code null} until first access.
     */
    private byte[] body;

    public CloudtRequestWrapper(HttpServletRequest request) {
        super(request);
    }

    /**
     * Adds a custom request header. Multiple values may be added under the same name.
     *
     * @param name  header name (matched case-insensitively on lookup), must not be {@code null}
     * @param value header value
     * @throws IllegalArgumentException if {@code name} is {@code null}
     */
    public void addHeader(String name, String value) {
        if (name == null) {
            throw new IllegalArgumentException("header name must not be null");
        }
        headers.add(name, value);
    }

    /**
     * Returns the cached request body decoded as text.
     *
     * <p>The body is decoded with the request's declared character encoding when it is
     * present and supported, otherwise with UTF-8.
     *
     * @return the body as a string, or {@code null} when the body is empty
     */
    public String getBodyString() {
        byte[] content = getBody();
        if (ArrayUtil.isEmpty(content)) {
            return null;
        }

        return new String(content, resolveCharset());
    }

    /**
     * Returns a date header as epoch milliseconds.
     *
     * <p>A custom header with text is interpreted as a decimal epoch-millisecond value
     * (not as an HTTP date string); otherwise the wrapped request is consulted.
     *
     * @param name header name
     * @return epoch milliseconds, or whatever the wrapped request returns
     * @throws NumberFormatException if the custom header value is not a valid {@code long}
     */
    @Override
    public long getDateHeader(String name) {
        String value = headers.getFirst(name);
        if (StringUtils.hasText(value)) {
            return Long.parseLong(value);
        }
        return super.getDateHeader(name);
    }

    /**
     * Returns the first custom value for the header when it has text, otherwise the value
     * from the wrapped request.
     *
     * @param name header name
     * @return the header value, or the wrapped request's result
     */
    @Override
    public String getHeader(String name) {
        String value = headers.getFirst(name);
        if (StringUtils.hasText(value)) {
            return value;
        }
        return super.getHeader(name);
    }

    /**
     * Returns all values for the header: custom values first, followed by the values of
     * the wrapped request.
     *
     * @param name header name
     * @return enumeration of header values
     */
    @Override
    public Enumeration<String> getHeaders(String name) {
        List<String> values = headers.get(name);
        if (CollectionUtils.isEmpty(values)) {
            return super.getHeaders(name);
        }

        List<String> result = new ArrayList<>(values);

        // Append the original header values; the container may return null here.
        Enumeration<String> headersOriginal = super.getHeaders(name);
        while (headersOriginal != null && headersOriginal.hasMoreElements()) {
            result.add(headersOriginal.nextElement());
        }

        return new IteratorEnumeration<>(result.iterator());
    }

    /**
     * Returns the names of all headers: custom names first, followed by the names of the
     * wrapped request that are not already covered by a custom header.
     *
     * @return enumeration of header names without case-insensitive duplicates of custom names
     */
    @Override
    public Enumeration<String> getHeaderNames() {
        if (headers.isEmpty()) {
            return super.getHeaderNames();
        }

        List<String> result = new ArrayList<>(headers.keySet());

        // Append the original header names, skipping those shadowed by a custom header
        // so that callers iterating names do not process the same header twice.
        Enumeration<String> headersOriginal = super.getHeaderNames();
        while (headersOriginal != null && headersOriginal.hasMoreElements()) {
            String originalName = headersOriginal.nextElement();
            if (!headers.containsKey(originalName)) {
                result.add(originalName);
            }
        }

        return new IteratorEnumeration<>(result.iterator());
    }

    /**
     * Returns a header parsed as an {@code int}, preferring a custom header with text.
     *
     * @param name header name
     * @return the parsed value, or whatever the wrapped request returns
     * @throws NumberFormatException if the custom header value is not a valid {@code int}
     */
    @Override
    public int getIntHeader(String name) {
        String value = headers.getFirst(name);
        if (StringUtils.hasText(value)) {
            return Integer.parseInt(value);
        }
        return super.getIntHeader(name);
    }

    /**
     * Returns a fresh stream over the cached body. Every call starts from the beginning,
     * which is what makes the body readable more than once.
     *
     * @return a new input stream positioned at the start of the body
     * @throws IOException declared by the overridden method; not thrown by this implementation
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        return new CachedBodyInputStream(getBody());
    }

    /**
     * Returns a fresh reader over the cached body, decoded with the request's declared
     * character encoding when present and supported, otherwise with UTF-8.
     *
     * @return a new buffered reader positioned at the start of the body
     * @throws IOException declared by the overridden method
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), resolveCharset()));
    }

    /**
     * Returns the cached body, reading it from the wrapped request on first access.
     */
    private byte[] getBody() {
        if (body == null) {
            body = parseBody(super.getRequest());
        }
        return body;
    }

    /**
     * Reads the whole body of the given request.
     *
     * <p>Best effort: any failure is logged and an empty array is returned, so a broken
     * body never propagates an exception from this wrapper.
     */
    private byte[] parseBody(ServletRequest request) {
        try {
            return request.getInputStream().readAllBytes();
        } catch (Exception e) {
            LOG.error("read body exception：", e);
        }
        return new byte[0];
    }

    /**
     * Picks the charset used to decode the body: the request's character encoding when it
     * is set and known to the JVM, otherwise UTF-8.
     */
    private Charset resolveCharset() {
        String encoding = getCharacterEncoding();
        if (StringUtils.hasText(encoding)) {
            try {
                return Charset.forName(encoding);
            } catch (IllegalArgumentException e) {
                // Covers both illegal and unsupported charset names.
                LOG.warn("unsupported request character encoding '{}', falling back to UTF-8", encoding);
            }
        }
        return StandardCharsets.UTF_8;
    }

    /**
     * {@link ServletInputStream} over an in-memory byte array.
     */
    private static final class CachedBodyInputStream extends ServletInputStream {
        private final ByteArrayInputStream delegate;

        CachedBodyInputStream(byte[] content) {
            this.delegate = new ByteArrayInputStream(content);
        }

        @Override
        public boolean isFinished() {
            // Finished once every cached byte has been consumed.
            return delegate.available() == 0;
        }

        @Override
        public boolean isReady() {
            // Data is already in memory, so a read never blocks.
            return true;
        }

        @Override
        public void setReadListener(ReadListener readListener) {
            // Intentionally a no-op: the listener is ignored.
        }

        @Override
        public int read() {
            return delegate.read();
        }

        @Override
        public int read(byte[] b, int off, int len) {
            // Bulk read instead of the byte-by-byte default inherited from InputStream.
            return delegate.read(b, off, len);
        }

        @Override
        public int available() {
            return delegate.available();
        }
    }
}