package com.elitesland.cbpl.tool.core.http;

import cn.hutool.core.net.url.UrlBuilder;
import cn.hutool.core.net.url.UrlQuery;
import cn.hutool.core.util.CharsetUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.extra.servlet.ServletUtil;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.net.URLDecoder;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;

/**
 * Static helpers for accessing the servlet request bound to the current thread and for
 * extracting parameters/headers from servlet requests and Spring server request/response
 * abstractions.
 *
 * @author eric.hao
 * @since 2023/07/01
 */
public class HttpServletUtil {

    /**
     * Returns the client IP address of the request bound to the current thread.
     * <p>
     * Resolution is delegated to Hutool's {@code ServletUtil.getClientIP}.
     * NOTE(review): if that implementation honours forwarding headers (e.g. {@code X-Forwarded-For}),
     * the value is client-controllable unless a trusted proxy rewrites them; do not rely on it alone
     * for security decisions.
     *
     * @return the client IP, or {@code null} when no servlet request is bound to the current thread
     */
    public static String currentClientIp() {
        HttpServletRequest request = currentRequest();
        if (request == null) {
            return null;
        }
        return ServletUtil.getClientIP(request);
    }

    /**
     * Returns the servlet request bound to the current thread through Spring's
     * {@link RequestContextHolder}.
     *
     * @return the current request, or {@code null} when no request attributes are bound to the
     *         current thread or the bound attributes are not servlet-based
     */
    public static HttpServletRequest currentRequest() {
        RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
        // instanceof is false for null as well; a non-servlet RequestAttributes implementation
        // would otherwise fail here with a ClassCastException.
        if (requestAttributes instanceof ServletRequestAttributes) {
            return ((ServletRequestAttributes) requestAttributes).getRequest();
        }
        return null;
    }

    /**
     * Wraps the request bound to the current thread in a {@link RequestWrapper}.
     *
     * @return the wrapper, or {@code null} when no servlet request is bound to the current thread
     */
    public static RequestWrapper currentRequestWrapper() {
        HttpServletRequest request = currentRequest();
        return ObjectUtil.isNotNull(request) ? new RequestWrapper(request) : null;
    }

    /**
     * Collects all request parameters, i.e. query-string and form parameters, as reported by
     * {@link HttpServletRequest#getParameterMap()}.
     * <p>
     * A parameter name with no values is kept and mapped to a single {@code null} entry.
     *
     * @param request the servlet request to read
     * @return a new, insertion-ordered multi-value map of parameter names to values
     */
    public static MultiValueMap<String, String> getParameters(HttpServletRequest request) {
        Map<String, String[]> parameterMap = request.getParameterMap();
        MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(parameterMap.size());
        parameterMap.forEach((key, values) -> addValues(parameters, key, values, 0));
        return parameters;
    }

    /**
     * Parses only the query-string parameters of the request.
     * <p>
     * Names and values are returned exactly as they appear in the raw query string (they are
     * <em>not</em> URL-decoded). Each pair is split on its first {@code '='}, so values may
     * themselves contain {@code '='}. A name without a value, or with an empty value, maps to
     * {@code null}. Blank segments (e.g. from {@code "a=1&&b=2"}) are ignored.
     *
     * @param request the servlet request to read
     * @return a new, insertion-ordered multi-value map; empty when there is no query string
     */
    public static MultiValueMap<String, String> getQueryParameters(HttpServletRequest request) {
        String queryString = request.getQueryString();
        if (!StringUtils.hasText(queryString)) {
            return new LinkedMultiValueMap<>(0);
        }

        String[] pairs = queryString.split("&");
        MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(pairs.length);
        for (String pair : pairs) {
            if (!StringUtils.hasText(pair)) {
                continue;
            }
            // Split on the first '=' only: values such as Base64 tokens legitimately contain '='.
            int eq = pair.indexOf('=');
            String name = eq < 0 ? pair : pair.substring(0, eq);
            String value = eq < 0 ? null : pair.substring(eq + 1);
            parameters.add(name, (value == null || value.isEmpty()) ? null : value);
        }
        return parameters;
    }

    /**
     * Parses the query parameters of a URI with Hutool's {@link UrlBuilder}, using UTF-8.
     *
     * @param uri the URI to parse
     * @return the query parameter object produced by Hutool
     */
    public static UrlQuery getQueryParameters(URI uri) {
        UrlBuilder builder = UrlBuilder.of(uri, CharsetUtil.CHARSET_UTF_8);
        return builder.getQuery();
    }

    /**
     * Parses the query parameters of an HTTP URL string with Hutool's {@link UrlBuilder},
     * using UTF-8.
     *
     * @param httpUrl the URL string to parse
     * @return the query parameter object produced by Hutool
     */
    public static UrlQuery getQueryParameters(String httpUrl) {
        UrlBuilder builder = UrlBuilder.ofHttp(httpUrl, CharsetUtil.CHARSET_UTF_8);
        return builder.getQuery();
    }

    /**
     * Returns only the form (request body) parameters of the request.
     * <p>
     * The container merges query-string and body parameters into
     * {@link HttpServletRequest#getParameterMap()}, with query-string values first (Servlet spec,
     * section 3.1). This method therefore removes, for every name, as many leading values as that
     * name contributed via the query string. Names that only came from the query string are omitted.
     * Query-string names are URL-decoded (UTF-8) before matching, because the parameter map holds
     * decoded names.
     *
     * @param request the servlet request to read
     * @return a new, insertion-ordered multi-value map of form parameter names to values
     */
    public static MultiValueMap<String, String> getFormParameters(HttpServletRequest request) {
        MultiValueMap<String, String> queryParams = getQueryParameters(request);
        if (queryParams.isEmpty()) {
            return getParameters(request);
        }

        // Number of values each decoded parameter name received from the query string. Distinct raw
        // spellings of the same name (e.g. "a%20b" and "a+b") are summed.
        Map<String, Integer> queryValueCounts = new HashMap<>(queryParams.size() * 2);
        queryParams.forEach((rawName, values) ->
                queryValueCounts.merge(decodeQueryComponent(rawName), values.size(), Integer::sum));

        Map<String, String[]> parameterMap = request.getParameterMap();
        MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(parameterMap.size());
        parameterMap.forEach((key, values) ->
                addValues(parameters, key, values, queryValueCounts.getOrDefault(key, 0)));
        return parameters;
    }

    /**
     * Collects all request headers, preserving every value of multi-valued headers.
     *
     * @param request the servlet request to read
     * @return a new multi-value map of header names to values; empty when the container does not
     *         expose headers (the servlet API allows {@code getHeaderNames()} to return {@code null})
     */
    public static MultiValueMap<String, String> getHeaders(HttpServletRequest request) {
        MultiValueMap<String, String> headerMap = new LinkedMultiValueMap<>(64);
        Enumeration<String> names = request.getHeaderNames();
        if (names == null) {
            return headerMap;
        }
        while (names.hasMoreElements()) {
            String name = names.nextElement();
            Enumeration<String> values = request.getHeaders(name);
            if (values == null) {
                continue;
            }
            while (values.hasMoreElements()) {
                headerMap.add(name, values.nextElement());
            }
        }
        return headerMap;
    }

    /**
     * Wraps a Spring {@link ServerHttpRequest} in a {@link RequestWrapper}.
     *
     * @param request the request; must be a {@link ServletServerHttpRequest}, otherwise the cast
     *                throws {@link ClassCastException}
     * @return a wrapper around the underlying servlet request
     */
    public static RequestWrapper wrapper(ServerHttpRequest request) {
        ServletServerHttpRequest serverHttpRequest = (ServletServerHttpRequest) request;
        return new RequestWrapper(serverHttpRequest.getServletRequest());
    }

    /**
     * Wraps a Spring {@link ServerHttpResponse} in a {@link ResponseWrapper}.
     *
     * @param response the response; must be a {@link ServletServerHttpResponse}, otherwise the
     *                 cast throws {@link ClassCastException}
     * @return a wrapper around the underlying servlet response
     */
    public static ResponseWrapper wrapper(ServerHttpResponse response) {
        ServletServerHttpResponse serverHttpResponse = (ServletServerHttpResponse) response;
        return new ResponseWrapper(serverHttpResponse.getServletResponse());
    }

    /**
     * Adds {@code values[from..]} to {@code target} under {@code key}.
     * <p>
     * When {@code from == 0} and the parameter has no values at all, a single {@code null} is
     * recorded so the name is not lost. When values are skipped and nothing remains, the key is
     * omitted.
     */
    private static void addValues(MultiValueMap<String, String> target, String key, String[] values, int from) {
        if (values == null || values.length == 0) {
            if (from == 0) {
                target.add(key, null);
            }
            return;
        }
        for (int i = from; i < values.length; i++) {
            target.add(key, values[i]);
        }
    }

    /**
     * URL-decodes a raw query-string component as UTF-8, falling back to the raw text when it
     * contains a malformed escape sequence.
     */
    private static String decodeQueryComponent(String raw) {
        try {
            return URLDecoder.decode(raw, CharsetUtil.UTF_8);
        } catch (UnsupportedEncodingException | IllegalArgumentException e) {
            // UTF-8 is always supported; IllegalArgumentException means a bad "%xx" escape.
            // Matching on the raw name is the best we can do, and must not fail the request.
            return raw;
        }
    }
}
