package com.elitescloud.cloudt.authorization.sdk.cas.provider;

import com.elitescloud.cloudt.authorization.sdk.cas.AuthorizeCacheable;
import com.elitescloud.cloudt.authorization.sdk.cas.model.AuthorizeDTO;
import com.elitescloud.cloudt.authorization.sdk.config.AuthorizationSdkProperties;
import com.elitescloud.cloudt.authorization.sdk.model.OAuthToken;
import com.elitescloud.cloudt.authorization.sdk.model.Result;
import com.elitescloud.cloudt.authorization.sdk.resolver.UniqueRequestResolver;
import com.elitescloud.cloudt.authorization.sdk.resolver.impl.DefaultUniquestResolver;
import com.elitescloud.cloudt.authorization.sdk.util.RestTemplateFactory;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.CompletableFuture;

/**
 * OAuth2 client used by the authorization SDK to talk to the authorization server.
 * <p>
 * It builds authorization-request URLs (optionally with PKCE), exchanges authorization codes for tokens,
 * refreshes and revokes tokens, obtains client-credentials tokens and queries the userinfo endpoint.
 * Endpoint addresses come from the configured client properties or, when no authorize endpoint is
 * configured, from the server's OIDC discovery metadata. Initialization is started asynchronously on
 * startup and is retried lazily (under a lock) on first use if it has not completed yet.
 *
 * @author Kaiser（wang shao）
 * @date 2022/12/11
 */
public class OAuth2ClientProvider implements InitializingBean {
    private static final Logger LOG = LoggerFactory.getLogger(OAuth2ClientProvider.class);
    /**
     * Random bytes per PKCE code verifier; 32 bytes encode to 43 base64url characters, the minimum
     * length allowed by RFC 7636.
     */
    private static final int CODE_VERIFIER_BYTES = 32;
    /** Cryptographically strong randomness for PKCE code verifiers; SecureRandom is safe to share. */
    private static final SecureRandom SECURE_RANDOM = new SecureRandom();

    private final AuthorizationSdkProperties sdkProperties;
    private final AuthorizeCacheable authorizeCacheable;
    private UniqueRequestResolver uniqueRequestResolver = new DefaultUniquestResolver("X-Auth-Req-Client");
    private RestTemplate restTemplate;
    /**
     * Resolved endpoints; {@code null} until initialization succeeds. Volatile so request threads see the
     * instance published by the asynchronous startup initialization (and the RestTemplate set before it).
     */
    private volatile OAuth2ClientBO clientBO;

    /**
     * @param sdkProperties      SDK configuration (auth server address, OAuth2 client settings)
     * @param authorizeCacheable cache for pending authorization requests; when {@code null}, an in-memory
     *                           Caffeine cache is used
     */
    public OAuth2ClientProvider(AuthorizationSdkProperties sdkProperties, AuthorizeCacheable authorizeCacheable) {
        this.sdkProperties = sdkProperties;
        this.authorizeCacheable = authorizeCacheable == null ? new AuthorizeCacheDefault() : authorizeCacheable;
    }

    /**
     * Builds the authorization-request URL and caches the request details for the later code exchange.
     *
     * @param response    current response, handed to the request resolver to sign this authorization request
     * @param redirectUrl redirect URI sent to the authorization server
     * @param state       state value passed to the authorization server
     * @return the authorization URL
     */
    public String getAuthorizeInfo(@NotNull HttpServletResponse response, @NotBlank String redirectUrl, String state) {
        OAuth2ClientBO client = ensureInitialized();
        // Request identifier; the same resolver reads it back in code2AccessToken.
        var reqId = uniqueRequestResolver.signRequest(response);
        var clientProperties = sdkProperties.getCasClient().getOauth2Client();

        AuthorizeDTO authorizeDTO = new AuthorizeDTO();
        authorizeDTO.setAuthorizeEndpoint(client.getAuthorizeEndpoint());
        authorizeDTO.setClientId(clientProperties.getClientId());
        authorizeDTO.setResponseType("code");
        authorizeDTO.setScope("openid");
        authorizeDTO.setRedirectUri(redirectUrl);

        // PKCE: the verifier is cached with this request and later sent by code2AccessToken
        if (clientProperties.isPkceEnabled()) {
            authorizeDTO.setCodeVerifier(generateCodeVerifier());
            authorizeDTO.setCodeChallengeMethod("S256");
            authorizeDTO.setCodeChallenge(generateCodeChallenge(authorizeDTO.getCodeVerifier()));
        }
        authorizeDTO.setState(state);

        authorizeCacheable.setCache(reqId, authorizeDTO);

        return authorizeDTO.getUrl();
    }

    /**
     * Exchanges an authorization code for tokens.
     *
     * @param request current request, used to resolve the identifier stored by getAuthorizeInfo
     * @param code    authorization code
     * @return token information, or a failed result if the exchange did not succeed
     * @throws IllegalArgumentException if no request identifier can be resolved, or no cached authorization
     *                                  request exists for it (e.g. it expired)
     */
    public Result<OAuthToken> code2AccessToken(@NotNull HttpServletRequest request, @NotBlank String code) {
        String reqId = uniqueRequestResolver.analyze(request);
        Assert.hasText(reqId, "请求失败，未获取到有效的请求标识");
        AuthorizeDTO authorizeInfo = authorizeCacheable.get(reqId);
        Assert.notNull(authorizeInfo, "认证超时，请重试");

        // token request parameters
        MultiValueMap<String, Object> postParam = new LinkedMultiValueMap<>(8);
        postParam.add(OAuth2ParameterNames.CLIENT_ID, authorizeInfo.getClientId());
        postParam.add(OAuth2ParameterNames.CLIENT_SECRET, sdkProperties.getCasClient().getOauth2Client().getClientSecret());
        postParam.add(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
        postParam.add(OAuth2ParameterNames.CODE, code);

        String redirectUri = authorizeInfo.getRedirectUri();
        if (StringUtils.hasText(redirectUri)) {
            postParam.add(OAuth2ParameterNames.REDIRECT_URI, redirectUri);
        }
        if (StringUtils.hasText(authorizeInfo.getCodeVerifier())) {
            postParam.add(PkceParameterNames.CODE_VERIFIER, authorizeInfo.getCodeVerifier());
        }

        try {
            OAuth2ClientBO client = ensureInitialized();
            var resp = restTemplate.exchange(client.getTokenEndpoint(),
                    HttpMethod.POST, new HttpEntity<>(postParam), new ParameterizedTypeReference<OAuthToken>() {
                    });
            if (resp.getStatusCode().is2xxSuccessful()) {
                return Result.ok(resp.getBody());
            }
            LOG.error("授权码转token失败：{}", resp);
            return Result.fail("获取认证token失败");
        } catch (Exception e) {
            LOG.error("获取认证token异常：", e);
            return Result.fail("获取认证token异常！");
        }
    }

    /**
     * Queries the userinfo endpoint.
     *
     * @param tokenType   token type, used as the Authorization header scheme
     * @param accessToken access token
     * @return user information, or a failed result if the request did not succeed
     * @throws IllegalArgumentException if either argument is blank
     */
    public Result<HashMap<String, String>> queryUserInfo(@NotBlank String tokenType, @NotBlank String accessToken) {
        Assert.hasText(tokenType, "token类型为空");
        Assert.hasText(accessToken, "token为空");

        MultiValueMap<String, String> headers = new LinkedMultiValueMap<>(4);
        headers.add(HttpHeaders.AUTHORIZATION, tokenType + " " + accessToken);
        try {
            OAuth2ClientBO client = ensureInitialized();
            var resp = restTemplate.exchange(client.getUserinfoEndpoint(), HttpMethod.GET, new HttpEntity<>(null, headers),
                    new ParameterizedTypeReference<HashMap<String, String>>() {
                    });
            if (resp.getStatusCode().is2xxSuccessful()) {
                return Result.ok(resp.getBody());
            }
            LOG.error("获取用户信息失败：{}", resp);
            return Result.fail("获取用户信息失败！");
        } catch (Exception e) {
            LOG.error("获取用户信息异常：", e);
            return Result.fail("获取用户信息异常！");
        }
    }

    /**
     * Obtains new tokens using a refresh token.
     *
     * @param refreshToken refresh token
     * @return new token information, or a failed result if the request did not succeed
     * @throws IllegalArgumentException if refreshToken is blank
     */
    public Result<OAuthToken> refreshToken(@NotBlank String refreshToken) {
        // hasText (not just notNull) to honor @NotBlank, consistent with the other entry points
        Assert.hasText(refreshToken, "刷新token为空");

        // token request parameters
        MultiValueMap<String, Object> postParam = new LinkedMultiValueMap<>(8);
        postParam.add(OAuth2ParameterNames.CLIENT_ID, sdkProperties.getCasClient().getOauth2Client().getClientId());
        postParam.add(OAuth2ParameterNames.CLIENT_SECRET, sdkProperties.getCasClient().getOauth2Client().getClientSecret());
        postParam.add(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue());
        postParam.add(OAuth2ParameterNames.REFRESH_TOKEN, refreshToken);

        try {
            OAuth2ClientBO client = ensureInitialized();
            var resp = restTemplate.exchange(client.getTokenEndpoint(),
                    HttpMethod.POST, new HttpEntity<>(postParam), new ParameterizedTypeReference<OAuthToken>() {
                    });
            if (resp.getStatusCode().is2xxSuccessful()) {
                return Result.ok(resp.getBody());
            }
            LOG.error("刷新token失败：{}", resp);
            return Result.fail("刷新token失败");
        } catch (Exception e) {
            LOG.error("刷新token异常：", e);
            return Result.fail("刷新认证token异常！");
        }
    }

    /**
     * Obtains a token for this client itself (client_credentials grant).
     *
     * @return token information, or a failed result if the request did not succeed
     */
    public Result<OAuthToken> clientToken() {
        // token request parameters
        MultiValueMap<String, Object> postParam = new LinkedMultiValueMap<>(8);
        postParam.add(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue());
        postParam.add(OAuth2ParameterNames.CLIENT_ID, sdkProperties.getCasClient().getOauth2Client().getClientId());
        postParam.add(OAuth2ParameterNames.CLIENT_SECRET, sdkProperties.getCasClient().getOauth2Client().getClientSecret());

        try {
            OAuth2ClientBO client = ensureInitialized();
            var resp = restTemplate.exchange(client.getTokenEndpoint(),
                    HttpMethod.POST, new HttpEntity<>(postParam), new ParameterizedTypeReference<OAuthToken>() {
                    });
            if (resp.getStatusCode().is2xxSuccessful()) {
                return Result.ok(resp.getBody());
            }
            LOG.error("生成token失败：{}", resp.getStatusCode());
            return Result.fail("获取认证token失败");
        } catch (Exception e) {
            LOG.error("获取认证token失败：", e);
            return Result.fail("获取认证token失败！");
        }
    }

    /**
     * Revokes a token at the revocation endpoint.
     *
     * @param token token to revoke (refresh token, access token, ...)
     * @return {@code true} on success, or a failed result
     * @throws IllegalArgumentException if token is blank
     */
    public Result<Boolean> revokeToken(@NotBlank String token) {
        Assert.hasText(token, "token为空");

        // revocation request parameters
        MultiValueMap<String, Object> postParam = new LinkedMultiValueMap<>(8);
        postParam.add(OAuth2ParameterNames.CLIENT_ID, sdkProperties.getCasClient().getOauth2Client().getClientId());
        postParam.add(OAuth2ParameterNames.CLIENT_SECRET, sdkProperties.getCasClient().getOauth2Client().getClientSecret());
        postParam.add(OAuth2ParameterNames.TOKEN, token);

        try {
            OAuth2ClientBO client = ensureInitialized();
            String revocationEndpoint = client.getRevocationEndpoint();
            if (!StringUtils.hasText(revocationEndpoint)) {
                // Only populated from discovery metadata; absent when endpoints are configured locally.
                LOG.error("注销token失败：未获取到revocation endpoint");
                return Result.fail("注销token失败！");
            }
            var resp = restTemplate.exchange(revocationEndpoint,
                    HttpMethod.POST, new HttpEntity<>(postParam), new ParameterizedTypeReference<String>() {
                    });
            if (resp.getStatusCode().is2xxSuccessful()) {
                return Result.ok(true);
            }
            LOG.error("注销token失败：{}", resp.getStatusCode());
            return Result.fail("注销token失败");
        } catch (Exception e) {
            LOG.error("注销token失败：", e);
            return Result.fail("注销token失败！");
        }
    }

    /**
     * Validates the configuration and starts endpoint resolution in the background. A failure is only logged;
     * initialization is retried on the first call that needs the endpoints.
     *
     * @throws IllegalArgumentException if PKCE is disabled and no client secret is configured
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        if (!sdkProperties.getCasClient().getOauth2Client().isPkceEnabled()) {
            Assert.hasText(sdkProperties.getCasClient().getOauth2Client().getClientSecret(), "OAuth2 Client的clientSecret为空");
        }

        CompletableFuture.runAsync(this::init)
                .whenComplete((res, exp) -> {
                    if (exp != null) {
                        LOG.error("初始化OAuth2客户端异常：", exp);
                    }
                });
    }

    /**
     * @param uniqueRequestResolver resolver that signs the authorization request and later identifies the callback
     */
    public void setUniqueRequestResolver(UniqueRequestResolver uniqueRequestResolver) {
        this.uniqueRequestResolver = uniqueRequestResolver;
    }

    /**
     * Resolves the endpoints and publishes them in {@code clientBO}. Synchronized and idempotent so the
     * asynchronous startup call and lazy calls from request threads do not race; a failed attempt leaves
     * {@code clientBO} null so the next call retries.
     */
    private synchronized void init() {
        if (clientBO != null) {
            return;
        }
        if (restTemplate == null) {
            restTemplate = RestTemplateFactory.instance();
        }

        // authorization server base address
        String authServer = obtainAuthServer();
        clientBO = initClient(authServer);
    }

    /**
     * Returns the resolved endpoints, initializing them first if necessary.
     *
     * @return the resolved endpoints, never {@code null}
     * @throws RuntimeException whatever initialization throws (e.g. unknown auth server address)
     */
    private OAuth2ClientBO ensureInitialized() {
        OAuth2ClientBO client = clientBO;
        if (client == null) {
            init();
            client = clientBO;
        }
        return client;
    }

    private String obtainAuthServer() {
        return sdkProperties.getAuthServer();
    }

    /**
     * Builds the endpoint set: from local configuration when an authorize endpoint is configured, otherwise
     * from the server's OIDC discovery metadata (the only source of the revocation endpoint).
     */
    private OAuth2ClientBO initClient(String authServer) {
        var client = new OAuth2ClientBO();

        var clientProperties = sdkProperties.getCasClient().getOauth2Client();
        client.setAuthorizeEndpoint(normalizeUrl(authServer, clientProperties.getAuthorizeEndpoint()));
        if (StringUtils.hasText(client.getAuthorizeEndpoint())) {
            client.setTokenEndpoint(normalizeUrl(authServer, clientProperties.getTokenEndpoint()));
            client.setUserinfoEndpoint(normalizeUrl(authServer, clientProperties.getUserinfoEndpoint()));
            return client;
        }

        // query the remote discovery endpoint
        Assert.hasText(authServer, "未知认证服务器地址");
        var queryResult = queryServerConfig(normalizeUrl(authServer, CasUrlConstant.URI_OIDC_METADATA));
        client.setAuthorizeEndpoint((String) queryResult.get("authorization_endpoint"));
        Assert.hasText(client.getAuthorizeEndpoint(), "OAuth2客户端初始化失败");
        client.setTokenEndpoint((String) queryResult.get("token_endpoint"));
        client.setUserinfoEndpoint((String) queryResult.get("userinfo_endpoint"));
        client.setRevocationEndpoint((String) queryResult.get("revocation_endpoint"));
        return client;
    }

    /**
     * Fetches the server metadata document.
     *
     * @return the metadata, or an empty map on any failure (including an empty body)
     */
    private Map<String, Object> queryServerConfig(String url) {
        try {
            var resp = restTemplate.exchange(url, HttpMethod.GET, null, new ParameterizedTypeReference<Map<String, Object>>() {
            });
            if (resp.getStatusCode().is2xxSuccessful()) {
                LOG.info("查询OAuth2服务端配置成功：{}", resp.getBody());
                Map<String, Object> body = resp.getBody();
                return body == null ? Collections.emptyMap() : body;
            }
            LOG.warn("查询OAuth2服务端配置失败：{}", resp.getStatusCode());
        } catch (Exception e) {
            LOG.error("查询OAuth2服务端配置异常", e);
        }
        return Collections.emptyMap();
    }

    /**
     * Generates a PKCE code verifier: 32 SecureRandom bytes, base64url-encoded without padding
     * (43 characters from the unreserved set required by RFC 7636).
     */
    private String generateCodeVerifier() {
        byte[] randomBytes = new byte[CODE_VERIFIER_BYTES];
        SECURE_RANDOM.nextBytes(randomBytes);
        return Base64.getUrlEncoder().withoutPadding().encodeToString(randomBytes);
    }

    /**
     * Computes the S256 code challenge: base64url(SHA-256(ASCII(codeVerifier))) without padding.
     * A fresh MessageDigest is used per call because MessageDigest instances are stateful and must not be
     * shared between concurrent requests.
     */
    private String generateCodeChallenge(String codeVerifier) {
        MessageDigest sha256;
        try {
            sha256 = MessageDigest.getInstance("SHA-256");
        } catch (NoSuchAlgorithmException e) {
            // SHA-256 is mandatory on every Java platform, so this indicates a broken runtime
            throw new IllegalStateException("SHA-256 MessageDigest is not available", e);
        }
        byte[] bytes = sha256.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII));
        return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes);
    }

    /**
     * Turns a configured endpoint into a URL: absolute http(s) URLs are used as configured, relative ones are
     * appended to the auth server address.
     *
     * @return the URL, or {@code null} if uri is blank
     * @throws IllegalArgumentException if uri is relative and no auth server address is configured
     */
    private String normalizeUrl(String authServer, String uri) {
        if (!StringUtils.hasText(uri)) {
            return null;
        }
        String lowerUri = uri.toLowerCase(Locale.ROOT);
        if (lowerUri.startsWith("http://") || lowerUri.startsWith("https://")) {
            // already absolute: must not be prefixed with the auth server address
            return UriComponentsBuilder.fromUriString(uri).toUriString();
        }
        Assert.hasText(authServer, "未知认证服务器地址");

        return UriComponentsBuilder.fromUriString(authServer + "/" + uri).toUriString();
    }

    /**
     * In-memory fallback cache for pending authorization requests, used when no AuthorizeCacheable is
     * supplied. Entries expire 5 minutes after being written; the size is bounded at 2000 entries.
     */
    static class AuthorizeCacheDefault implements AuthorizeCacheable {
        private final Cache<String, AuthorizeDTO> authorizeCache;

        public AuthorizeCacheDefault() {
            this.authorizeCache = Caffeine.newBuilder()
                    .maximumSize(2000)
                    .expireAfterWrite(Duration.ofMinutes(5))
                    .build();
        }

        @Override
        public void setCache(String reqId, AuthorizeDTO authorizeDTO) {
            authorizeCache.put(reqId, authorizeDTO);
        }

        @Override
        public AuthorizeDTO get(String reqId) {
            return authorizeCache.getIfPresent(reqId);
        }
    }
}
