package com.elitesland.cloudt.tenant.config.datasource;

import com.elitesland.cloudt.tenant.config.TenantClientProperties;
import com.zaxxer.hikari.HikariDataSource;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.boot.context.event.ApplicationContextInitializedEvent;
import org.springframework.boot.context.event.ApplicationEnvironmentPreparedEvent;
import org.springframework.boot.context.properties.bind.BindResult;
import org.springframework.boot.context.properties.bind.Binder;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.event.SmartApplicationListener;
import org.springframework.core.env.Environment;
import org.springframework.lang.NonNull;
import org.springframework.orm.jpa.vendor.Database;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import javax.sql.DataSource;
import java.util.logging.Logger;

/**
 * Tenant data source listener configuration.
 *
 * <p>On {@link ApplicationContextInitializedEvent} it resolves three values from the context's
 * {@link Environment} and stores them in {@link AbstractTenantDatasourceProvider}:
 * <ul>
 *     <li>the default data source (ShardingSphere definition first, then Hikari),</li>
 *     <li>the default tenant schema,</li>
 *     <li>the database type.</li>
 * </ul>
 * When created as a bean it asserts that the default schema and default data source were populated.
 *
 * @author Kaiser（wang shao）
 * @date 2022/4/4
 */
@Configuration(proxyBeanMethods = false)
public class TenantDataSourceListener implements SmartApplicationListener, InitializingBean {

    private static final Logger LOG = Logger.getLogger(TenantDataSourceListener.class.getName());

    /** Property prefix of ShardingSphere data source definitions. */
    private static final String SHARDING_DATASOURCE_PREFIX = "spring.shardingsphere.datasource.";

    /** Property prefix of a directly configured Hikari data source. */
    private static final String HIKARI_PREFIX = "spring.datasource.hikari";

    /** Property holding the JPA database type, used as fallback for database detection. */
    private static final String JPA_DATABASE_PROPERTY = "spring.jpa.database";

    /**
     * Accepts {@link ApplicationContextInitializedEvent} and its subclasses only.
     *
     * @param eventType the event type published by the multicaster
     * @return {@code true} if the event is (a subtype of) {@link ApplicationContextInitializedEvent}
     */
    @Override
    public boolean supportsEventType(Class<? extends ApplicationEvent> eventType) {
        // The published event type must be assignable TO ApplicationContextInitializedEvent.
        return ApplicationContextInitializedEvent.class.isAssignableFrom(eventType);
    }

    /**
     * Loads the default data source, default schema and database type from the environment.
     *
     * @param event the published event; anything other than
     *              {@link ApplicationContextInitializedEvent} is ignored
     */
    @Override
    public void onApplicationEvent(@NonNull ApplicationEvent event) {
        if (event instanceof ApplicationContextInitializedEvent) {
            Environment env = ((ApplicationContextInitializedEvent) event).getApplicationContext().getEnvironment();

            // Load the data source.
            DataSource dataSource = loadingDataSource(env);

            // Load the default schema.
            loadingDefaultSchema(env);

            // Load the database type.
            loadingDatabase(env, dataSource);
        }
    }

    /**
     * Verifies that the default tenant schema and default data source have been set.
     *
     * @throws IllegalArgumentException if either value is missing
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        Assert.hasText(AbstractTenantDatasourceProvider.getDefaultSchema(), "默认租户schema为空");
        Assert.notNull(AbstractTenantDatasourceProvider.getDefaultDataSource(), "默认数据源为空");
    }

    /**
     * Determines the database type and, when found, stores it in the provider.
     *
     * <p>The type derived from a Hikari JDBC URL takes precedence. Otherwise the
     * {@code spring.jpa.database} property is used.
     *
     * @param env        environment to read properties from
     * @param dataSource the resolved default data source, may be {@code null}
     * @return the detected database type, or {@code null} if it could not be determined
     */
    private Database loadingDatabase(Environment env, DataSource dataSource) {
        Database database = null;

        // Hikari data source: derive the type from its JDBC URL.
        if (dataSource instanceof HikariDataSource) {
            String jdbcUrl = ((HikariDataSource) dataSource).getJdbcUrl();
            database = AbstractTenantDatasourceProvider.getDatabaseTypeByUrl(jdbcUrl);
        }

        // Fall back to the type declared for JPA.
        if (database == null) {
            database = parseDatabase(env.getProperty(JPA_DATABASE_PROPERTY));
        }

        if (database != null) {
            AbstractTenantDatasourceProvider.setDatabaseType(database);
        }
        return database;
    }

    /**
     * Resolves a {@link Database} constant from a configured value.
     *
     * <p>Matching ignores case and any non-alphanumeric separators, so {@code mysql},
     * {@code SQL_SERVER} and {@code sql-server} are all accepted. Lower-casing is locale-independent.
     *
     * @param value the configured value, may be {@code null}
     * @return the matching constant, or {@code null} if the value is blank or unrecognized
     */
    private static Database parseDatabase(String value) {
        if (!StringUtils.hasText(value)) {
            return null;
        }
        String wanted = canonicalName(value);
        for (Database candidate : Database.values()) {
            if (canonicalName(candidate.name()).equals(wanted)) {
                return candidate;
            }
        }
        // Best effort: an unknown value must not abort startup, but it should not go unnoticed.
        LOG.warning("Unrecognized " + JPA_DATABASE_PROPERTY + " value: " + value);
        return null;
    }

    /** Keeps only letters and digits, lower-cased without regard to the default locale. */
    private static String canonicalName(String name) {
        StringBuilder canonical = new StringBuilder(name.length());
        for (int i = 0; i < name.length(); i++) {
            char c = name.charAt(i);
            if (Character.isLetterOrDigit(c)) {
                canonical.append(Character.toLowerCase(c));
            }
        }
        return canonical.toString();
    }

    /**
     * Reads the default tenant schema and, if present, stores it in the provider.
     *
     * @param env environment to read properties from
     * @return the configured schema, possibly {@code null} or blank
     */
    private String loadingDefaultSchema(Environment env) {
        String schema = env.getProperty(TenantClientProperties.CONFIG_PREFIX + ".default-schema");
        if (StringUtils.hasText(schema)) {
            AbstractTenantDatasourceProvider.setDefaultSchema(schema);
        }
        return schema;
    }

    /**
     * Resolves the default data source and stores it in the provider.
     *
     * <p>A ShardingSphere definition is preferred; otherwise a Hikari data source is used. The stored
     * value is {@code null} if neither is configured.
     *
     * @param env environment to read properties from
     * @return the resolved data source, or {@code null}
     */
    private DataSource loadingDataSource(Environment env) {
        DataSource dataSource = loadDataSourceBySharding(env);

        if (dataSource == null) {
            dataSource = loadDataSourceByHikari(env);
        }

        AbstractTenantDatasourceProvider.setDefaultDataSource(dataSource);

        return dataSource;
    }

    /**
     * Binds the first ShardingSphere data source definition.
     *
     * <p>The definition is found via {@code spring.shardingsphere.datasource.name}, or else the first
     * non-blank entry of {@code spring.shardingsphere.datasource.names}.
     *
     * @param env environment to read properties from
     * @return the bound data source, or {@code null} if no ShardingSphere data source is configured
     * @throws IllegalArgumentException if the definition has no {@code type} or the type is not a
     *                                  {@link DataSource}
     * @throws RuntimeException         if the type class cannot be found
     * @throws IllegalStateException    if no property could be bound to the data source
     */
    private DataSource loadDataSourceBySharding(Environment env) {
        // Resolve the data source name; take the first one if several are configured.
        String datasourceName = resolveShardingDataSourceName(env);
        if (datasourceName == null) {
            return null;
        }

        // Instantiate the data source.
        String datasourceKey = SHARDING_DATASOURCE_PREFIX + datasourceName;
        String datasourceType = env.getProperty(datasourceKey + ".type");
        Assert.hasText(datasourceType, "未知shardingSphere配置的数据源类型");

        Class<?> rawClass;
        try {
            rawClass = Class.forName(datasourceType.trim());
        } catch (ClassNotFoundException e) {
            throw new RuntimeException("未找到数据源类型class：" + datasourceType, e);
        }
        if (!DataSource.class.isAssignableFrom(rawClass)) {
            throw new IllegalArgumentException("数据源类型不是DataSource的实现：" + datasourceType);
        }
        Class<? extends DataSource> dataSourceClass = rawClass.asSubclass(DataSource.class);

        BindResult<? extends DataSource> dataSourceBindResult = Binder.get(env).bind(datasourceKey, dataSourceClass);
        if (dataSourceBindResult.isBound()) {
            return dataSourceBindResult.get();
        }

        throw new IllegalStateException("初始化数据源失败");
    }

    /**
     * Returns the trimmed ShardingSphere data source name.
     *
     * <p>{@code ...datasource.name} takes precedence. Otherwise the first non-blank entry of the
     * comma-separated {@code ...datasource.names} is used.
     *
     * @param env environment to read properties from
     * @return the name, or {@code null} if none is configured
     */
    private static String resolveShardingDataSourceName(Environment env) {
        String name = env.getProperty(SHARDING_DATASOURCE_PREFIX + "name");
        if (StringUtils.hasText(name)) {
            return name.trim();
        }
        String names = env.getProperty(SHARDING_DATASOURCE_PREFIX + "names");
        if (names != null) {
            for (String candidate : names.split(",")) {
                if (StringUtils.hasText(candidate)) {
                    return candidate.trim();
                }
            }
        }
        return null;
    }

    /**
     * Binds a {@link HikariDataSource} from {@code spring.datasource.hikari.*}.
     *
     * @param env environment to read properties from
     * @return the bound data source, or {@code null} if {@code spring.datasource.hikari.jdbc-url} is blank
     * @throws IllegalStateException if no property could be bound to the data source
     */
    private HikariDataSource loadDataSourceByHikari(Environment env) {
        String jdbcUrl = env.getProperty(HIKARI_PREFIX + ".jdbc-url");
        if (!StringUtils.hasText(jdbcUrl)) {
            return null;
        }

        BindResult<HikariDataSource> dataSourceBindResult = Binder.get(env).bind(HIKARI_PREFIX, HikariDataSource.class);
        if (dataSourceBindResult.isBound()) {
            return dataSourceBindResult.get();
        }

        throw new IllegalStateException("初始化数据源失败");
    }
}
