/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Values;
import org.apache.calcite.rel.rules.ImmutableJoinLeftSingleRuleConfig;
import org.apache.calcite.rel.rules.ImmutableJoinLeftSingleValueRuleWithExprConfig;
import org.apache.calcite.rel.rules.ImmutableJoinRightSingleRuleConfig;
import org.apache.calcite.rel.rules.ImmutableJoinRightSingleValueRuleWithExprConfig;
import org.apache.calcite.rel.rules.SubstitutionRule;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;

public abstract class SingleValuesOptimizationRules {
    /** Removes a single-row {@link Values} that is the left input of a {@link Join}. */
    public static final RelOptRule JOIN_LEFT_INSTANCE =
        JoinLeftSingleRuleConfig.DEFAULT.toRule();

    /** Removes a single-row {@link Values} that is the right input of a {@link Join}. */
    public static final RelOptRule JOIN_RIGHT_INSTANCE =
        JoinRightSingleRuleConfig.DEFAULT.toRule();

    /** Removes a {@link Project} over a single-row {@link Values} on the left of a {@link Join}. */
    public static final RelOptRule JOIN_LEFT_PROJECT_INSTANCE =
        JoinLeftSingleValueRuleWithExprConfig.DEFAULT.toRule();

    /** Removes a {@link Project} over a single-row {@link Values} on the right of a {@link Join}. */
    public static final RelOptRule JOIN_RIGHT_PROJECT_INSTANCE =
        JoinRightSingleValueRuleWithExprConfig.DEFAULT.toRule();

    /**
     * Configuration for the rule that matches
     * {@code Join(any, Project(Values))} where the Values has exactly one row.
     */
    @Value.Immutable
    interface JoinRightSingleValueRuleWithExprConfig
    extends PruneSingleValueRule.Config {
        JoinRightSingleValueRuleWithExprConfig DEFAULT = ImmutableJoinRightSingleValueRuleWithExprConfig.of()
            .withOperandSupplier(join -> join.operand(Join.class).inputs(
                left -> left.operand(RelNode.class).anyInputs(),
                right -> right.operand(Project.class).inputs(
                    proj -> proj.operand(Values.class)
                        .predicate(Values::isSingleValue)
                        .noInputs())))
            .withDescription("PruneJoinSingleValueRuleWithExpr(right)");

        @Override
        default PruneSingleValueRule toRule() {
            return new PruneSingleValueRule(this) {
                @Override
                public void onMatch(RelOptRuleCall call) {
                    // Operand order: 0 = Join, 1 = left input, 2 = Project, 3 = Values.
                    final Join matchedJoin = (Join) call.rel(0);
                    final RelNode leftInput = call.rel(1);
                    final Project valuesProject = (Project) call.rel(2);
                    final Values singleRow = (Values) call.rel(3);
                    onMatch(call, singleRow, valuesProject, matchedJoin, leftInput, false);
                }
            };
        }
    }

    /**
     * Configuration for the rule that matches
     * {@code Join(Project(Values), any)} where the Values has exactly one row.
     */
    @Value.Immutable
    interface JoinLeftSingleValueRuleWithExprConfig
    extends PruneSingleValueRule.Config {
        JoinLeftSingleValueRuleWithExprConfig DEFAULT = ImmutableJoinLeftSingleValueRuleWithExprConfig.of()
            .withOperandSupplier(join -> join.operand(Join.class).inputs(
                left -> left.operand(Project.class).inputs(
                    proj -> proj.operand(Values.class)
                        .predicate(Values::isSingleValue)
                        .noInputs()),
                right -> right.operand(RelNode.class).anyInputs()))
            .withDescription("PruneJoinSingleValueRuleWithExpr(left)");

        @Override
        default PruneSingleValueRule toRule() {
            return new PruneSingleValueRule(this) {
                @Override
                public void onMatch(RelOptRuleCall call) {
                    // Operand order: 0 = Join, 1 = Project, 2 = Values, 3 = right input.
                    final Join matchedJoin = (Join) call.rel(0);
                    final Project valuesProject = (Project) call.rel(1);
                    final Values singleRow = (Values) call.rel(2);
                    final RelNode rightInput = call.rel(3);
                    onMatch(call, singleRow, valuesProject, matchedJoin, rightInput, true);
                }
            };
        }
    }

    /**
     * Configuration for the rule that matches {@code Join(Values, any)} where
     * the Values has exactly one row.
     */
    @Value.Immutable
    interface JoinLeftSingleRuleConfig
    extends PruneSingleValueRule.Config {
        JoinLeftSingleRuleConfig DEFAULT = ImmutableJoinLeftSingleRuleConfig.of()
            .withOperandSupplier(join -> join.operand(Join.class).inputs(
                left -> left.operand(Values.class)
                    .predicate(Values::isSingleValue)
                    .noInputs(),
                right -> right.operand(RelNode.class).anyInputs()))
            .withDescription("PruneJoinSingleValueRule(left)");

        @Override
        default PruneSingleValueRule toRule() {
            return new PruneSingleValueRule(this) {
                @Override
                public void onMatch(RelOptRuleCall call) {
                    // Operand order: 0 = Join, 1 = Values, 2 = right input.
                    final Join matchedJoin = (Join) call.rel(0);
                    final Values singleRow = (Values) call.rel(1);
                    final RelNode rightInput = call.rel(2);
                    // No Project sits between the Join and the Values here.
                    onMatch(call, singleRow, null, matchedJoin, rightInput, true);
                }
            };
        }
    }

    /**
     * Configuration for the rule that matches {@code Join(any, Values)} where
     * the Values has exactly one row.
     */
    @Value.Immutable
    interface JoinRightSingleRuleConfig
    extends PruneSingleValueRule.Config {
        JoinRightSingleRuleConfig DEFAULT = ImmutableJoinRightSingleRuleConfig.of()
            .withOperandSupplier(join -> join.operand(Join.class).inputs(
                left -> left.operand(RelNode.class).anyInputs(),
                right -> right.operand(Values.class)
                    .predicate(Values::isSingleValue)
                    .noInputs()))
            .withDescription("PruneJoinSingleValue(right)");

        @Override
        default PruneSingleValueRule toRule() {
            return new PruneSingleValueRule(this) {
                @Override
                public void onMatch(RelOptRuleCall call) {
                    // Operand order: 0 = Join, 1 = left input, 2 = Values.
                    final Join matchedJoin = (Join) call.rel(0);
                    final RelNode leftInput = call.rel(1);
                    final Values singleRow = (Values) call.rel(2);
                    // No Project sits between the Join and the Values here.
                    onMatch(call, singleRow, null, matchedJoin, leftInput, false);
                }
            };
        }
    }

    /**
     * Base class for rules that remove a single-row {@link Values} input
     * (optionally under a {@link Project}) from a {@link Join}.
     *
     * <p>The row's expressions are inlined into a Project over the other join
     * input. For an INNER join the other input is additionally filtered by the
     * (inlined) join condition. For the outer join that preserves the other
     * input, each inlined expression is wrapped as
     * {@code CASE WHEN condition THEN expr ELSE NULL END} and nothing is filtered.
     */
    protected static abstract class PruneSingleValueRule
    extends RelRule<PruneSingleValueRule.Config>
    implements SubstitutionRule {
        protected PruneSingleValueRule(Config config) {
            super(config);
        }

        /**
         * Returns the function that adapts the inlined row expressions to the
         * join type.
         *
         * @param rexBuilder builder used to create CASE calls and NULL literals
         * @param joinRelType type of the join being rewritten
         * @return for LEFT and RIGHT joins, a function mapping each expression
         *     {@code e} to {@code CASE(condition, e, NULL)}; for every other
         *     join type, a function returning the expressions unchanged
         */
        protected BiFunction<RexNode, List<RexNode>, List<RexNode>> getRexTransformer(RexBuilder rexBuilder, JoinRelType joinRelType) {
            switch (joinRelType) {
                case LEFT:
                case RIGHT:
                    // Rows of the preserved side that do not satisfy the join
                    // condition must see NULLs in place of the single row.
                    return (condition, exprs) -> exprs.stream()
                        .map(expr -> rexBuilder.makeCall(SqlStdOperatorTable.CASE,
                            condition, expr, rexBuilder.makeNullLiteral(expr.getType())))
                        .collect(Collectors.toList());
                default:
                    return (condition, exprs) -> exprs;
            }
        }

        /**
         * Shared implementation for all single-value join rules.
         *
         * @param call rule call used to obtain a builder and register the result
         * @param values the single-row Values input
         * @param project Project directly above {@code values}, or {@code null}
         *     if the Values is a direct join input
         * @param join the matched join
         * @param other the join input that is not the single-row side
         * @param isOnLeftSide whether the single-row side is the left join input
         */
        protected void onMatch(RelOptRuleCall call, Values values, @Nullable Project project, Join join, RelNode other, boolean isOnLeftSide) {
            final Predicate<Join> transformableCheck = isJoinTransformable(isOnLeftSide);
            if (!transformableCheck.test(join)) {
                return;
            }
            final List<RexNode> row = new ArrayList<>(values.tuples.get(0));
            final List<RexNode> rexNodes;
            if (project == null) {
                rexNodes = row;
            } else {
                // Inline the row into the Project's expressions so the
                // Project+Values pair collapses into a list of expressions.
                final ImmutableBitSet valuesFields =
                    ImmutableBitSet.range(0, values.getRowType().getFieldCount());
                final RexNodeReplacer replacer = new RexNodeReplacer(valuesFields, row, 0);
                rexNodes = project.getProjects().stream()
                    .map(replacer::go)
                    .collect(Collectors.toList());
            }
            final RelBuilder relBuilder = call.builder();
            final BiFunction<RexNode, List<RexNode>, List<RexNode>> transformer =
                getRexTransformer(relBuilder.getRexBuilder(), join.getJoinType());
            final RelNode transformed = new SingleValuesRelTransformer(join, rexNodes, other,
                transformableCheck, isOnLeftSide, transformer).transform(relBuilder);
            if (transformed != null) {
                call.transformTo(transformed);
            }
        }

        /**
         * Returns a predicate telling whether a join can be rewritten when the
         * single-row side is the left ({@code isLeft}) or right input.
         *
         * <p>Only INNER joins and the outer join that preserves the non-Values
         * side are accepted. For those, the output is every row of the other side
         * (filtered for INNER) next to the single row. Other join types are
         * rejected:
         * <ul>
         *   <li>the outer join preserving the Values side and FULL joins can emit
         *       an all-NULL row for the other side;</li>
         *   <li>SEMI and ANTI joins do not output the right input's fields.</li>
         * </ul>
         */
        static Predicate<Join> isJoinTransformable(boolean isLeft) {
            final JoinRelType preservesOtherSide = isLeft ? JoinRelType.RIGHT : JoinRelType.LEFT;
            return jn -> jn.getJoinType() == JoinRelType.INNER
                || jn.getJoinType() == preservesOtherSide;
        }

        @Override
        public boolean autoPruneOld() {
            return true;
        }

        /** Rule configuration. */
        protected static interface Config
        extends RelRule.Config {
            @Override
            public PruneSingleValueRule toRule();
        }
    }

    /**
     * Shuttle that substitutes expressions for selected input references.
     *
     * <p>An input reference whose index is set in the bit set is replaced by
     * the element of the replacement list at {@code index + offset}. Every other
     * node is handled by the default {@link RexShuttle} behavior.
     */
    private static class RexNodeReplacer
    extends RexShuttle {
        /** Indexes of the input references to substitute. */
        private final ImmutableBitSet replacedIndexes;
        /** Replacement expressions. */
        private final List<RexNode> replacements;
        /** Added to a reference's index to locate its replacement. */
        private final int indexShift;

        RexNodeReplacer(ImmutableBitSet bitSet, List<RexNode> values, int offset) {
            replacedIndexes = bitSet;
            replacements = values;
            indexShift = offset;
        }

        @Override
        public RexNode visitInputRef(RexInputRef inputRef) {
            final int index = inputRef.getIndex();
            if (!replacedIndexes.get(index)) {
                return super.visitInputRef(inputRef);
            }
            return replacements.get(index + indexShift);
        }

        /** Applies this shuttle to {@code expression} and returns the rewritten tree. */
        public RexNode go(RexNode expression) {
            return expression.accept(this);
        }
    }

    /**
     * Rewrites a {@link Join} one of whose inputs is a single row into a
     * Project (and, for INNER joins, a Filter) over the other input.
     */
    private static class SingleValuesRelTransformer {
        /** The join being rewritten. */
        private final Join join;
        /** The join input that is not the single-row side. */
        private final RelNode relNode;
        /** Decides whether {@link #join} may be rewritten at all. */
        private final Predicate<Join> transformable;
        /** Adapts the inlined row expressions to the join type. */
        private final BiFunction<RexNode, List<RexNode>, List<RexNode>> litTransformer;
        /** Whether the single-row side is the join's left input. */
        private final boolean valuesAsLeftChild;
        /** Expressions replacing the single-row side's fields, in field order. */
        private final List<RexNode> literals;

        protected SingleValuesRelTransformer(Join join, List<RexNode> rexNodes, RelNode otherNode, Predicate<Join> transformable, boolean isValuesLeftChild, BiFunction<RexNode, List<RexNode>, List<RexNode>> litTransformer) {
            this.relNode = otherNode;
            this.join = join;
            this.transformable = transformable;
            this.litTransformer = litTransformer;
            this.valuesAsLeftChild = isValuesLeftChild;
            this.literals = rexNodes;
        }

        /**
         * Builds the replacement for {@link #join}.
         *
         * @param relBuilder builder used to construct the replacement
         * @return the replacement, or {@code null} if the join cannot be rewritten
         */
        public @Nullable RelNode transform(RelBuilder relBuilder) {
            if (!this.transformable.test(this.join)) {
                return null;
            }
            final int leftCount = this.join.getLeft().getRowType().getFieldCount();
            final int rightCount = this.join.getRight().getRowType().getFieldCount();
            // The rewrite emits every field of both inputs. Join types whose
            // output omits the right input (e.g. SEMI, ANTI) cannot be expressed
            // this way, and the replacement's row type would not match.
            if (this.join.getRowType().getFieldCount() != leftCount + rightCount) {
                return null;
            }
            final ImmutableBitSet valuesFields = this.valuesAsLeftChild
                ? ImmutableBitSet.range(0, leftCount)
                : ImmutableBitSet.range(leftCount, leftCount + rightCount);
            // Maps a join-level field index onto an index into `literals`.
            final int literalOffset = this.valuesAsLeftChild ? 0 : -leftCount;
            final RexNode inlined = new RexNodeReplacer(valuesFields, this.literals, literalOffset)
                .go(this.join.getCondition());
            // After inlining, the condition references only the other input. When
            // that input is the right one, its fields must be renumbered from 0.
            final RexNode condition = this.valuesAsLeftChild
                ? RexUtil.shift(inlined, -leftCount)
                : inlined;
            final List<RexNode> valueExprs = this.litTransformer.apply(condition, this.literals);

            relBuilder.push(this.relNode);
            // Outer joins keep every row of the other side (the condition is
            // folded into CASE expressions instead), so only INNER joins filter.
            if (!this.join.getJoinType().isOuterJoin()) {
                relBuilder.filter(condition);
            }
            final List<RexNode> otherFields = relBuilder.fields();

            final List<RexNode> projects = new ArrayList<>(leftCount + rightCount);
            if (this.valuesAsLeftChild) {
                projects.addAll(valueExprs);
                projects.addAll(otherFields);
            } else {
                projects.addAll(otherFields);
                projects.addAll(valueExprs);
            }
            return relBuilder.project(projects).build();
        }
    }
}

