package com.elitescloud.boot.auth.provider.sso2.support.convert;

import cn.hutool.core.lang.Assert;
import cn.hutool.core.text.CharSequenceUtil;
import com.elitescloud.boot.auth.provider.security.grant.InternalAuthenticationGranter;
import com.elitescloud.boot.auth.provider.sso2.common.SsoConvertProperty;
import com.elitescloud.boot.auth.provider.sso2.common.SsoTypeEnum;
import com.elitescloud.boot.auth.provider.sso2.support.convert.properties.JwtSsoConvertProperty;
import com.elitescloud.boot.exception.BusinessException;
import com.elitescloud.boot.util.JSONUtil;
import com.elitescloud.boot.util.RsaUtil;
import com.fasterxml.jackson.core.type.TypeReference;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.crypto.*;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jwt.EncryptedJWT;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.PlainJWT;
import com.nimbusds.jwt.SignedJWT;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.nio.charset.StandardCharsets;
import java.security.PublicKey;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.util.Date;
import java.util.Map;

/**
 * SSO authentication converter for JWT tokens.
 *
 * <p>Reads a JWT from the configured request parameter. Depending on the
 * {@link JwtSsoConvertProperty} configuration, the token is handled in one of three ways:
 * <ul>
 *     <li>encrypted (JWE): decrypted; {@code exp} is checked if present</li>
 *     <li>signed (JWS): signature verified unless the sign type is {@code NONE}; {@code exp} is checked if present</li>
 *     <li>plain (unsecured JWT): parsed only</li>
 * </ul>
 * It then extracts the username from the claims and turns it into an internal authentication token.
 *
 * @author Kaiser（wang shao）
 * @date 2025/6/6 周五
 */
public class JwtSsoAuthenticationConvert extends BasePlainSsoAuthenticationConvert {
    private static final Logger logger = LoggerFactory.getLogger(JwtSsoAuthenticationConvert.class);

    @Override
    public SsoTypeEnum supportType() {
        return SsoTypeEnum.JWT;
    }

    @Override
    @SuppressWarnings("unchecked")
    public <T extends SsoConvertProperty> Class<T> propertyType() {
        // Safe: callers always pass a JwtSsoConvertProperty for this converter type.
        return (Class<T>) JwtSsoConvertProperty.class;
    }

    /**
     * Builds an internal authentication token from the JWT carried by the request.
     *
     * @param request    current request; the token is read via {@code getParam} using the configured name/location
     * @param response   current response (unused here)
     * @param properties converter configuration; must be a {@link JwtSsoConvertProperty}
     * @return token holding the configured id type and the username resolved from the JWT claims
     * @throws IllegalArgumentException if the token parameter is blank
     * @throws BusinessException        if the token cannot be parsed/decrypted/verified, is expired,
     *                                  or yields a blank username
     */
    @Nullable
    @Override
    public <T extends SsoConvertProperty> InternalAuthenticationGranter.InternalAuthenticationToken convert(HttpServletRequest request, HttpServletResponse response, T properties) {
        JwtSsoConvertProperty props = (JwtSsoConvertProperty) properties;

        String value = getParam(request, props.getParamName(), props.getParamIn());
        if (CharSequenceUtil.isBlank(value)) {
            throw new IllegalArgumentException("参数为空:" + props.getParamName());
        }

        // Parse the token; encryption takes precedence over signing in the configuration.
        JWT jwt;
        if (props.isEncrypt()) {
            jwt = parseEncryptedJwt(value, props);
        } else if (props.isSigned()) {
            jwt = parseSignedJwt(value, props);
        } else {
            jwt = parsePlainJwt(value);
        }

        // Resolve the username from the JWT payload.
        String username;
        try {
            username = obtainUsernameFromJwt(jwt, props);
        } catch (BusinessException e) {
            // Keep specific business messages (e.g. JSON parse failures) instead of masking them.
            throw e;
        } catch (Exception e) {
            throw new BusinessException("解析令牌异常", e);
        }
        if (CharSequenceUtil.isBlank(username)) {
            throw new BusinessException("授权账户为空");
        }

        return new InternalAuthenticationGranter.InternalAuthenticationToken(props.getIdType(), username);
    }

    /**
     * Extracts the username from the JWT claims.
     *
     * <p>If no payload path is configured, the claim named by {@code payloadUserName} is read.
     * Otherwise the path is split at its first dot. The part before the dot names a claim holding
     * an object, which may be a JSON object or a JSON string. The remainder is resolved inside that
     * object via {@code getValueByPath}.
     *
     * @return the username, or {@code null} if the token or the referenced claim is absent
     * @throws Exception if the claims cannot be read or a claim has an unexpected type
     */
    private String obtainUsernameFromJwt(JWT jwt, JwtSsoConvertProperty props) throws Exception {
        if (jwt == null) {
            return null;
        }
        var claims = jwt.getJWTClaimsSet();

        // No path configured: read the username claim directly.
        String payloadUserNamePath = props.getPayloadUserNamePath();
        if (CharSequenceUtil.isBlank(payloadUserNamePath)) {
            return claims.getStringClaim(props.getPayloadUserName());
        }

        // A path without a (non-leading) dot is just a top-level claim name.
        int dotIndex = payloadUserNamePath.indexOf('.');
        if (dotIndex <= 0) {
            return claims.getStringClaim(payloadUserNamePath);
        }

        String payloadName = payloadUserNamePath.substring(0, dotIndex);
        String path = payloadUserNamePath.substring(dotIndex + 1);
        Map<String, Object> userInfo = readObjectClaim(payloadName, claims.getClaim(payloadName));
        if (userInfo == null) {
            return null;
        }
        return getValueByPath(path, userInfo);
    }

    /**
     * Normalises a claim that is expected to hold an object.
     *
     * <p>Accepts both a real JSON object claim (exposed as a {@link Map}) and a claim whose value
     * is a JSON-encoded string.
     *
     * @return the claim as a map, or {@code null} if the claim is absent
     * @throws ParseException if the claim is neither an object nor a string
     */
    @SuppressWarnings("unchecked")
    private Map<String, Object> readObjectClaim(String claimName, Object claim) throws ParseException {
        if (claim == null) {
            return null;
        }
        if (claim instanceof Map) {
            return (Map<String, Object>) claim;
        }
        if (claim instanceof String) {
            Map<String, Object> parsed = JSONUtil.json2Obj((String) claim, new TypeReference<>() {}, true, () -> "解析用户信息响应内容异常");
            return parsed;
        }
        throw new ParseException("The " + claimName + " claim is not a JSON object", 0);
    }

    /**
     * Parses an unsecured (plain) JWT. No signature or expiry check is performed.
     *
     * @throws BusinessException if the value is not a valid plain JWT
     */
    private PlainJWT parsePlainJwt(String jwtValue) {
        try {
            return PlainJWT.parse(jwtValue);
        } catch (Exception e) {
            throw new BusinessException("解析令牌异常", e);
        }
    }

    /**
     * Parses a signed JWT, verifies its signature (unless the sign type is {@code NONE}) and rejects it if expired.
     *
     * @throws IllegalArgumentException if the sign type or sign key is not configured
     * @throws BusinessException        if parsing fails, the signature is invalid, or the token is expired
     */
    private SignedJWT parseSignedJwt(String jwtValue, JwtSsoConvertProperty props) {
        Assert.notNull(props.getSignType(), "签名类型为空");
        Assert.notBlank(props.getSignKey(), "签名密钥为空");

        try {
            SignedJWT jwt = SignedJWT.parse(jwtValue);

            // A null verifier means the configuration explicitly disables verification (NONE).
            JWSVerifier verifier = createVerifier(jwt, props);
            if (verifier != null && !jwt.verify(verifier)) {
                throw new BusinessException("令牌签名验证失败");
            }

            checkNotExpired(jwt);
            return jwt;
        } catch (BusinessException e) {
            throw e;
        } catch (Exception e) {
            throw new BusinessException("解析令牌异常", e);
        }
    }

    /**
     * Creates the signature verifier for the configured sign type.
     *
     * @return the verifier, or {@code null} when the sign type is {@code NONE}
     * @throws Exception if the key cannot be decoded or the sign type is unsupported
     */
    private JWSVerifier createVerifier(SignedJWT jwt, JwtSsoConvertProperty props) throws Exception {
        String signKey = props.getSignKey();
        switch (props.getSignType()) {
            case HMAC:
                OctetSequenceKey jwk = new OctetSequenceKey.Builder(signKey.getBytes(StandardCharsets.UTF_8))
                        .algorithm(jwt.getHeader().getAlgorithm())
                        .build();
                return new MACVerifier(jwk);
            case RSA:
                PublicKey publicKey = "X.509".equals(props.getRsaFormat())
                        ? RsaUtil.convert2PublicKey(signKey)
                        : RsaUtil.convert2PublicKeyForPkcs1(signKey);
                return new RSASSAVerifier((RSAPublicKey) publicKey);
            case ECDSA:
                return new ECDSAVerifier(ECKey.parse(signKey));
            case NONE:
                return null;
            default:
                throw new IllegalArgumentException("不支持的签名方式:" + props.getSignType());
        }
    }

    /**
     * Parses and decrypts an encrypted JWT, then rejects it if its {@code exp} claim is in the past.
     *
     * @throws IllegalArgumentException if the encrypt type or key is not configured
     * @throws BusinessException        if parsing/decryption fails or the token is expired
     */
    private EncryptedJWT parseEncryptedJwt(String jwtValue, JwtSsoConvertProperty props) {
        Assert.notNull(props.getEncryptType(), "加密方式为空");
        Assert.notBlank(props.getEncryptKey(), "密钥为空");

        String encryptKey = props.getEncryptKey();
        try {
            EncryptedJWT jwt = EncryptedJWT.parse(jwtValue);

            // Decrypt the payload with the configured algorithm family.
            switch (props.getEncryptType()) {
                case RSA:
                    jwt.decrypt(new RSADecrypter(RsaUtil.convert2PrivateKey(encryptKey)));
                    break;
                case AES:
                    jwt.decrypt(new AESDecrypter(encryptKey.getBytes(StandardCharsets.UTF_8)));
                    break;
                case PASSWORD:
                    jwt.decrypt(new PasswordBasedDecrypter(encryptKey));
                    break;
                default:
                    throw new IllegalArgumentException("不支持的加密方式:" + props.getEncryptType());
            }

            // Same lifetime check as for signed tokens; previously expired JWEs were accepted.
            checkNotExpired(jwt);
            return jwt;
        } catch (BusinessException e) {
            throw e;
        } catch (Exception e) {
            throw new BusinessException("解析令牌异常", e);
        }
    }

    /**
     * Rejects the token if it carries an {@code exp} claim that lies in the past.
     * Tokens without {@code exp} are accepted.
     *
     * @throws ParseException    if the payload is not a valid claims set
     * @throws BusinessException if the token is expired
     */
    private void checkNotExpired(JWT jwt) throws ParseException {
        Date expiration = jwt.getJWTClaimsSet().getExpirationTime();
        if (expiration != null && expiration.before(new Date())) {
            throw new BusinessException("令牌已过期");
        }
    }
}
