/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.opensearch.planner.physical;

import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataframe.ColumnMeta;
import org.opensearch.ml.common.dataframe.ColumnValue;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
import org.opensearch.ml.common.dataframe.Row;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.parameter.sample.SampleAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.sql.data.model.ExprBooleanValue;
import org.opensearch.sql.data.model.ExprDoubleValue;
import org.opensearch.sql.data.model.ExprFloatValue;
import org.opensearch.sql.data.model.ExprIntegerValue;
import org.opensearch.sql.data.model.ExprLongValue;
import org.opensearch.sql.data.model.ExprShortValue;
import org.opensearch.sql.data.model.ExprStringValue;
import org.opensearch.sql.data.model.ExprTupleValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.opensearch.client.MLClient;
import org.opensearch.sql.planner.physical.PhysicalPlan;
import org.opensearch.transport.client.node.NodeClient;

/**
 * Base class for physical operators that send their child plan's rows to ML Commons and turn the
 * ML Commons output back into SQL {@link ExprValue} tuples.
 */
public abstract class MLCommonsOperatorActions
extends PhysicalPlan {

    /** Maximum number of seconds to wait for an ML Commons call to complete. */
    private static final long ML_CALL_TIMEOUT_SECONDS = 30L;

    /**
     * Drains {@code input} into a single {@link DataFrame}.
     *
     * <p>Rows in which any field is null or missing are skipped.
     *
     * @param input the child plan; it is fully consumed by this call
     * @return a data frame built from the retained rows
     */
    protected DataFrame generateInputDataset(PhysicalPlan input) {
        MLInputRows inputData = new MLInputRows();
        while (input.hasNext()) {
            inputData.addTupleValue(input.next().tupleValue());
        }
        return inputData.toDataFrame();
    }

    /**
     * Drains {@code input} and groups its rows by the value of {@code categoryField}.
     *
     * <p>When {@code categoryField} is {@code null}, every row falls into a single group. Rows in
     * which any field is null or missing are skipped. Groups that end up with no rows are omitted
     * from the result.
     *
     * @param input the child plan; it is fully consumed by this call
     * @param categoryField the field to group by, or {@code null} for no grouping
     * @return one pair per non-empty group. The left frame contains all fields, including the
     *     category field. The right frame has the {@code categoryField} column removed.
     */
    protected List<Pair<DataFrame, DataFrame>> generateCategorizedInputDataset(
            PhysicalPlan input, String categoryField) {
        Map<ExprValue, MLInputRows> inputMap = new HashMap<>();
        while (input.hasNext()) {
            Map<String, ExprValue> tupleValue = input.next().tupleValue();
            // A row that lacks the field entirely is grouped under the null key.
            ExprValue categoryValue = categoryField == null ? null : tupleValue.get(categoryField);
            inputMap.computeIfAbsent(categoryValue, k -> new MLInputRows()).addTupleValue(tupleValue);
        }
        return inputMap.values().stream()
                // A group can be empty when every one of its rows was skipped by addTupleValue.
                .filter(inputData -> !inputData.isEmpty())
                .map(inputData -> new ImmutablePair<>(
                        inputData.toDataFrame(),
                        inputData.toFilteredDataFrame(e -> !e.getKey().equals(categoryField))))
                .collect(Collectors.toList());
    }

    /**
     * Converts one data-frame row into a name-to-value map, using the column names as keys.
     *
     * @param columnMetas column metadata; entry {@code i} names value {@code i} of {@code row}
     * @param row the row to convert
     * @return an immutable map, in column order, of the columns with a supported type
     */
    protected Map<String, ExprValue> convertRowIntoExprValue(ColumnMeta[] columnMetas, Row row) {
        ImmutableMap.Builder<String, ExprValue> resultBuilder = new ImmutableMap.Builder<>();
        for (int i = 0; i < columnMetas.length; ++i) {
            populateResultBuilder(row.getValue(i), columnMetas[i].getName(), resultBuilder);
        }
        return resultBuilder.build();
    }

    /**
     * Wraps {@code columnValue} in the matching {@link ExprValue} and adds it to the builder under
     * {@code resultKeyName}.
     *
     * <p>Only INTEGER, DOUBLE, STRING, SHORT, LONG, FLOAT and BOOLEAN are handled. A value of any
     * other column type is not added.
     *
     * @param columnValue the ML Commons value to convert
     * @param resultKeyName the key to store the converted value under
     * @param resultBuilder the builder to add the entry to
     */
    protected void populateResultBuilder(
            ColumnValue columnValue,
            String resultKeyName,
            ImmutableMap.Builder<String, ExprValue> resultBuilder) {
        switch (columnValue.columnType()) {
            case INTEGER:
                resultBuilder.put(resultKeyName, new ExprIntegerValue(columnValue.intValue()));
                break;
            case DOUBLE:
                resultBuilder.put(resultKeyName, new ExprDoubleValue(columnValue.doubleValue()));
                break;
            case STRING:
                resultBuilder.put(resultKeyName, new ExprStringValue(columnValue.stringValue()));
                break;
            case SHORT:
                resultBuilder.put(resultKeyName, new ExprShortValue(columnValue.shortValue()));
                break;
            case LONG:
                resultBuilder.put(resultKeyName, new ExprLongValue(columnValue.longValue()));
                break;
            case FLOAT:
                resultBuilder.put(resultKeyName, new ExprFloatValue(Float.valueOf(columnValue.floatValue())));
                break;
            case BOOLEAN:
                resultBuilder.put(resultKeyName, ExprBooleanValue.of(columnValue.booleanValue()));
                break;
            default:
                // Unsupported column types are intentionally skipped.
                break;
        }
    }

    /**
     * Converts one ML result row into a name-to-value map, renaming columns that clash with the
     * input schema.
     *
     * <p>A result column whose name is also a key of {@code schema} is renamed to
     * {@code name + k}, using the smallest {@code k >= 1} that collides with neither a schema key
     * nor any other result column name. In the common case this yields {@code name + "1"}.
     * Previously only {@code "1"} was appended, without checking the new name. That could produce
     * a duplicate key and make the {@link ImmutableMap.Builder} throw.
     *
     * @param columnMetas result column metadata; entry {@code i} names value {@code i} of {@code row}
     * @param row the result row to convert
     * @param schema the input-row fields that result column names must not collide with
     * @return an immutable map, in column order, of the columns with a supported type
     */
    protected Map<String, ExprValue> convertResultRowIntoExprValue(
            ColumnMeta[] columnMetas, Row row, Map<String, ExprValue> schema) {
        // Reserve every input field and every native result column name. This stops a renamed
        // column from taking a name that another column keeps unchanged.
        Set<String> takenNames = new HashSet<>(schema.keySet());
        for (ColumnMeta columnMeta : columnMetas) {
            takenNames.add(columnMeta.getName());
        }
        ImmutableMap.Builder<String, ExprValue> resultBuilder = new ImmutableMap.Builder<>();
        for (int i = 0; i < columnMetas.length; ++i) {
            ColumnValue columnValue = row.getValue(i);
            String resultKeyName = columnMetas[i].getName();
            if (schema.containsKey(resultKeyName)) {
                resultKeyName = reserveSuffixedName(resultKeyName, takenNames);
            }
            populateResultBuilder(columnValue, resultKeyName, resultBuilder);
        }
        return resultBuilder.build();
    }

    /**
     * Returns {@code baseName + k} for the smallest {@code k >= 1} not in {@code takenNames}, and
     * adds it to {@code takenNames}.
     */
    private static String reserveSuffixedName(String baseName, Set<String> takenNames) {
        int suffix = 1;
        String candidate = baseName + suffix;
        while (takenNames.contains(candidate)) {
            ++suffix;
            candidate = baseName + suffix;
        }
        takenNames.add(candidate);
        return candidate;
    }

    /**
     * Merges the next input row with the next prediction row into one tuple.
     *
     * <p>The prediction columns come first, followed by the input columns. Prediction columns
     * whose names clash with input columns are renamed; see
     * {@link #convertResultRowIntoExprValue}.
     *
     * @param inputRowIter iterator over the input rows; advanced by one
     * @param inputDataFrame the input frame, used for its column metadata
     * @param predictionResult the prediction output, used for its column metadata
     * @param resultRowIter iterator over the prediction rows; advanced by one
     * @return the merged tuple
     */
    protected ExprTupleValue buildResult(
            Iterator<Row> inputRowIter,
            DataFrame inputDataFrame,
            MLPredictionOutput predictionResult,
            Iterator<Row> resultRowIter) {
        ImmutableMap<String, ExprValue> resultSchema = ImmutableMap.copyOf(
                convertRowIntoExprValue(inputDataFrame.columnMetas(), inputRowIter.next()));
        ImmutableMap.Builder<String, ExprValue> resultBuilder = new ImmutableMap.Builder<>();
        resultBuilder.putAll(convertResultRowIntoExprValue(
                predictionResult.getPredictionResult().columnMetas(),
                resultRowIter.next(),
                resultSchema));
        resultBuilder.putAll(resultSchema);
        return ExprTupleValue.fromExprValueMap(resultBuilder.build());
    }

    /**
     * Runs ML Commons {@code trainAndPredict} on {@code inputDataFrame} and waits for the result.
     *
     * @param functionName the algorithm to run
     * @param mlAlgoParams the algorithm parameters
     * @param inputDataFrame the data to train on and predict for
     * @param nodeClient the client used to obtain the ML Commons client
     * @return the prediction output; the response is cast, so a different {@link MLOutput} type
     *     fails with {@link ClassCastException}
     */
    protected MLPredictionOutput getMLPredictionResult(
            FunctionName functionName,
            MLAlgoParams mlAlgoParams,
            DataFrame inputDataFrame,
            NodeClient nodeClient) {
        MLInput mlinput = MLInput.builder()
                .algorithm(functionName)
                .parameters(mlAlgoParams)
                .inputDataset((MLInputDataset) new DataFrameInputDataset(inputDataFrame))
                .build();
        MachineLearningNodeClient machineLearningClient = MLClient.getMLClient(nodeClient);
        return (MLPredictionOutput) machineLearningClient
                .trainAndPredict(mlinput)
                .actionGet(ML_CALL_TIMEOUT_SECONDS, TimeUnit.SECONDS);
    }

    /**
     * Runs an ML Commons {@code run} request on {@code inputDataFrame} and waits for the result.
     *
     * <p>The algorithm and parameters on the {@link MLInput} are fixed to {@code SAMPLE_ALGO} and
     * {@code new SampleAlgoParams(0)}. NOTE(review): presumably {@code arguments} selects the real
     * algorithm — confirm against {@code MachineLearningNodeClient#run}.
     *
     * @param inputDataFrame the data to send
     * @param arguments request arguments passed through to {@code run}
     * @param nodeClient the client used to obtain the ML Commons client
     * @return the ML Commons output
     */
    protected MLOutput getMLOutput(
            DataFrame inputDataFrame, Map<String, Object> arguments, NodeClient nodeClient) {
        MLInput mlinput = MLInput.builder()
                .inputDataset((MLInputDataset) new DataFrameInputDataset(inputDataFrame))
                .algorithm(FunctionName.SAMPLE_ALGO)
                .parameters((MLAlgoParams) new SampleAlgoParams(0))
                .build();
        MachineLearningNodeClient machineLearningClient = MLClient.getMLClient(nodeClient);
        return (MLOutput) machineLearningClient
                .run(mlinput, arguments)
                .actionGet(ML_CALL_TIMEOUT_SECONDS, TimeUnit.SECONDS);
    }

    /**
     * Builds the output tuple for a PPL ML command.
     *
     * @param isPredict {@code true} to merge the next input row with the next prediction row;
     *     {@code false} to report the training result
     * @param inputRowIter iterator over the input rows (used only when predicting)
     * @param inputDataFrame the input frame (used only when predicting)
     * @param mlResult an {@link MLPredictionOutput} when predicting, else an {@link MLTrainingOutput}
     * @param resultRowIter iterator over the prediction rows (used only when predicting)
     * @return the output tuple
     */
    protected ExprTupleValue buildPPLResult(
            boolean isPredict,
            Iterator<Row> inputRowIter,
            DataFrame inputDataFrame,
            MLOutput mlResult,
            Iterator<Row> resultRowIter) {
        if (isPredict) {
            return buildResult(inputRowIter, inputDataFrame, (MLPredictionOutput) mlResult, resultRowIter);
        }
        return buildTrainResult((MLTrainingOutput) mlResult);
    }

    /**
     * Builds a tuple with the {@code model_id}, {@code task_id} and {@code status} of a training run.
     *
     * @param trainResult the training output
     * @return the tuple of three string values
     */
    protected ExprTupleValue buildTrainResult(MLTrainingOutput trainResult) {
        ImmutableMap.Builder<String, ExprValue> resultBuilder = new ImmutableMap.Builder<>();
        resultBuilder.put("model_id", new ExprStringValue(trainResult.getModelId()));
        resultBuilder.put("task_id", new ExprStringValue(trainResult.getTaskId()));
        resultBuilder.put("status", new ExprStringValue(trainResult.getStatus()));
        return ExprTupleValue.fromExprValueMap(resultBuilder.build());
    }

    /**
     * Row buffer that stores SQL tuples as the plain {@code Map<String, Object>} rows that
     * {@link DataFrameBuilder#load} accepts.
     */
    private static class MLInputRows
    extends LinkedList<Map<String, Object>> {
        private MLInputRows() {
        }

        /**
         * Appends the raw values of {@code tupleValue} as one row.
         *
         * <p>The tuple is skipped if any of its values is null or missing.
         */
        public void addTupleValue(Map<String, ExprValue> tupleValue) {
            if (tupleValue.values().stream().anyMatch(e -> e.isNull() || e.isMissing())) {
                return;
            }
            this.add(tupleValue.entrySet().stream()
                    .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().value())));
        }

        /** Builds a data frame from all buffered rows. */
        public DataFrame toDataFrame() {
            return DataFrameBuilder.load(this);
        }

        /**
         * Builds a data frame from all buffered rows. Each row keeps only the entries accepted by
         * {@code filter}.
         */
        public DataFrame toFilteredDataFrame(Predicate<Map.Entry<String, Object>> filter) {
            return DataFrameBuilder.load(this.stream()
                    .map(row -> row.entrySet().stream()
                            .filter(filter)
                            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)))
                    .collect(Collectors.toList()));
        }
    }
}

