package com.elitescloud.boot.websocket.support;

import com.elitescloud.cloudt.security.entity.GeneralUserDetails;
import com.elitescloud.cloudt.system.dto.SysTenantDTO;
import org.springframework.web.socket.WebSocketSession;

import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
import java.io.Serializable;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
 * WebSocket session manager.
 * <p>
 * Keeps an in-memory registry of every WebSocket session registered on this node, plus an index from username to
 * that user's authenticated sessions.
 * <p>
 * Because {@link org.springframework.web.socket.WebSocketSession} is not serializable, this registry is local to the
 * current JVM; in a distributed or clustered deployment, messages have to be relayed to other nodes through
 * middleware.
 * <p>
 * Thread-safety: all state is held in {@link ConcurrentHashMap}s, and the per-user session lists are unmodifiable
 * snapshots that are only ever replaced atomically through {@code compute}/{@code computeIfPresent}, so the methods
 * of this class may be called concurrently.
 *
 * @author Kaiser（wang shao）
 * @date 2023/5/31
 */
public class WebSocketSessionManager {

    /**
     * Authenticated sessions indexed by username. Each value is an unmodifiable, non-empty snapshot that is replaced
     * atomically on every change; a user whose last session is removed is removed from the map as well.
     */
    private static final ConcurrentHashMap<String, List<WebSocketSession>> USER_SESSION_ALL = new ConcurrentHashMap<>();
    /**
     * Every registered session (anonymous, authenticated or tenant-bound), keyed by {@link WebSocketSession#getId()}.
     */
    private static final ConcurrentHashMap<String, WebSocketSessionWrapper> SESSION_ALL = new ConcurrentHashMap<>();

    /**
     * Returns the authenticated sessions of a user.
     *
     * @param username username
     * @return unmodifiable snapshot of the user's sessions; empty if the user has none (or {@code username} is null)
     */
    public static List<WebSocketSession> getSession(@NotBlank String username) {
        if (username == null) {
            // ConcurrentHashMap rejects null keys; treat an absent username as "no sessions".
            return Collections.emptyList();
        }
        return USER_SESSION_ALL.getOrDefault(username, Collections.emptyList());
    }

    /**
     * Number of registered sessions (anonymous, authenticated and tenant sessions alike).
     *
     * @return session count
     */
    public static int countSession() {
        return SESSION_ALL.size();
    }

    /**
     * Usernames that currently have at least one authenticated session.
     *
     * @return new mutable list of usernames
     */
    public static List<String> onlineUsers() {
        // USER_SESSION_ALL never keeps empty entries, so its keys are exactly the online users.
        // SESSION_ALL is keyed by session id and must not be used here.
        return new ArrayList<>(USER_SESSION_ALL.keySet());
    }

    /**
     * Describes the authenticated sessions of a user.
     *
     * @param username username
     * @return one {@link SessionUser} per session of the user that is still registered
     */
    public static List<SessionUser> sessionOfUser(@NotBlank String username) {
        return getSession(username).stream()
                .map(WebSocketSessionManager::toSessionUser)
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
    }

    /**
     * Registers an anonymous session.
     *
     * @param session session
     */
    static void addSession(@NotNull WebSocketSession session) {
        addAuthSession(session, null);
    }

    /**
     * Registers a session, indexing it under the user's username when {@code userDetails} is present.
     *
     * @param session     session
     * @param userDetails authenticated user, may be null for an anonymous session
     */
    static void addAuthSession(@NotNull WebSocketSession session, GeneralUserDetails userDetails) {
        registerSession(session, new WebSocketSessionWrapper(session, userDetails));
    }

    /**
     * Registers a session bound to a tenant (without a user).
     *
     * @param session   session
     * @param tenantDTO tenant information
     */
    static void addTenantSession(@NotNull WebSocketSession session, SysTenantDTO tenantDTO) {
        registerSession(session, new WebSocketSessionWrapper(session, tenantDTO));
    }

    /**
     * Returns the user attached to a registered session.
     *
     * @param session session
     * @return the user, or null if the session is unknown or anonymous
     */
    static GeneralUserDetails obtainAuthUser(@NotNull WebSocketSession session) {
        var wrapperSession = SESSION_ALL.get(session.getId());
        return wrapperSession == null ? null : wrapperSession.getCurrentUser();
    }

    /**
     * Unregisters a session and closes it. Does nothing if the session is not registered.
     *
     * @param session session
     */
    public static void removeSession(@NotNull WebSocketSession session) {
        String sessionId = session.getId();
        // remove() is atomic: only the caller that actually removes the entry continues, so the session is
        // unindexed and closed at most once even if this method races with itself.
        var sessionWrapper = SESSION_ALL.remove(sessionId);
        if (sessionWrapper == null) {
            return;
        }

        String username = usernameOf(sessionWrapper);
        if (username != null) {
            unindexUserSession(username, sessionId);
        }

        try {
            session.close();
        } catch (Exception ignored) {
            // Best effort: the session may already be closed, and the registry has been cleaned up regardless.
        }
    }

    /**
     * Stores the wrapper and keeps the username index consistent with it, including when a session id is
     * re-registered under a different user (or without one).
     */
    private static void registerSession(WebSocketSession session, WebSocketSessionWrapper wrapper) {
        String sessionId = session.getId();
        var previous = SESSION_ALL.put(sessionId, wrapper);

        String newUsername = usernameOf(wrapper);
        String oldUsername = usernameOf(previous);
        if (oldUsername != null && !oldUsername.equals(newUsername)) {
            unindexUserSession(oldUsername, sessionId);
        }
        if (newUsername != null) {
            indexUserSession(newUsername, session);
        }
    }

    /** Atomically adds (or replaces) the session in the user's snapshot list. */
    private static void indexUserSession(String username, WebSocketSession session) {
        String sessionId = session.getId();
        USER_SESSION_ALL.compute(username, (key, sessions) -> {
            // Drop any entry with the same id first so re-registration never duplicates a session.
            List<WebSocketSession> updated = withoutSession(sessions, sessionId);
            updated.add(session);
            return Collections.unmodifiableList(updated);
        });
    }

    /** Atomically removes the session from the user's list, dropping the user entry once it becomes empty. */
    private static void unindexUserSession(String username, String sessionId) {
        USER_SESSION_ALL.computeIfPresent(username, (key, sessions) -> {
            List<WebSocketSession> remaining = withoutSession(sessions, sessionId);
            // Returning null removes the mapping, so users without sessions do not linger.
            return remaining.isEmpty() ? null : Collections.unmodifiableList(remaining);
        });
    }

    /** Returns a new mutable copy of {@code sessions} (null treated as empty) without the given session id. */
    private static List<WebSocketSession> withoutSession(List<WebSocketSession> sessions, String sessionId) {
        List<WebSocketSession> result = new ArrayList<>();
        if (sessions != null) {
            for (WebSocketSession candidate : sessions) {
                if (!Objects.equals(sessionId, candidate.getId())) {
                    result.add(candidate);
                }
            }
        }
        return result;
    }

    /** Username of the wrapper's user, or null when there is no wrapper or no user. */
    private static String usernameOf(WebSocketSessionWrapper wrapper) {
        if (wrapper == null || wrapper.getCurrentUser() == null) {
            return null;
        }
        return wrapper.getCurrentUser().getUsername();
    }

    /** Builds the public view of a session, or null if it was unregistered concurrently or has no user. */
    private static SessionUser toSessionUser(WebSocketSession webSocketSession) {
        // Single lookup: a containsKey/get pair could observe a concurrent removal in between.
        var wrapper = SESSION_ALL.get(webSocketSession.getId());
        if (wrapper == null || wrapper.getCurrentUser() == null) {
            return null;
        }
        var user = wrapper.getCurrentUser();

        SessionUser sessionUser = new SessionUser();
        sessionUser.setUsername(user.getUsername());
        sessionUser.setUserId(user.getUserId());
        sessionUser.setCreateTime(wrapper.getCreateTime());
        sessionUser.setSessionId(webSocketSession.getId());
        // Use the getter: it falls back to the user's tenant, whereas the raw field is null for user sessions.
        var tenant = wrapper.getTenantDTO();
        if (tenant != null) {
            sessionUser.setTenantCode(tenant.getTenantCode());
        }
        return sessionUser;
    }

    /**
     * A registered session together with its optional user, optional tenant and registration time.
     */
    static class WebSocketSessionWrapper {
        private final WebSocketSession webSocketSession;
        private final GeneralUserDetails currentUser;
        private final SysTenantDTO tenantDTO;
        private final LocalDateTime createTime;

        public WebSocketSessionWrapper(@NotNull WebSocketSession webSocketSession) {
            this.webSocketSession = webSocketSession;
            this.currentUser = null;
            this.tenantDTO = null;
            this.createTime = LocalDateTime.now();
        }

        public WebSocketSessionWrapper(@NotNull WebSocketSession webSocketSession, GeneralUserDetails currentUser) {
            this.webSocketSession = webSocketSession;
            this.currentUser = currentUser;
            this.tenantDTO = null;
            this.createTime = LocalDateTime.now();
        }

        public WebSocketSessionWrapper(@NotNull WebSocketSession webSocketSession, SysTenantDTO tenantDTO) {
            this.webSocketSession = webSocketSession;
            this.currentUser = null;
            this.tenantDTO = tenantDTO;
            this.createTime = LocalDateTime.now();
        }

        public WebSocketSession getWebSocketSession() {
            return webSocketSession;
        }

        public GeneralUserDetails getCurrentUser() {
            return currentUser;
        }

        /**
         * @return the user's tenant when a user with a tenant is attached, otherwise the tenant given at construction
         */
        public SysTenantDTO getTenantDTO() {
            if (currentUser != null && currentUser.getTenant() != null) {
                return currentUser.getTenant();
            }
            return tenantDTO;
        }

        /**
         * @return time the wrapper was created (local time of this JVM)
         */
        public LocalDateTime getCreateTime() {
            return createTime;
        }
    }

    /**
     * Serializable description of one authenticated session.
     */
    public static class SessionUser implements Serializable {
        private static final long serialVersionUID = -1969237509857337862L;
        /** Username of the session's user. */
        private String username;
        /** Id of the session's user. */
        private Long userId;
        /** Tenant code, null when no tenant is associated. */
        private String tenantCode;
        /** Time the session was registered. */
        private LocalDateTime createTime;
        /** WebSocket session id. */
        private String sessionId;

        public String getUsername() {
            return username;
        }

        public void setUsername(String username) {
            this.username = username;
        }

        public Long getUserId() {
            return userId;
        }

        public void setUserId(Long userId) {
            this.userId = userId;
        }

        public LocalDateTime getCreateTime() {
            return createTime;
        }

        public void setCreateTime(LocalDateTime createTime) {
            this.createTime = createTime;
        }

        public String getSessionId() {
            return sessionId;
        }

        public void setSessionId(String sessionId) {
            this.sessionId = sessionId;
        }

        public String getTenantCode() {
            return tenantCode;
        }

        public void setTenantCode(String tenantCode) {
            this.tenantCode = tenantCode;
        }
    }
}
