package com.elitesland.cloudt.tenant.config.datasource.hibernate;

import cn.hutool.core.text.CharSequenceUtil;
import com.elitesland.cloudt.tenant.config.datasource.AbstractTenantDatasourceProvider;
import com.elitesland.cloudt.tenant.config.datasource.TenantSession;
import com.elitesland.cloudt.tenant.config.support.TenantContextHolder;
import lombok.extern.log4j.Log4j2;
import org.hibernate.engine.jdbc.connections.spi.AbstractDataSourceBasedMultiTenantConnectionProviderImpl;
import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
import org.springframework.orm.jpa.vendor.Database;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Tenant data source provider for Hibernate.
 *
 * <p>Every tenant is served from the default data source exposed by
 * {@link AbstractTenantDatasourceProvider}; tenant isolation is achieved by executing a
 * schema-switch statement on each connection before it is handed out.
 *
 * @author Kaiser（wang shao）
 * @date 2022/3/25
 */
@Log4j2
public class HibernateTenantDatasourceProvider extends AbstractDataSourceBasedMultiTenantConnectionProviderImpl {

    private static final long serialVersionUID = -3009600657851822507L;

    /** Not referenced anywhere else in this class; kept so the declared field set is unchanged. */
    private final Map<String, ConnectionProvider> connectionProviderMap;

    /**
     * Cache of the database type per tenant identifier.
     *
     * <p>A concurrent map is used because this cache is mutated from {@link #getConnection(String)},
     * and a connection provider instance is expected to be called from several threads.
     */
    private final Map<String, Database> databaseMapMap;

    /** Creates a provider with empty caches. */
    public HibernateTenantDatasourceProvider() {
        this.connectionProviderMap = new HashMap<>();
        this.databaseMapMap = new ConcurrentHashMap<>();
    }

    /**
     * Obtains a connection for the tenant and switches it to the schema resolved for the
     * current tenant context.
     *
     * <p>If anything fails after the connection has been obtained, the connection is closed
     * before the failure is propagated, so it is not leaked.
     *
     * @param tenantIdentifier the tenant identifier supplied by Hibernate
     * @return a connection on which the schema-switch statement has been executed
     * @throws SQLException if the connection cannot be obtained or the schema switch fails
     */
    @Override
    public Connection getConnection(String tenantIdentifier) throws SQLException {
        Connection connection = super.getConnection(tenantIdentifier);
        try {
            String schema = getTenantSchema();
            Database database = getDatabaseType(tenantIdentifier);
            log.info("use schema '{}' for tenant '{}'", schema, tenantIdentifier);

            String command = AbstractTenantDatasourceProvider.generateSwitchSchemaSql(database, schema);
            try (var statement = connection.createStatement()) {
                statement.execute(command);
            }
            return connection;
        } catch (SQLException | RuntimeException e) {
            // The caller never receives this connection, so nobody else can release it.
            try {
                connection.close();
            } catch (SQLException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * @return the default data source from {@link AbstractTenantDatasourceProvider}
     */
    @Override
    protected DataSource selectAnyDataSource() {
        return AbstractTenantDatasourceProvider.getDefaultDataSource();
    }

    /**
     * @param tenantIdentifier ignored; all tenants share the default data source
     * @return the default data source from {@link AbstractTenantDatasourceProvider}
     */
    @Override
    protected DataSource selectDataSource(String tenantIdentifier) {
        return AbstractTenantDatasourceProvider.getDefaultDataSource();
    }

    /**
     * Releases the connection and then calls {@link TenantSession#clearCurrent()}, even when
     * the release itself fails.
     *
     * @param connection the connection to release
     * @throws SQLException if releasing the connection fails
     */
    @Override
    public void releaseAnyConnection(Connection connection) throws SQLException {
        try {
            super.releaseAnyConnection(connection);
        } finally {
            TenantSession.clearCurrent();
        }
    }

    /**
     * Releases the tenant connection and then calls {@link TenantSession#clearCurrent()}, even
     * when the release itself fails.
     *
     * @param tenantIdentifier the tenant identifier the connection was obtained for
     * @param connection       the connection to release
     * @throws SQLException if releasing the connection fails
     */
    @Override
    public void releaseConnection(String tenantIdentifier, Connection connection) throws SQLException {
        try {
            super.releaseConnection(tenantIdentifier, connection);
        } finally {
            TenantSession.clearCurrent();
        }
    }

    /**
     * Resolves the schema name to switch to for the current call.
     *
     * @return the default schema when the session asks for it or when no tenant is set;
     *         otherwise the tenant's schema name, prefixed with {@code <defaultSchema>_} when
     *         the default schema is not blank
     */
    private String getTenantSchema() {
        String defaultSchema = AbstractTenantDatasourceProvider.getDefaultSchema();

        // Check whether the default schema was explicitly requested.
        if (TenantSession.getUseDefaultSchema()) {
            return defaultSchema;
        }

        // Prefer the tenant set on the session, then fall back to the context holder.
        var tenant = TenantSession.getCurrentTenant();
        if (tenant == null) {
            tenant = TenantContextHolder.getCurrentTenant();
        }

        if (tenant == null) {
            // When there is no current tenant, use the default schema.
            return defaultSchema;
        }

        String prefix = CharSequenceUtil.isBlank(defaultSchema) ? "" : defaultSchema + "_";
        return prefix + tenant.getSchemaName();
    }

    /**
     * Returns the database type for the tenant, caching it per tenant identifier.
     *
     * @param tenantIdentifier the tenant identifier; may be {@code null}, in which case the
     *                         value is looked up without caching
     * @return the database type reported by {@link AbstractTenantDatasourceProvider}
     */
    private Database getDatabaseType(String tenantIdentifier) {
        if (tenantIdentifier == null) {
            // ConcurrentHashMap rejects null keys, so a null identifier bypasses the cache.
            return AbstractTenantDatasourceProvider.getDatabaseType();
        }

        Database cached = databaseMapMap.get(tenantIdentifier);
        if (cached != null) {
            return cached;
        }

        Database resolved = AbstractTenantDatasourceProvider.getDatabaseType();
        if (resolved != null) {
            // Null values cannot be stored; an unresolved type is simply looked up again next time.
            databaseMapMap.put(tenantIdentifier, resolved);
        }
        return resolved;
    }
}
