/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.metadata.BuiltInMetadata;
import org.apache.calcite.rel.metadata.RelMdMeasure;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.ImmutableAggregateMeasure2RuleConfig;
import org.apache.calcite.rel.rules.ImmutableAggregateMeasureRuleConfig;
import org.apache.calcite.rel.rules.ImmutableFilterSortMeasureRuleConfig;
import org.apache.calcite.rel.rules.ImmutableProjectMeasureRuleConfig;
import org.apache.calcite.rel.rules.ImmutableProjectSortMeasureRuleConfig;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlInternalOperators;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.MonotonicSupplier;
import org.apache.calcite.util.Util;
import org.immutables.value.Value;
import shaded.com.google.common.base.Suppliers;
import shaded.com.google.common.collect.ImmutableList;
import shaded.com.google.common.collect.ImmutableSet;
import shaded.com.google.common.collect.Iterables;

public abstract class MeasureRules {
    public static final RelOptRule AGGREGATE = AggregateMeasureRuleConfig.DEFAULT.toRule();
    public static final RelOptRule PROJECT = ProjectMeasureRuleConfig.DEFAULT.toRule();
    public static final RelOptRule AGGREGATE2 = AggregateMeasure2RuleConfig.DEFAULT.toRule();
    public static final RelOptRule FILTER_SORT = FilterSortMeasureRuleConfig.DEFAULT.as(FilterSortMeasureRuleConfig.class).toRule();
    public static final RelOptRule PROJECT_SORT = ProjectSortMeasureRuleConfig.DEFAULT.as(ProjectSortMeasureRuleConfig.class).toRule();

    private MeasureRules() {
    }

    public static Iterable<? extends RelOptRule> rules() {
        return ImmutableList.of(AGGREGATE2, PROJECT, PROJECT_SORT);
    }

    @Value.Immutable
    public static interface ProjectSortMeasureRuleConfig
    extends RelRule.Config {
        public static final ProjectSortMeasureRuleConfig DEFAULT = ImmutableProjectSortMeasureRuleConfig.of().withOperandSupplier(b -> b.operand(Project.class).predicate(RexUtil.M2V_FINDER::inProject).oneInput(b2 -> b2.operand(Sort.class).anyInputs()));

        @Override
        default public ProjectSortMeasureRule toRule() {
            return new ProjectSortMeasureRule(this);
        }
    }

    public static class ProjectSortMeasureRule
    extends RelRule<ProjectSortMeasureRuleConfig>
    implements TransformationRule {
        protected ProjectSortMeasureRule(ProjectSortMeasureRuleConfig config) {
            super(config);
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            Project project = (Project)call.rel(0);
            Sort sort = (Sort)call.rel(1);
            final RelBuilder relBuilder = call.builder();
            final List<RexNode> projects = project.getAliasedProjects(relBuilder);
            final LinkedHashMap map = new LinkedHashMap();
            List<RexNode> newProjects = new RexShuttle(){

                @Override
                public RexNode visitCall(RexCall call) {
                    if (call.getKind() == SqlKind.M2V) {
                        return map.computeIfAbsent(call, c -> relBuilder.getRexBuilder().makeInputRef(call.getType(), projects.size() + map.size()));
                    }
                    return super.visitCall(call);
                }
            }.apply(projects);
            relBuilder.push(sort.getInput()).projectPlus(map.keySet()).sortLimit(sort.offset == null ? 0 : RexLiteral.intValue(sort.offset), sort.fetch == null ? -1 : RexLiteral.intValue(sort.fetch), sort.getSortExps()).project(newProjects);
            call.transformTo(relBuilder.build());
        }
    }

    public static class FilterSortMeasureRule
    extends RelRule<FilterSortMeasureRuleConfig>
    implements TransformationRule {
        protected FilterSortMeasureRule(FilterSortMeasureRuleConfig config) {
            super(config);
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            Filter filter = (Filter)call.rel(0);
            RexNode condition = filter.getCondition();
            if (condition.equals(filter.getCondition())) {
                return;
            }
            RelBuilder relBuilder = this.relBuilderFactory.create(filter.getCluster(), null);
            relBuilder.push(filter.getInput()).filter(condition);
            call.transformTo(relBuilder.build());
        }
    }

    @Value.Immutable
    public static interface FilterSortMeasureRuleConfig
    extends RelRule.Config {
        public static final FilterSortMeasureRuleConfig DEFAULT = ImmutableFilterSortMeasureRuleConfig.of().withOperandSupplier(b -> b.operand(Filter.class).oneInput(b2 -> b2.operand(Sort.class).anyInputs()));

        @Override
        default public FilterSortMeasureRule toRule() {
            return new FilterSortMeasureRule(this);
        }
    }

    public static class ProjectMeasureRule
    extends RelRule<ProjectMeasureRuleConfig>
    implements TransformationRule {
        protected ProjectMeasureRule(ProjectMeasureRuleConfig config) {
            super(config);
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            Aggregate aggregate = (Aggregate)call.rel(0);
            Project project = (Project)call.rel(1);
            RelBuilder b = call.builder();
            b.push(project).aggregateRex(b.groupKey(aggregate.getGroupSet(), (Iterable<? extends ImmutableBitSet>)aggregate.getGroupSets()), true, Util.transform(aggregate.getAggCallList(), aggregateCall -> ProjectMeasureRule.toRex(aggregateCall, project)));
            call.transformTo(b.build());
        }

        private static RexNode toRex(AggregateCall aggregateCall, Project project) {
            switch (aggregateCall.getAggregation().kind) {
                case SINGLE_VALUE: {
                    int arg = Iterables.getOnlyElement(aggregateCall.getArgList());
                    RexNode e = project.getProjects().get(arg);
                    switch (e.getKind()) {
                        case M2X: {
                            RexCall callM2x = (RexCall)e;
                            switch (((RexNode)callM2x.operands.get(0)).getKind()) {
                                case V2M: {
                                    RexCall callV2m = (RexCall)callM2x.operands.get(0);
                                    return (RexNode)callV2m.operands.get(0);
                                }
                            }
                            throw new UnsupportedOperationException();
                        }
                    }
                    throw new UnsupportedOperationException();
                }
            }
            throw new UnsupportedOperationException();
        }
    }

    public static class AggregateMeasure2Rule
    extends RelRule<AggregateMeasure2RuleConfig>
    implements TransformationRule {
        protected AggregateMeasure2Rule(AggregateMeasure2RuleConfig config) {
            super(config);
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            RelMetadataQuery mq = call.getMetadataQuery();
            Aggregate aggregate = (Aggregate)call.rel(0);
            RelBuilder b = call.builder();
            b.push(aggregate.getInput());
            MonotonicSupplier<RexCorrelVariable> holder = MonotonicSupplier.empty();
            ArrayList aggCallList = new ArrayList();
            ArrayList projects = new ArrayList();
            b.variable(holder).let(b2 -> {
                aggregate.getGroupSet().forEachInt(i -> projects.add(b4 -> b4.field(i)));
                Supplier<RelBuilder> builderSupplier = Suppliers.memoize(call::builder)::get;
                BuiltInMetadata.Measure.Context context = RelMdMeasure.Contexts.forAggregate(aggregate, builderSupplier, (RexCorrelVariable)holder.get());
                aggregate.getAggCallList().forEach(c -> {
                    if (c.getAggregation().kind == SqlKind.AGG_M2V) {
                        int arg = Iterables.getOnlyElement(c.getArgList());
                        aggCallList.add(b3 -> b3.aggregateCall(SqlInternalOperators.AGG_M2M, b3.fields(c.getArgList())).filter(c.filterArg < 0 ? null : b3.field(c.filterArg)));
                        RelMdMeasure.DelegatingContext context2 = new RelMdMeasure.DelegatingContext(context, (AggregateCall)c){
                            final /* synthetic */ AggregateCall val$c;
                            {
                                this.val$c = aggregateCall;
                                super(context);
                            }

                            @Override
                            public List<RexNode> getFilters(RelBuilder b) {
                                ImmutableList.Builder builder = ImmutableList.builder();
                                builder.addAll(super.getFilters(b));
                                if (this.val$c.filterArg >= 0) {
                                    builder.add(b.field(this.val$c.filterArg));
                                }
                                return builder.build();
                            }
                        };
                        projects.add(b4 -> mq.expand(b4.peek(), arg, context2));
                    } else {
                        int i = aggregate.getGroupSet().cardinality() + aggCallList.size();
                        aggCallList.add(b3 -> b3.aggregateCall((AggregateCall)c).filter(c.filterArg < 0 ? null : b3.field(c.filterArg)));
                        projects.add(b4 -> b4.field(i));
                    }
                });
                return b2;
            });
            b.aggregate(b.groupKey(aggregate.getGroupSet(), (Iterable<? extends ImmutableBitSet>)aggregate.groupSets), AggregateMeasure2Rule.bind(aggCallList).apply(b));
            b.project(AggregateMeasure2Rule.bind(projects).apply(b), aggregate.getRowType().getFieldNames(), false, ImmutableSet.of(holder.get().id));
            call.transformTo(b.build());
        }

        private static <T, E> Function<T, List<E>> bind(List<Function<T, E>> list) {
            return t2 -> {
                ImmutableList.Builder builder = ImmutableList.builder();
                list.forEach(f -> builder.add(f.apply(t2)));
                return builder.build();
            };
        }
    }

    @Value.Immutable
    public static interface AggregateMeasure2RuleConfig
    extends RelRule.Config {
        public static final AggregateMeasure2RuleConfig DEFAULT = ImmutableAggregateMeasure2RuleConfig.of().withOperandSupplier(b -> b.operand(Aggregate.class).predicate(b2 -> b2.getAggCallList().stream().anyMatch(c -> c.getAggregation() == SqlInternalOperators.AGG_M2V)).anyInputs());

        @Override
        default public AggregateMeasure2Rule toRule() {
            return new AggregateMeasure2Rule(this);
        }
    }

    @Value.Immutable
    public static interface ProjectMeasureRuleConfig
    extends RelRule.Config {
        public static final ProjectMeasureRuleConfig DEFAULT = ImmutableProjectMeasureRuleConfig.of().withOperandSupplier(b -> b.operand(Aggregate.class).predicate(aggregate -> aggregate.getAggCallList().stream().allMatch(c -> c.getAggregation() == SqlStdOperatorTable.SINGLE_VALUE)).oneInput(b2 -> b2.operand(Project.class).predicate(RexUtil.find(SqlKind.V2M)::inProject).anyInputs()));

        @Override
        default public ProjectMeasureRule toRule() {
            return new ProjectMeasureRule(this);
        }
    }

    public static class AggregateMeasureRule
    extends RelRule<AggregateMeasureRuleConfig>
    implements TransformationRule {
        protected AggregateMeasureRule(AggregateMeasureRuleConfig config) {
            super(config);
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            Aggregate aggregate = (Aggregate)call.rel(0);
            RelBuilder b = call.builder();
            b.push(aggregate.getInput());
            ArrayList aggCallList = new ArrayList();
            ArrayList extraProjects = new ArrayList();
            aggregate.getAggCallList().forEach(c -> {
                if (c.getAggregation().kind == SqlKind.AGG_M2V) {
                    int arg = Iterables.getOnlyElement(c.getArgList());
                    int i = b.fields().size() + extraProjects.size();
                    extraProjects.add(b.call(SqlInternalOperators.M2X, b.field(arg), b.call(SqlInternalOperators.SAME_PARTITION, b.fields(aggregate.getGroupSet()))));
                    aggCallList.add(b2 -> b2.aggregateCall(SqlStdOperatorTable.SINGLE_VALUE, b2.field(i)));
                } else {
                    aggCallList.add(b2 -> b2.aggregateCall((AggregateCall)c));
                }
            });
            b.projectPlus(extraProjects);
            b.aggregate(b.groupKey(aggregate.getGroupSet(), (Iterable<? extends ImmutableBitSet>)aggregate.groupSets), AggregateMeasureRule.bind(aggCallList).apply(b));
            call.transformTo(b.build());
        }

        private static <T, E> Function<T, List<E>> bind(List<Function<T, E>> list) {
            return t2 -> {
                ImmutableList.Builder builder = ImmutableList.builder();
                list.forEach(f -> builder.add(f.apply(t2)));
                return builder.build();
            };
        }
    }

    @Value.Immutable
    public static interface AggregateMeasureRuleConfig
    extends RelRule.Config {
        public static final AggregateMeasureRuleConfig DEFAULT = ImmutableAggregateMeasureRuleConfig.of().withOperandSupplier(b -> b.operand(Aggregate.class).predicate(b2 -> b2.getAggCallList().stream().anyMatch(c -> c.getAggregation() == SqlInternalOperators.AGG_M2V)).anyInputs());

        @Override
        default public AggregateMeasureRule toRule() {
            return new AggregateMeasureRule(this);
        }
    }
}

