package com.elitescloud.boot.auth.provider.security.handler;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.text.CharSequenceUtil;
import com.elitescloud.boot.auth.client.common.AuthorizationException;
import com.elitescloud.boot.auth.client.common.SecurityConstants;
import com.elitescloud.boot.auth.client.config.AuthorizationProperties;
import com.elitescloud.boot.auth.client.config.support.AuthenticationCache;
import com.elitescloud.boot.auth.client.config.support.AuthenticationCallable;
import com.elitescloud.boot.auth.provider.common.AuthorizationConstant;
import com.elitescloud.boot.auth.provider.common.LoginDeviceLimitStrategy;
import com.elitescloud.boot.auth.provider.common.param.UserLoginDeviceDTO;
import com.elitescloud.boot.auth.provider.config.properties.AuthorizationProviderProperties;
import com.elitescloud.boot.auth.provider.config.system.LoginProperties;
import com.elitescloud.boot.auth.provider.security.TokenPropertiesProvider;
import com.elitescloud.boot.util.DatetimeUtil;
import com.elitescloud.cloudt.context.util.HttpServletUtil;
import com.elitescloud.cloudt.security.entity.GeneralUserDetails;
import lombok.extern.log4j.Log4j2;
import org.springframework.http.HttpHeaders;
import org.springframework.security.core.Authentication;
import org.springframework.util.StringUtils;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.time.Duration;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

/**
 * Authentication callback that caches the logged-in user and tracks the user's online login devices.
 * <p>
 * On login the principal is cached under its token and an entry describing the login device is appended to
 * a per-user device list; on logout the cached principal is removed and the device entry is cleared
 * asynchronously.
 *
 * @author Kaiser（wang shao）
 * @date 2022/7/14
 */
@Log4j2
public class CacheUserAuthenticationCallable implements AuthenticationCallable {

    /** Number of leading and trailing token characters kept visible when a token is written to the log. */
    private static final int TOKEN_LOG_VISIBLE_CHARS = 4;

    private final AuthorizationProperties authorizationProperties;
    private final AuthorizationProviderProperties authorizationProviderProperties;
    private final AuthenticationCache authenticationCache;
    private final TokenPropertiesProvider tokenPropertiesProvider;

    public CacheUserAuthenticationCallable(AuthorizationProperties authorizationProperties, AuthorizationProviderProperties authorizationProviderProperties,
                                           AuthenticationCache authenticationCache, TokenPropertiesProvider tokenPropertiesProvider) {
        this.authorizationProperties = authorizationProperties;
        this.authorizationProviderProperties = authorizationProviderProperties;
        this.authenticationCache = authenticationCache;
        this.tokenPropertiesProvider = tokenPropertiesProvider;
    }

    /**
     * Caches the authenticated user under {@code token} and records the current login device.
     * <p>
     * Does nothing unless the principal is a {@link GeneralUserDetails} and {@code token} has text. Device
     * tracking is skipped (with a warning) when the user has no user id, since no device cache key can be built.
     *
     * @param request        current request; may be null, in which case client id / IP / user agent are not recorded
     * @param response       current response (unused)
     * @param token          the issued token
     * @param authentication the successful authentication
     */
    @Override
    public void onLogin(HttpServletRequest request, HttpServletResponse response, String token, Authentication authentication) throws IOException, ServletException {
        if (authentication == null || !(authentication.getPrincipal() instanceof GeneralUserDetails) || !StringUtils.hasText(token)) {
            return;
        }
        GeneralUserDetails user = (GeneralUserDetails) authentication.getPrincipal();

        // Never log the full token: anyone with log access could replay it.
        log.info("用户{}登录，token：{}", user.getUsername(), maskToken(token));
        authenticationCache.setUserDetail(token, user, cachePrincipalDuration());

        String deviceCacheKey = deviceCacheKey(user);
        if (deviceCacheKey == null) {
            log.warn("用户{}缺少userId，跳过登录设备记录", user.getUsername());
            return;
        }

        // Login device: keep only still-valid sessions, then apply the client's device-limit strategy.
        UserLoginDeviceDTO loginDevice = buildLoginDevice(request, user, token);
        List<UserLoginDeviceDTO> onlineDevices = queryUserDeviceOnline(deviceCacheKey);

        // Copy because the helpers may hand back an immutable list (e.g. Collections.emptyList()).
        List<UserLoginDeviceDTO> deviceList = new ArrayList<>(expireOtherDevice(onlineDevices, loginDevice.getClientId()));
        deviceList.add(loginDevice);

        // NOTE(review): this read-modify-write of the device list is not atomic; concurrent logins/logouts of
        // the same user may overwrite each other's updates.
        authenticationCache.setAttribute(deviceCacheKey, deviceList, null);
    }

    /**
     * Removes the cached principal for {@code token} and, for {@link GeneralUserDetails} principals,
     * asynchronously removes the matching entry from the user's device list.
     *
     * @param request   current request (unused)
     * @param response  current response (unused)
     * @param token     the token being logged out; nothing happens when it has no text
     * @param principal the principal associated with the token, may be of any type
     */
    @Override
    public void onLogout(HttpServletRequest request, HttpServletResponse response, String token, Object principal) {
        if (!StringUtils.hasText(token)) {
            return;
        }
        if (principal instanceof GeneralUserDetails) {
            GeneralUserDetails user = (GeneralUserDetails) principal;
            log.info("用户{}注销", user.getUsername());

            // Clear the online device entry off the request thread; failures are logged, not propagated.
            // NOTE(review): runs on ForkJoinPool.commonPool(); consider a dedicated executor if the cache does
            // blocking I/O.
            CompletableFuture.runAsync(() -> clearUserDevice(user, token))
                    .whenComplete((res, exp) -> {
                        if (exp != null) {
                            log.error("清理用户在线设备异常：", exp);
                        }
                    });
        }

        authenticationCache.removeUserDetail(token);
    }

    /**
     * Removes every device entry whose token equals {@code token} from the user's cached device list.
     * Leaves the cache untouched when the list is absent/empty or the user has no id.
     */
    private void clearUserDevice(GeneralUserDetails user, String token) {
        String cacheKey = deviceCacheKey(user);
        if (cacheKey == null) {
            return;
        }
        List<UserLoginDeviceDTO> deviceList = readDeviceList(cacheKey);
        if (deviceList.isEmpty()) {
            return;
        }

        List<UserLoginDeviceDTO> remaining = new ArrayList<>(deviceList.size());
        for (UserLoginDeviceDTO device : deviceList) {
            if (!token.equals(device.getToken())) {
                remaining.add(device);
            }
        }
        authenticationCache.setAttribute(cacheKey, remaining, null);
    }

    /**
     * Returns the cached devices stored under {@code cacheKey} whose token still exists in the cache.
     */
    private List<UserLoginDeviceDTO> queryUserDeviceOnline(String cacheKey) {
        List<UserLoginDeviceDTO> deviceList = readDeviceList(cacheKey);
        if (deviceList.isEmpty()) {
            return Collections.emptyList();
        }

        // Keep only devices whose session is still valid.
        return deviceList.stream().filter(t -> authenticationCache.exists(t.getToken())).collect(Collectors.toList());
    }

    /**
     * Applies the device-limit strategy configured for {@code clientId}.
     * <p>
     * When the strategy is {@link LoginDeviceLimitStrategy#INVALID_OTHER}, every device of the same client is
     * logged out (its cached principal removed) and dropped from the returned list; otherwise the input list is
     * returned unchanged.
     */
    private List<UserLoginDeviceDTO> expireOtherDevice(List<UserLoginDeviceDTO> currentDeviceList, String clientId) {
        if (CollUtil.isEmpty(currentDeviceList) || CharSequenceUtil.isBlank(clientId)) {
            return currentDeviceList;
        }
        if (LoginDeviceLimitStrategy.INVALID_OTHER != resolveDeviceLimitStrategy(clientId)) {
            return currentDeviceList;
        }

        List<UserLoginDeviceDTO> finalDeviceList = new ArrayList<>(currentDeviceList.size());
        for (UserLoginDeviceDTO userLoginDeviceDTO : currentDeviceList) {
            if (clientId.equals(userLoginDeviceDTO.getClientId())) {
                authenticationCache.removeUserDetail(userLoginDeviceDTO.getToken());
                log.info("自动注销已登录的设备：{}, {}, {}", userLoginDeviceDTO.getLoginTime(), userLoginDeviceDTO.getUserAgent(),
                        maskToken(userLoginDeviceDTO.getToken()));
                continue;
            }

            finalDeviceList.add(userLoginDeviceDTO);
        }

        return finalDeviceList;
    }

    /**
     * Looks up the device-limit strategy of the first limiter configured for {@code clientId}.
     *
     * @return the strategy, or null when no login config / limiter applies
     */
    private LoginDeviceLimitStrategy resolveDeviceLimitStrategy(String clientId) {
        var login = authorizationProviderProperties.getLogin();
        if (login == null) {
            return null;
        }
        var limiters = login.getLoginDeviceLimiters();
        if (limiters == null) {
            return null;
        }
        for (LoginProperties.LoginDeviceLimiter limiter : limiters) {
            if (limiter != null && clientId.equals(limiter.getClientId())) {
                return limiter.getStrategy();
            }
        }
        return null;
    }

    /**
     * Builds the device record for the current login: client id (from a request attribute), client IP,
     * user agent, token and the current time.
     */
    private UserLoginDeviceDTO buildLoginDevice(HttpServletRequest request, GeneralUserDetails user, String token) {
        UserLoginDeviceDTO loginDevice = new UserLoginDeviceDTO();
        if (request != null) {
            String clientId = (String) request.getAttribute(AuthorizationConstant.REQUEST_ATTRIBUTE_CLIENT_ID);
            loginDevice.setClientId(clientId);
            // Reads the IP from HttpServletUtil's current-request context rather than from `request`.
            loginDevice.setLoginIp(HttpServletUtil.currentClientIp());
            loginDevice.setUserAgent(request.getHeader(HttpHeaders.USER_AGENT));
        }
        loginDevice.setToken(token);
        loginDevice.setLoginTime(DatetimeUtil.toStr(LocalDateTime.now()));

        return loginDevice;
    }

    /**
     * TTL used when caching the user principal.
     * <p>
     * Uses the TTL from {@link TokenPropertiesProvider} when it supplies token properties. Otherwise it uses
     * {@link AuthorizationProperties#getTokenTtl()} when that is positive, and otherwise returns null.
     * NOTE(review): the provider's TTL is returned without the positivity check applied to the fallback — confirm
     * this is intended.
     */
    private Duration cachePrincipalDuration() {
        var tokenProperties = tokenPropertiesProvider.get();
        if (tokenProperties != null) {
            return tokenProperties.getTokenTtl();
        }

        if (authorizationProperties.getTokenTtl() != null && authorizationProperties.getTokenTtl().getSeconds() > 0) {
            return authorizationProperties.getTokenTtl();
        }
        return null;
    }

    /**
     * Cache key of the user's online-device list, or null when the user has no id.
     */
    private static String deviceCacheKey(GeneralUserDetails user) {
        if (user.getUserId() == null) {
            return null;
        }
        return user.getUserId().toString() + SecurityConstants.CACHE_SUFFIX_CURRENT_USER_ATTRIBUTE_DEVICE;
    }

    /**
     * Reads the device list stored under {@code cacheKey} with runtime type checks instead of an unchecked cast.
     * Missing values, non-list values and non-{@link UserLoginDeviceDTO} elements are ignored.
     *
     * @return a new mutable list, never null
     */
    private List<UserLoginDeviceDTO> readDeviceList(String cacheKey) {
        Object cached = authenticationCache.getAttribute(cacheKey);
        if (!(cached instanceof List)) {
            return new ArrayList<>();
        }
        List<?> rawList = (List<?>) cached;
        List<UserLoginDeviceDTO> devices = new ArrayList<>(rawList.size());
        for (Object element : rawList) {
            if (element instanceof UserLoginDeviceDTO) {
                devices.add((UserLoginDeviceDTO) element);
            }
        }
        return devices;
    }

    /**
     * Masks a token for logging, keeping only a few leading and trailing characters for correlation.
     */
    private static String maskToken(String token) {
        if (token == null || token.length() <= TOKEN_LOG_VISIBLE_CHARS * 2) {
            return "****";
        }
        return token.substring(0, TOKEN_LOG_VISIBLE_CHARS) + "****" + token.substring(token.length() - TOKEN_LOG_VISIBLE_CHARS);
    }
}
