/*
 * Decompiled with CFR 0.152.
 */
package org.freeplane.plugin.ai.chat;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.observability.api.event.ToolExecutedEvent;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
import org.freeplane.core.util.LogUtils;
import org.freeplane.plugin.ai.chat.AssistantProfileChatMemory;
import org.freeplane.plugin.ai.chat.ChatTokenCounterMode;
import org.freeplane.plugin.ai.chat.ChatTokenUsageState;
import org.freeplane.plugin.ai.chat.ChatUsageTotals;

/**
 * Tracks provider-reported token usage per assistant response ("turn") and publishes
 * {@link ChatUsageTotals} to a consumer according to the active {@link ChatTokenCounterMode}.
 *
 * <p>Input and output counts are kept in two parallel per-turn lists. A cursor
 * ({@code currentTurnCount}) marks how many leading turns are currently in effect.
 * {@link #undoLastResponse()} and {@link #redoLastResponse()} move that cursor without
 * discarding history. Recording new provider usage after an undo drops the turns past the cursor.
 *
 * <p>All public methods except the constructor and {@link #logToolExecuted} are
 * {@code synchronized} on this instance. The synchronized publishing methods invoke the
 * totals consumer while that lock is held.
 */
public class ChatTokenUsageTracker {
    /** Receives every computed {@link ChatUsageTotals}. */
    private final Consumer<ChatUsageTotals> totalsConsumer;
    /** Output token count per recorded turn; a {@code null} entry means the provider did not report it. */
    private final List<Long> outputByTurn = new ArrayList<>();
    /** Input token count per recorded turn, parallel to {@link #outputByTurn}. */
    private final List<Long> inputByTurn = new ArrayList<>();
    /** Number of leading turns in the lists that are currently in effect. */
    private int currentTurnCount;
    private ChatTokenCounterMode counterMode = ChatTokenCounterMode.HIDDEN;
    private String counterModeLabel;

    /**
     * Creates a tracker and immediately publishes hidden totals.
     *
     * @param totalsConsumer receiver of computed totals; must not be {@code null}
     * @throws NullPointerException if {@code totalsConsumer} is {@code null}
     */
    public ChatTokenUsageTracker(Consumer<ChatUsageTotals> totalsConsumer) {
        this.totalsConsumer = Objects.requireNonNull(totalsConsumer, "totalsConsumer");
        publishTotals(ChatUsageTotals.hidden());
    }

    /**
     * Sets the counter mode and clears the mode label.
     *
     * @param counterMode the new mode; {@code null} is treated as {@link ChatTokenCounterMode#HIDDEN}
     */
    public synchronized void setCounterMode(ChatTokenCounterMode counterMode) {
        setCounterMode(counterMode, null);
    }

    /**
     * Sets the counter mode and the label attached to totals computed from now on.
     * Does not publish anything by itself.
     *
     * @param counterMode the new mode; {@code null} is treated as {@link ChatTokenCounterMode#HIDDEN}
     * @param counterModeLabel label passed to {@code ChatUsageTotals.withLabel}; may be {@code null}
     */
    public synchronized void setCounterMode(ChatTokenCounterMode counterMode, String counterModeLabel) {
        this.counterMode = counterMode == null ? ChatTokenCounterMode.HIDDEN : counterMode;
        this.counterModeLabel = counterModeLabel;
    }

    /**
     * Records the provider-reported usage of one response as a new turn.
     *
     * <p>If earlier turns were undone, they are discarded first, so the new turn directly follows
     * the current one. Usage with neither an input nor an output count is ignored.
     *
     * @param tokenUsage provider usage; {@code null} is ignored
     */
    public synchronized void recordProviderUsage(TokenUsage tokenUsage) {
        if (tokenUsage == null) {
            return;
        }
        Integer inputCount = tokenUsage.inputTokenCount();
        Integer outputCount = tokenUsage.outputTokenCount();
        if (inputCount == null && outputCount == null) {
            return;
        }
        truncateHistoryAfterCurrentTurn();
        outputByTurn.add(toLong(outputCount));
        inputByTurn.add(toLong(inputCount));
        currentTurnCount = outputByTurn.size();
    }

    /**
     * Logs the name of an executed tool at info level.
     *
     * <p>This is a best-effort hook. A {@code null} event or an event without a request is
     * ignored rather than failing the caller.
     *
     * @param event the tool execution event; may be {@code null}
     */
    public void logToolExecuted(ToolExecutedEvent event) {
        ToolExecutionRequest request = event == null ? null : event.request();
        if (request == null) {
            return;
        }
        LogUtils.info(buildToolCallLogMessage(request));
    }

    /** Discards all recorded turns and publishes totals for the current mode with no memory. */
    public synchronized void resetTotals() {
        outputByTurn.clear();
        inputByTurn.clear();
        currentTurnCount = 0;
        publishTotals(calculateTotals(null));
    }

    /**
     * Captures the per-turn history and the turn cursor.
     *
     * <p>The lists are copied, so the returned state does not change when this tracker records
     * or discards turns later. This also makes {@code restoreState(snapshotState())} safe.
     *
     * @return an independent snapshot of the tracker state
     */
    public synchronized ChatTokenUsageState snapshotState() {
        List<Long> outputCopy = new ArrayList<Long>(outputByTurn);
        List<Long> inputCopy = new ArrayList<Long>(inputByTurn);
        return new ChatTokenUsageState(outputCopy, inputCopy, currentTurnCount);
    }

    /**
     * Replaces the tracker state with {@code state}, or resets it when {@code state} is {@code null}.
     *
     * <p>Both lists are trimmed to their common length, and the turn cursor is clamped into
     * {@code [0, thatLength]}. For a non-null state, totals are not published. Call
     * {@link #refreshTotals} afterwards to update the consumer.
     *
     * @param state the state to restore; may be {@code null}
     */
    public synchronized void restoreState(ChatTokenUsageState state) {
        if (state == null) {
            resetTotals();
            return;
        }
        // Copy before clearing: the state's lists may alias this tracker's own lists, in which
        // case clear() followed by addAll() would silently lose the history.
        List<Long> restoredOutput = new ArrayList<Long>(state.getOutputByTurn());
        List<Long> restoredInput = new ArrayList<Long>(state.getInputByTurn());
        outputByTurn.clear();
        outputByTurn.addAll(restoredOutput);
        inputByTurn.clear();
        inputByTurn.addAll(restoredInput);
        int turnCount = Math.min(outputByTurn.size(), inputByTurn.size());
        truncateTo(outputByTurn, turnCount);
        truncateTo(inputByTurn, turnCount);
        currentTurnCount = Math.max(0, Math.min(state.getCurrentTurnCount(), turnCount));
    }

    /**
     * Computes totals for the current mode and publishes them with the mode label and the given
     * input/output labels.
     *
     * @param memory chat memory used by the estimating modes; may be {@code null}
     * @param inputLabel label for the input count
     * @param outputLabel label for the output count
     */
    public synchronized void refreshTotals(AssistantProfileChatMemory memory, String inputLabel, String outputLabel) {
        ChatUsageTotals totals = calculateTotals(memory)
                .withLabel(counterModeLabel)
                .withInputOutputLabels(inputLabel, outputLabel);
        publishTotals(totals);
    }

    /** Moves the turn cursor back by one turn, if any; recorded usage is kept for redo. Does not publish. */
    public synchronized void undoLastResponse() {
        if (currentTurnCount <= 0) {
            return;
        }
        --currentTurnCount;
    }

    /** Moves the turn cursor forward by one previously undone turn, if any. Does not publish. */
    public synchronized void redoLastResponse() {
        if (currentTurnCount >= outputByTurn.size()) {
            return;
        }
        ++currentTurnCount;
    }

    private void publishTotals(ChatUsageTotals totals) {
        totalsConsumer.accept(totals);
    }

    /** Drops every recorded turn beyond the current cursor (turns that were undone). */
    private void truncateHistoryAfterCurrentTurn() {
        truncateTo(outputByTurn, currentTurnCount);
        truncateTo(inputByTurn, currentTurnCount);
    }

    /** Removes elements from {@code list} until it holds at most {@code size} elements. */
    private static void truncateTo(List<Long> list, int size) {
        if (list.size() > size) {
            list.subList(size, list.size()).clear();
        }
    }

    /** Widens a possibly-null {@code Integer} to a {@code Long}, preserving {@code null}. */
    private static Long toLong(Integer value) {
        return value == null ? null : Long.valueOf(value.longValue());
    }

    private static String buildToolCallLogMessage(ToolExecutionRequest request) {
        return "Tool call: " + request.name();
    }

    /**
     * Computes totals according to the current counter mode.
     *
     * <ul>
     *   <li>HIDDEN: hidden totals.</li>
     *   <li>MODEL_RESPONSE: the provider-reported usage of the current turn.</li>
     *   <li>No memory available (and not one of the above): zero totals.</li>
     *   <li>CONTEXT_WINDOW: the memory's estimate for the active window.</li>
     *   <li>Any other mode: the memory's estimate for the full conversation.</li>
     * </ul>
     */
    private ChatUsageTotals calculateTotals(AssistantProfileChatMemory memory) {
        if (counterMode == ChatTokenCounterMode.HIDDEN) {
            return ChatUsageTotals.hidden();
        }
        if (counterMode == ChatTokenCounterMode.MODEL_RESPONSE) {
            return totalsFromProviderUsage().withLabel(counterModeLabel);
        }
        if (memory == null) {
            return new ChatUsageTotals(0L, 0L).withLabel(counterModeLabel);
        }
        if (counterMode == ChatTokenCounterMode.CONTEXT_WINDOW) {
            return memory.estimateTokenUsageForActiveWindow().withLabel(counterModeLabel);
        }
        return memory.estimateTokenUsageForFullConversation().withLabel(counterModeLabel);
    }

    /**
     * Returns the provider-reported usage of the most recent turn in effect.
     * Missing counts are reported as zero, and zero totals are returned when no turn is in effect.
     */
    private ChatUsageTotals totalsFromProviderUsage() {
        // Clamp against both lists so a cursor/list mismatch cannot index out of bounds.
        int available = Math.min(outputByTurn.size(), inputByTurn.size());
        int index = Math.min(currentTurnCount, available) - 1;
        if (index < 0) {
            return new ChatUsageTotals(0L, 0L);
        }
        Long output = outputByTurn.get(index);
        Long input = inputByTurn.get(index);
        long outputCount = output == null ? 0L : output;
        long inputCount = input == null ? 0L : input;
        return new ChatUsageTotals(inputCount, outputCount);
    }
}

