/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query.nativelib;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.TopKnnCollectorManager;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOSupplier;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNWeight;
import org.opensearch.knn.index.query.PerLeafResult;
import org.opensearch.knn.index.query.ResultUtil;
import org.opensearch.knn.index.query.TopDocsDISI;
import org.opensearch.knn.index.query.common.QueryUtils;
import org.opensearch.knn.index.query.exactsearch.ExactSearcher;
import org.opensearch.knn.index.query.memoryoptsearch.MemoryOptimizedKNNWeight;
import org.opensearch.knn.index.query.memoryoptsearch.optimistic.OptimisticSearchStrategyUtils;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.profile.KNNProfileUtil;
import org.opensearch.knn.profile.LongMetric;
import org.opensearch.knn.profile.StopWatchUtils;
import org.opensearch.lucene.ReentrantKnnCollectorManager;
import org.opensearch.search.profile.AbstractProfileBreakdown;
import org.opensearch.search.profile.ContextualProfileBreakdown;
import org.opensearch.search.profile.query.QueryProfiler;

/**
 * Lucene {@link Query} that wraps a {@link KNNQuery} and resolves it eagerly inside
 * {@link #createWeight(IndexSearcher, ScoreMode, float)}.
 *
 * <p>Per leaf, the wrapped query is searched through {@link KNNWeight#searchLeaf}. When a rescore
 * context is enabled, the oversampled first-pass hits are re-scored with exact search. Results are
 * reduced to k and, when {@code expandNestedDocs} is set, expanded to all nested siblings. The merged
 * shard-level hits are finally served through the weight of the query built by
 * {@link QueryUtils#createDocAndScoreQuery}.
 */
public class NativeEngineKnnVectorQuery extends Query {
    @Generated
    private static final Logger log = LogManager.getLogger(NativeEngineKnnVectorQuery.class);

    /**
     * Read from system property {@code mem_opt_srch.force_reenter}. When {@code true}, the second
     * optimistic deep-dive search runs on every non-empty approximate-search leaf, skipping the score
     * comparison. Presumably intended for testing only, as the name suggests.
     */
    private static final boolean FORCE_REENTER_TESTING = Boolean.parseBoolean(System.getProperty("mem_opt_srch.force_reenter", "false"));

    /** The wrapped k-NN query executed against every leaf. */
    private final KNNQuery knnQuery;
    /** Helper used to collect nested siblings and to build the final doc-and-score query. */
    private final QueryUtils queryUtils;
    /** When {@code true}, every hit is expanded to all of its nested sibling documents. */
    private final boolean expandNestedDocs;

    /**
     * Executes the whole k-NN search for this shard and returns a weight over the merged hits.
     *
     * @param indexSearcher searcher whose leaves are searched and whose task executor runs per-leaf work
     * @param scoreMode     score mode forwarded to the created weights
     * @param boost         boost forwarded to the final weight
     * @return a {@link MatchNoDocsQuery} weight when nothing matched, otherwise the weight of the
     *         doc-and-score query built from the merged top docs
     * @throws IOException if any per-leaf search, rescore or expansion fails
     */
    @Override
    public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, float boost) throws IOException {
        final boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(this.knnQuery.getIndexName());
        final Integer firstPassKFor2PhaseSearch = this.getFirstPassK(isShardLevelRescoringDisabled);
        final KNNWeight knnWeight = this.createKnnWeight(firstPassKFor2PhaseSearch, indexSearcher, scoreMode);

        final IndexReader reader = indexSearcher.getIndexReader();
        final List<LeafReaderContext> leafReaderContexts = reader.leaves();
        final int finalK = this.knnQuery.getK();

        List<PerLeafResult> perLeafResults;
        if (!this.isRescoreRequired(firstPassKFor2PhaseSearch)) {
            perLeafResults = this.doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK);
        } else {
            // Two-phase search: oversampled first pass, then exact rescoring down to finalK.
            perLeafResults = this.doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassKFor2PhaseSearch);
            if (!isShardLevelRescoringDisabled) {
                ResultUtil.reduceToTopK(perLeafResults, firstPassKFor2PhaseSearch);
            }
            final StopWatch stopWatch = new StopWatch().start();
            perLeafResults = this.doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK);
            final long rescoreTime = stopWatch.stop().totalTime().millis();
            log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassKFor2PhaseSearch, leafReaderContexts.size());
        }

        // Memory-optimized search skips this reduction; getTotalTopDoc then keeps all of its hits in the merge.
        if (!this.knnQuery.isMemoryOptimizedSearch()) {
            ResultUtil.reduceToTopK(perLeafResults, finalK);
        }

        if (this.expandNestedDocs) {
            final StopWatch stopWatch = new StopWatch().start();
            perLeafResults = this.retrieveAll(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, Objects.isNull(this.knnQuery.getRescoreContext()));
            final long timeInMillis = stopWatch.stop().totalTime().millis();
            if (log.isDebugEnabled()) {
                final long totalNestedDocs = perLeafResults.stream().mapToLong(perLeafResult -> perLeafResult.getResult().scoreDocs.length).sum();
                log.debug("Expanding of nested docs took {} ms. totalNestedDocs:{} ", timeInMillis, totalNestedDocs);
            }
        }

        final TopDocs topK = this.mergeToShardTopDocs(perLeafResults, leafReaderContexts);
        if (topK.scoreDocs.length == 0) {
            return new MatchNoDocsQuery().createWeight(indexSearcher, scoreMode, boost);
        }
        return this.queryUtils.createDocAndScoreQuery(reader, topK, knnWeight).createWeight(indexSearcher, scoreMode, boost);
    }

    /**
     * Creates the {@link KNNWeight}. When a profiler is active, creation is bracketed by
     * {@code getQueryBreakdown(knnQuery)} / {@code pollLastElement()}.
     */
    private KNNWeight createKnnWeight(Integer firstPassKFor2PhaseSearch, IndexSearcher indexSearcher, ScoreMode scoreMode) throws IOException {
        final IOSupplier<KNNWeight> weightSupplier = this.getKNNWeightSupplier(firstPassKFor2PhaseSearch, indexSearcher, scoreMode);
        final QueryProfiler profiler = KNNProfileUtil.getProfiler(indexSearcher);
        if (profiler == null) {
            return weightSupplier.get();
        }
        profiler.getQueryBreakdown(this.knnQuery);
        try {
            return weightSupplier.get();
        } finally {
            // Pair getQueryBreakdown() with pollLastElement() even when weight creation throws,
            // so the profiler is not left with a dangling breakdown.
            profiler.pollLastElement();
        }
    }

    /**
     * Rebases every leaf hit to a shard-level doc id (adding the leaf's {@code docBase}) and merges all
     * leaves into one {@link TopDocs} of size {@link #getTotalTopDoc(TopDocs[])}.
     *
     * <p>Note: this mutates the {@link ScoreDoc} instances referenced by {@code perLeafResults}.
     */
    private TopDocs mergeToShardTopDocs(List<PerLeafResult> perLeafResults, List<LeafReaderContext> leafReaderContexts) {
        final TopDocs[] topDocs = new TopDocs[perLeafResults.size()];
        for (int i = 0; i < perLeafResults.size(); ++i) {
            final TopDocs leafTopDocs = perLeafResults.get(i).getResult();
            final int docBase = leafReaderContexts.get(i).docBase;
            for (ScoreDoc scoreDoc : leafTopDocs.scoreDocs) {
                scoreDoc.doc += docBase;
            }
            topDocs[i] = leafTopDocs;
        }
        return TopDocs.merge(this.getTotalTopDoc(topDocs), topDocs);
    }

    /** Two-phase (search + rescore) execution is required exactly when a first-pass k was computed. */
    private boolean isRescoreRequired(Integer firstPassKFor2PhaseSearch) {
        return firstPassKFor2PhaseSearch != null;
    }

    /**
     * Returns a lazy supplier of the {@link KNNWeight}. The first-pass k is passed to
     * {@code KNNQuery#createWeight} only when two-phase search is required.
     */
    private IOSupplier<KNNWeight> getKNNWeightSupplier(Integer firstPassKFor2PhaseSearch, IndexSearcher indexSearcher, ScoreMode scoreMode) {
        if (!this.isRescoreRequired(firstPassKFor2PhaseSearch)) {
            return () -> (KNNWeight) this.knnQuery.createWeight(indexSearcher, scoreMode, 1.0f);
        }
        return () -> (KNNWeight) this.knnQuery.createWeight(indexSearcher, scoreMode, 1.0f, firstPassKFor2PhaseSearch);
    }

    /**
     * Returns the oversampled first-pass k from the query's rescore context.
     *
     * @param isShardLevelRescoringDisabled forwarded to {@code RescoreContext#getFirstPassK}
     * @return the first-pass k, or {@code null} when there is no rescore context or rescoring is disabled
     */
    private Integer getFirstPassK(boolean isShardLevelRescoringDisabled) {
        final RescoreContext rescoreContext = this.knnQuery.getRescoreContext();
        if (rescoreContext == null || !rescoreContext.isRescoreEnabled()) {
            return null;
        }
        final int dimension = this.knnQuery.getQueryVector().length;
        return rescoreContext.getFirstPassK(this.knnQuery.getK(), isShardLevelRescoringDisabled, dimension);
    }

    /**
     * Number of hits the final merge keeps. With memory-optimized search or nested expansion, every
     * collected hit is kept, because those paths can carry more than k hits at this point. Otherwise
     * the query's k is used.
     */
    private int getTotalTopDoc(TopDocs[] topDocs) {
        if (!this.knnQuery.isMemoryOptimizedSearch() && !this.expandNestedDocs) {
            return this.knnQuery.getK();
        }
        int sum = 0;
        for (TopDocs topDoc : topDocs) {
            sum += topDoc.scoreDocs.length;
        }
        return sum;
    }

    /**
     * Expands every leaf's hits to all of their nested siblings in parallel on the searcher's task
     * executor. When profiling, each leaf's {@code num_nested_docs} metric is set to the size of its
     * expanded result.
     */
    private List<PerLeafResult> retrieveAll(IndexSearcher indexSearcher, List<LeafReaderContext> leafReaderContexts, KNNWeight knnWeight, List<PerLeafResult> perLeafResults, boolean useQuantizedVectors) throws IOException {
        // The profiler lookup only depends on the searcher, so resolve it once instead of once per leaf.
        final QueryProfiler profiler = KNNProfileUtil.getProfiler(indexSearcher);
        final List<Callable<PerLeafResult>> nestedQueryTasks = new ArrayList<>(leafReaderContexts.size());
        for (int i = 0; i < perLeafResults.size(); ++i) {
            final LeafReaderContext leafReaderContext = leafReaderContexts.get(i);
            final int leafIndex = i;
            nestedQueryTasks.add(() -> {
                final PerLeafResult result = this.retrieveLeafResult(leafReaderContext, knnWeight, perLeafResults, useQuantizedVectors, leafIndex);
                if (profiler != null) {
                    final AbstractProfileBreakdown profile = ((ContextualProfileBreakdown) profiler.getProfileBreakdown(this)).context(leafReaderContext);
                    final LongMetric metric = (LongMetric) profile.getMetric("num_nested_docs");
                    metric.setValue(Long.valueOf(result.getResult().scoreDocs.length));
                }
                return result;
            });
        }
        return indexSearcher.getTaskExecutor().invokeAll(nestedQueryTasks);
    }

    /**
     * Expands one leaf's hits to their siblings and exact-scores them.
     *
     * <p>Siblings come from {@code QueryUtils#getAllSiblings} using the query's parents filter and the
     * leaf's filter bits. Exact search runs with k set to the sibling iterator's cost, so presumably all
     * siblings are kept (confirm that {@code cost()} is exact here). Leaves without hits are returned
     * unchanged.
     */
    private PerLeafResult retrieveLeafResult(LeafReaderContext leafReaderContext, KNNWeight knnWeight, List<PerLeafResult> perLeafResults, boolean useQuantizedVectors, int finalI) throws IOException {
        final PerLeafResult perLeafResult = perLeafResults.get(finalI);
        if (perLeafResult.getResult().scoreDocs.length == 0) {
            return perLeafResult;
        }
        final Set<Integer> docIds = Arrays.stream(perLeafResult.getResult().scoreDocs).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toSet());
        final DocIdSetIterator allSiblings = this.queryUtils.getAllSiblings(leafReaderContext, docIds, this.knnQuery.getParentsFilter(), (Bits) perLeafResult.getFilterBits());
        final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder()
            .matchedDocsIterator(allSiblings)
            .numberOfMatchedDocs(allSiblings.cost())
            .useQuantizedVectorsForSearch(useQuantizedVectors)
            .k((int) allSiblings.cost())
            .field(this.knnQuery.getField())
            .radius(this.knnQuery.getRadius())
            .floatQueryVector(this.knnQuery.getQueryVector())
            .byteQueryVector(this.knnQuery.getByteQueryVector())
            .isMemoryOptimizedSearchEnabled(this.knnQuery.isMemoryOptimizedSearch())
            .build();
        final TopDocs rescoreResult = knnWeight.exactSearch(leafReaderContext, exactSearcherContext);
        return new PerLeafResult(perLeafResult.getFilterBits(), perLeafResult.getFilterBitsCardinality(), rescoreResult, PerLeafResult.SearchMode.EXACT_SEARCH);
    }

    /**
     * Searches every leaf in parallel for its top {@code k} hits.
     *
     * <p>For memory-optimized search over more than one leaf, an optimistic second pass
     * ({@link #reentrantSearch}) runs afterwards and may replace some leaves' results in place.
     */
    private List<PerLeafResult> doSearch(IndexSearcher indexSearcher, List<LeafReaderContext> leafReaderContexts, KNNWeight knnWeight, int k) throws IOException {
        final List<Callable<PerLeafResult>> tasks = new ArrayList<>(leafReaderContexts.size());
        for (LeafReaderContext leafReaderContext : leafReaderContexts) {
            tasks.add(() -> this.searchLeaf(leafReaderContext, knnWeight, k));
        }
        final List<PerLeafResult> perLeafResults = indexSearcher.getTaskExecutor().invokeAll(tasks);
        if (this.knnQuery.isMemoryOptimizedSearch() && perLeafResults.size() > 1) {
            log.debug("Running second deep dive search in optimistic while memory optimized search is enabled. perLeafResults.size()={}", perLeafResults.size());
            final StopWatch stopWatch = StopWatchUtils.startStopWatch(log);
            this.reentrantSearch(perLeafResults, knnWeight, leafReaderContexts, k, indexSearcher);
            StopWatchUtils.stopStopWatchAndLog(log, stopWatch, "2ndOptimisticSearch", this.knnQuery.getShardId(), "All Shards", this.knnQuery.getField());
        }
        return perLeafResults;
    }

    /**
     * Optimistic second pass for memory-optimized search.
     *
     * <p>The method first finds the k-th largest score across all leaves. It then re-runs approximate
     * search on every non-empty approximate-search leaf whose last returned score is still
     * {@code >= } that score (or on all such leaves when {@link #FORCE_REENTER_TESTING} is set). The
     * re-run goes through a {@link ReentrantKnnCollectorManager} that receives the first-pass results
     * keyed by segment ord. Re-run leaves have their results replaced in place.
     *
     * <p>Logs an error and does nothing if {@code knnWeight} is not a {@link MemoryOptimizedKNNWeight}.
     */
    private void reentrantSearch(List<PerLeafResult> perLeafResults, KNNWeight knnWeight, List<LeafReaderContext> leafReaderContexts, int k, IndexSearcher indexSearcher) throws IOException {
        if (!(knnWeight instanceof MemoryOptimizedKNNWeight)) {
            log.error("Memory optimized search was enabled, but got [" + (knnWeight == null ? "null" : knnWeight.getClass().getSimpleName()) + "], expected=" + MemoryOptimizedKNNWeight.class.getSimpleName());
            return;
        }
        assert (perLeafResults.size() == leafReaderContexts.size());
        final MemoryOptimizedKNNWeight memoryOptKNNWeight = (MemoryOptimizedKNNWeight) knnWeight;

        int totalResults = 0;
        for (PerLeafResult perLeafResult : perLeafResults) {
            totalResults += perLeafResult.getResult().scoreDocs.length;
        }
        if (totalResults == 0) {
            return;
        }

        final float minTopKScore = OptimisticSearchStrategyUtils.findKthLargestScore(perLeafResults, k, totalResults);
        final List<Callable<TopDocs>> secondDeepDiveTasks = new ArrayList<>();
        final List<Integer> contextIndices = new ArrayList<>();
        final HashMap<Integer, TopDocs> segmentOrdToResults = new HashMap<>();
        for (int i = 0; i < leafReaderContexts.size(); ++i) {
            final LeafReaderContext leafReaderContext = leafReaderContexts.get(i);
            final PerLeafResult perLeafResult = perLeafResults.get(i);
            final TopDocs perLeaf = perLeafResult.getResult();
            if (perLeaf.scoreDocs.length <= 0 || perLeafResult.getSearchMode() != PerLeafResult.SearchMode.APPROXIMATE_SEARCH) {
                continue;
            }
            // The leaf's last (presumably lowest) returned score still reaching the global k-th score
            // suggests the leaf may hold more competitive hits.
            final boolean worthReentering = FORCE_REENTER_TESTING || perLeaf.scoreDocs[perLeaf.scoreDocs.length - 1].score >= minTopKScore;
            if (!worthReentering) {
                continue;
            }
            log.debug("Entering the second deep dive approximate search while FORCE_REENTER_TESTING={}", FORCE_REENTER_TESTING);
            segmentOrdToResults.put(leafReaderContext.ord, perLeaf);
            secondDeepDiveTasks.add(() -> knnWeight.approximateSearch(leafReaderContext, perLeafResult.getFilterBits(), perLeafResult.getFilterBitsCardinality(), k));
            contextIndices.add(i);
        }

        if (secondDeepDiveTasks.isEmpty()) {
            return;
        }
        // Float fields pass the float query vector; other data types pass the byte query vector.
        final Object queryVector = this.knnQuery.getVectorDataType() == VectorDataType.FLOAT
            ? this.knnQuery.getQueryVector()
            : this.knnQuery.getByteQueryVector();
        final ReentrantKnnCollectorManager reentrantCollectorManager = new ReentrantKnnCollectorManager(
            new TopKnnCollectorManager(k, indexSearcher),
            segmentOrdToResults,
            queryVector,
            this.knnQuery.getField()
        );
        memoryOptKNNWeight.setReentrantKNNCollectorManager(reentrantCollectorManager);
        final List<TopDocs> deepDiveTopDocs = indexSearcher.getTaskExecutor().invokeAll(secondDeepDiveTasks);
        for (int i = 0; i < deepDiveTopDocs.size(); ++i) {
            perLeafResults.get(contextIndices.get(i)).setResult(deepDiveTopDocs.get(i));
        }
    }

    /**
     * Re-scores each leaf's first-pass hits with exact search, without quantized vectors, keeping the
     * top {@code k} per leaf. When the query has a parents filter, hits are first expanded to their
     * siblings. Leaves without hits are returned unchanged.
     */
    private List<PerLeafResult> doRescore(IndexSearcher indexSearcher, List<LeafReaderContext> leafReaderContexts, KNNWeight knnWeight, List<PerLeafResult> perLeafResults, int k) throws IOException {
        final List<Callable<PerLeafResult>> rescoreTasks = new ArrayList<>(leafReaderContexts.size());
        for (int i = 0; i < perLeafResults.size(); ++i) {
            final LeafReaderContext leafReaderContext = leafReaderContexts.get(i);
            final int leafIndex = i;
            rescoreTasks.add(() -> {
                final PerLeafResult perLeafResult = perLeafResults.get(leafIndex);
                if (perLeafResult.getResult().scoreDocs.length == 0) {
                    return perLeafResult;
                }
                final Set<Integer> docIds = Arrays.stream(perLeafResult.getResult().scoreDocs).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toSet());
                final DocIdSetIterator matchedDocs = this.knnQuery.getParentsFilter() != null
                    ? this.queryUtils.getAllSiblings(leafReaderContext, docIds, this.knnQuery.getParentsFilter(), (Bits) perLeafResult.getFilterBits())
                    : new TopDocsDISI(perLeafResult.getResult());
                final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder()
                    .matchedDocsIterator(matchedDocs)
                    .numberOfMatchedDocs(matchedDocs.cost())
                    .useQuantizedVectorsForSearch(false)
                    .k(k)
                    .radius(this.knnQuery.getRadius())
                    .field(this.knnQuery.getField())
                    .floatQueryVector(this.knnQuery.getQueryVector())
                    .byteQueryVector(this.knnQuery.getByteQueryVector())
                    .isMemoryOptimizedSearchEnabled(this.knnQuery.isMemoryOptimizedSearch())
                    .parentsFilter(this.knnQuery.getParentsFilter())
                    .build();
                final TopDocs rescoreResult = knnWeight.exactSearch(leafReaderContext, exactSearcherContext);
                return new PerLeafResult(perLeafResult.getFilterBits(), perLeafResult.getFilterBitsCardinality(), rescoreResult, PerLeafResult.SearchMode.EXACT_SEARCH);
            });
        }
        return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks);
    }

    /**
     * Runs {@link KNNWeight#searchLeaf} on one leaf and drops hits on deleted documents. Whenever the
     * leaf has live-docs bits, the result's total hits is reset to the exact count of surviving hits.
     */
    private PerLeafResult searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k) throws IOException {
        final PerLeafResult perLeafResult = queryWeight.searchLeaf(ctx, k);
        final Bits liveDocs = ctx.reader().getLiveDocs();
        if (liveDocs == null) {
            return perLeafResult;
        }
        final List<ScoreDoc> liveScoreDocs = new ArrayList<>();
        for (ScoreDoc scoreDoc : perLeafResult.getResult().scoreDocs) {
            if (liveDocs.get(scoreDoc.doc)) {
                liveScoreDocs.add(scoreDoc);
            }
        }
        final ScoreDoc[] filteredScoreDoc = liveScoreDocs.toArray(new ScoreDoc[0]);
        final TotalHits totalHits = new TotalHits(filteredScoreDoc.length, TotalHits.Relation.EQUAL_TO);
        perLeafResult.setResult(new TopDocs(totalHits, filteredScoreDoc));
        return perLeafResult;
    }

    @Override
    public String toString(String field) {
        return this.getClass().getSimpleName() + "[" + field + "]..." + KNNQuery.class.getSimpleName() + "[" + this.knnQuery.toString() + "]";
    }

    @Override
    public void visit(QueryVisitor visitor) {
        visitor.visitLeaf(this);
    }

    /**
     * Two instances are equal when they wrap the very same {@link KNNQuery} instance (identity
     * comparison) and agree on {@code expandNestedDocs}. That flag changes the produced hits, so it is
     * part of query identity.
     */
    @Override
    public boolean equals(Object obj) {
        if (!this.sameClassAs(obj)) {
            return false;
        }
        final NativeEngineKnnVectorQuery other = (NativeEngineKnnVectorQuery) obj;
        return this.knnQuery == other.knnQuery && this.expandNestedDocs == other.expandNestedDocs;
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.classHash(), this.knnQuery.hashCode(), this.expandNestedDocs);
    }

    @Generated
    public KNNQuery getKnnQuery() {
        return this.knnQuery;
    }

    @Generated
    public QueryUtils getQueryUtils() {
        return this.queryUtils;
    }

    @Generated
    public boolean isExpandNestedDocs() {
        return this.expandNestedDocs;
    }

    @Generated
    public NativeEngineKnnVectorQuery(KNNQuery knnQuery, QueryUtils queryUtils, boolean expandNestedDocs) {
        this.knnQuery = knnQuery;
        this.queryUtils = queryUtils;
        this.expandNestedDocs = expandNestedDocs;
    }
}

