/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.analysis;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.Generated;
import org.opensearch.sql.analysis.AnalysisContext;
import org.opensearch.sql.analysis.QualifierAnalyzer;
import org.opensearch.sql.analysis.TypeEnvironment;
import org.opensearch.sql.analysis.symbol.Namespace;
import org.opensearch.sql.analysis.symbol.Symbol;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ast.expression.AllFields;
import org.opensearch.sql.ast.expression.And;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.Between;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.HighlightFunction;
import org.opensearch.sql.ast.expression.In;
import org.opensearch.sql.ast.expression.Interval;
import org.opensearch.sql.ast.expression.LambdaFunction;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Not;
import org.opensearch.sql.ast.expression.Or;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.RelevanceFieldList;
import org.opensearch.sql.ast.expression.ScoreFunction;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.UnresolvedArgument;
import org.opensearch.sql.ast.expression.UnresolvedAttribute;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.When;
import org.opensearch.sql.ast.expression.WindowFunction;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.ast.expression.subquery.ExistsSubquery;
import org.opensearch.sql.ast.expression.subquery.InSubquery;
import org.opensearch.sql.ast.expression.subquery.ScalarSubquery;
import org.opensearch.sql.calcite.utils.CalciteUtils;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.HighlightExpression;
import org.opensearch.sql.expression.LiteralExpression;
import org.opensearch.sql.expression.NamedArgumentExpression;
import org.opensearch.sql.expression.NamedExpression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.aggregation.Aggregator;
import org.opensearch.sql.expression.conditional.cases.CaseClause;
import org.opensearch.sql.expression.conditional.cases.WhenClause;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.OpenSearchFunctions;
import org.opensearch.sql.expression.parse.ParseExpression;
import org.opensearch.sql.expression.span.SpanExpression;
import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction;
import shaded.com.google.common.collect.ImmutableList;
import shaded.com.google.common.collect.ImmutableMap;
import shaded.com.google.common.collect.ImmutableSet;

/**
 * AST visitor that turns {@link UnresolvedExpression} nodes into {@link Expression} instances.
 *
 * <p>Function-like nodes (casts, comparisons, scalar and aggregate functions) are compiled through
 * the {@link BuiltinFunctionRepository} supplied at construction time; identifiers are resolved
 * against the {@link TypeEnvironment} on top of the {@link AnalysisContext}.
 */
public class ExpressionAnalyzer extends AbstractNodeVisitor<Expression, AnalysisContext> {

    /**
     * Names of functions that this analyzer rejects via
     * {@link CalciteUtils#getOnlyForCalciteException(String)}. Built once instead of on every
     * {@link #visitFunction} call.
     */
    private static final ImmutableSet<String> CALCITE_ONLY_FUNCTIONS = ImmutableSet.of(
            BuiltinFunctionName.REGEX_MATCH.getName().getFunctionName(),
            BuiltinFunctionName.STRFTIME.getName().getFunctionName());

    /** Repository used to compile function names plus analyzed arguments into expressions. */
    private final BuiltinFunctionRepository repository;

    /**
     * Creates an analyzer backed by the given function repository.
     *
     * @param repository repository used to compile functions
     */
    public ExpressionAnalyzer(BuiltinFunctionRepository repository) {
        this.repository = repository;
    }

    /**
     * Analyzes an unresolved expression by dispatching it to this visitor.
     *
     * @param unresolved the AST expression to analyze
     * @param context the analysis context passed through to every visit method
     * @return whatever the matching {@code visit*} method returns
     */
    public Expression analyze(UnresolvedExpression unresolved, AnalysisContext context) {
        return unresolved.accept(this, context);
    }

    /** Analyzes the cast operand and compiles the node's conversion function over it. */
    @Override
    public Expression visitCast(Cast node, AnalysisContext context) {
        Expression operand = node.getExpression().accept(this, context);
        return (Expression) this.repository.compile(
                context.getFunctionProperties(),
                node.convertFunctionName(),
                Collections.singletonList(operand));
    }

    /** Resolves the attribute name the same way as a plain identifier. */
    @Override
    public Expression visitUnresolvedAttribute(UnresolvedAttribute node, AnalysisContext context) {
        return this.visitIdentifier(node.getAttr(), context);
    }

    /** Analyzes both operands (left first) and combines them with {@code DSL.equal}. */
    @Override
    public Expression visitEqualTo(EqualTo node, AnalysisContext context) {
        Expression lhs = node.getLeft().accept(this, context);
        Expression rhs = node.getRight().accept(this, context);
        return DSL.equal(lhs, rhs);
    }

    /** Builds a literal expression from the node's value and the core type of its data type. */
    @Override
    public Expression visitLiteral(Literal node, AnalysisContext context) {
        return DSL.literal(
                ExprValueUtils.fromObjectValue(node.getValue(), node.getType().getCoreType()));
    }

    /** Analyzes the interval value and pairs it with the unit's name as a literal. */
    @Override
    public Expression visitInterval(Interval node, AnalysisContext context) {
        Expression amount = node.getValue().accept(this, context);
        LiteralExpression unitName = DSL.literal(node.getUnit().name());
        return DSL.interval(amount, unitName);
    }

    /** Analyzes both operands (left first) and combines them with {@code DSL.and}. */
    @Override
    public Expression visitAnd(And node, AnalysisContext context) {
        Expression lhs = node.getLeft().accept(this, context);
        Expression rhs = node.getRight().accept(this, context);
        return DSL.and(lhs, rhs);
    }

    /** Analyzes both operands (left first) and combines them with {@code DSL.or}. */
    @Override
    public Expression visitOr(Or node, AnalysisContext context) {
        Expression lhs = node.getLeft().accept(this, context);
        Expression rhs = node.getRight().accept(this, context);
        return DSL.or(lhs, rhs);
    }

    /** Analyzes both operands (left first) and combines them with {@code DSL.xor}. */
    @Override
    public Expression visitXor(Xor node, AnalysisContext context) {
        Expression lhs = node.getLeft().accept(this, context);
        Expression rhs = node.getRight().accept(this, context);
        return DSL.xor(lhs, rhs);
    }

    /** Analyzes the operand and wraps it with {@code DSL.not}. */
    @Override
    public Expression visitNot(Not node, AnalysisContext context) {
        return DSL.not(node.getExpression().accept(this, context));
    }

    /**
     * Compiles an aggregate function into an {@link Aggregator}.
     *
     * <p>The aggregated field is analyzed first, followed by the extra arguments in order. The
     * node's distinct flag is applied, and its condition (when non-null) is analyzed and attached.
     *
     * @throws SemanticCheckException if {@link BuiltinFunctionName#ofAggregation} does not know
     *     the function name
     */
    @Override
    public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext context) {
        BuiltinFunctionName aggregationName = BuiltinFunctionName.ofAggregation(node.getFuncName())
                .orElseThrow(() -> new SemanticCheckException(
                        "Unsupported aggregation function " + node.getFuncName()));

        ImmutableList.Builder<Expression> arguments = ImmutableList.builder();
        arguments.add(node.getField().accept(this, context));
        for (UnresolvedExpression extraArg : node.getArgList()) {
            arguments.add(extraArg.accept(this, context));
        }

        Aggregator aggregator = (Aggregator) this.repository.compile(
                context.getFunctionProperties(), aggregationName.getName(), arguments.build());
        aggregator.distinct(node.getDistinct());
        if (node.condition() != null) {
            aggregator.condition(this.analyze(node.condition(), context));
        }
        return aggregator;
    }

    /** Wraps a copy of the node's field list in a tuple-valued literal expression. */
    @Override
    public Expression visitRelevanceFieldList(RelevanceFieldList node, AnalysisContext context) {
        return new LiteralExpression(
                ExprValueUtils.tupleValue(ImmutableMap.copyOf(node.getFieldList())));
    }

    /**
     * Analyzes every argument of a function call and compiles the call through the repository.
     *
     * @throws UnsupportedOperationException if analyzing an argument yields {@code null}
     */
    @Override
    public Expression visitFunction(Function node, AnalysisContext context) {
        FunctionName functionName = FunctionName.of(node.getFuncName());
        // Reject Calcite-only functions before any argument is analyzed.
        this.validateCalciteOnlyFunction(functionName);

        List<UnresolvedExpression> unresolvedArgs = node.getFuncArgs();
        List<Expression> analyzedArgs = new ArrayList<>(unresolvedArgs.size());
        for (UnresolvedExpression unresolvedArg : unresolvedArgs) {
            Expression analyzed = this.analyze(unresolvedArg, context);
            if (analyzed == null) {
                throw new UnsupportedOperationException(
                        String.format("Invalid use of expression %s", unresolvedArg));
            }
            analyzedArgs.add(analyzed);
        }
        return (Expression) this.repository.compile(
                context.getFunctionProperties(), functionName, analyzedArgs);
    }

    /**
     * Throws the "only for Calcite" exception when the function is listed in
     * {@link #CALCITE_ONLY_FUNCTIONS}.
     */
    private void validateCalciteOnlyFunction(FunctionName functionName) {
        if (this.isCalciteOnlyFunction(functionName)) {
            // Locale.ROOT keeps the upper-cased name independent of the JVM default locale.
            throw CalciteUtils.getOnlyForCalciteException(
                    functionName.getFunctionName().toUpperCase(Locale.ROOT));
        }
    }

    /** Returns whether the name matches a Calcite-only function, ignoring case. */
    private boolean isCalciteOnlyFunction(FunctionName functionName) {
        String candidate = functionName.getFunctionName();
        return CALCITE_ONLY_FUNCTIONS.stream().anyMatch(name -> name.equalsIgnoreCase(candidate));
    }

    /**
     * Analyzes the wrapped function; an {@link Aggregator} result is wrapped in an
     * {@link AggregateWindowFunction}, any other result is returned unchanged.
     */
    @Override
    public Expression visitWindowFunction(WindowFunction node, AnalysisContext context) {
        Expression analyzed = node.getFunction().accept(this, context);
        if (analyzed instanceof Aggregator) {
            return new AggregateWindowFunction((Aggregator) analyzed);
        }
        return analyzed;
    }

    /** Analyzes the highlight field and wraps it in a {@link HighlightExpression}. */
    @Override
    public Expression visitHighlightFunction(HighlightFunction node, AnalysisContext context) {
        Expression highlightField = node.getHighlightField().accept(this, context);
        return new HighlightExpression(highlightField);
    }

    /**
     * Folds the score boost into the wrapped relevance query and analyzes the result.
     *
     * <p>An existing {@code boost} argument (matched case-insensitively) has its string value
     * parsed as a double and multiplied by the score boost; when no such argument exists, one is
     * appended. The rebuilt function is then analyzed and marked as score-tracked.
     *
     * @throws SemanticCheckException if the boost literal's type is not {@link DataType#DOUBLE}
     */
    @Override
    public Expression visitScoreFunction(ScoreFunction node, AnalysisContext context) {
        Literal boostLiteral = node.getRelevanceFieldWeight();
        if (!boostLiteral.getType().equals(DataType.DOUBLE)) {
            throw new SemanticCheckException(String.format(
                    "Expected boost type '%s' but got '%s'",
                    DataType.DOUBLE.name(), boostLiteral.getType().name()));
        }
        Double scoreBoost = (Double) boostLiteral.getValue();

        Function relevanceQuery = (Function) node.getRelevanceQuery();
        List<UnresolvedExpression> originalArgs = relevanceQuery.getFuncArgs();
        List<UnresolvedExpression> rewrittenArgs = new ArrayList<>();
        boolean boostSeen = false;
        for (UnresolvedExpression originalArg : originalArgs) {
            UnresolvedArgument namedArg = (UnresolvedArgument) originalArg;
            String argName = namedArg.getArgName();
            if (!argName.equalsIgnoreCase("boost")) {
                rewrittenArgs.add(originalArg);
                continue;
            }
            boostSeen = true;
            // The existing boost is carried as a string literal; scale it by the score boost.
            Literal existingBoost = (Literal) namedArg.getValue();
            double combinedBoost =
                    Double.parseDouble((String) existingBoost.getValue()) * scoreBoost;
            rewrittenArgs.add(new UnresolvedArgument(
                    argName, new Literal(Double.toString(combinedBoost), DataType.STRING)));
        }
        if (!boostSeen) {
            rewrittenArgs.add(new UnresolvedArgument(
                    "boost", new Literal(Double.toString(scoreBoost), DataType.STRING)));
        }

        Function boostedQuery = new Function(relevanceQuery.getFuncName(), rewrittenArgs);
        OpenSearchFunctions.OpenSearchFunction analyzedQuery =
                (OpenSearchFunctions.OpenSearchFunction) boostedQuery.accept(this, context);
        analyzedQuery.setScoreTracked(true);
        return analyzedQuery;
    }

    /** Expands {@code field IN (v1, v2, ...)} into OR-combined equality comparisons. */
    @Override
    public Expression visitIn(In node, AnalysisContext context) {
        return this.visitIn(node.getField(), node.getValueList(), context);
    }

    /**
     * Compiles one {@code field = value} comparison per value, in list order, and joins them with
     * OR.
     *
     * @throws SemanticCheckException if {@code valueList} is empty
     */
    private Expression visitIn(
            UnresolvedExpression field,
            List<UnresolvedExpression> valueList,
            AnalysisContext context) {
        if (valueList.isEmpty()) {
            throw new SemanticCheckException("Values in In clause should not be empty");
        }
        // Iterate rather than index so non-random-access lists are not walked repeatedly.
        List<Expression> comparisons = new ArrayList<>(valueList.size());
        for (UnresolvedExpression value : valueList) {
            comparisons.add(this.visitCompare(new Compare("=", field, value), context));
        }
        return this.buildOrTree(comparisons, 0, comparisons.size());
    }

    /**
     * Combines {@code operands[from, to)} into an OR tree by recursively splitting the range at
     * its midpoint, so nesting depth grows logarithmically with the number of operands.
     *
     * @param operands non-empty list of operands
     * @param from inclusive start index
     * @param to exclusive end index
     */
    private Expression buildOrTree(List<Expression> operands, int from, int to) {
        if (to - from <= 1) {
            return operands.get(from);
        }
        int mid = from + (to - from) / 2;
        return DSL.or(this.buildOrTree(operands, from, mid), this.buildOrTree(operands, mid, to));
    }

    /** Analyzes both operands (left first) and compiles the operator as a two-argument function. */
    @Override
    public Expression visitCompare(Compare node, AnalysisContext context) {
        FunctionName operator = FunctionName.of(node.getOperator());
        Expression lhs = this.analyze(node.getLeft(), context);
        Expression rhs = this.analyze(node.getRight(), context);
        return (Expression) this.repository.compile(
                context.getFunctionProperties(), operator, Arrays.asList(lhs, rhs));
    }

    /** Rewrites {@code value BETWEEN lo AND hi} as {@code value >= lo AND value <= hi}. */
    @Override
    public Expression visitBetween(Between node, AnalysisContext context) {
        return AstDSL.and(
                        AstDSL.compare(">=", node.getValue(), node.getLowerBound()),
                        AstDSL.compare("<=", node.getValue(), node.getUpperBound()))
                .accept(this, context);
    }

    /**
     * Builds a {@link CaseClause} from the node's WHEN clauses and optional ELSE clause.
     *
     * <p>When the node has a case value, each WHEN condition is rewritten as
     * {@code caseValue = condition} before analysis.
     *
     * @throws SemanticCheckException if the clause reports more than one distinct result type
     */
    @Override
    public Expression visitCase(Case node, AnalysisContext context) {
        UnresolvedExpression caseValue = node.getCaseValue();
        List<WhenClause> whenClauses = new ArrayList<>();
        for (When when : node.getWhenClauses()) {
            When effectiveWhen = when;
            if (caseValue != null) {
                // Merge the case value and the compared value into one equality condition.
                effectiveWhen = new When(
                        new Function("=", Arrays.asList(caseValue, when.getCondition())),
                        when.getResult());
            }
            whenClauses.add((WhenClause) this.analyze(effectiveWhen, context));
        }

        Expression defaultResult = node.getElseClause()
                .map(elseClause -> this.analyze(elseClause, context))
                .orElse(null);
        CaseClause caseClause = new CaseClause(whenClauses, defaultResult);

        List<ExprType> resultTypes = caseClause.allResultTypes();
        if (ImmutableSet.copyOf(resultTypes).size() > 1) {
            throw new SemanticCheckException(
                    "All result types of CASE clause must be the same, but found " + resultTypes);
        }
        return caseClause;
    }

    /** Analyzes the condition, then the result, and pairs them in a {@link WhenClause}. */
    @Override
    public Expression visitWhen(When node, AnalysisContext context) {
        return new WhenClause(
                this.analyze(node.getCondition(), context),
                this.analyze(node.getResult(), context));
    }

    /** Resolves the field's underlying expression, cast to {@link QualifiedName}. */
    @Override
    public Expression visitField(Field node, AnalysisContext context) {
        return this.visitQualifiedName((QualifiedName) node.getField(), context);
    }

    /** Represents "all fields" as the string literal {@code "*"}. */
    @Override
    public Expression visitAllFields(AllFields node, AnalysisContext context) {
        return DSL.literal("*");
    }

    /**
     * Resolves a qualified name to a reference expression.
     *
     * <p>If any part of the name is registered under {@link Namespace#HIDDEN_FIELD_NAME} in the
     * current type environment or one of its parents, the unqualified name is returned as a
     * reference typed with that registered type. Otherwise it is resolved as an ordinary
     * identifier.
     */
    @Override
    public Expression visitQualifiedName(QualifiedName node, AnalysisContext context) {
        QualifierAnalyzer qualifierAnalyzer = new QualifierAnalyzer(context);
        for (String part : node.getParts()) {
            for (TypeEnvironment env = context.peek(); env != null; env = env.getParent()) {
                ExprType hiddenType = env.lookupAllFields(Namespace.HIDDEN_FIELD_NAME).get(part);
                if (hiddenType != null) {
                    return this.visitMetadata(
                            qualifierAnalyzer.unqualified(node), (ExprCoreType) hiddenType, context);
                }
            }
        }
        return this.visitIdentifier(qualifierAnalyzer.unqualified(node), context);
    }

    /** Analyzes the span's field and value and combines them with the span's unit. */
    @Override
    public Expression visitSpan(Span node, AnalysisContext context) {
        return new SpanExpression(
                node.getField().accept(this, context),
                node.getValue().accept(this, context),
                node.getUnit());
    }

    /** Analyzes the argument value and pairs it with the argument name. */
    @Override
    public Expression visitUnresolvedArgument(UnresolvedArgument node, AnalysisContext context) {
        return new NamedArgumentExpression(node.getArgName(), node.getValue().accept(this, context));
    }

    /** Analyzes the argument value and pairs it with the argument name. */
    @Override
    public Expression visitArgument(Argument node, AnalysisContext context) {
        return new NamedArgumentExpression(node.getArgName(), node.getValue().accept(this, context));
    }

    /** Always throws: this analyzer does not handle scalar subqueries. */
    @Override
    public Expression visitScalarSubquery(ScalarSubquery node, AnalysisContext context) {
        throw CalciteUtils.getOnlyForCalciteException("Subsearch");
    }

    /** Always throws: this analyzer does not handle EXISTS subqueries. */
    @Override
    public Expression visitExistsSubquery(ExistsSubquery node, AnalysisContext context) {
        throw CalciteUtils.getOnlyForCalciteException("Subsearch");
    }

    /** Always throws: this analyzer does not handle IN subqueries. */
    @Override
    public Expression visitInSubquery(InSubquery node, AnalysisContext context) {
        throw CalciteUtils.getOnlyForCalciteException("Subsearch");
    }

    /** Always throws: this analyzer does not handle lambda functions. */
    @Override
    public Expression visitLambdaFunction(LambdaFunction node, AnalysisContext context) {
        throw CalciteUtils.getOnlyForCalciteException("Lambda function");
    }

    /** Creates a reference to {@code ident} with the given type; {@code context} is unused. */
    private Expression visitMetadata(String ident, ExprCoreType exprCoreType, AnalysisContext context) {
        return DSL.ref(ident, exprCoreType);
    }

    /**
     * Resolves an identifier to an expression.
     *
     * <p>A named parse expression whose name or alias equals {@code ident} and whose delegate is a
     * {@link ParseExpression} takes precedence. Otherwise the identifier is resolved as a field
     * name in the current type environment.
     */
    private Expression visitIdentifier(String ident, AnalysisContext context) {
        for (NamedExpression named : context.getNamedParseExpressions()) {
            if (named.getNameOrAlias().equals(ident)
                    && named.getDelegated() instanceof ParseExpression) {
                return named.getDelegated();
            }
        }
        TypeEnvironment typeEnv = context.peek();
        return DSL.ref(ident, typeEnv.resolve(new Symbol(Namespace.FIELD_NAME, ident)));
    }

    /** Returns the function repository this analyzer compiles functions with. */
    @Generated
    public BuiltinFunctionRepository getRepository() {
        return this.repository;
    }
}

