package dev.langchain4j.community.store.embedding.redis;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.Pipeline;
import redis.clients.jedis.params.ScanParams;
import redis.clients.jedis.resps.ScanResult;
import redis.clients.jedis.search.Document;
import redis.clients.jedis.search.FTCreateParams;
import redis.clients.jedis.search.IndexDataType;
import redis.clients.jedis.search.Query;
import redis.clients.jedis.search.RediSearchUtil;
import redis.clients.jedis.search.schemafields.SchemaField;
import redis.clients.jedis.search.schemafields.TextField;

/* loaded from: input_file:dev/langchain4j/community/store/embedding/redis/RedisEmbeddingStore.class */
/**
 * {@link EmbeddingStore} backed by Redis, using a RediSearch vector index over RedisJSON documents.
 *
 * <p>Every entry is stored as a JSON document under the key {@code prefix + id}. On construction the store
 * checks whether the configured index exists and, if it does not, creates it on the key prefix; a
 * {@code dimension} is only required in that case.
 */
public class RedisEmbeddingStore implements EmbeddingStore<TextSegment>, AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(RedisEmbeddingStore.class);

    /** KNN query: pre-filter, number of neighbours, vector field name, alias of the distance field. */
    private static final String QUERY_TEMPLATE = "%s=>[ KNN %d @%s $BLOB AS %s ]";

    /** Page size used when collecting ids for deletion, and the maximum number of keys per DEL command. */
    private static final int DELETE_BATCH_SIZE = 1000;

    private final JedisPooled client;
    private final RedisSchema schema;
    private final RedisMetadataFilterMapper filterMapper;

    /** Fluent builder for {@link RedisEmbeddingStore}; a {@code uri}, when set, takes precedence over host/port. */
    public static class Builder {
        private String uri;
        private String host;
        private Integer port;
        private String user;
        private String password;
        private String indexName;
        private String prefix;
        private Integer dimension;
        private Map<String, SchemaField> metadataConfig = new HashMap<>();

        /** Redis connection URI; when set, host/port/user/password are ignored. */
        public Builder uri(String uri) {
            this.uri = uri;
            return this;
        }

        /** Redis host, used when no {@code uri} is given. */
        public Builder host(String host) {
            this.host = host;
            return this;
        }

        /** Redis port, used when no {@code uri} is given. */
        public Builder port(Integer port) {
            this.port = port;
            return this;
        }

        /** ACL user name, used when no {@code uri} is given. */
        public Builder user(String user) {
            this.user = user;
            return this;
        }

        /** Password, used when no {@code uri} is given. */
        public Builder password(String password) {
            this.password = password;
            return this;
        }

        /** Index name; defaults to {@code "embedding-index"}. */
        public Builder indexName(String indexName) {
            this.indexName = indexName;
            return this;
        }

        /** Key prefix for stored documents; defaults to {@code "embedding:"}. */
        public Builder prefix(String prefix) {
            this.prefix = prefix;
            return this;
        }

        /** Vector dimension; required only when the index does not exist yet. */
        public Builder dimension(Integer dimension) {
            this.dimension = dimension;
            return this;
        }

        /**
         * @deprecated use {@link #metadataKeys(Collection)} instead.
         */
        @Deprecated
        public Builder metadataFieldsName(Collection<String> metadataFieldsName) {
            return metadataKeys(metadataFieldsName);
        }

        /** Registers each key as a JSON text field at path {@code $.key}, aliased to the key, with weight 1.0. */
        public Builder metadataKeys(Collection<String> metadataKeys) {
            metadataKeys.forEach(key -> metadataConfig.put(key, TextField.of("$." + key).as(key).weight(1.0d)));
            return this;
        }

        /**
         * Replaces the metadata field configuration. The map is copied, so later {@link #metadataKeys}
         * calls work even when an immutable map is supplied, and later caller mutations are not seen.
         */
        public Builder metadataConfig(Map<String, SchemaField> metadataConfig) {
            this.metadataConfig = metadataConfig == null ? null : new HashMap<>(metadataConfig);
            return this;
        }

        /** Builds the store, connecting via {@code uri} if set and via host/port otherwise. */
        public RedisEmbeddingStore build() {
            if (uri != null) {
                return new RedisEmbeddingStore(uri, indexName, prefix, dimension, metadataConfig);
            }
            return new RedisEmbeddingStore(host, port, user, password, indexName, prefix, dimension, metadataConfig);
        }
    }

    /**
     * Creates a store connected to {@code host:port}.
     *
     * @param host           Redis host; must not be blank
     * @param port           Redis port; must not be null
     * @param user           ACL user name, or {@code null}
     * @param password       password, or {@code null}
     * @param indexName      index name; {@code null} means {@code "embedding-index"}
     * @param prefix         key prefix; {@code null} means {@code "embedding:"}
     * @param dimension      vector dimension; required only if the index has to be created
     * @param metadataConfig metadata field definitions, may be {@code null}
     */
    public RedisEmbeddingStore(String host, Integer port, String user, String password, String indexName,
                               String prefix, Integer dimension, Map<String, SchemaField> metadataConfig) {
        this(createClient(host, port, user, password), indexName, prefix, dimension, metadataConfig);
    }

    /**
     * Creates a store connected via a Redis URI.
     *
     * @param uri            Redis connection URI; must not be blank
     * @param indexName      index name; {@code null} means {@code "embedding-index"}
     * @param prefix         key prefix; {@code null} means {@code "embedding:"}
     * @param dimension      vector dimension; required only if the index has to be created
     * @param metadataConfig metadata field definitions, may be {@code null}
     */
    public RedisEmbeddingStore(String uri, String indexName, String prefix, Integer dimension,
                               Map<String, SchemaField> metadataConfig) {
        this(createClient(uri), indexName, prefix, dimension, metadataConfig);
    }

    /**
     * Shared initialisation: builds the schema and filter mapper and ensures the index exists.
     * If anything fails, the freshly created client is closed so its connection pool is not leaked.
     */
    private RedisEmbeddingStore(JedisPooled client, String indexName, String prefix, Integer dimension,
                                Map<String, SchemaField> metadataConfig) {
        this.client = client;
        try {
            this.schema = RedisSchema.builder()
                    .indexName(Utils.getOrDefault(indexName, "embedding-index"))
                    .prefix(Utils.getOrDefault(prefix, "embedding:"))
                    .dimension(dimension)
                    .metadataConfig(Utils.copyIfNotNull(metadataConfig))
                    .build();
            this.filterMapper = new RedisMetadataFilterMapper(metadataConfig);
            if (!isIndexExist(schema.indexName())) {
                // The dimension is only needed to create a new vector index.
                ValidationUtils.ensureNotNull(dimension, "dimension");
                createIndex(schema.indexName());
            }
        } catch (RuntimeException e) {
            try {
                client.close();
            } catch (RuntimeException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /** Validates host/port and opens a pooled client, authenticating when a user or password is given. */
    private static JedisPooled createClient(String host, Integer port, String user, String password) {
        ValidationUtils.ensureNotBlank(host, "host");
        ValidationUtils.ensureNotNull(port, "port");
        if (user == null && password == null) {
            return new JedisPooled(host, port);
        }
        // A null user with a non-null password authenticates with the password alone.
        return new JedisPooled(host, port, user, password);
    }

    /** Validates the URI and opens a pooled client from it. */
    private static JedisPooled createClient(String uri) {
        ValidationUtils.ensureNotBlank(uri, "uri");
        return new JedisPooled(uri);
    }

    /** Stores an embedding under a random UUID and returns that id. */
    @Override
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        add(id, embedding);
        return id;
    }

    /** Stores an embedding (without text) under the given id. */
    @Override
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    /** Stores an embedding with its text segment under a random UUID and returns that id. */
    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    /** Stores embeddings (without text), each under a fresh random UUID; returns the ids in input order. */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = embeddings.stream()
                .map(ignored -> Utils.randomUUID())
                .collect(Collectors.toList());
        addAll(ids, embeddings, null);
        return ids;
    }

    /**
     * Runs a KNN vector search, optionally pre-filtered by metadata.
     *
     * <p>Results are sorted by ascending distance, limited to {@code maxResults}, then dropped if their
     * score is below {@code minScore}.
     */
    @Override
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
        int maxResults = request.maxResults();
        String queryString = String.format(
                QUERY_TEMPLATE,
                filterMapper.mapToFilter(request.filter()),
                maxResults,
                schema.vectorFieldName(),
                RedisSchema.SCORE_FIELD_NAME);
        Query query = new Query(queryString)
                .addParam("BLOB", RediSearchUtil.toByteArray(request.queryEmbedding().vector()))
                .setSortBy(RedisSchema.SCORE_FIELD_NAME, true)
                .limit(0, maxResults)
                .dialect(2); // query parameters ($BLOB) require dialect 2
        List<Document> documents = client.ftSearch(schema.indexName(), query).getDocuments();
        return new EmbeddingSearchResult<>(toEmbeddingMatch(documents, request.minScore()));
    }

    /** Deletes the entries with the given ids in a single DEL command. */
    @Override
    public void removeAll(Collection<String> ids) {
        ValidationUtils.ensureNotEmpty(ids, "ids");
        String[] keys = ids.stream()
                .map(id -> schema.prefix() + id)
                .toArray(String[]::new);
        client.del(keys);
    }

    /**
     * Deletes every entry that matches the metadata filter.
     *
     * <p>FT.SEARCH returns only 10 documents unless LIMIT is given, so all matches are paged through. Ids are
     * collected before deleting anything, because deleting while paging by offset would shift later pages.
     */
    @Override
    public void removeAll(Filter filter) {
        ValidationUtils.ensureNotNull(filter, "filter");
        String redisFilter = filterMapper.mapToFilter(filter);
        List<String> keys = new ArrayList<>();
        int offset = 0;
        List<Document> page;
        do {
            Query query = new Query(redisFilter).setNoContent().limit(offset, DELETE_BATCH_SIZE);
            page = client.ftSearch(schema.indexName(), query).getDocuments();
            for (Document document : page) {
                keys.add(document.getId());
            }
            offset += DELETE_BATCH_SIZE;
        } while (page.size() == DELETE_BATCH_SIZE);
        deleteKeys(keys);
    }

    /** Deletes every key under this store's prefix; the index itself is left in place. */
    @Override
    public void removeAll() {
        ScanParams params = new ScanParams().match(schema.prefix() + "*");
        // SCAN may return the same key more than once, so collect into a set.
        Collection<String> keys = new HashSet<>();
        String cursor = "0";
        do {
            ScanResult<String> page = client.scan(cursor, params);
            keys.addAll(page.getResult());
            cursor = page.getCursor();
        } while (!"0".equals(cursor));
        deleteKeys(new ArrayList<>(keys));
    }

    /** Deletes the given keys in DEL commands of at most DELETE_BATCH_SIZE keys; does nothing if empty. */
    private void deleteKeys(List<String> keys) {
        for (int from = 0; from < keys.size(); from += DELETE_BATCH_SIZE) {
            int to = Math.min(from + DELETE_BATCH_SIZE, keys.size());
            client.del(keys.subList(from, to).toArray(new String[0]));
        }
    }

    /**
     * Creates a JSON index on the store's key prefix using the schema fields.
     *
     * @throws RedisRequestFailedException if Redis does not answer "OK"
     */
    private void createIndex(String indexName) {
        FTCreateParams params = FTCreateParams.createParams()
                .on(IndexDataType.JSON)
                .addPrefix(schema.prefix());
        String response = client.ftCreate(indexName, params, schema.toSchemaFields());
        if (!"OK".equals(response)) {
            log.error("create index error, msg={}", response);
            throw new RedisRequestFailedException("create index error, msg=" + response);
        }
    }

    /** Returns whether FT._LIST reports an index with this name. */
    private boolean isIndexExist(String indexName) {
        return client.ftList().contains(indexName);
    }

    /** Stores a single entry through the batch path; {@code textSegment} may be null. */
    private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
        addAll(Collections.singletonList(id), Collections.singletonList(embedding),
                textSegment == null ? null : Collections.singletonList(textSegment));
    }

    /**
     * Stores entries as JSON documents in one pipeline.
     *
     * @param ids        entry ids; an empty or null list makes this a logged no-op
     * @param embeddings embeddings, same size as {@code ids}
     * @param embedded   text segments, same size as {@code embeddings}, or {@code null} for none
     * @throws RedisRequestFailedException if any write does not answer "OK"
     */
    @Override
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("do not add empty embeddings to redis");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(),
                "embeddings size is not equal to embedded size");

        List<Object> responses;
        try (Pipeline pipeline = client.pipelined()) {
            for (int i = 0; i < ids.size(); i++) {
                TextSegment textSegment = embedded == null ? null : embedded.get(i);
                pipeline.jsonSetWithEscape(schema.prefix() + ids.get(i), RedisSchema.JSON_SET_PATH,
                        toJsonFields(embeddings.get(i), textSegment));
            }
            responses = pipeline.syncAndReturnAll();
        }

        Optional<Object> failure = responses.stream()
                .filter(response -> !"OK".equals(response))
                .findAny();
        if (failure.isPresent()) {
            log.error("add embedding failed, msg={}", failure.get());
            throw new RedisRequestFailedException("add embedding failed, msg=" + failure.get());
        }
    }

    /**
     * Builds the JSON document for one entry. Metadata is written first so that a metadata key sharing a
     * name with the text or vector field cannot overwrite them.
     */
    private Map<String, Object> toJsonFields(Embedding embedding, TextSegment textSegment) {
        Map<String, Object> fields = new HashMap<>();
        if (textSegment != null) {
            fields.putAll(textSegment.metadata().toMap());
            fields.put(schema.scalarFieldName(), textSegment.text());
        }
        fields.put(schema.vectorFieldName(), embedding.vector());
        return fields;
    }

    /** Converts search hits to matches, keeping only those with score >= {@code minScore}. */
    private List<EmbeddingMatch<TextSegment>> toEmbeddingMatch(List<Document> documents, double minScore) {
        if (documents == null || documents.isEmpty()) {
            return new ArrayList<>();
        }
        return documents.stream()
                .map(this::documentToMatch)
                .filter(match -> match.score() >= minScore)
                .collect(Collectors.toList());
    }

    /**
     * Converts one search hit into an {@link EmbeddingMatch}.
     *
     * <p>The distance is mapped to a score as {@code (2 - distance) / 2}. NOTE(review): this assumes the index
     * uses COSINE distance (range [0, 2]); the metric is configured in RedisSchema, which is not visible here.
     */
    private EmbeddingMatch<TextSegment> documentToMatch(Document document) {
        double distance = Double.parseDouble(document.getString(RedisSchema.SCORE_FIELD_NAME));
        double score = (2.0d - distance) / 2.0d;
        String id = document.getId().substring(schema.prefix().length());

        Map<String, Object> properties = RedisJsonUtils.toProperties(document.getString(RedisSchema.JSON_KEY));
        // The vector comes back as a JSON array of numbers.
        List<?> rawVector = (List<?>) properties.get(schema.vectorFieldName());
        List<Float> vector = rawVector.stream()
                .map(value -> ((Number) value).floatValue())
                .collect(Collectors.toList());
        Embedding embedding = Embedding.from(vector);

        String text = (String) properties.get(schema.scalarFieldName());
        TextSegment textSegment = null;
        if (text != null) {
            // Only schema-declared fields become metadata; null JSON values are skipped.
            Map<String, Object> metadata = new HashMap<>();
            for (String key : schema.schemaFieldMap().keySet()) {
                Object value = properties.get(key);
                if (value != null) {
                    metadata.put(key, value);
                }
            }
            textSegment = TextSegment.from(text, Metadata.from(metadata));
        }
        return new EmbeddingMatch<>(score, id, embedding, textSegment);
    }

    /** Creates a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /** Closes the underlying Redis connection pool. */
    @Override
    public void close() {
        client.close();
    }
}
