/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.task;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.get.GetRequest;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService;
import org.opensearch.ml.action.contextmanagement.ContextManagerFactory;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.contextmanager.ContextManagementTemplate;
import org.opensearch.ml.common.contextmanager.ContextManager;
import org.opensearch.ml.common.contextmanager.ContextManagerConfig;
import org.opensearch.ml.common.contextmanager.ContextManagerHookProvider;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.hooks.HookRegistry;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.settings.MLCommonsSettings;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.indices.MLInputDatasetHandler;
import org.opensearch.ml.stats.ActionName;
import org.opensearch.ml.stats.MLActionLevelStat;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.StreamTransportResponseHandler;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.TransportException;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.client.Client;
import org.opensearch.transport.stream.StreamTransportResponse;

/**
 * Task runner for ML "execute" requests.
 *
 * <p>Requests are executed on a dedicated thread pool (a separate pool is used when the request
 * carries a streaming {@link TransportChannel}). For agent requests the runner first resolves an
 * optional context management template name; when one is found, the template's context managers
 * are turned into a {@link HookRegistry} that is attached to the {@link AgentMLInput} before the
 * request is handed to the {@link MLEngine}. All other requests go straight to the engine.
 */
public class MLExecuteTaskRunner extends MLTaskRunner<MLExecuteTaskRequest, MLExecuteTaskResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(MLExecuteTaskRunner.class);

    /** Thread pool used for non-streaming execute requests. */
    private static final String EXECUTE_THREAD_POOL = "opensearch_ml_execute";
    /** Thread pool used when the request carries a streaming channel. */
    private static final String STREAM_EXECUTE_THREAD_POOL = "opensearch_ml_execute_stream";
    /** Index the agent document is read from when resolving its context management template. */
    private static final String AGENT_INDEX = ".plugins-ml-agent";
    /** Dataset parameter: when equal to "true", template lookup is skipped. */
    private static final String PARAM_CONTEXT_MANAGEMENT_PROCESSED = "context_management_processed";
    /** Dataset parameter that may carry a context management name. */
    private static final String PARAM_CONTEXT_MANAGEMENT_NAME = "context_management_name";

    private final ThreadPool threadPool;
    private final ClusterService clusterService;
    private final Client client;
    private final MLInputDatasetHandler mlInputDatasetHandler;
    protected final DiscoveryNodeHelper nodeHelper;
    private final MLEngine mlEngine;
    private final ContextManagementTemplateService contextManagementTemplateService;
    private final ContextManagerFactory contextManagerFactory;
    // Written by the cluster-settings update consumer registered in the constructor.
    private volatile Boolean isPythonModelEnabled;

    /**
     * Creates the runner and subscribes to updates of
     * {@link MLCommonsSettings#ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL}.
     */
    public MLExecuteTaskRunner(ThreadPool threadPool, ClusterService clusterService, Client client, MLTaskManager mlTaskManager, MLStats mlStats, MLInputDatasetHandler mlInputDatasetHandler, MLTaskDispatcher mlTaskDispatcher, MLCircuitBreakerService mlCircuitBreakerService, DiscoveryNodeHelper nodeHelper, MLEngine mlEngine, ContextManagementTemplateService contextManagementTemplateService, ContextManagerFactory contextManagerFactory) {
        super(mlTaskManager, mlStats, nodeHelper, mlTaskDispatcher, mlCircuitBreakerService, clusterService);
        this.threadPool = threadPool;
        this.clusterService = clusterService;
        this.client = client;
        this.mlInputDatasetHandler = mlInputDatasetHandler;
        this.nodeHelper = nodeHelper;
        this.mlEngine = mlEngine;
        this.contextManagementTemplateService = contextManagementTemplateService;
        this.contextManagerFactory = contextManagerFactory;
        this.isPythonModelEnabled = MLCommonsSettings.ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL.get(this.clusterService.getSettings());
        this.clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL, it -> this.isPythonModelEnabled = it);
    }

    @Override
    protected String getTransportActionName() {
        return "cluster:admin/opensearch/ml/execute";
    }

    @Override
    protected String getTransportStreamActionName() {
        return "cluster:admin/opensearch/ml/execute/stream";
    }

    /** A request is treated as streaming when it carries a streaming channel. */
    @Override
    protected boolean isStreamingRequest(MLExecuteTaskRequest request) {
        return request.getStreamingChannel() != null;
    }

    @Override
    protected TransportResponseHandler<MLExecuteTaskResponse> getResponseHandler(ActionListener<MLExecuteTaskResponse> listener) {
        return new ActionListenerResponseHandler<>(listener, MLExecuteTaskResponse::new);
    }

    /**
     * Builds a handler that forwards every streamed response to the request's streaming channel
     * and completes that channel once the upstream stream is exhausted.
     */
    @Override
    protected TransportResponseHandler<MLExecuteTaskResponse> getResponseStreamHandler(MLExecuteTaskRequest request) {
        final TransportChannel channel = request.getStreamingChannel();
        return new StreamTransportResponseHandler<MLExecuteTaskResponse>() {

            @Override
            public void handleStreamResponse(StreamTransportResponse<MLExecuteTaskResponse> streamResponse) {
                try {
                    MLExecuteTaskResponse response;
                    while ((response = streamResponse.nextResponse()) != null) {
                        channel.sendResponseBatch(response);
                    }
                    channel.completeStream();
                    streamResponse.close();
                } catch (Exception e) {
                    // Surface the failure in the logs before cancelling the upstream stream.
                    log.error("Failed to forward streaming execute response", e);
                    streamResponse.cancel("Stream error", e);
                }
            }

            @Override
            public void handleException(TransportException exp) {
                try {
                    channel.sendResponse(exp);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            @Override
            public String executor() {
                return "same";
            }

            @Override
            public MLExecuteTaskResponse read(StreamInput in) throws IOException {
                return new MLExecuteTaskResponse(in);
            }
        };
    }

    /**
     * Executes the request on the execute (or streaming execute) thread pool.
     *
     * <p>Agent requests first resolve a context management name asynchronously; a non-blank name
     * routes to {@link #executeAgentWithContextManagement}, anything else (including a lookup
     * failure) falls back to {@link #continueNormalExecution}.
     *
     * @param request  the execute request
     * @param listener notified with the engine output or the failure
     */
    @Override
    protected void executeTask(MLExecuteTaskRequest request, ActionListener<MLExecuteTaskResponse> listener) {
        TransportChannel channel = request.getStreamingChannel();
        String threadPoolName = channel != null ? STREAM_EXECUTE_THREAD_POOL : EXECUTE_THREAD_POOL;
        threadPool.executor(threadPoolName).execute(() -> {
            try {
                mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();
                mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
                mlStats
                    .createCounterStatIfAbsent(request.getFunctionName(), ActionName.EXECUTE, MLActionLevelStat.ML_ACTION_REQUEST_COUNT)
                    .increment();

                Input input = request.getInput();
                if (FunctionName.AGENT == request.getFunctionName() && input instanceof AgentMLInput) {
                    getEffectiveContextManagementNameAsync((AgentMLInput) input, ActionListener.wrap(contextManagementName -> {
                        if (contextManagementName != null && !contextManagementName.trim().isEmpty()) {
                            executeAgentWithContextManagement(request, contextManagementName, channel, listener);
                        } else {
                            continueNormalExecution(request, channel, listener);
                        }
                    }, e -> {
                        log.debug("Failed to get context management name, continuing with normal execution: {}", e.getMessage());
                        continueNormalExecution(request, channel, listener);
                    }));
                    return;
                }
                // Non-agent requests share the exact same path as agents without context management.
                continueNormalExecution(request, channel, listener);
            } catch (Exception e) {
                mlStats
                    .createCounterStatIfAbsent(request.getFunctionName(), ActionName.EXECUTE, MLActionLevelStat.ML_ACTION_FAILURE_COUNT)
                    .increment();
                listener.onFailure(e);
            } finally {
                mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement();
            }
        });
    }

    /**
     * Loads the named context management template, attaches a hook registry built from it to the
     * agent input, and executes the agent.
     *
     * @param request               the execute request; its input is expected to be an {@link AgentMLInput}
     * @param contextManagementName name of the template to load
     * @param channel               streaming channel, or {@code null} for non-streaming requests
     * @param listener              notified with the response or the failure
     */
    private void executeAgentWithContextManagement(MLExecuteTaskRequest request, String contextManagementName, TransportChannel channel, ActionListener<MLExecuteTaskResponse> listener) {
        log.debug("Executing agent with context management: {}", contextManagementName);
        contextManagementTemplateService.getTemplate(contextManagementName, ActionListener.wrap(template -> {
            if (template == null) {
                listener.onFailure(new IllegalArgumentException("Context management template not found: " + contextManagementName));
                return;
            }

            // Phase 1: build the context managers and attach the hook registry to the agent input.
            int managerCount;
            try {
                List<ContextManager> contextManagers = createContextManagers(template);
                HookRegistry hookRegistry = createHookRegistry(contextManagers, template);
                ((AgentMLInput) request.getInput()).setHookRegistry(hookRegistry);
                managerCount = contextManagers.size();
            } catch (Exception e) {
                log.error("Failed to create context managers from template: {}", contextManagementName, e);
                listener.onFailure(e);
                return;
            }
            log.debug("Executing agent with context management template: {} using {} context managers", contextManagementName, managerCount);

            // Phase 2: run the agent. Kept separate so each failure is reported exactly once.
            try {
                mlEngine.execute(request.getInput(), ActionListener.wrap(output -> {
                    log.debug("Agent execution completed successfully with context management");
                    listener.onResponse(new MLExecuteTaskResponse(request.getFunctionName(), output));
                }, error -> {
                    log.error("Agent execution failed with context management", error);
                    listener.onFailure(error);
                }), channel);
            } catch (Exception e) {
                log.error("Failed to execute agent with context management", e);
                listener.onFailure(e);
            }
        }, error -> {
            log.error("Failed to retrieve context management template: {}", contextManagementName, error);
            listener.onFailure(error);
        }));
    }

    /**
     * Resolves the context management name for an agent request.
     *
     * <p>Resolution order: the name set directly on the input; the template name stored on the
     * agent document (when the input has an agent id); finally
     * {@link #getFallbackContextManagementName}. The listener receives {@code null} when nothing
     * applies. Lookup and parse problems never fail the listener; they fall back instead.
     */
    private void getEffectiveContextManagementNameAsync(AgentMLInput agentInput, ActionListener<String> listener) {
        String runtimeContextManagementName = agentInput.getContextManagementName();
        if (runtimeContextManagementName != null && !runtimeContextManagementName.trim().isEmpty()) {
            log.debug("Using runtime context management name: {}", runtimeContextManagementName);
            listener.onResponse(runtimeContextManagementName);
            return;
        }

        String agentId = agentInput.getAgentId();
        if (agentId == null) {
            listener.onResponse(getFallbackContextManagementName(agentInput));
            return;
        }

        // The thread context is stashed for the get call and restored before the callback runs.
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            client.get(new GetRequest(AGENT_INDEX, agentId), ActionListener.runBefore(ActionListener.wrap(response -> {
                String templateName = response.isExists() ? readTemplateName(response.getSourceAsString()) : null;
                // Notify outside the parsing code so a downstream failure is not logged as a parse failure.
                listener.onResponse(templateName != null ? templateName : getFallbackContextManagementName(agentInput));
            }, e -> {
                log.debug("Failed to retrieve agent, using fallback: {}", e.getMessage());
                listener.onResponse(getFallbackContextManagementName(agentInput));
            }), () -> context.restore()));
        }
    }

    /**
     * Parses an agent document and returns its context management template name.
     *
     * @param agentSource JSON source of the agent document
     * @return the non-blank template name, or {@code null} when the agent has none or the
     *         document cannot be parsed
     */
    private String readTemplateName(String agentSource) {
        // try-with-resources: the parser is closeable and must be released after parsing.
        try (XContentParser parser = JsonXContent.jsonXContent.createParser(null, LoggingDeprecationHandler.INSTANCE, agentSource)) {
            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
            MLAgent mlAgent = MLAgent.parse(parser);
            if (mlAgent.hasContextManagementTemplate()) {
                String templateName = mlAgent.getContextManagementTemplateName();
                if (templateName != null && !templateName.trim().isEmpty()) {
                    return templateName;
                }
            }
        } catch (Exception e) {
            log.debug("Failed to parse agent, using fallback: {}", e.getMessage());
        }
        return null;
    }

    /**
     * Looks for a context management name in the parameters of a remote-inference dataset.
     *
     * @return the non-blank {@code context_management_name} parameter, or {@code null} when the
     *         dataset is not a {@link RemoteInferenceInputDataSet}, has no parameters, is flagged
     *         {@code context_management_processed=true}, or carries no usable name
     */
    private String getFallbackContextManagementName(AgentMLInput agentInput) {
        if (!(agentInput.getInputDataset() instanceof RemoteInferenceInputDataSet)) {
            return null;
        }
        RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) agentInput.getInputDataset();
        Map<String, String> parameters = dataset.getParameters();
        if (parameters == null) {
            // No parameters at all: nothing to look up.
            return null;
        }
        if ("true".equals(parameters.get(PARAM_CONTEXT_MANAGEMENT_PROCESSED))) {
            log.debug("Context management already processed by MLAgentExecutor, skipping template lookup");
            return null;
        }
        String runtimeContextManagementName = parameters.get(PARAM_CONTEXT_MANAGEMENT_NAME);
        if (runtimeContextManagementName != null && !runtimeContextManagementName.trim().isEmpty()) {
            log.debug("Using runtime context management name from parameters: {}", runtimeContextManagementName);
            return runtimeContextManagementName;
        }
        return null;
    }

    /**
     * Executes the request on the engine without context management.
     *
     * <p>{@link FunctionName#METRICS_CORRELATION} is rejected with an
     * {@link IllegalArgumentException} while the in-house Python model setting is disabled.
     */
    private void continueNormalExecution(MLExecuteTaskRequest request, TransportChannel channel, ActionListener<MLExecuteTaskResponse> listener) {
        Input input = request.getInput();
        FunctionName functionName = request.getFunctionName();
        if (FunctionName.METRICS_CORRELATION == functionName && !isPythonModelEnabled) {
            listener.onFailure(new IllegalArgumentException("This algorithm is not enabled from settings"));
            return;
        }
        try {
            mlEngine.execute(
                input,
                ActionListener.wrap(output -> listener.onResponse(new MLExecuteTaskResponse(functionName, output)), listener::onFailure),
                channel
            );
        } catch (Exception e) {
            log.error("Failed to execute ML function", e);
            listener.onFailure(e);
        }
    }

    /**
     * Instantiates a context manager for every config in the template's hooks.
     *
     * <p>Configs whose manager cannot be created (factory returns {@code null} or throws) are
     * logged and skipped, so the result may contain fewer managers than there are configs.
     */
    private List<ContextManager> createContextManagers(ContextManagementTemplate template) {
        List<ContextManager> contextManagers = new ArrayList<>();
        Map<String, List<ContextManagerConfig>> hooks = template.getHooks();
        if (hooks != null) {
            for (Map.Entry<String, List<ContextManagerConfig>> entry : hooks.entrySet()) {
                String hookName = entry.getKey();
                List<ContextManagerConfig> configs = entry.getValue();
                if (configs == null) {
                    continue;
                }
                for (ContextManagerConfig config : configs) {
                    try {
                        ContextManager manager = contextManagerFactory.createContextManager(config);
                        if (manager == null) {
                            log.warn("Failed to create context manager of type: {}", config.getType());
                            continue;
                        }
                        contextManagers.add(manager);
                        log.debug("Created context manager: {} for hook: {}", config.getType(), hookName);
                    } catch (Exception e) {
                        log.error("Error creating context manager of type: {}", config.getType(), e);
                    }
                }
            }
        }
        log.info("Created {} context managers from template: {}", contextManagers.size(), template.getName());
        return contextManagers;
    }

    /**
     * Creates a hook registry and, when at least one context manager exists, registers the
     * managers' hooks on it. With no managers the registry is returned without registrations.
     */
    private HookRegistry createHookRegistry(List<ContextManager> contextManagers, ContextManagementTemplate template) {
        HookRegistry hookRegistry = new HookRegistry();
        if (contextManagers.isEmpty()) {
            return hookRegistry;
        }
        ContextManagerHookProvider hookProvider = new ContextManagerHookProvider(contextManagers, template.getHooks());
        hookProvider.registerHooks(hookRegistry);
        log.debug("Registered context manager hooks for {} managers", contextManagers.size());
        return hookRegistry;
    }
}

