package com.elitescloud.boot.mq.config;

import com.elitescloud.boot.provider.TenantClientProvider;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.cloud.stream.config.ChannelsEndpointAutoConfiguration;
import org.springframework.cloud.stream.messaging.DirectWithAttributesChannel;
import org.springframework.cloud.stream.messaging.Sink;
import org.springframework.cloud.stream.messaging.Source;
import org.springframework.context.annotation.Bean;

/**
 * Spring Cloud Stream configuration.
 * <p>
 * Registers a {@link BeanPostProcessor} that installs tenant interceptors on every
 * {@link DirectWithAttributesChannel} bean:
 * <ul>
 *     <li>channels whose {@code type} attribute equals {@link Sink#INPUT} get a
 *     {@code CloudtMessagingInputInterceptor};</li>
 *     <li>channels whose {@code type} attribute equals {@link Source#OUTPUT} get a
 *     {@code CloudtMessagingOutputInterceptor};</li>
 *     <li>any other {@code type} value is logged at error level and left untouched.</li>
 * </ul>
 *
 * @author Kaiser（wang shao）
 * @date 2024/3/22
 */
@Slf4j
@ConditionalOnClass({ChannelsEndpointAutoConfiguration.class})
public class CloudtSpringCloudStreamAutoConfiguration {

    /**
     * Creates the post-processor that adds tenant interceptors to message channels.
     * <p>
     * The method is {@code static} so that the declaring configuration class does not have to be
     * instantiated early merely to obtain a {@link BeanPostProcessor}. The
     * {@link TenantClientProvider} is obtained lazily through an {@link ObjectProvider} for a
     * related reason: beans injected directly into a {@link BeanPostProcessor} factory method are
     * created before the remaining post-processors are registered, and are therefore not eligible
     * for processing by them (for example AOP auto-proxying).
     *
     * @param applicationName      value of {@code spring.application.name}, or {@code "unknown"} when unset
     * @param tenantClientProvider lazy handle to the tenant client provider; resolved only when a
     *                             matching channel is post-processed
     * @return the channel-intercepting post-processor
     */
    @Bean
    public static BeanPostProcessor springCloudStreamChannelInterceptorTenant(
            @Value("${spring.application.name:#{'unknown'}}") String applicationName,
            ObjectProvider<TenantClientProvider> tenantClientProvider) {
        return new BeanPostProcessor() {
            @Override
            public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
                if (bean instanceof DirectWithAttributesChannel) {
                    DirectWithAttributesChannel channel = (DirectWithAttributesChannel) bean;
                    // Keep the attribute as Object: equals() handles null and non-String values
                    // without risking a ClassCastException from an unchecked cast.
                    Object type = channel.getAttribute("type");
                    if (Sink.INPUT.equals(type)) {
                        // Index 0 puts the tenant interceptor ahead of any already-registered interceptors.
                        channel.addInterceptor(0, new CloudtMessagingInputInterceptor(applicationName,
                                tenantClientProvider.getObject()));
                    } else if (Source.OUTPUT.equals(type)) {
                        channel.addInterceptor(0, new CloudtMessagingOutputInterceptor(tenantClientProvider.getObject()));
                    } else {
                        log.error("未知MessageChannel类型：{}，{}", beanName, type);
                    }
                }
                return bean;
            }
        };
    }
}
