package com.elitescloud.boot.auth.provider.sso2.support.impl;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.text.CharSequenceUtil;
import com.elitescloud.boot.auth.model.OAuthToken;
import com.elitescloud.boot.auth.provider.common.AuthorizationConstant;
import com.elitescloud.boot.auth.provider.config.properties.Sso2Properties;
import com.elitescloud.boot.auth.provider.security.grant.InternalAuthenticationGranter;
import com.elitescloud.boot.auth.provider.sso2.common.SsoAuthenticationConvert;
import com.elitescloud.boot.auth.provider.sso2.common.SsoConvertProperty;
import com.elitescloud.boot.auth.provider.sso2.support.SsoUnifyClientSupportProvider;
import com.elitescloud.boot.exception.BusinessException;
import com.elitescloud.boot.util.ClassUtil;
import com.elitescloud.boot.util.JSONUtil;
import com.elitescloud.cloudt.common.base.ApiResult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.AuthenticationException;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * 统一客户端认证.
 *
 * @author Kaiser（wang shao）
 * @date 2025/6/6 周五
 */
public class SsoUnifyClientSupportProviderImpl implements SsoUnifyClientSupportProvider, InitializingBean {
    private static final Logger logger = LoggerFactory.getLogger(SsoUnifyClientSupportProviderImpl.class);

    private static final String PARAM_AUTH_TYPE = "ac";

    private final Sso2Properties sso2Properties;
    private final InternalAuthenticationGranter internalAuthenticationGranter;
    private final List<SsoAuthenticationConvert> authenticationConverts;

    public SsoUnifyClientSupportProviderImpl(Sso2Properties sso2Properties, InternalAuthenticationGranter internalAuthenticationGranter,
                                             List<SsoAuthenticationConvert> authenticationConverts) {
        this.sso2Properties = sso2Properties;
        this.internalAuthenticationGranter = internalAuthenticationGranter;
        this.authenticationConverts = authenticationConverts;
    }

    @Override
    public void afterPropertiesSet() throws Exception {
        if (sso2Properties.getUnifyClient() == null) {
            return;
        }
        if (CollUtil.isNotEmpty(sso2Properties.getUnifyClient().getClients())) {
            Set<String> existsAuthCodes = new HashSet<>();
            for (Sso2Properties.UnifyClientProperty client : sso2Properties.getUnifyClient().getClients()) {
                if (!client.isEnabled()) {
                    continue;
                }

                Assert.notBlank(client.getAuthCode(), "授权编码为空");
                Assert.isFalse(existsAuthCodes.contains(client.getAuthCode()), "授权编码存在重复：" + client.getAuthCode());
                existsAuthCodes.add(client.getAuthCode());

                Assert.notNull(client.getSsoType(), "存在配置单点登录类型ssoType为空");
            }
        }

        if (CollUtil.isNotEmpty(authenticationConverts)) {
            for (var authenticationConvert : authenticationConverts) {
                if (authenticationConvert.supportType() == null) {
                    throw new IllegalStateException(ClassUtil.getTargetClass(authenticationConvert).getName() + "的supportType为空");
                }
                if (authenticationConvert.propertyType() == null) {
                    throw new IllegalStateException(ClassUtil.getTargetClass(authenticationConvert).getName() + "的propertyType为空");
                }
                if (!SsoConvertProperty.class.isAssignableFrom(authenticationConvert.propertyType())) {
                    throw new IllegalStateException(ClassUtil.getTargetClass(authenticationConvert).getName() + "的propertyType类型错误");
                }
            }
        }
    }

    @Override
    public ApiResult<OAuthToken> authenticate(HttpServletRequest request, HttpServletResponse response) {
        response.setStatus(HttpStatus.UNAUTHORIZED.value());

        // 获取配置
        Sso2Properties.UnifyClientProperty clientProperty = obtainProperty(request);
        if (clientProperty == null) {
            return ApiResult.fail("不支持的认证方式");
        }

        // 令牌转换器
        var authenticationConvert = matchConvert(request, clientProperty);
        if (authenticationConvert == null) {
            return ApiResult.fail("不支持的认证类型");
        }

        // 转换参数
        SsoConvertProperty convertProperty = convertProperty(clientProperty, authenticationConvert.propertyType());

        // 转换令牌
        InternalAuthenticationGranter.InternalAuthenticationToken authenticationToken = authenticationConvert.convert(request, response, convertProperty);
        if (authenticationToken == null) {
            return ApiResult.fail("转换认证令牌失败");
        }

        // 生成token
        if (CharSequenceUtil.isNotBlank(clientProperty.getClientId())) {
            request.setAttribute(AuthorizationConstant.REQUEST_ATTRIBUTE_CLIENT_ID, clientProperty.getClientId());
        }
        OAuthToken token = null;
        try {
            token = internalAuthenticationGranter.authenticate(request, response, authenticationToken);
        } catch (AuthenticationException e) {
            return ApiResult.fail("认证异常，" + e.getMessage());
        }
        response.setStatus(HttpStatus.OK.value());
        return ApiResult.ok(token);
    }

    private SsoConvertProperty convertProperty(Sso2Properties.UnifyClientProperty clientProperty, Class<?> clazz) {
        SsoConvertProperty convertProperty = null;
        try {
            if (CollUtil.isEmpty(clientProperty.getProperties())) {
                convertProperty = (SsoConvertProperty) clazz.getDeclaredConstructor().newInstance();
            } else {
                convertProperty = (SsoConvertProperty) JSONUtil.convertObj(clientProperty.getProperties(), clazz, true);
            }
        } catch (Exception e) {
            throw new BusinessException("认证异常，请联系管理员", e);
        }

        convertProperty.setParamName(clientProperty.getParamName());
        convertProperty.setParamIn(clientProperty.getParamIn());
        convertProperty.setIdType(clientProperty.getIdType());
        convertProperty.setClientId(clientProperty.getClientId());

        // 参数校验
        convertProperty.validate();
        return convertProperty;
    }

    private Sso2Properties.UnifyClientProperty obtainProperty(HttpServletRequest request) {
        String authTypeCode = request.getParameter(PARAM_AUTH_TYPE);
        if (CharSequenceUtil.isBlank(authTypeCode)) {
            throw new BusinessException("缺少必要的参数：" + PARAM_AUTH_TYPE);
        }

        if (sso2Properties.getUnifyClient() == null || CollUtil.isEmpty(sso2Properties.getUnifyClient().getClients())) {
            return null;
        }

        for (Sso2Properties.UnifyClientProperty client : sso2Properties.getUnifyClient().getClients()) {
            if (authTypeCode.equals(client.getAuthCode())) {
                return client;
            }
        }
        return null;
    }

    private SsoAuthenticationConvert matchConvert(HttpServletRequest request, Sso2Properties.UnifyClientProperty property) {
        if (CollUtil.isEmpty(authenticationConverts) || property == null) {
            return null;
        }

        for (var authenticationConvert : authenticationConverts) {
            if (authenticationConvert.supportType() != null && authenticationConvert.supportType() == property.getSsoType()) {
                return authenticationConvert;
            }
        }

        return null;
    }
}
