package com.elitescloud.boot.auth.cas.provider.impl;

import com.elitescloud.boot.auth.cas.AuthorizeCacheable;
import com.elitescloud.boot.auth.cas.model.AuthorizeDTO;
import com.elitescloud.boot.auth.cas.model.OAuth2UserInfoDTO;
import com.elitescloud.boot.auth.cas.provider.OAuth2ClientTemplate;
import com.elitescloud.boot.auth.common.AuthSdkConstant;
import com.elitescloud.boot.auth.config.AuthorizationSdkProperties;
import com.elitescloud.boot.auth.config.CloudtOAuth2ClientProperties;
import com.elitescloud.boot.auth.model.OAuthToken;
import com.elitescloud.boot.auth.resolver.UniqueRequestResolver;
import com.elitescloud.boot.auth.resolver.impl.DefaultUniquestResolver;
import com.elitescloud.boot.auth.util.AuthSdkUtil;
import com.elitescloud.boot.util.ObjectMapperFactory;
import com.elitescloud.boot.util.RestTemplateFactory;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.util.*;
import org.springframework.web.client.HttpStatusCodeException;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.util.UriComponentsBuilder;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.validation.constraints.NotNull;
import java.net.URI;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * 默认OAuth2Client操作实现.
 *
 * @author Kaiser（wang shao）
 * @date 2023/10/19
 */
public class DefaultOAuth2ClientTemplate implements OAuth2ClientTemplate {
    private static final Logger LOG = LoggerFactory.getLogger(DefaultOAuth2ClientTemplate.class);
    protected static final String PARAM_REDIRECT_URL = "redirectUrl";
    protected static final String PARAM_STATE = "state";
    protected static final String PARAM_AUTH_SERVER = "authServer";
    protected static final String PARAM_SOURCE = "source";

    private final AuthorizationSdkProperties sdkProperties;
    private final AuthorizeCacheable authorizeCacheable;
    protected ObjectMapper objectMapper = ObjectMapperFactory.instance();
    protected UniqueRequestResolver uniqueRequestResolver = new DefaultUniquestResolver("X-Auth-Cas-Client");
    protected RestTemplate restTemplate = RestTemplateFactory.dynamicInstance(null, AuthSdkConstant.serverName);
    private final Map<String, EndpointInfo> endpointInfoMap = new HashMap<>(4);
    private final Map<String, CloudtOAuth2ClientProperties> clientPropertiesMap = new HashMap<>(4);

    public DefaultOAuth2ClientTemplate(AuthorizationSdkProperties sdkProperties,
                                       AuthorizeCacheable authorizeCacheable) {
        this.sdkProperties = sdkProperties;
        this.authorizeCacheable = authorizeCacheable == null ? new AuthorizeCacheDefault() : authorizeCacheable;
    }

    @Override
    public AuthorizeDTO generateAuthorizeInfo(HttpServletRequest request, @NotNull HttpServletResponse response) {
        var clientProperties = this.detectOAuth2ClientProperties(request);
        Assert.notNull(clientProperties, "OAuth2客户端未配置");
        clientPropertiesMap.put(clientProperties.getClientId(), clientProperties);

        var endpointInfo = endpointInfoMap.get(clientProperties.getClientId());
        if (endpointInfo == null) {
            endpointInfo = this.buildEndpointInfo(clientProperties);
            if (endpointInfo == null) {
                throw new OAuth2AuthenticationException("OAuth2客户端初始化失败，请联系管理员检查配置");
            }
            endpointInfoMap.put(clientProperties.getClientId(), endpointInfo);
        }

        AuthorizeDTO authorizeDTO = new AuthorizeDTO();
        authorizeDTO.setAuthorizeEndpoint(endpointInfo.getAuthorizeEndpoint());
        authorizeDTO.setClientId(clientProperties.getClientId());
        authorizeDTO.setResponseType("code");
        authorizeDTO.setScope("openid");
        authorizeDTO.setRedirectUri(this.obtainRedirectUrl(request, clientProperties));

        // pkce
        if (clientProperties.isPkceEnabled()) {
            authorizeDTO.setCodeVerifier(AuthSdkUtil.generateCodeVerifier());
            authorizeDTO.setCodeChallengeMethod("S256");
            authorizeDTO.setCodeChallenge(AuthSdkUtil.generateCodeChallenge(authorizeDTO.getCodeVerifier()));
        }
        authorizeDTO.setAuthServer(this.detectServerAddr(request, clientProperties, endpointInfo.getServerAddr()));
        authorizeDTO.setServerRouter(clientProperties.getDetectServerAddrRouter());
        authorizeDTO.setState(request.getParameter(PARAM_STATE));
        authorizeDTO.setLogoutUrl(this.normalizeEndpoint(authorizeDTO.getAuthServer(),
                this.buildLogoutUri(clientProperties, authorizeDTO.getRedirectUri())));

        // 缓存请求标识
        var reqId = uniqueRequestResolver.signRequest(response);
        authorizeCacheable.setCache(reqId, authorizeDTO);

        return authorizeDTO;
    }

    @Override
    public OAuthToken code2Token(HttpServletRequest request, HttpServletResponse response, String code) {
        // 获取客户端信息
        AuthorizeDTO authorizedInfo = null;
        String reqId = uniqueRequestResolver.analyze(request);
        if (StringUtils.hasText(reqId)) {
            authorizedInfo = authorizeCacheable.get(reqId);
        }
        CloudtOAuth2ClientProperties clientProperties = null;
        if (authorizedInfo != null) {
            clientProperties = clientPropertiesMap.get(authorizedInfo.getClientId());
        }
        if (clientProperties == null) {
            LOG.info("未获取到已认证信息，尝试获取默认配置信息...");
            clientProperties = this.detectOAuth2ClientProperties(request);
        }
        if (clientProperties == null) {
            LOG.info("获取认证信息失败：{}，或已超时", reqId);
            throw new OAuth2AuthenticationException("系统繁忙，请稍后再试");
        }

        var endpointInfo = endpointInfoMap.get(clientProperties.getClientId());
        if (endpointInfo == null) {
            endpointInfo = buildEndpointInfo(clientProperties);
        }

        // 组织请求参数
        MultiValueMap<String, Object> postParam = new LinkedMultiValueMap<>(8);
        postParam.add(OAuth2ParameterNames.CLIENT_ID, clientProperties.getClientId());
        postParam.add(OAuth2ParameterNames.CLIENT_SECRET, clientProperties.getClientSecret());
        postParam.add(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
        postParam.add(OAuth2ParameterNames.CODE, code);

        // 重定向地址
        String redirectUri = authorizedInfo == null ? clientProperties.getRedirectUrl() : authorizedInfo.getRedirectUri();
        if (StringUtils.hasText(redirectUri)) {
            LOG.info("redirectUri：{}", redirectUri);
            postParam.add(OAuth2ParameterNames.REDIRECT_URI, redirectUri);
        }
        if (authorizedInfo != null && StringUtils.hasText(authorizedInfo.getCodeVerifier())) {
            LOG.info("codeVerifier：{}", authorizedInfo.getCodeVerifier());
            postParam.add(PkceParameterNames.CODE_VERIFIER, authorizedInfo.getCodeVerifier());
        }

        // 发起请求
        try {
            var resp = restTemplate.exchange(endpointInfo.getTokenEndpoint(),
                    HttpMethod.POST, new HttpEntity<>(postParam), new ParameterizedTypeReference<HashMap<String, Object>>() {
                    });
            if (resp.getStatusCode().is2xxSuccessful()) {
                uniqueRequestResolver.clear(response, reqId);
                return this.convertAuthToken(resp.getBody());
            }

            LOG.error("授权码转token失败，参数：{}, 响应：{}", objectMapper.writeValueAsString(postParam), resp);
            if (resp.getStatusCode() == HttpStatus.UNAUTHORIZED) {
                throw new OAuth2AuthenticationException("认证失败，请稍后重试！");
            }

            return null;
        } catch (Exception e) {
            if (e instanceof OAuth2AuthenticationException) {
                throw (OAuth2AuthenticationException) e;
            }
            try {
                LOG.error("获取认证token异常，参数：" + objectMapper.writeValueAsString(postParam) + ", ：", e);
            } catch (JsonProcessingException ex) {
                LOG.error("获取认证token异常，打印入参异常，", e);
            }
        }

        return null;
    }

    @Override
    public OAuth2UserInfoDTO getUserInfo(HttpServletRequest request, String tokenType, String token) {
        Assert.hasText(tokenType, "token类型为空");
        Assert.hasText(token, "token为空");

        var clientProperties = this.detectOAuth2ClientProperties(request);
        Assert.notNull(clientProperties, "未获取到有效的客户端配置");
        var endpointInfo = getEndpointInfo(clientProperties);

        MultiValueMap<String, String> headers = new LinkedMultiValueMap<>(4);
        headers.add(HttpHeaders.AUTHORIZATION, tokenType + " " + token);

        try {
            var resp = restTemplate.exchange(endpointInfo.getUserinfoEndpoint(), HttpMethod.GET, new HttpEntity<>(null, headers),
                    new ParameterizedTypeReference<OAuth2UserInfoDTO>() {
                    });
            if (resp.getStatusCode().is2xxSuccessful()) {
                return resp.getBody();
            }
            LOG.error("获取用户信息失败，token：{}，响应：{}", tokenType + " " + token, resp);
        } catch (Exception e) {
            LOG.error("获取用户信息异常，token:{}，异常：", tokenType + " " + token, e);
        }
        return null;
    }

    @Override
    public OAuthToken clientToken() {
        var clientProperties = this.detectOAuth2ClientProperties(null);
        Assert.notNull(clientProperties, "未获取到有效的客户端配置");
        var endpointInfo = endpointInfoMap.get(clientProperties.getClientId());
        if (endpointInfo == null) {
            endpointInfo = this.buildEndpointInfo(clientProperties);
        }

        // 组织请求参数
        MultiValueMap<String, Object> postParam = new LinkedMultiValueMap<>(8);
        postParam.add(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue());
        postParam.add(OAuth2ParameterNames.CLIENT_ID, clientProperties.getClientId());
        postParam.add(OAuth2ParameterNames.CLIENT_SECRET, clientProperties.getClientSecret());

        try {
            var resp = restTemplate.exchange(endpointInfo.getTokenEndpoint(),
                    HttpMethod.POST, new HttpEntity<>(postParam), new ParameterizedTypeReference<OAuthToken>() {
                    });
            if (resp.getStatusCode().is2xxSuccessful()) {
                return resp.getBody();
            }
            LOG.error("生成token失败：{}", resp.getStatusCode());
        } catch (Exception e) {
            LOG.error("获取认证token失败：", e);
        }
        return null;
    }

    @Override
    public Boolean revokeToken(String token) {
        Assert.hasText(token, "token为空");
        var clientProperties = this.detectOAuth2ClientProperties(null);
        Assert.notNull(clientProperties, "未获取到有效的客户端配置");
        var endpointInfo = endpointInfoMap.get(clientProperties.getClientId());
        if (endpointInfo == null) {
            endpointInfo = this.buildEndpointInfo(clientProperties);
        }

        // 组织请求参数
        MultiValueMap<String, Object> postParam = new LinkedMultiValueMap<>(8);
        postParam.add(OAuth2ParameterNames.CLIENT_ID, sdkProperties.getCasClient().getOauth2Client().getClientId());
        postParam.add(OAuth2ParameterNames.CLIENT_SECRET, sdkProperties.getCasClient().getOauth2Client().getClientSecret());
        postParam.add(OAuth2ParameterNames.TOKEN, token);

        try {
            var resp = restTemplate.exchange(endpointInfo.getRevocationEndpoint(),
                    HttpMethod.POST, new HttpEntity<>(postParam), new ParameterizedTypeReference<String>() {
                    });
            if (resp.getStatusCode().is2xxSuccessful()) {
                return true;
            }
            LOG.error("注销token失败：{}", resp.getStatusCode());
        } catch (Exception e) {
            LOG.error("注销token失败：", e);
        }
        return false;
    }

    public void setUniqueRequestResolver(UniqueRequestResolver uniqueRequestResolver) {
        this.uniqueRequestResolver = uniqueRequestResolver;
    }

    public void setObjectMapper(ObjectMapper objectMapper) {
        this.objectMapper = objectMapper;
    }

    protected CloudtOAuth2ClientProperties detectOAuth2ClientProperties(HttpServletRequest request) {
        var source = obtainSource(request);
        if (!StringUtils.hasText(source)) {
            return sdkProperties.getCasClient().getOauth2Client();
        }
        LOG.info("认证来源：{}", source);

        var oauth2Clients = sdkProperties.getCasClient().getExternalOauth2Client();
        if (oauth2Clients != null && !oauth2Clients.isEmpty()) {
            if (oauth2Clients.containsKey(source)) {
                return oauth2Clients.get(source);
            }
        }
        return sdkProperties.getCasClient().getOauth2Client();
    }

    protected OAuthToken convertAuthToken(Map<String, Object> map) {
        try {
            return objectMapper.convertValue(map, OAuthToken.class);
        } catch (Exception e) {
            String originalValue = null;
            try {
                originalValue = objectMapper.writeValueAsString(map);
            } catch (JsonProcessingException ex) {
                LOG.info("json转字符串异常：", ex);
            }
            LOG.error("转换token失败：{}", originalValue, e);
        }
        return null;
    }

    protected String redirectUrlNameOfLogout() {
        return PARAM_REDIRECT_URL;
    }

    protected String detectServerAddr(HttpServletRequest request, CloudtOAuth2ClientProperties clientProperties,
                                      String defaultServerAddr) {
        if (request == null) {
            return defaultServerAddr;
        }

        // 优先根据请求的
        String serverAddr = request.getParameter(PARAM_AUTH_SERVER);
        if (StringUtils.hasText(serverAddr)) {
            LOG.info("detect serverAddr from param：{}，{}", PARAM_AUTH_SERVER, serverAddr);
            return serverAddr;
        }

        // 其次根据请求头
        var serverAddrHeaderName = clientProperties.getDetectServerAddrHeader();
        if (!StringUtils.hasText(serverAddrHeaderName)) {
            return defaultServerAddr;
        }
        var serverAddrHeader = request.getHeader(serverAddrHeaderName);
        if (!StringUtils.hasText(serverAddrHeader) || serverAddrHeader.startsWith(defaultServerAddr)) {
            return defaultServerAddr;
        }
        LOG.info("detect serverAddr from header：{}，{}", serverAddrHeaderName, serverAddrHeader);
        try {
            var path = URI.create(serverAddrHeader).getPath();
            var router = clientProperties.getDetectServerAddrRouter() == null ? "" : clientProperties.getDetectServerAddrRouter();
            return serverAddrHeader.substring(0, serverAddrHeader.indexOf(path)) + router;
        } catch (Exception e) {
            throw new IllegalArgumentException("获取认证服务地址异常，解析请求头失败：" + serverAddrHeader);
        }
    }

    protected EndpointInfo getEndpointInfo(CloudtOAuth2ClientProperties clientProperties) {
        var endpointInfo = endpointInfoMap.get(clientProperties.getClientId());
        if (endpointInfo == null) {
            endpointInfo = this.buildEndpointInfo(clientProperties);
        }
        return endpointInfo;
    }

    protected String obtainSource(HttpServletRequest request) {
        if (request == null) {
            return null;
        }
        var source = request.getParameter(PARAM_SOURCE);
        if (StringUtils.hasText(source)) {
            return source;
        }

        String referer = request.getHeader(HttpHeaders.REFERER);
        if (!StringUtils.hasText(referer)) {
            return null;
        }

        return UriComponentsBuilder.fromUriString(referer)
                .build()
                .getQueryParams()
                .getFirst(PARAM_SOURCE);
    }

    private EndpointInfo buildEndpointInfo(CloudtOAuth2ClientProperties clientProperties) {
        String serverAddr = StringUtils.hasText(clientProperties.getServerAddr()) ? clientProperties.getServerAddr() :
                sdkProperties.getAuthServer();

        EndpointInfo endpointInfo = new EndpointInfo();
        endpointInfo.setServerAddr(serverAddr);
        if (!clientProperties.isDetectEndpoint()) {
            // 根据配置获取
            endpointInfo.setAuthorizeEndpoint(this.normalizeEndpoint(serverAddr, clientProperties.getAuthorizeEndpoint()));
            endpointInfo.setTokenEndpoint(this.normalizeEndpoint(serverAddr, clientProperties.getTokenEndpoint()));
            endpointInfo.setUserinfoEndpoint(this.normalizeEndpoint(serverAddr, clientProperties.getUserinfoEndpoint()));
            endpointInfo.setRevocationEndpoint(this.normalizeEndpoint(serverAddr, clientProperties.getRevocationEndpoint()));
            endpointInfo.setJwksUri(this.normalizeEndpoint(serverAddr, clientProperties.getJwksUri()));

            return endpointInfo;
        }

        // 自动根据接口获取
        var metadataUrl = this.normalizeEndpoint(serverAddr, clientProperties.getMetadataEndpoint());
        var metadataInfo = this.restExchangeByGet(metadataUrl);
        if (CollectionUtils.isEmpty(metadataInfo)) {
            return null;
        }
        serverAddr = (String) metadataInfo.get("issuer");
        endpointInfo.setServerAddr(serverAddr);
        endpointInfo.setAuthorizeEndpoint(this.detectEndpoint(clientProperties.getAuthorizeEndpoint(), (String) metadataInfo.get("authorization_endpoint")));
        endpointInfo.setTokenEndpoint(this.detectEndpoint(clientProperties.getTokenEndpoint(), (String) metadataInfo.get("token_endpoint")));
        endpointInfo.setUserinfoEndpoint(this.detectEndpoint(clientProperties.getUserinfoEndpoint(), (String) metadataInfo.get("userinfo_endpoint")));
        endpointInfo.setRevocationEndpoint(this.detectEndpoint(clientProperties.getRevocationEndpoint(), (String) metadataInfo.get("revocation_endpoint")));
        endpointInfo.setJwksUri(this.detectEndpoint(clientProperties.getJwksUri(), (String) metadataInfo.get("jwks_uri")));

        return endpointInfo;
    }

    protected static HttpServletRequest currentRequest() {
        RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
        return requestAttributes == null ? null : ((ServletRequestAttributes) requestAttributes).getRequest();
    }

    private String buildLogoutUri(CloudtOAuth2ClientProperties clientProperties, String redirectUrl) {
        var logoutUri = clientProperties.getLogoutUri();
        if (!StringUtils.hasText(logoutUri)) {
            return null;
        }

        // 判断是否已包含
        var redirectUrlName = redirectUrlNameOfLogout();
        if (StringUtils.hasText(redirectUrlName)) {
            var index = logoutUri.indexOf("?");
            if (index < 1) {
                return logoutUri + "?" + redirectUrlName + "=" + redirectUrl;
            }

            var paramsTemp = logoutUri.substring(index);
            return paramsTemp.contains(redirectUrlName) ? logoutUri : logoutUri + "&" + redirectUrlName + "=" + redirectUrl;
        }

        return logoutUri;
    }

    private String obtainRedirectUrl(HttpServletRequest request, CloudtOAuth2ClientProperties clientProperties) {
        var redirectUrl = request.getParameter(PARAM_REDIRECT_URL);
        if (StringUtils.hasText(redirectUrl)) {
            return redirectUrl;
        }

        return clientProperties.getRedirectUrl();
    }

    private String detectEndpoint(String customEndpoint, String metadataEndpoint) {
        if (StringUtils.hasText(customEndpoint) && (customEndpoint.startsWith("http://") || customEndpoint.startsWith("https://"))) {
            return customEndpoint;
        }
        return metadataEndpoint;
    }

    private String normalizeEndpoint(String serverAddr, String endpoint) {
        if (!StringUtils.hasText(endpoint)) {
            return serverAddr;
        }
        if (endpoint.toLowerCase().startsWith("http://") || endpoint.toLowerCase().startsWith("https://")) {
            return endpoint;
        }

        Assert.hasText(serverAddr, "serverAddr为空");

        if (serverAddr.endsWith("/")) {
            serverAddr = serverAddr.substring(0, serverAddr.length() - 1);
        }
        if (!endpoint.startsWith("/")) {
            endpoint = "/" + endpoint;
        }

        return serverAddr + endpoint;
    }

    private Map<String, Object> restExchangeByGet(String url) {
        try {
            var resp = restTemplate.exchange(url, HttpMethod.GET, null, String.class);
            if (resp.getStatusCode().is2xxSuccessful()) {
                LOG.info("查询OAuth2服务端配置成功：{}", resp.getBody());
                return objectMapper.readValue(resp.getBody(), new TypeReference<HashMap<String, Object>>() {
                });
            }
            LOG.warn("查询OAuth2服务端配置失败：{}, {}", url, resp.getStatusCode());
        } catch (Exception e) {
            LOG.error("查询OAuth2服务端配置异常: {}", url, e);
        }
        return Collections.emptyMap();
    }

    protected static class EndpointInfo {
        /**
         * 服务地址
         */
        private String serverAddr;
        /**
         * 服务端认证地址
         */
        private String authorizeEndpoint;
        /**
         * 服务端生成token地址
         */
        private String tokenEndpoint;
        /**
         * 用户信息地址
         */
        private String userinfoEndpoint;
        /**
         * 撤销token地址
         */
        private String revocationEndpoint;
        /**
         * jwks地址
         */
        private String jwksUri;

        public String getServerAddr() {
            return serverAddr;
        }

        public void setServerAddr(String serverAddr) {
            this.serverAddr = serverAddr;
        }

        public String getAuthorizeEndpoint() {
            return authorizeEndpoint;
        }

        public void setAuthorizeEndpoint(String authorizeEndpoint) {
            this.authorizeEndpoint = authorizeEndpoint;
        }

        public String getTokenEndpoint() {
            return tokenEndpoint;
        }

        public void setTokenEndpoint(String tokenEndpoint) {
            this.tokenEndpoint = tokenEndpoint;
        }

        public String getUserinfoEndpoint() {
            return userinfoEndpoint;
        }

        public void setUserinfoEndpoint(String userinfoEndpoint) {
            this.userinfoEndpoint = userinfoEndpoint;
        }

        public String getRevocationEndpoint() {
            return revocationEndpoint;
        }

        public void setRevocationEndpoint(String revocationEndpoint) {
            this.revocationEndpoint = revocationEndpoint;
        }

        public String getJwksUri() {
            return jwksUri;
        }

        public void setJwksUri(String jwksUri) {
            this.jwksUri = jwksUri;
        }
    }

    static class AuthorizeCacheDefault implements AuthorizeCacheable {
        private final Cache<String, AuthorizeDTO> authorizeCache;

        public AuthorizeCacheDefault() {
            this.authorizeCache = Caffeine.newBuilder()
                    .maximumSize(2000)
                    .expireAfterWrite(Duration.ofMinutes(60))
                    .build();
        }

        @Override
        public void setCache(String reqId, AuthorizeDTO authorizeDTO) {
            authorizeCache.put(reqId, authorizeDTO);
        }

        @Override
        public AuthorizeDTO get(String reqId) {
            return authorizeCache.getIfPresent(reqId);
        }
    }
}
