package com.elitescloud.boot.auth.sso.configurer.filter;

import com.elitescloud.boot.auth.config.AuthorizationSdkProperties;
import com.elitescloud.boot.auth.sso.SsoProvider;
import com.elitescloud.boot.auth.sso.TicketResolver;
import com.elitescloud.boot.auth.sso.common.SdkSsoConstants;
import com.elitescloud.boot.auth.sso.common.TicketAuthentication;
import com.elitescloud.boot.auth.sso.model.UserInfoDTO;
import com.elitescloud.boot.auth.util.AuthorizationServerHelper;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.Collections;

/**
 * Single sign-on (SSO) filter.
 *
 * <p>Intercepts two endpoints taken from {@link AuthorizationSdkProperties}:
 * the SSO authorize endpoint (exchanges a ticket for a locally authenticated user)
 * and the SSO revoke endpoint (logs the user out). Every other request is passed
 * down the filter chain untouched.
 *
 * <p>The {@link TicketResolver} is required: it is dereferenced without a null check
 * as soon as one of the two endpoints is hit. The {@link SsoProvider} and both
 * authentication handlers are optional and are skipped when not set.
 *
 * @author Kaiser（wang shao）
 * @date 2022/7/14
 */
public class CloudtSsoFilter extends OncePerRequestFilter {
    private static final Logger LOG = LoggerFactory.getLogger(CloudtSsoFilter.class);

    private final AuthorizationSdkProperties sdkProperties;
    /** Reads, stores and clears the SSO ticket for a request; must be set before the filter is used. */
    private TicketResolver ticketResolver;
    /** Optional bridge to the application's own authentication; may be {@code null}. */
    private SsoProvider ssoProvider;
    /** Optional callback invoked after a successful ticket authentication. */
    private AuthenticationSuccessHandler authenticationSuccessHandler;
    /** Optional callback invoked after a failed ticket authentication. */
    private AuthenticationFailureHandler authenticationFailureHandler;
    private final AuthorizationServerHelper authorizationServerHelper = AuthorizationServerHelper.getInstance();
    private final ObjectMapper objectMapper = new ObjectMapper();

    /**
     * Creates the filter.
     *
     * @param sdkProperties SDK configuration supplying the auth server address and the SSO endpoints
     */
    public CloudtSsoFilter(AuthorizationSdkProperties sdkProperties) {
        this.sdkProperties = sdkProperties;
    }

    /**
     * Dispatches the request to the SSO authorize or revoke handling when its URI matches the
     * configured endpoint; otherwise continues the filter chain.
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        // SSO authorize request: handled here, the chain is not continued
        if (supportAuthorize(request)) {
            ssoAuthorize(request, response);
            return;
        }

        // SSO revoke (logout) request: handled here, the chain is not continued
        if (supportRevoke(request)) {
            ssoRevoke(request, response);
            return;
        }

        filterChain.doFilter(request, response);
    }

    public void setSsoProvider(SsoProvider ssoProvider) {
        this.ssoProvider = ssoProvider;
    }

    public void setTicketResolver(TicketResolver ticketResolver) {
        this.ticketResolver = ticketResolver;
    }

    public void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) {
        this.authenticationSuccessHandler = authenticationSuccessHandler;
    }

    public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
        this.authenticationFailureHandler = authenticationFailureHandler;
    }

    /**
     * Handles the SSO authorize endpoint.
     *
     * <p>Flow: short-circuit when the provider reports the request as already authenticated;
     * otherwise exchange the ticket for user info at the auth server, run the provider's
     * authentication, notify the success/failure handler, and on success persist the ticket
     * locally and answer with the "authorized" JSON body.
     *
     * @param request  current request
     * @param response current response
     * @throws ServletException propagated from the authentication handlers
     * @throws IOException      if writing the response fails
     */
    private void ssoAuthorize(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        // Already authenticated inside this application
        if (ssoProvider != null && ssoProvider.isAuthenticated(request)) {
            respAuthorized(response, false, null);
            return;
        }

        // Exchange the ticket for the current user at the auth server
        String ticket = ticketResolver.obtain(request);
        UserInfoDTO userInfoDTO = StringUtils.hasText(ticket) ? obtainUser(ticket) : null;
        if (userInfoDTO == null) {
            // No user could be resolved, the client has to log in again
            respLogin(response);
            return;
        }

        // Authenticate (the provider is expected to produce the internal token)
        TicketAuthentication authentication = new TicketAuthentication(ticket, userInfoDTO);
        if (ssoProvider != null) {
            try {
                authentication = ssoProvider.authentication(request, response, authentication);
            } catch (Exception e) {
                // Trailing throwable keeps the stack trace in the log
                LOG.info("认证失败：{}", e.getMessage(), e);
                respLogin(response);
                return;
            }
        } else {
            authentication = new TicketAuthentication(ticket, userInfoDTO, Collections.emptyList());
        }

        // Failure (including a provider that returned null): never fall through to the
        // success response, otherwise the client would receive authorized=true.
        if (authentication == null || !authentication.isAuthenticated()) {
            if (authenticationFailureHandler != null) {
                // The handler owns the response from here on
                authenticationFailureHandler.onAuthenticationFailure(request, response, new AuthenticationServiceException("认证失败"));
            } else {
                respLogin(response);
            }
            return;
        }

        // Post-authentication callback
        // NOTE(review): the handler runs before the ticket is saved and the body is written;
        // a handler that commits the response would interfere with both - confirm handlers don't.
        if (authenticationSuccessHandler != null) {
            authenticationSuccessHandler.onAuthenticationSuccess(request, response, authentication);
        }

        // Persist the ticket locally and report success
        ticketResolver.save(request, response, ticket);
        respAuthorized(response, true, authentication.getToken());
    }

    /**
     * Handles the SSO revoke endpoint: clears the provider's local authentication, revokes the
     * ticket at the auth server (unless the request is flagged as coming from the server), clears
     * the locally stored ticket and answers with the "re-login" JSON body.
     *
     * @param request  current request
     * @param response current response
     * @throws ServletException declared for symmetry with the authorize handling
     * @throws IOException      if writing the response fails
     */
    private void ssoRevoke(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        // Resolve the ticket
        String ticket = ticketResolver.obtain(request);

        if (ssoProvider != null) {
            // Clear the application's own authentication state
            ssoProvider.clearToken(request, response, ticket);
        }

        // Ask the auth server to revoke the ticket everywhere
        if (StringUtils.hasText(ticket)) {
            String server = request.getParameter(SdkSsoConstants.PARAM_SSO_SERVER);
            if (!"true".equalsIgnoreCase(server)) {
                // Not flagged as coming from the auth server, so the server has to be notified;
                // requests flagged with "true" skip this call.
                authorizationServerHelper.revokeTicket(sdkProperties.getAuthServer(), ticket);
            }

            ticketResolver.clear(request, response);
        }

        respLogin(response);
    }

    /**
     * Tells whether the request targets the configured SSO authorize endpoint.
     * Logs an error and returns {@code false} when no endpoint is configured.
     */
    private boolean supportAuthorize(HttpServletRequest request) {
        String endpoint = sdkProperties.getSso().getAuthorizeEndpoint();
        if (!StringUtils.hasText(endpoint)) {
            LOG.error("单点登录拦截失效，认证地址未配置！");
            return false;
        }
        return request.getRequestURI().equals(endpoint);
    }

    /**
     * Tells whether the request targets the configured SSO revoke endpoint.
     * Logs an error and returns {@code false} when no endpoint is configured.
     */
    private boolean supportRevoke(HttpServletRequest request) {
        String endpoint = sdkProperties.getSso().getAuthorizeRevokeEndpoint();
        if (!StringUtils.hasText(endpoint)) {
            LOG.error("单点登录的注销地址未配置！");
            return false;
        }
        return request.getRequestURI().equals(endpoint);
    }

    /**
     * Exchanges a ticket for user info at the auth server.
     *
     * @param ticket the SSO ticket
     * @return the resolved user, or {@code null} when the call failed or returned nothing
     */
    private UserInfoDTO obtainUser(String ticket) {
        UserInfoDTO userInfoDTO = null;
        try {
            userInfoDTO = authorizationServerHelper.ticket2UserInfo(sdkProperties.getAuthServer(), ticket);
        } catch (Exception e) {
            // Deliberately best-effort: a failed exchange is treated as "no user"
            LOG.error("单点登录异常：", e);
        }
        if (userInfoDTO == null) {
            LOG.info("未解析到有效用户信息，需登录");
        }
        return userInfoDTO;
    }

    /**
     * Writes an HTTP 200 JSON body reporting that the request is authorized.
     *
     * @param response current response
     * @param sso      {@code true} when authorization came from the ticket exchange,
     *                 {@code false} when the application was already authenticated
     * @param token    internal token to return to the client, may be {@code null}
     * @throws IOException if writing the response fails
     */
    private void respAuthorized(HttpServletResponse response, boolean sso, Serializable token) throws IOException {
        response.setCharacterEncoding(StandardCharsets.UTF_8.name());
        response.setContentType(MediaType.APPLICATION_JSON_VALUE);
        response.setStatus(HttpStatus.OK.value());

        // Write the response body
        try (var writer = response.getWriter()) {
            String result = objectMapper.writeValueAsString(new AuthorizedResult(true, sso ? "SSO服务已认证" : "服务已认证", token));
            writer.write(result);
        }
    }

    /**
     * Writes an HTTP 401 JSON body asking the client to log in again.
     *
     * @param response current response
     * @throws IOException if writing the response fails
     */
    private void respLogin(HttpServletResponse response) throws IOException {
        response.setCharacterEncoding(StandardCharsets.UTF_8.name());
        response.setContentType(MediaType.APPLICATION_JSON_VALUE);
        response.setStatus(HttpStatus.UNAUTHORIZED.value());

        // Write the response body
        try (var writer = response.getWriter()) {
            String result = objectMapper.writeValueAsString(new AuthorizedResult(false, "请重新登录"));
            writer.write(result);
        }
    }

    /**
     * JSON payload returned by the authorize and revoke endpoints.
     */
    static class AuthorizedResult implements Serializable {
        private static final long serialVersionUID = -4599981681764185389L;
        /**
         * Whether the request is authorized
         */
        private final boolean authorized;
        /**
         * Human-readable message
         */
        private final String msg;
        /**
         * Internal token
         */
        private final Serializable token;

        public AuthorizedResult() {
            this.authorized = false;
            this.msg = "";
            this.token = null;
        }

        public AuthorizedResult(boolean authorized, String msg) {
            this.authorized = authorized;
            this.msg = msg;
            this.token = null;
        }

        public AuthorizedResult(boolean authorized, String msg, Serializable token) {
            this.authorized = authorized;
            this.msg = msg;
            this.token = token;
        }

        public boolean isAuthorized() {
            return authorized;
        }

        public String getMsg() {
            return msg;
        }

        public Serializable getToken() {
            return token;
        }
    }
}
