package com.elitescloud.boot.support;

import com.elitescloud.boot.CloudtSpringContextHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.discovery.DiscoveryClient;
import org.springframework.cloud.client.loadbalancer.LoadBalancerUriTools;
import org.springframework.http.HttpRequest;
import org.springframework.http.client.ClientHttpRequestExecution;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.support.HttpRequestWrapper;
import org.springframework.lang.NonNull;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;

import java.io.IOException;
import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;

/**
 * Request interceptor that resolves the target server address through service discovery.
 * <p>
 * For each outgoing request whose URI host is an eligible service name, the host is looked up via the
 * {@link DiscoveryClient}. If at least one instance is returned, one is picked at random and the request
 * is executed against a URI rebuilt for that instance. Otherwise the request is executed unchanged.
 *
 * @author Kaiser（wang shao）
 * @date 2024/1/30
 */
public class DiscoveryClientHttpRequestInterceptor implements ClientHttpRequestInterceptor {
    private static final Logger LOG = LoggerFactory.getLogger(DiscoveryClientHttpRequestInterceptor.class);

    /**
     * Host names that should be resolved through service discovery.
     * {@code null} or empty means every host is treated as a candidate service name.
     */
    private final Set<String> serviceNames;

    /**
     * Lazily resolved from the application context.
     * {@code volatile} so the resolved reference is safely visible if this interceptor is invoked from
     * several threads at once. A concurrent duplicate lookup is tolerated (assumes {@code getBean}
     * returns the same singleton DiscoveryClient — confirm).
     */
    private volatile DiscoveryClient discoveryClient;

    /**
     * @param serviceNames host names to resolve via service discovery; {@code null} or empty means
     *                     every request host is looked up
     */
    public DiscoveryClientHttpRequestInterceptor(Set<String> serviceNames) {
        this.serviceNames = serviceNames;
    }

    /**
     * Rewrites the request URI to a discovered service instance when possible, then continues the
     * execution chain.
     *
     * @throws IllegalArgumentException if the request URI is {@code null}
     * @throws IOException              propagated from {@code execution.execute}
     */
    @NonNull
    @Override
    public ClientHttpResponse intercept(@NonNull HttpRequest request, @NonNull byte[] body, @NonNull ClientHttpRequestExecution execution) throws IOException {
        var uri = request.getURI();
        Assert.notNull(uri, "Request URI为空");

        // Rebuild the URI; null means "send the request unchanged".
        var newUri = this.reBuildUri(uri);
        if (newUri == null) {
            LOG.info("request uri is ：{}", uri);
            return execution.execute(request, body);
        }

        // Execute against the rebuilt URI; every other request attribute is delegated to the original request.
        LOG.info("new request uri is: {}", newUri);
        return execution.execute(new HttpRequestWrapper(request) {
            @Override
            public URI getURI() {
                return newUri;
            }
        }, body);
    }

    /**
     * Returns the discovery-resolved URI for {@code uri}, or {@code null} when any of these hold:
     * the host is blank, the host is not among the configured service names, or no instance could be found.
     */
    private URI reBuildUri(URI uri) {
        var host = uri.getHost();
        if (!StringUtils.hasText(host)) {
            return null;
        }
        // A null/empty service-name set means every host is a candidate; otherwise only listed hosts are.
        if (!CollectionUtils.isEmpty(serviceNames) && !serviceNames.contains(host)) {
            return null;
        }
        return this.reBuildUriByDiscoveryClient(uri, host);
    }

    /**
     * Looks up instances of {@code serviceName} and rebuilds {@code uri} to target a randomly chosen one.
     *
     * @return the rebuilt URI, or {@code null} if the lookup failed or returned no instances
     */
    private URI reBuildUriByDiscoveryClient(URI uri, String serviceName) {
        List<ServiceInstance> instances = null;
        try {
            instances = getDiscoveryClient().getInstances(serviceName);
        } catch (Exception e) {
            // Best effort: a registry/context failure falls back to sending the original request.
            LOG.warn("服务注册中心获取服务失败：{}", serviceName, e);
        }
        if (CollectionUtils.isEmpty(instances)) {
            LOG.info("未发现有效服务：{}", serviceName);
            return null;
        }

        // Uniform random pick; nextInt(1) is always 0, so a single instance needs no special case.
        var instance = instances.get(ThreadLocalRandom.current().nextInt(instances.size()));
        return LoadBalancerUriTools.reconstructURI(new CustomServiceInstance(instance), uri);
    }

    /**
     * Returns the DiscoveryClient bean, resolving it from the application context on first use.
     */
    private DiscoveryClient getDiscoveryClient() {
        // Read the volatile field once so the null check and the return see the same value.
        var client = discoveryClient;
        if (client == null) {
            client = CloudtSpringContextHolder.getApplicationContext().getBean(DiscoveryClient.class);
            discoveryClient = client;
        }
        return client;
    }

    /**
     * Delegating {@link ServiceInstance} that derives the scheme from {@link #getUri()} when the
     * wrapped instance does not report one.
     */
    public static class CustomServiceInstance implements ServiceInstance {

        private final ServiceInstance instance;

        /**
         * @param instance the discovered instance to delegate to
         */
        public CustomServiceInstance(ServiceInstance instance) {
            this.instance = instance;
        }

        @Override
        public String getServiceId() {
            return instance.getServiceId();
        }

        @Override
        public String getHost() {
            return instance.getHost();
        }

        @Override
        public int getPort() {
            return instance.getPort();
        }

        @Override
        public boolean isSecure() {
            return instance.isSecure();
        }

        @Override
        public URI getUri() {
            return instance.getUri();
        }

        @Override
        public Map<String, String> getMetadata() {
            return instance.getMetadata();
        }

        @Override
        public String getInstanceId() {
            return instance.getInstanceId();
        }

        /**
         * Returns the wrapped instance's scheme if it has text. Otherwise returns the scheme of the
         * wrapped instance's URI, or {@code null} when that URI is {@code null}.
         * NOTE(review): returning null presumably lets LoadBalancerUriTools fall back to its own
         * scheme computation — confirm for the Spring Cloud version in use.
         */
        @Override
        public String getScheme() {
            var scheme = instance.getScheme();
            if (StringUtils.hasText(scheme)) {
                return scheme;
            }

            var uri = instance.getUri();
            return uri == null ? null : uri.getScheme();
        }
    }
}
