package com.elitescloud.boot.flyway.common;

import cn.hutool.core.lang.Assert;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.db.dialect.DriverUtil;
import com.elitescloud.boot.flyway.FlywayCloudtProperties;
import com.elitescloud.boot.exception.BusinessException;
import lombok.extern.log4j.Log4j2;
import org.flywaydb.core.Flyway;
import org.springframework.core.io.ClassPathResource;
import org.springframework.util.StringUtils;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Locale;

/**
 * Builds {@link Flyway} instances for the system (default) schema and for tenant schemas.
 * <p>
 * Script locations come from {@link FlywayCloudtProperties}. Each configured location is narrowed
 * to a project-specific and then a database-vendor-specific sub-directory, but only when that
 * sub-directory exists on the classpath.
 *
 * @author Kaiser（wang shao）
 * @date 2022/3/23
 */
@Log4j2
public class FlywayBuilder {

    private final DataSource dataSource;
    private final FlywayCloudtProperties properties;

    public FlywayBuilder(DataSource dataSource, FlywayCloudtProperties properties) {
        this.dataSource = dataSource;
        this.properties = properties;
    }

    /**
     * Creates a Flyway instance for the default (system) schema.
     * <p>
     * Scripts are read from {@code properties.getLocationSys()}. The managed schema is
     * {@code properties.getSchemaDefault()} when it has text. Otherwise no schema is configured,
     * so Flyway works against the connection's default schema.
     *
     * @return a configured Flyway instance; no migration has been run yet
     * @throws IllegalArgumentException if no system script location is configured
     * @throws BusinessException        if the database vendor cannot be determined
     */
    public Flyway createFlywayForSys() {
        Assert.notEmpty(properties.getLocationSys(), "未知系统默认数据库脚本路径");
        // Resolve the vendor once; each lookup opens a JDBC connection.
        String databaseProductName = getDatabaseProductName();
        String[] locations = properties.getLocationSys().stream()
                .map(location -> buildLocation(location, databaseProductName))
                .toArray(String[]::new);
        log.info("默认SQL脚本路径：{}", () -> Arrays.toString(locations));

        // Passing a lone null to schemas(...) would register a null schema name; an empty array
        // leaves the schema list unset instead.
        String schemaDefault = properties.getSchemaDefault();
        String[] schemas = StringUtils.hasText(schemaDefault) ? new String[]{schemaDefault} : new String[0];
        return configureAndLoad(locations, schemas);
    }

    /**
     * Creates a Flyway instance for a tenant schema.
     * <p>
     * Scripts are read from {@code properties.getLocationTenant()}. The managed schema is
     * {@code prefix + "_" + schema}, or just {@code schema} when {@code prefix} is blank.
     *
     * @param prefix schema prefix; may be blank
     * @param schema the tenant's schema name; must not be blank
     * @return a configured Flyway instance; no migration has been run yet
     * @throws IllegalArgumentException if no tenant script location is configured or
     *                                  {@code schema} is blank
     * @throws BusinessException        if the database vendor cannot be determined
     */
    public Flyway createFlywayForTenant(String prefix, String schema) {
        Assert.notEmpty(properties.getLocationTenant(), "未知租户数据库脚本路径");
        Assert.notBlank(schema, "租户schema不能为空");
        // Resolve the vendor once; each lookup opens a JDBC connection.
        String databaseProductName = getDatabaseProductName();
        String[] locations = properties.getLocationTenant().stream()
                .map(location -> buildLocation(location, databaseProductName))
                .toArray(String[]::new);
        log.info("租户SQL脚本路径：{}", () -> Arrays.toString(locations));

        String tenantSchema = CharSequenceUtil.isBlank(prefix) ? schema : prefix + "_" + schema;
        return configureAndLoad(locations, tenantSchema);
    }

    /**
     * Applies the settings shared by system and tenant migrations and loads the Flyway instance.
     * <p>
     * Keeping this in one place ensures both entry points honour the same properties
     * (notably {@code baselineOnMigrate}, which the tenant path previously never applied).
     *
     * @param locations resolved script locations
     * @param schemas   schemas to manage; empty to leave the schema list unset
     * @return the loaded Flyway instance
     */
    private Flyway configureAndLoad(String[] locations, String... schemas) {
        return Flyway.configure()
                .dataSource(dataSource)
                .locations(locations)
                .schemas(schemas)
                .baselineOnMigrate(properties.isBaselineOnMigrate())
                .baselineVersion(properties.getBaselineVersion())
                .validateOnMigrate(properties.isValidateOnMigrate())
                .validateMigrationNaming(properties.isValidateMigrationNaming())
                .outOfOrder(properties.isOutOfOrder())
                .load();
    }

    /**
     * Narrows a configured script location to the most specific existing sub-directory.
     * <p>
     * When a project is configured, {@code location/<project>} is tried first. Then
     * {@code <current>/<databaseProductName>} is tried. Each step is taken only if that path
     * exists on the classpath.
     *
     * @param location            configured location, optionally prefixed with {@code classpath:}
     * @param databaseProductName lower-case database vendor name
     * @return the resolved location
     */
    private String buildLocation(String location, String databaseProductName) {
        String resolved = location;

        // Prefer the project-specific sub-directory when a project is configured and it exists.
        if (StringUtils.hasText(properties.getProject())) {
            String projectLocation = resolved + "/" + properties.getProject();
            if (resourceExists(projectLocation)) {
                resolved = projectLocation;
            }
        }

        // Prefer vendor-specific scripts when the database has a dedicated sub-directory.
        String vendorLocation = resolved + "/" + databaseProductName;
        if (resourceExists(vendorLocation)) {
            resolved = vendorLocation;
        }
        return resolved;
    }

    /**
     * Checks whether a classpath resource exists.
     *
     * @param path resource path, with or without a leading {@code classpath:} prefix
     * @return {@code true} if the resource exists; {@code false} if it does not or the lookup fails
     */
    private boolean resourceExists(String path) {
        String classpath = "classpath:";
        if (path.startsWith(classpath)) {
            path = path.substring(classpath.length());
        }

        try {
            return new ClassPathResource(path).exists();
        } catch (Exception ignored) {
            // Best effort: a failed lookup is treated as "not present", so the caller keeps the
            // less specific location.
            return false;
        }
    }

    /**
     * Determines the database vendor name in lower case.
     * <p>
     * The name reported by the JDBC connection metadata is used when available. Otherwise the
     * vendor is inferred from the JDBC driver, and only MySQL and Oracle are recognized.
     *
     * @return the lower-case vendor name
     * @throws BusinessException if the vendor cannot be determined
     */
    private String getDatabaseProductName() {
        String productName = null;
        try (Connection connection = dataSource.getConnection()) {
            productName = connection.getMetaData().getDatabaseProductName();
        } catch (SQLException e) {
            log.error("获取数据库厂商名称失败：", e);
        }

        // Locale.ROOT avoids locale-specific casing (e.g. Turkish dotless i) in directory names.
        if (StringUtils.hasText(productName)) {
            return productName.toLowerCase(Locale.ROOT);
        }

        // Metadata unavailable: fall back to identifying the vendor from the JDBC driver.
        String driver = ObjectUtil.defaultIfNull(DriverUtil.identifyDriver(dataSource), "").toLowerCase(Locale.ROOT);
        if (driver.contains("mysql")) {
            return "mysql";
        } else if (driver.contains("oracle")) {
            return "oracle";
        }
        throw new BusinessException("未知数据库厂商");
    }

}
