/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.zoo.nlp.textgeneration;

import ai.djl.modality.nlp.generate.CausalLMOutput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslatorContext;
import java.util.Collection;
import java.util.stream.Collectors;

public class PtGptTranslator
implements NoBatchifyTranslator<NDList, CausalLMOutput> {
    private long kvDim;
    private int numAttentionHeads;
    private int numLayers;
    private String tupleName;

    public PtGptTranslator(long kvDim, int numAttentionHeads, int numLayers) {
        this.kvDim = kvDim;
        this.numAttentionHeads = numAttentionHeads;
        this.numLayers = numLayers;
        this.tupleName = "past_key_values(" + numLayers + ',' + 2 + ')';
    }

    public NDList processInput(TranslatorContext ctx, NDList input) throws Exception {
        NDManager manager = ctx.getNDManager();
        if (input.size() == 3) {
            ctx.setAttachment("useDummyPastKeyValues", (Object)Boolean.TRUE);
            this.initialDummyPastKeyValues((NDArray)input.get(0), manager, input);
            long batchSize = ((NDArray)input.get(0)).getShape().get(0);
            NDArray attentionMask = manager.zeros(new Shape(new long[]{batchSize, 1L}), DataType.INT64).concat((NDArray)input.get(2), -1);
            input.set(2, (Object)attentionMask);
        }
        for (int i = 3; i < this.numLayers * 2 + 3; ++i) {
            ((NDArray)input.get(i)).setName(this.tupleName);
        }
        return input;
    }

    public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) throws Exception {
        NDArray logitsOutput = (NDArray)output.get(0);
        NDManager manager = output.getManager();
        NDList pastKeyValuesOutput = output.subNDList(1, this.numLayers * 2 + 1);
        NDArray hiddenStatesOutput = output.size() > this.numLayers * 2 + 1 ? (NDArray)output.get(this.numLayers * 2 + 1) : manager.zeros(new Shape(new long[]{1L}));
        if (ctx.getAttachment("useDummyPastKeyValues") != null) {
            NDIndex index2 = new NDIndex(":, :, 1:, ...", new Object[0]);
            pastKeyValuesOutput = new NDList((Collection)pastKeyValuesOutput.stream().map(object -> object.get(index2)).collect(Collectors.toList()));
        }
        for (NDArray array : pastKeyValuesOutput) {
            array.setName(this.tupleName);
        }
        return new CausalLMOutput(logitsOutput, hiddenStatesOutput, pastKeyValuesOutput);
    }

    private void initialDummyPastKeyValues(NDArray inputIds, NDManager manager, NDList list) {
        long numBatch = inputIds.getShape().get(0);
        for (int i = 0; i < this.numLayers * 2; ++i) {
            NDArray array = manager.zeros(new Shape(new long[]{numBatch, this.numAttentionHeads, 1L, this.kvDim}));
            list.add((Object)array);
        }
    }
}

