package com.elitesland.cloudt.tenant.transaction;

import com.elitesland.cloudt.tenant.TenantClient;
import com.elitesland.cloudt.tenant.config.datasource.TenantSession;
import com.elitesland.yst.common.exception.BusinessException;
import com.elitesland.yst.core.annotation.TenantTransaction;
import com.elitesland.yst.core.security.util.SecurityUtil;
import lombok.extern.log4j.Log4j2;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.aop.aspectj.MethodInvocationProceedingJoinPoint;
import org.springframework.core.Ordered;
import org.springframework.util.Assert;

/**
 * Tenant transaction handling interceptor.
 *
 * <p>Intercepts methods annotated with {@link TenantTransaction}, as well as methods declared in
 * types annotated with it, and decides whether the invocation runs against the default schema or
 * against the current tenant's schema.
 *
 * @author Kaiser（wang shao）
 * @date 2022/4/7
 */
@Log4j2
@Aspect
public class TenantTransactionAspect implements Ordered {

    /** Matches methods that are directly annotated with {@link TenantTransaction}. */
    @Pointcut("@annotation(com.elitesland.yst.core.annotation.TenantTransaction)")
    private void pointCutMethod() {
    }

    /** Matches methods declared within a type annotated with {@link TenantTransaction}. */
    @Pointcut("@within(com.elitesland.yst.core.annotation.TenantTransaction)")
    private void pointCutClass() {
    }

    /**
     * Around advice that selects the schema for the intercepted invocation.
     *
     * <p>If {@code TenantSession} already reports that the default schema is in use, the call
     * proceeds unchanged. Otherwise the default schema is used when the annotation requests it
     * ({@code defaultSchema()}) or when the current user is a system admin. In every other case
     * the call runs in tenant mode.
     *
     * @param point the intercepted join point
     * @return whatever the intercepted method returns
     * @throws Throwable anything thrown by the intercepted method; additionally
     *     {@link IllegalArgumentException} (from {@code Assert.notNull}) when no
     *     {@link TenantTransaction} annotation can be resolved, and {@link BusinessException} when
     *     a tenant is required but none is available
     */
    @Around("pointCutClass() || pointCutMethod()")
    public Object cutAround(ProceedingJoinPoint point) throws Throwable {
        if (TenantSession.getUseDefaultSchema()) {
            // The default schema has already been selected (e.g. by an enclosing invocation of
            // this aspect); keep it and let the outer invocation clear the flag.
            return point.proceed();
        }

        TenantTransaction annotation = obtainAnnotation(point);
        Assert.notNull(annotation, "未获取到TenantTransaction注解信息");

        var user = SecurityUtil.getUser();
        if (annotation.defaultSchema() || (user != null && user.isSystemAdmin())) {
            return cutForDefaultSchema(point);
        } else {
            return cutForTenantSchema(point, annotation.tenantRequired());
        }
    }

    /**
     * Resolves the {@link TenantTransaction} annotation that applies to the intercepted method.
     *
     * <p>Method-level annotations take precedence over type-level ones. The lookup is done
     * against the target object's class rather than {@code point.getThis()}. In Spring AOP,
     * {@code getThis()} is the proxy, and the proxy's class (a JDK {@code $Proxy} class or a
     * CGLIB subclass) does not carry a non-{@code @Inherited} type annotation. That made
     * class-level annotations matched by {@code @within} unresolvable.
     *
     * @param point the intercepted join point
     * @return the applicable annotation, or {@code null} if none can be found
     */
    private TenantTransaction obtainAnnotation(ProceedingJoinPoint point) {
        if (!(point.getSignature() instanceof MethodSignature)) {
            return null;
        }
        var method = ((MethodSignature) point.getSignature()).getMethod();
        var declaringClass = method.getDeclaringClass();
        var target = point.getTarget();
        Class<?> targetClass = target != null ? target.getClass() : declaringClass;

        // 1. The implementation on the target class. With a JDK dynamic proxy, the signature's
        //    method is the interface method, which does not see annotations on the implementation.
        TenantTransaction annotation = null;
        if (targetClass != declaringClass) {
            try {
                annotation = targetClass.getMethod(method.getName(), method.getParameterTypes())
                        .getAnnotation(TenantTransaction.class);
            } catch (NoSuchMethodException ignored) {
                // Not a public method of the target class; fall back to the signature's method.
            }
        }
        // 2. The method from the join point signature.
        if (annotation == null) {
            annotation = method.getAnnotation(TenantTransaction.class);
        }
        // 3. The target class itself (not the proxy class).
        if (annotation == null) {
            annotation = targetClass.getAnnotation(TenantTransaction.class);
        }
        // 4. The class declaring the signature's method, e.g. an annotated interface or superclass.
        if (annotation == null && targetClass != declaringClass) {
            annotation = declaringClass.getAnnotation(TenantTransaction.class);
        }
        return annotation;
    }

    /**
     * Runs the invocation with the default schema flag set on {@code TenantSession}.
     *
     * <p>The flag is cleared unconditionally in {@code finally}. This is safe because
     * {@link #cutAround} returns early when the flag is already set, so only the invocation that
     * set it reaches this method.
     *
     * @param point the intercepted join point
     * @return whatever the intercepted method returns
     * @throws Throwable anything thrown by the intercepted method
     */
    private Object cutForDefaultSchema(ProceedingJoinPoint point) throws Throwable {
        TenantSession.setUseDefaultSchema();
        try {
            return point.proceed();
        } finally {
            TenantSession.clearUseDefaultSchema();
        }
    }

    /**
     * Runs the invocation in tenant mode.
     *
     * <p>A tenant is considered missing when neither {@code TenantClient.getCurrentTenant()} nor
     * {@code TenantSession.getCurrentTenant()} returns a value.
     *
     * @param point          the intercepted join point
     * @param requiredTenant whether a missing tenant is an error ({@code true}) or only a
     *                       warning ({@code false})
     * @return whatever the intercepted method returns
     * @throws BusinessException if {@code requiredTenant} is {@code true} and no tenant is found
     * @throws Throwable anything thrown by the intercepted method
     */
    private Object cutForTenantSchema(ProceedingJoinPoint point, boolean requiredTenant) throws Throwable {
        // Defensive: cutAround only gets here when the default-schema flag is not set.
        TenantSession.clearUseDefaultSchema();

        if (TenantClient.getCurrentTenant() == null && TenantSession.getCurrentTenant() == null) {
            if (requiredTenant) {
                throw new BusinessException("未获取到当前租户信息");
            }
            log.warn("未获取到当前租户信息");
        }
        return point.proceed();
    }

    /**
     * Orders this aspect with {@link Ordered#HIGHEST_PRECEDENCE}.
     *
     * @return {@link Ordered#HIGHEST_PRECEDENCE}
     */
    @Override
    public int getOrder() {
        return Ordered.HIGHEST_PRECEDENCE;
    }
}
