package com.elitescloud.boot.auth.provider.sso2.support.impl;

import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.crypto.digest.MD5;
import com.elitescloud.boot.SpringContextHolder;
import com.elitescloud.boot.auth.client.common.SecurityConstants;
import com.elitescloud.boot.auth.provider.config.properties.Sso2Properties;
import com.elitescloud.boot.auth.provider.sso2.common.TicketProvider;
import com.elitescloud.boot.auth.util.SecurityContextUtil;
import com.elitescloud.boot.auth.util.SecurityUtil;
import com.elitescloud.boot.redis.util.RedisUtils;
import com.elitescloud.boot.util.DatetimeUtil;
import com.elitescloud.boot.wrapper.RedisWrapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

/**
 * Default ticket provider.
 *
 * <p>Derives an SSO ticket from the current token, stores the ticket-to-token
 * mapping in Redis and later exchanges a ticket back for the token it was issued for.
 *
 * @author Kaiser（wang shao）
 * @date 2024/6/26
 */
public class DefaultTicketProvider implements TicketProvider {
    private static final Logger logger = LoggerFactory.getLogger(DefaultTicketProvider.class);

    private final RedisUtils redisUtils;
    protected final Sso2Properties sso2Properties;
    /** Resolved lazily from the Spring context on first use, see {@link #getRedisWrapper()}. */
    private RedisWrapper redisWrapper;

    public DefaultTicketProvider(RedisUtils redisUtils, Sso2Properties sso2Properties) {
        this.redisUtils = redisUtils;
        this.sso2Properties = sso2Properties;
    }

    /**
     * Generates a ticket for the token of the current security context and stores it.
     *
     * @param request  the current HTTP request, forwarded to {@link #produceTicket}
     * @param response the current HTTP response, forwarded to {@link #produceTicket}
     * @return the generated ticket; {@code null} only if no current token exists and
     *         {@code SecurityUtil.throwUnauthorizedException()} returns normally
     */
    @Override
    public String generateTicket(HttpServletRequest request, HttpServletResponse response) {
        var token = SecurityContextUtil.currentToken();
        if (token == null) {
            SecurityUtil.throwUnauthorizedException();
            return null;
        }

        // Generate the ticket
        var ticket = this.produceTicket(request, response, token);

        // Persist the ticket -> token mapping
        this.storageTicket(ticket, token);

        return ticket;
    }

    /**
     * Exchanges a ticket for the token it was issued for.
     *
     * @param ticket the ticket to exchange
     * @return the token stored for the ticket, or {@code null} if none is found
     */
    @Override
    public Object exchangeTicket(String ticket) {
        return this.retrieveTokenByTicket(ticket);
    }


    /**
     * Produces the ticket value: the MD5 hex digest of the token, {@code "::"} and the
     * current time as returned by {@code DatetimeUtil.currentTimeLong()}.
     *
     * @param request  the current HTTP request (unused by this implementation)
     * @param response the current HTTP response (unused by this implementation)
     * @param token    the token the ticket is issued for
     * @return the ticket
     */
    protected String produceTicket(HttpServletRequest request, HttpServletResponse response, String token) {
        // Use a fresh digester per call: a digester instance shared between request threads
        // wraps a stateful MessageDigest, so concurrent calls could interleave their
        // update/digest steps and yield a wrong (even input-independent) ticket.
        var ticket = MD5.create().digestHex(token + "::" + DatetimeUtil.currentTimeLong());
        // NOTE(review): the ticket can be exchanged for a token, so it is credential-like;
        // consider lowering this log level or masking the value.
        logger.info("produce sso ticket：{}", ticket);
        return ticket;
    }

    /**
     * Stores the ticket-to-token mapping in Redis under
     * {@code SecurityConstants.CACHE_PREFIX_SSO_TICKET + ticket}.
     *
     * @param ticket the ticket, used as the cache key suffix
     * @param token  the token stored as the cache value
     */
    protected void storageTicket(String ticket, String token) {
        var ttl = sso2Properties.getServer().getTicketTtl();
        // -1 is passed through when no TTL is configured.
        // NOTE(review): how RedisUtils.set treats a negative TTL is not visible here — confirm.
        var ttlSeconds = ttl == null ? -1 : ttl.toSeconds();
        this.supplyRedis(redis -> {
            redis.set(SecurityConstants.CACHE_PREFIX_SSO_TICKET + ticket, token, ttlSeconds, TimeUnit.SECONDS);
            return null;
        });
    }

    /**
     * Looks up the token stored for a ticket and, when the server property
     * {@code expireTicketOnUsed} is {@code true}, deletes the ticket afterwards.
     *
     * @param ticket the ticket to look up
     * @return the stored token, or {@code null} if the ticket is blank or no
     *         non-blank token is stored for it
     */
    protected String retrieveTokenByTicket(String ticket) {
        // Without this guard a null ticket would be looked up under the key "<prefix>null"
        // and a blank one under the bare prefix; skip the Redis round trip instead.
        if (CharSequenceUtil.isBlank(ticket)) {
            return null;
        }
        var cacheKey = SecurityConstants.CACHE_PREFIX_SSO_TICKET + ticket;

        return this.supplyRedis(redis -> {
            var token = (String) redis.get(cacheKey);
            if (CharSequenceUtil.isBlank(token)) {
                return null;
            }

            // Remove the ticket once it has been used.
            // NOTE(review): get + del are two separate calls, so two concurrent exchanges of
            // the same ticket may both succeed; an atomic get-and-delete would close that gap.
            if (Boolean.TRUE.equals(sso2Properties.getServer().getExpireTicketOnUsed())) {
                redis.del(cacheKey);
            }
            return token;
        });
    }

    /**
     * Runs a Redis operation through the {@link RedisWrapper}.
     *
     * @param callback the operation to run against the {@link RedisUtils} instance
     * @param <T>      the result type of the operation
     * @return whatever the wrapper returns for the operation
     */
    // The wrapper's result is cast back to the callback's own result type.
    @SuppressWarnings("unchecked")
    protected <T> T supplyRedis(Function<RedisUtils, T> callback) {
        // NOTE(review): the second argument is passed as null; presumably a fallback value —
        // confirm against RedisWrapper.apply.
        return (T) getRedisWrapper().apply(() -> callback.apply(redisUtils), null);
    }

    /**
     * Returns the {@link RedisWrapper}, fetching it from the Spring context on first use
     * and caching it in the instance field afterwards.
     */
    private RedisWrapper getRedisWrapper() {
        if (redisWrapper == null) {
            redisWrapper = SpringContextHolder.getBean(RedisWrapper.class);
        }
        return redisWrapper;
    }
}
