package com.elitescloud.boot.log.interceptor;

import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.extra.servlet.ServletUtil;
import com.elitescloud.boot.auth.util.SecurityContextUtil;
import com.elitescloud.boot.constant.CommonConstant;
import com.elitescloud.boot.log.LogProperties;
import com.elitescloud.boot.log.model.bo.AccessLogBO;
import com.elitescloud.boot.log.common.RequestBodyDesensitize;
import com.elitescloud.boot.log.common.ResponseBodyDesensitize;
import com.elitescloud.boot.log.queue.LogEvent;
import com.elitescloud.boot.support.CloudtInterceptor;
import com.elitescloud.boot.threadpool.common.ThreadPoolHolder;
import com.elitescloud.boot.wrapper.CloudtRequestWrapper;
import com.elitescloud.cloudt.common.base.ApiResult;
import com.elitescloud.boot.exception.BusinessException;
import com.lmax.disruptor.RingBuffer;
import io.swagger.annotations.ApiOperation;
import lombok.extern.log4j.Log4j2;
import org.slf4j.MDC;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.web.servlet.error.ErrorAttributes;
import org.springframework.core.MethodParameter;
import org.springframework.core.Ordered;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpInputMessage;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.PathContainer;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.lang.NonNull;
import org.springframework.util.CollectionUtils;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.mvc.method.annotation.RequestBodyAdvice;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyAdvice;
import org.springframework.web.servlet.resource.ResourceHttpRequestHandler;
import org.springframework.web.util.pattern.PathPattern;
import org.springframework.web.util.pattern.PathPatternParser;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.stream.Collectors;

/**
 * Interceptor that records API access logs.
 * <p>
 * For every request that is not excluded, a log object is created in {@link #preHandle} and kept in a
 * thread-local. Request bodies read through {@link RequestBodyAdvice} and the response body written through
 * {@link ResponseBodyAdvice} are attached to it. In {@link #afterCompletion} the log is completed and then
 * published to the {@link RingBuffer} from {@link #executor}.
 *
 * @author Kaiser（wang shao）
 * @date 2022/8/18
 */
@Log4j2
@ControllerAdvice
@ConditionalOnBean(LogProperties.class)
public class AccessLogInterceptor<T extends Serializable> implements CloudtInterceptor, RequestBodyAdvice, ResponseBodyAdvice<T> {

    private final LogProperties configProperties;
    private final RingBuffer<LogEvent> ringBuffer;
    /** Requests matching any of these patterns (static resources included) are not logged. */
    private final List<PathPattern> excludeRequestMatcher;
    /** Only requests matching one of these patterns keep their request body in the log. */
    private final List<PathPattern> includeRecordReqBodyMatcher;
    /** Only requests matching one of these patterns keep their response body in the log. */
    private final List<PathPattern> includeRecordRespBodyMatcher;
    /** Executor that desensitizes captured bodies and publishes the log off the request thread. */
    private final Executor executor;

    /** Log object of the request currently being handled on this thread. */
    private final ThreadLocal<AccessLogBO> accessLogThreadLocal = new ThreadLocal<>();

    private ErrorAttributes errorAttributes;
    // Initialised empty so the interceptor stays usable even if the setters below are never invoked.
    private List<RequestBodyDesensitize> requestBodyDesensitizes = Collections.emptyList();
    private List<ResponseBodyDesensitize> responseBodyDesensitizes = Collections.emptyList();

    public AccessLogInterceptor(LogProperties configProperties,
                                RingBuffer<LogEvent> ringBuffer) {
        this.configProperties = configProperties;
        this.ringBuffer = ringBuffer;
        this.excludeRequestMatcher = createExcludeRequestMatcher(configProperties.getAccessLog().getExcludedRequest());
        this.includeRecordReqBodyMatcher = convertRequestMatcher(configProperties.getAccessLog().getRecordRequestBody());
        this.includeRecordRespBodyMatcher = convertRequestMatcher(configProperties.getAccessLog().getRecordResponseBody());
        this.executor = createExecutor();
    }

    @Override
    public int order() {
        return Ordered.LOWEST_PRECEDENCE;
    }

    /**
     * Starts an access log for the request unless it is a static resource or matches an excluded pattern.
     *
     * @return always {@code true}; this interceptor never blocks a request
     */
    @Override
    public boolean preHandle(@NonNull HttpServletRequest request, @NonNull HttpServletResponse response, @NonNull Object handler) throws Exception {
        // Skip excluded endpoints
        if (handler instanceof ResourceHttpRequestHandler || isMatch(request, excludeRequestMatcher)) {
            return true;
        }

        // Put the log object into the thread context
        var logBo = new AccessLogBO();
        logBo.setRequestTime(LocalDateTime.now());
        accessLogThreadLocal.set(logBo);

        return true;
    }

    /**
     * Completes the log started in {@link #preHandle} and hands it to the queue asynchronously.
     * <p>
     * When {@code ex} is {@code null}, the failure recorded on the request (or exposed via
     * {@link ErrorAttributes}) is used instead. It is kept as a {@link Throwable}, because it may be an
     * {@link Error}, which is not an {@link Exception}.
     */
    @Override
    public void afterCompletion(@NonNull HttpServletRequest request, @NonNull HttpServletResponse response, @NonNull Object handler, Exception ex) throws Exception {
        AccessLogBO accessLogBO = accessLogThreadLocal.get();
        accessLogThreadLocal.remove();
        if (accessLogBO == null) {
            return;
        }
        Throwable failure = ex != null ? ex : obtainException(request);

        // Build the log object; failures here must not affect the already completed request
        CompletableFuture<AccessLogBO> logFuture;
        try {
            logFuture = fillAccessLog(accessLogBO, request, handler, failure);
        } catch (Exception e) {
            log.error("组装接口日志参数异常：", e);
            return;
        }

        logFuture
                // Add to the queue
                .thenAccept(this::addToQueue)
                .whenComplete((r, e) -> {
                    if (e != null) {
                        log.error("组装接口日志参数异常：", e);
                    }
                })
        ;
    }

    @Override
    public boolean supports(@NonNull MethodParameter methodParameter, @NonNull Type targetType, @NonNull Class<? extends HttpMessageConverter<?>> converterType) {
        return true;
    }

    @NonNull
    @Override
    public HttpInputMessage beforeBodyRead(@NonNull HttpInputMessage inputMessage, @NonNull MethodParameter parameter, @NonNull Type targetType, @NonNull Class<? extends HttpMessageConverter<?>> converterType) throws IOException {
        return inputMessage;
    }

    /**
     * Appends each deserialized request body to the current log object, if one is active on this thread.
     *
     * @return {@code body}, unchanged
     */
    @NonNull
    @Override
    public Object afterBodyRead(@NonNull Object body, @NonNull HttpInputMessage inputMessage, @NonNull MethodParameter parameter, @NonNull Type targetType, @NonNull Class<? extends HttpMessageConverter<?>> converterType) {
        AccessLogBO accessLogBO = accessLogThreadLocal.get();
        if (accessLogBO != null) {
            List<Object> bodies = ObjectUtil.defaultIfNull(accessLogBO.getRequestBody(), new ArrayList<>(4));
            bodies.add(body);
            accessLogBO.setRequestBody(bodies);
        }
        return body;
    }

    @Override
    public Object handleEmptyBody(Object body, @NonNull HttpInputMessage inputMessage, @NonNull MethodParameter parameter, @NonNull Type targetType, @NonNull Class<? extends HttpMessageConverter<?>> converterType) {
        return body;
    }

    /**
     * Applies the response advice only to handler methods whose declared return type is {@link Serializable}.
     */
    @Override
    public boolean supports(@NonNull MethodParameter returnType, @NonNull Class<? extends HttpMessageConverter<?>> converterType) {
        var parameterType = returnType.getParameterType();
        return Serializable.class.isAssignableFrom(parameterType);
    }

    /**
     * Stores the response body on the current log object, if one is active on this thread.
     *
     * @return {@code body}, unchanged
     */
    @Override
    public T beforeBodyWrite(T body, @NonNull MethodParameter returnType, @NonNull MediaType selectedContentType,
                             @NonNull Class<? extends HttpMessageConverter<?>> selectedConverterType, @NonNull ServerHttpRequest request, @NonNull ServerHttpResponse response) {
        AccessLogBO accessLogBO = accessLogThreadLocal.get();
        if (accessLogBO != null) {
            accessLogBO.setResult(body);
        }
        return body;
    }

    @Autowired(required = false)
    public void setErrorAttributes(ErrorAttributes errorAttributes) {
        this.errorAttributes = errorAttributes;
    }

    /**
     * Collects the request-body desensitizers in {@code @Order}/{@code Ordered} order.
     * The first one that supports a request wins.
     */
    @Autowired
    public void setRequestBodyDesensitizeProvider(ObjectProvider<RequestBodyDesensitize> requestBodyDesensitizeProvider) {
        this.requestBodyDesensitizes = requestBodyDesensitizeProvider.orderedStream().collect(Collectors.toList());
    }

    /**
     * Collects the response-body desensitizers in {@code @Order}/{@code Ordered} order.
     * The first one that supports a request wins.
     */
    @Autowired
    public void setResponseBodyDesensitizeProvider(ObjectProvider<ResponseBodyDesensitize> responseBodyDesensitizeProvider) {
        this.responseBodyDesensitizes = responseBodyDesensitizeProvider.orderedStream().collect(Collectors.toList());
    }

    /**
     * Returns whether the request URI matches any of the given patterns.
     * An empty list matches nothing.
     */
    private boolean isMatch(HttpServletRequest request, List<PathPattern> matchers) {
        if (matchers.isEmpty()) {
            return false;
        }

        var pathContainer = PathContainer.parsePath(request.getRequestURI());
        for (PathPattern requestMatcher : matchers) {
            if (requestMatcher.matches(pathContainer)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Publishes the log object to the ring buffer.
     * The claimed sequence is always published, even if populating the slot fails.
     */
    private void addToQueue(AccessLogBO accessLogBO) {
        var msgSequence = ringBuffer.next();
        try {
            ringBuffer.get(msgSequence)
                    .setLog(accessLogBO);
        } catch (Exception e) {
            log.error("添加接口访问日志到队列时异常：", e);
        } finally {
            ringBuffer.publish(msgSequence);
        }
    }

    /**
     * Completes the log object and schedules the remaining work on {@link #executor}.
     * <p>
     * Everything that reads the servlet request or thread-bound state is collected synchronously on the
     * calling thread. This covers the trace id, the token, the URI matchers and the desensitizer
     * {@code support} checks. The reason is that the container may recycle the request object once request
     * processing ends, and thread-local context is not visible on executor threads. Only applying the
     * selected desensitizers to the already captured bodies runs asynchronously.
     *
     * @param accessLogBO log object created in {@link #preHandle}
     * @param request     current request
     * @param handler     handler that processed the request
     * @param ex          failure raised while handling the request, or {@code null}
     * @return future completed with the log object once it is ready to be queued
     */
    private CompletableFuture<AccessLogBO> fillAccessLog(AccessLogBO accessLogBO, HttpServletRequest request, Object handler, Throwable ex) {
        accessLogBO.setResponseTime(LocalDateTime.now());
        accessLogBO.setThreadId(Thread.currentThread().getName());
        accessLogBO.setTraceId(MDC.get(CommonConstant.LOG_TRACE_ID));

        // Controller method, if the handler is one
        Method handlerMethod = handler instanceof HandlerMethod ? ((HandlerMethod) handler).getMethod() : null;
        if (handlerMethod == null) {
            log.warn("未能识别的方法类型：{}", handler.getClass());
        }

        // Request and response information
        fillRequestInfo(accessLogBO, request, handlerMethod);
        fillResponseInfo(accessLogBO, request, ex);

        // Choose the desensitizers while the request is still valid; apply them off the request thread
        RequestBodyDesensitize requestBodyDesensitize = accessLogBO.getRequestBody() == null
                ? null : findRequestBodyDesensitize(handlerMethod, request);
        ResponseBodyDesensitize responseBodyDesensitize = accessLogBO.getResult() == null
                ? null : findResponseBodyDesensitize(handlerMethod, request);

        return CompletableFuture.supplyAsync(() -> {
            if (requestBodyDesensitize != null) {
                desensitizeRequestBody(accessLogBO, requestBodyDesensitize);
            }
            if (responseBodyDesensitize != null) {
                var result = accessLogBO.getResult();
                accessLogBO.setResult(responseBodyDesensitize.desensitize(result));
            }
            return accessLogBO;
        }, executor);
    }

    /**
     * Copies request metadata into the log.
     * The captured request body is discarded unless the URI is configured for body recording.
     */
    private void fillRequestInfo(AccessLogBO accessLogBO, HttpServletRequest request, Method handlerMethod) {
        accessLogBO.setToken(SecurityContextUtil.currentToken());
        accessLogBO.setUserAgent(request.getHeader(HttpHeaders.USER_AGENT));
        accessLogBO.setMethod(request.getMethod());
        accessLogBO.setReqContentType(request.getContentType());
        accessLogBO.setUri(request.getRequestURI());
        accessLogBO.setOperation(obtainOperation(handlerMethod));
        accessLogBO.setReqIp(ServletUtil.getClientIP(request));
        accessLogBO.setReqOuterIp(request.getRemoteAddr());
        accessLogBO.setQueryParams(request.getQueryString());

        // Request parameters do not need to be recorded for this URI
        if (accessLogBO.getRequestBody() != null && !isMatch(request, includeRecordReqBodyMatcher)) {
            accessLogBO.setRequestBody(null);
        }
    }

    /**
     * Records the result code, message and failure.
     * The captured response body is discarded unless the URI is configured for body recording.
     * <p>
     * Code and message come from the original response when it is an {@link ApiResult}. Otherwise they are
     * derived from {@code ex}: success when {@code null}, the exception's own code/message for a
     * {@link BusinessException}, and 500 otherwise.
     */
    private void fillResponseInfo(AccessLogBO accessLogBO, HttpServletRequest request, Throwable ex) {
        var result = accessLogBO.getResult();
        if (result != null && !isMatch(request, includeRecordRespBodyMatcher)) {
            // The response body does not need to be recorded for this URI
            accessLogBO.setResult(null);
        }

        if (result instanceof ApiResult) {
            ApiResult<?> apiResult = (ApiResult<?>) result;
            accessLogBO.setResultCode(apiResult.getCode());
            accessLogBO.setMsg(apiResult.getMsg());
        } else if (ex == null) {
            accessLogBO.setResultCode(HttpStatus.OK.value());
            accessLogBO.setMsg("操作成功");
        } else if (ex instanceof BusinessException) {
            BusinessException exp = (BusinessException) ex;
            accessLogBO.setResultCode(ObjectUtil.defaultIfNull(exp.getCode(), HttpStatus.INTERNAL_SERVER_ERROR.value()));
            accessLogBO.setMsg(CharSequenceUtil.blankToDefault(exp.getMessage(), "操作失败"));
        } else {
            accessLogBO.setResultCode(HttpStatus.INTERNAL_SERVER_ERROR.value());
            accessLogBO.setMsg("操作失败");
        }

        accessLogBO.setThrowable(ex);
    }

    /**
     * Returns the first request-body desensitizer that supports this request, or {@code null} if none does.
     */
    private RequestBodyDesensitize findRequestBodyDesensitize(Method handlerMethod, HttpServletRequest request) {
        for (RequestBodyDesensitize candidate : requestBodyDesensitizes) {
            if (candidate.support(handlerMethod, request)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Returns the first response-body desensitizer that supports this request, or {@code null} if none does.
     */
    private ResponseBodyDesensitize findResponseBodyDesensitize(Method handlerMethod, HttpServletRequest request) {
        for (ResponseBodyDesensitize candidate : responseBodyDesensitizes) {
            if (candidate.support(handlerMethod, request)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Replaces the captured request bodies with their desensitized form.
     * Bodies the desensitizer maps to {@code null} are dropped.
     */
    private void desensitizeRequestBody(AccessLogBO accessLogBO, RequestBodyDesensitize desensitize) {
        List<Object> newBodies = new ArrayList<>(4);
        for (Object obj : accessLogBO.getRequestBody()) {
            var newObj = desensitize.desensitize(obj);
            if (newObj != null) {
                newBodies.add(newObj);
            }
        }
        accessLogBO.setRequestBody(newBodies);
    }

    /**
     * Returns the {@link ApiOperation#value()} of the handler method.
     * Returns {@code null} if there is no method or no annotation.
     */
    private String obtainOperation(Method method) {
        if (method == null) {
            return null;
        }
        var apiOperation = method.getAnnotation(ApiOperation.class);
        if (apiOperation != null) {
            return apiOperation.value();
        }

        return null;
    }

    /**
     * Returns the cached body string when the request is a {@link CloudtRequestWrapper}, otherwise {@code null}.
     * Currently not referenced within this class.
     */
    private String obtainRequestBody(HttpServletRequest request) {
        if (request instanceof CloudtRequestWrapper) {
            return ((CloudtRequestWrapper) request).getBodyString();
        }

        return null;
    }

    /**
     * Looks up the failure of the current request.
     * The exception stored by {@link DispatcherServlet} takes precedence over the one exposed by
     * {@link ErrorAttributes}. Returns {@code null} when neither is available.
     */
    private Throwable obtainException(HttpServletRequest request) {
        Object attribute = request.getAttribute(DispatcherServlet.EXCEPTION_ATTRIBUTE);
        if (attribute instanceof Throwable) {
            return (Throwable) attribute;
        }

        if (errorAttributes != null) {
            return errorAttributes.getError(new ServletWebRequest(request));
        }
        return null;
    }

    /**
     * Builds the exclusion list: static-resource patterns followed by the configured ones.
     */
    private List<PathPattern> createExcludeRequestMatcher(List<LogProperties.Matcher> matchers) {
        // Copy into a mutable list: the helper's return value is not guaranteed to support addAll
        List<PathPattern> requestMatchers = new ArrayList<>(
                CloudtInterceptor.convert2RequestMatcher(CloudtInterceptor.staticResourcePatter()));
        requestMatchers.addAll(convertRequestMatcher(matchers));

        return requestMatchers;
    }

    /**
     * Parses the configured URIs into path patterns.
     * Returns an empty list when nothing is configured.
     */
    private List<PathPattern> convertRequestMatcher(List<LogProperties.Matcher> matchers) {
        if (CollectionUtils.isEmpty(matchers)) {
            return Collections.emptyList();
        }

        List<PathPattern> requestMatchers = new ArrayList<>(matchers.size());
        for (LogProperties.Matcher matcher : matchers) {
            requestMatchers.add(PathPatternParser.defaultInstance.parse(matcher.getUri()));
        }
        return requestMatchers;
    }

    /**
     * Creates the executor used for asynchronous log assembly from the configured thread-pool settings.
     */
    private Executor createExecutor() {
        var threadPool = configProperties.getThreadPool();
        return ThreadPoolHolder.createThreadPool(threadPool.getThreadNamePrefix(), threadPool.getCoreSize(), threadPool.getMaxSize());
    }
}
