/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.search.aggregations.bucket;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.OptionalLong;
import java.util.function.BiConsumer;
import java.util.function.Function;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.NumericUtils;
import org.opensearch.common.Rounding;
import org.opensearch.common.lucene.search.function.FunctionScoreQuery;
import org.opensearch.index.mapper.DateFieldMapper;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.DateRangeIncludingNowQuery;
import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceConfig;
import org.opensearch.search.aggregations.bucket.composite.RoundingValuesSource;
import org.opensearch.search.aggregations.bucket.histogram.LongBounds;
import org.opensearch.search.internal.SearchContext;

/**
 * Helpers for answering eligible aggregations with a set of {@link PointRangeQuery} filters whose per-segment
 * match counts are read through {@link Weight#count(LeafReaderContext)}, and reported to the caller through a
 * callback instead of being computed document by document.
 *
 * <p>The optimization only applies when the (unwrapped) top-level query is a {@link PointRangeQuery},
 * {@link MatchAllDocsQuery} or {@link FieldExistsQuery} on the aggregated field, or when a segment is shown to
 * match every live document.
 */
public final class FastFilterRewriteHelper {
    private static final Logger logger = LogManager.getLogger(FastFilterRewriteHelper.class);

    /** Upper limit on the number of range filters built; beyond it the optimization is skipped. */
    private static final int MAX_NUM_FILTER_BUCKETS = 1024;

    /**
     * Maps a wrapper query class to a function returning the query it wraps. Lookups use the exact runtime
     * class, so subclasses of these wrappers are not unwrapped.
     */
    private static final Map<Class<?>, Function<Query, Query>> queryWrappers = new HashMap<>();

    static {
        queryWrappers.put(ConstantScoreQuery.class, q -> ((ConstantScoreQuery) q).getQuery());
        queryWrappers.put(FunctionScoreQuery.class, q -> ((FunctionScoreQuery) q).getSubQuery());
        queryWrappers.put(DateRangeIncludingNowQuery.class, q -> ((DateRangeIncludingNowQuery) q).getQuery());
        // Only the points-based side of IndexOrDocValuesQuery is inspected.
        queryWrappers.put(IndexOrDocValuesQuery.class, q -> ((IndexOrDocValuesQuery) q).getIndexQuery());
    }

    private FastFilterRewriteHelper() {
        // static helpers only
    }

    /**
     * Repeatedly strips registered wrapper queries until a non-wrapper query remains.
     *
     * @param query the query to unwrap; may be {@code null}
     * @return the innermost non-wrapper query, or {@code null} if {@code query} (or any unwrapped layer) is null
     */
    private static Query unwrapIntoConcreteQuery(Query query) {
        Function<Query, Query> unwrapper;
        while (query != null && (unwrapper = queryWrappers.get(query.getClass())) != null) {
            query = unwrapper.apply(query);
        }
        return query;
    }

    /** Decodes the first 8 sortable bytes of a packed point value into a {@code long}. */
    private static long decodeLong(byte[] packedValue) {
        return NumericUtils.sortableBytesToLong(packedValue, 0);
    }

    /**
     * Computes the minimum and maximum indexed point value of {@code fieldName} across all segments of the shard.
     *
     * @return {@code {min, max}}, or {@code null} if no segment has points for the field
     */
    private static long[] getShardBounds(SearchContext context, String fieldName) throws IOException {
        final List<LeafReaderContext> leaves = context.searcher().getIndexReader().leaves();
        long min = Long.MAX_VALUE;
        long max = Long.MIN_VALUE;
        for (LeafReaderContext leaf : leaves) {
            final PointValues values = leaf.reader().getPointValues(fieldName);
            if (values == null) {
                continue;
            }
            min = Math.min(min, decodeLong(values.getMinPackedValue()));
            max = Math.max(max, decodeLong(values.getMaxPackedValue()));
        }
        // The initial values double as "nothing found" sentinels. NOTE(review): a field whose only values are
        // Long.MAX_VALUE / Long.MIN_VALUE is therefore also reported as having no bounds, which conservatively
        // disables the optimization.
        if (min == Long.MAX_VALUE || max == Long.MIN_VALUE) {
            return null;
        }
        return new long[] { min, max };
    }

    /**
     * Computes the minimum and maximum indexed point value of {@code fieldName} within a single segment.
     *
     * @return {@code {min, max}}, or {@code null} if the segment has no points for the field (or the extremes
     *     collide with the sentinel values used by {@link #getShardBounds})
     */
    private static long[] getSegmentBounds(LeafReaderContext context, String fieldName) throws IOException {
        final PointValues values = context.reader().getPointValues(fieldName);
        if (values == null) {
            return null;
        }
        final long min = decodeLong(values.getMinPackedValue());
        final long max = decodeLong(values.getMaxPackedValue());
        // Same sentinel semantics as getShardBounds.
        if (min == Long.MAX_VALUE || max == Long.MIN_VALUE) {
            return null;
        }
        return new long[] { min, max };
    }

    /**
     * Derives the shard-level value bounds for a date histogram on {@code fieldName} from the search query.
     *
     * @return {@code {low, high}} (inclusive), or {@code null} when the query shape is not supported or no
     *     values fall inside it
     */
    public static long[] getDateHistoAggBounds(SearchContext context, String fieldName) throws IOException {
        final Query cq = unwrapIntoConcreteQuery(context.query());
        if (cq instanceof PointRangeQuery) {
            final long[] indexBounds = getShardBounds(context, fieldName);
            if (indexBounds == null) {
                return null;
            }
            return getBoundsWithRangeQuery((PointRangeQuery) cq, fieldName, indexBounds);
        }
        if (cq instanceof MatchAllDocsQuery) {
            return getShardBounds(context, fieldName);
        }
        if (cq instanceof FieldExistsQuery && ((FieldExistsQuery) cq).getField().equals(fieldName)) {
            return getShardBounds(context, fieldName);
        }
        return null;
    }

    /**
     * Intersects a range query on {@code fieldName} with the indexed bounds.
     *
     * @return the intersection {@code {lower, upper}}, or {@code null} if the query targets another field or the
     *     intersection is empty
     */
    private static long[] getBoundsWithRangeQuery(PointRangeQuery prq, String fieldName, long[] indexBounds) {
        if (!prq.getField().equals(fieldName)) {
            return null;
        }
        final long lower = Math.max(decodeLong(prq.getLowerPoint()), indexBounds[0]);
        final long upper = Math.min(decodeLong(prq.getUpperPoint()), indexBounds[1]);
        if (lower > upper) {
            return null;
        }
        return new long[] { lower, upper };
    }

    /**
     * Builds one range filter per rounded bucket covering {@code [low, high]}.
     *
     * <p>A first pass counts the buckets, bailing out once {@link #MAX_NUM_FILTER_BUCKETS} is exceeded. A second
     * pass creates a {@link Weight} for each bucket's {@link PointRangeQuery}.
     *
     * @param low inclusive lower bound, in the field's stored resolution
     * @param high inclusive upper bound, in the field's stored resolution
     * @return the filters, or {@code null} when there are no buckets or too many
     */
    private static Weight[] createFilterForAggregations(
        SearchContext context,
        DateFieldMapper.DateFieldType fieldType,
        long interval,
        Rounding.Prepared preparedRounding,
        long low,
        long high
    ) throws IOException {
        final long highMillis = fieldType.convertNanosToMillis(high);
        long roundedLow = preparedRounding.round(fieldType.convertNanosToMillis(low));
        long prevRounded = roundedLow;
        int bucketCount = 0;
        while (roundedLow <= highMillis) {
            bucketCount++;
            if (bucketCount > MAX_NUM_FILTER_BUCKETS) {
                logger.debug("Max number of filters reached [{}], skip the fast filter optimization", MAX_NUM_FILTER_BUCKETS);
                return null;
            }
            roundedLow = preparedRounding.round(roundedLow + interval);
            // Stop if rounding no longer advances; otherwise the loop would never terminate.
            if (prevRounded == roundedLow) {
                break;
            }
            prevRounded = roundedLow;
        }
        if (bucketCount == 0) {
            return null;
        }

        final Weight[] filters = new Weight[bucketCount];
        roundedLow = preparedRounding.round(fieldType.convertNanosToMillis(low));
        for (int i = 0; i < bucketCount; i++) {
            // The first filter starts at `low` itself rather than at its rounded-down bucket start.
            final byte[] lower = new byte[Long.BYTES];
            NumericUtils.longToSortableBytes(i == 0 ? low : fieldType.convertRoundedMillisToNanos(roundedLow), lower, 0);
            roundedLow = preparedRounding.round(roundedLow + interval);
            // The last filter ends at `high`; the others end one unit before the next bucket start.
            final byte[] upper = new byte[Long.BYTES];
            NumericUtils.longToSortableBytes(
                i + 1 == bucketCount ? high : fieldType.convertRoundedMillisToNanos(roundedLow) - 1L,
                upper,
                0
            );
            filters[i] = context.searcher().createWeight(new PointRangeQuery(fieldType.name(), lower, upper, 1) {
                @Override
                protected String toString(int dimension, byte[] value) {
                    return Long.toString(LongPoint.decodeDimension(value, 0));
                }
            }, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
        }
        return filters;
    }

    /**
     * Returns {@code true} when a composite aggregation has exactly one source and that source is a
     * {@link RoundingValuesSource}.
     */
    public static boolean isCompositeAggRewriteable(CompositeValuesSourceConfig[] sourceConfigs) {
        return sourceConfigs.length == 1 && sourceConfigs[0].valuesSource() instanceof RoundingValuesSource;
    }

    /**
     * Maps a possibly negative ordinal to a non-negative one: negative values {@code x} become {@code -1 - x}.
     * NOTE(review): presumably decodes the "bucket already exists" encoding returned by a bucket-ords
     * {@code add} call; confirm against callers.
     */
    public static long getBucketOrd(long bucketOrd) {
        return bucketOrd < 0L ? -1L - bucketOrd : bucketOrd;
    }

    /**
     * Tries to compute a segment's bucket counts from the prebuilt (or per-segment) filters.
     *
     * @param ctx the segment being aggregated
     * @param fastFilterContext shard-level optimization state; {@code null} disables the optimization
     * @param incrementDocCount receives {@code (bucketKey, count)} for each non-empty bucket. The key is the
     *     rounded bucket start in millis for date histograms, and the filter index otherwise.
     * @return {@code true} if counts were reported through {@code incrementDocCount}; on {@code false}, nothing
     *     has been reported
     */
    public static boolean tryFastFilterAggregation(
        LeafReaderContext ctx,
        FastFilterContext fastFilterContext,
        BiConsumer<Long, Integer> incrementDocCount
    ) throws IOException {
        if (fastFilterContext == null || !fastFilterContext.rewriteable) {
            return false;
        }

        // Filter counts are plain document counts, so any document carrying a _doc_count value disqualifies the
        // segment.
        final NumericDocValues docCountValues = DocValues.getNumeric(ctx.reader(), "_doc_count");
        if (docCountValues.nextDoc() != NumericDocValues.NO_MORE_DOCS) {
            logger.debug(
                "Shard {} segment {} has at least one document with _doc_count field, skip fast filter optimization",
                fastFilterContext.context.indexShard().shardId(),
                ctx.ord
            );
            return false;
        }

        // Without shard-level filters, the segment must match every live document for per-segment filters
        // built from its own bounds to be correct.
        if (!fastFilterContext.filtersBuiltAtShardLevel && !segmentMatchAll(fastFilterContext.context, ctx)) {
            return false;
        }

        Weight[] filters = fastFilterContext.filters;
        if (filters == null) {
            logger.debug(
                "Shard {} segment {} functionally match all documents. Build the fast filter",
                fastFilterContext.context.indexShard().shardId(),
                ctx.ord
            );
            filters = fastFilterContext.buildFastFilter(ctx);
            if (filters == null) {
                return false;
            }
        }

        // Gather every count before reporting any, so nothing is emitted when one filter cannot be counted cheaply.
        final int[] counts = new int[filters.length];
        for (int i = 0; i < filters.length; i++) {
            counts[i] = filters[i].count(ctx);
            if (counts[i] == -1) {
                return false;
            }
        }

        final AggregationType aggregationType = fastFilterContext.aggregationType;
        final int size = aggregationType.getSize();
        int reported = 0;
        for (int i = 0; i < filters.length; i++) {
            if (counts[i] <= 0) {
                continue;
            }
            long bucketKey = i;
            if (aggregationType instanceof AbstractDateHistogramAggregationType) {
                final DateFieldMapper.DateFieldType fieldType = ((AbstractDateHistogramAggregationType) aggregationType)
                    .getFieldType();
                bucketKey = fieldType.convertNanosToMillis(decodeLong(((PointRangeQuery) filters[i].getQuery()).getLowerPoint()));
            }
            incrementDocCount.accept(bucketKey, counts[i]);
            reported++;
            // NOTE(review): this admits up to size + 1 non-empty buckets; confirm the extra bucket is intended.
            if (reported > size) {
                break;
            }
        }
        logger.debug("Fast filter optimization applied to shard {} segment {}", fastFilterContext.context.indexShard().shardId(), ctx.ord);
        return true;
    }

    /**
     * Returns {@code true} when the search query matches exactly as many documents in the segment as it has live
     * documents. A {@code -1} ("cannot count cheaply") result from {@link Weight#count} yields {@code false}.
     */
    private static boolean segmentMatchAll(SearchContext ctx, LeafReaderContext leafCtx) throws IOException {
        final Weight weight = ctx.searcher().createWeight(ctx.query(), ScoreMode.COMPLETE_NO_SCORES, 1.0f);
        return weight != null && weight.count(leafCtx) == leafCtx.reader().numDocs();
    }

    /** Base fast-filter support for date-histogram-like aggregations on a {@link DateFieldMapper.DateFieldType}. */
    public static abstract class AbstractDateHistogramAggregationType implements AggregationType {
        private final MappedFieldType fieldType;
        private final boolean missing;
        private final boolean hasScript;
        /** Optional hard bounds used to clamp the computed value range; may be {@code null}. */
        private final LongBounds hardBounds;

        public AbstractDateHistogramAggregationType(MappedFieldType fieldType, boolean missing, boolean hasScript) {
            this(fieldType, missing, hasScript, null);
        }

        public AbstractDateHistogramAggregationType(
            MappedFieldType fieldType,
            boolean missing,
            boolean hasScript,
            LongBounds hardBounds
        ) {
            this.fieldType = fieldType;
            this.missing = missing;
            this.hasScript = hasScript;
            this.hardBounds = hardBounds;
        }

        /**
         * Rewriteable only for a top-level aggregation (no parent, no sub-aggregations) without a missing value or
         * script, over a searchable date field.
         */
        @Override
        public boolean isRewriteable(Object parent, int subAggLength) {
            return parent == null
                && subAggLength == 0
                && !this.missing
                && !this.hasScript
                && this.fieldType instanceof DateFieldMapper.DateFieldType
                && this.fieldType.isSearchable();
        }

        /** Builds filters from shard-level bounds derived from the search query. */
        @Override
        public Weight[] buildFastFilter(SearchContext context) throws IOException {
            final long[] bounds = getDateHistoAggBounds(context, this.fieldType.name());
            logger.debug("Bounds are {} for shard {}", bounds, context.indexShard().shardId());
            return this.buildFastFilter(context, bounds);
        }

        /** Builds filters from the indexed bounds of a single segment. */
        @Override
        public Weight[] buildFastFilter(LeafReaderContext leaf, SearchContext context) throws IOException {
            final long[] bounds = getSegmentBounds(leaf, this.fieldType.name());
            logger.debug("Bounds are {} for shard {} segment {}", bounds, context.indexShard().shardId(), leaf.ord);
            return this.buildFastFilter(context, bounds);
        }

        /**
         * Clamps {@code bounds} to the hard bounds, resolves a fixed interval and builds the filters.
         *
         * @return the filters, or {@code null} if the bounds are empty, the rounding has no fixed interval, or too
         *     many buckets would be needed
         */
        private Weight[] buildFastFilter(SearchContext context, long[] bounds) throws IOException {
            bounds = this.processHardBounds(bounds);
            if (bounds == null) {
                return null;
            }
            assert bounds[0] <= bounds[1] : "Low bound should be less than high bound";

            final Rounding rounding = this.getRounding(bounds[0], bounds[1]);
            final OptionalLong intervalOpt = Rounding.getInterval(rounding);
            if (intervalOpt.isEmpty()) {
                return null;
            }
            final long interval = intervalOpt.getAsLong();

            this.processAfterKey(bounds, interval);
            return createFilterForAggregations(
                context,
                (DateFieldMapper.DateFieldType) this.fieldType,
                interval,
                this.getRoundingPrepared(),
                bounds[0],
                bounds[1]
            );
        }

        /** Returns the rounding to use for the value range {@code [low, high]}. */
        protected abstract Rounding getRounding(long low, long high);

        /** Returns the prepared rounding used to compute bucket boundaries. */
        protected abstract Rounding.Prepared getRoundingPrepared();

        /** Hook allowing subclasses to adjust {@code bound} in place once the interval is known; no-op by default. */
        protected void processAfterKey(long[] bound, long interval) {}

        /**
         * Clamps {@code bounds} in place to {@code [hardBounds.min, hardBounds.max - 1]} (the hard-bound max is
         * treated as exclusive).
         *
         * @return {@code bounds}, or {@code null} if it was {@code null} or the clamped range is empty
         */
        protected long[] processHardBounds(long[] bounds) {
            if (bounds != null && this.hardBounds != null) {
                if (this.hardBounds.getMin() > bounds[0]) {
                    bounds[0] = this.hardBounds.getMin();
                }
                if (this.hardBounds.getMax() - 1L < bounds[1]) {
                    bounds[1] = this.hardBounds.getMax() - 1L;
                }
                if (bounds[0] > bounds[1]) {
                    return null;
                }
            }
            return bounds;
        }

        /** Returns the aggregated field as a date field type. */
        public DateFieldMapper.DateFieldType getFieldType() {
            assert this.fieldType instanceof DateFieldMapper.DateFieldType;
            return (DateFieldMapper.DateFieldType) this.fieldType;
        }
    }

    /** Aggregation-specific strategy for deciding rewriteability and building fast filters. */
    interface AggregationType {
        /** Returns whether this aggregation can use the fast-filter path. */
        boolean isRewriteable(Object parent, int subAggLength);

        /** Builds shard-level filters, or returns {@code null} when not possible. */
        Weight[] buildFastFilter(SearchContext context) throws IOException;

        /** Builds filters for a single segment, or returns {@code null} when not possible. */
        Weight[] buildFastFilter(LeafReaderContext leaf, SearchContext context) throws IOException;

        /** Bound on the number of non-empty buckets reported per segment; unbounded by default. */
        default int getSize() {
            return Integer.MAX_VALUE;
        }
    }

    /** Per-shard state for the fast-filter optimization. */
    public static class FastFilterContext {
        private boolean rewriteable = false;
        /** Filters built at shard level, or {@code null} if not built (yet). */
        private Weight[] filters = null;
        private boolean filtersBuiltAtShardLevel = false;
        private AggregationType aggregationType;
        private final SearchContext context;

        public FastFilterContext(SearchContext context) {
            this.context = context;
        }

        public AggregationType getAggregationType() {
            return this.aggregationType;
        }

        public void setAggregationType(AggregationType aggregationType) {
            this.aggregationType = aggregationType;
        }

        /**
         * Evaluates and records whether the optimization applies. {@link #setAggregationType} must be called
         * first.
         */
        public boolean isRewriteable(Object parent, int subAggLength) {
            final boolean canRewrite = this.aggregationType.isRewriteable(parent, subAggLength);
            logger.debug("Fast filter rewriteable: {} for shard {}", canRewrite, this.context.indexShard().shardId());
            this.rewriteable = canRewrite;
            return canRewrite;
        }

        /** Builds and stores the shard-level filters; intended to be called at most once. */
        public void buildFastFilter() throws IOException {
            assert this.filters == null : "Filters should only be built once, but they are already built";
            this.filters = this.aggregationType.buildFastFilter(this.context);
            if (this.filters != null) {
                logger.debug("Fast filter built for shard {}", this.context.indexShard().shardId());
                this.filtersBuiltAtShardLevel = true;
            }
        }

        /** Builds filters for one segment without storing them. */
        public Weight[] buildFastFilter(LeafReaderContext leaf) throws IOException {
            final Weight[] segmentFilters = this.aggregationType.buildFastFilter(leaf, this.context);
            if (segmentFilters != null) {
                logger.debug("Fast filter built for shard {} segment {}", this.context.indexShard().shardId(), leaf.ord);
            }
            return segmentFilters;
        }
    }
}

