package com.elitesland.cloudt.tenant.util;

import com.elitesland.cloudt.context.SpringContextHolder;
import com.elitesland.cloudt.tenant.config.TenantClientProperties;
import com.elitesland.cloudt.tenant.core.common.ConstantTenant;
import com.elitesland.cloudt.tenant.provider.TenantProvider;
import com.elitesland.yst.common.exception.BusinessException;
import com.elitesland.yst.core.security.util.SecurityUtil;
import com.elitesland.yst.system.dto.SysTenantDTO;
import lombok.extern.log4j.Log4j2;
import org.springframework.util.StringUtils;

import javax.servlet.http.HttpServletRequest;

/**
 * Static helpers that work out which tenant an incoming HTTP request belongs to.
 *
 * <p>The tenant is taken from the authenticated user, from the tenant-ID request header, or from
 * the host name of the request URL, in that order. The {@link TenantClientProperties} and
 * {@link TenantProvider} beans are fetched through {@link SpringContextHolder} on first use and
 * then cached in static fields; that lazy initialisation is not synchronized.
 *
 * @author Kaiser（wang shao）
 * @date 2022/4/9
 */
@Log4j2
public final class TenantRequestUtil {

    /** Lazily resolved tenant client configuration; see {@link #getClientProperties()}. */
    private static TenantClientProperties clientProperties;
    /** Lazily resolved tenant lookup service; see {@link #getTenantProvider()}. */
    private static TenantProvider tenantProvider;

    private TenantRequestUtil() {
    }

    /**
     * Extracts the tenant domain from a request URL.
     *
     * <p>The scheme, the path and the port are stripped, leaving the host name. When a default
     * tenant domain is configured and the host is a sub-domain of it, only the sub-domain prefix
     * is returned (for example {@code acme.example.com} with default domain {@code example.com}
     * yields {@code acme}).
     *
     * @param url the request URL; may be {@code null}
     * @return the host name or sub-domain prefix, or {@code null} when {@code url} is
     *         {@code null} or does not start with the HTTP/HTTPS protocol prefix
     */
    public static String obtainTenantDomain(String url) {
        if (url == null) {
            return null;
        }

        // NOTE(review): the offsets 7 and 8 assume the constants are "http://" and "https://"
        // - confirm against ConstantTenant.
        String host;
        if (url.startsWith(ConstantTenant.PROTOCOL_HTTP)) {
            host = url.substring(7);
        } else if (url.startsWith(ConstantTenant.PROTOCOL_HTTPS)) {
            host = url.substring(8);
        } else {
            // Not an HTTP(S) URL: no domain can be derived.
            return null;
        }

        // Drop everything from the first '/' (the path) onwards.
        int pathStart = host.indexOf('/');
        if (pathStart > 0) {
            host = host.substring(0, pathStart);
        }
        // Drop the port, if any.
        int portStart = host.indexOf(':');
        if (portStart > 0) {
            host = host.substring(0, portStart);
        }

        // Reduce the host to its sub-domain prefix when it belongs to the configured domain.
        String defaultDomain = getClientProperties().getTenantDomain();
        if (StringUtils.hasText(defaultDomain)) {
            return filterChildDomain(host, defaultDomain);
        }
        return host;
    }

    /**
     * Resolves the tenant for the given request.
     *
     * <p>Resolution order:
     * <ol>
     *   <li>the tenant attached to the current user, if any;</li>
     *   <li>the tenant whose ID is carried in the {@code ConstantTenant.HEADER_TENANT_ID}
     *       header (a missing or blank header is skipped);</li>
     *   <li>the tenant bound to the domain of the request URL, but only when a default tenant
     *       domain is configured.</li>
     * </ol>
     *
     * @param request the current HTTP request
     * @return the resolved tenant, or {@code null} when none of the sources yields one
     * @throws BusinessException     if the header ID or the request domain does not map to a
     *                               tenant
     * @throws NumberFormatException if the tenant-ID header is present but not a valid
     *                               {@code long}
     */
    public static SysTenantDTO obtainTenant(HttpServletRequest request) {
        // Prefer the tenant of the currently authenticated user.
        var user = SecurityUtil.getUser();
        if (user != null && user.getTenant() != null) {
            return user.getTenant();
        }

        // Next, the explicit tenant ID from the request header. A blank header is treated as
        // absent instead of being handed to Long.parseLong.
        String tenantId = request.getHeader(ConstantTenant.HEADER_TENANT_ID);
        if (StringUtils.hasText(tenantId)) {
            var tenant = getTenantProvider().getById(Long.parseLong(tenantId.trim()));
            // Supplier form: the exception is only created when the tenant is really missing.
            return tenant.orElseThrow(() -> new BusinessException("租户不存在"));
        }

        // Finally, map the request domain to a tenant.
        if (StringUtils.hasText(getClientProperties().getTenantDomain())) {
            String domain = obtainTenantDomain(request.getRequestURL().toString());
            if (StringUtils.hasText(domain)) {
                var tenant = getTenantProvider().getByDomain(domain);
                return tenant.orElseThrow(() -> new BusinessException("未知域名所绑定的租户"));
            }
        }

        return null;
    }

    /**
     * Returns the tenant client configuration, fetching the bean from
     * {@link SpringContextHolder} on the first call and caching it afterwards.
     *
     * @return the cached {@link TenantClientProperties} bean
     */
    public static TenantClientProperties getClientProperties() {
        if (clientProperties == null) {
            clientProperties = SpringContextHolder.getBean(TenantClientProperties.class);
        }
        return clientProperties;
    }

    /**
     * Returns the tenant provider, fetching the bean from {@link SpringContextHolder} on the
     * first call and caching it afterwards.
     *
     * @return the cached {@link TenantProvider} bean
     */
    public static TenantProvider getTenantProvider() {
        if (tenantProvider == null) {
            tenantProvider = SpringContextHolder.getBean(TenantProvider.class);
        }
        return tenantProvider;
    }

    /**
     * Strips the default domain from a host name that is a sub-domain of it.
     *
     * <p>The host must end with {@code "." + defaultDomain} and have a non-empty prefix before
     * that dot. A host that merely shares trailing characters with the default domain (for
     * example {@code myexample.com} versus {@code example.com}) is returned unchanged, as is a
     * host equal to the default domain itself.
     *
     * @param domain        the host name to inspect; may be {@code null}
     * @param defaultDomain the configured default tenant domain
     * @return the sub-domain prefix, the unchanged {@code domain} when it is not a sub-domain,
     *         or {@code null} when {@code domain} is {@code null}
     */
    private static String filterChildDomain(String domain, String defaultDomain) {
        if (domain == null) {
            return null;
        }
        // Match on a label boundary so that unrelated hosts are never truncated.
        String suffix = "." + defaultDomain;
        if (domain.length() > suffix.length() && domain.endsWith(suffix)) {
            return domain.substring(0, domain.length() - suffix.length());
        }
        return domain;
    }
}
