/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AMorphingMMColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCZeros;
import org.apache.sysds.runtime.compress.colgroup.FORUtil;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.offset.AIterator;
import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;

/**
 * Sparse-default column group with frame-of-reference (FOR) encoding.
 *
 * <p>Every row of the group holds the values in {@code _reference} by default. Rows listed in {@code _indexes}
 * additionally add the dictionary tuple selected by {@code _data}. The value at (row, column {@code j}) is therefore
 * {@code dict[_data[k]][j] + _reference[j]} when {@code row} is the {@code k}-th offset, and {@code _reference[j]}
 * otherwise.</p>
 */
public class ColGroupSDCFOR extends AMorphingMMColGroup {
    private static final long serialVersionUID = 3883228464052204203L;

    /** Row offsets of the rows whose values differ from the plain reference. */
    protected AOffset _indexes;
    /** For each offset, the index of the dictionary tuple added to the reference in that row. */
    protected AMapToData _data;
    /** One reference value per column of the group; it is part of every row's value. */
    protected double[] _reference;

    /**
     * Creates a group with only the row count set. The fields of this class are left unset and must be filled in
     * afterwards, for instance by {@link #readFields(DataInput)}.
     *
     * @param numRows number of rows in the group
     */
    protected ColGroupSDCFOR(int numRows) {
        super(numRows);
    }

    /**
     * Builds the group from already-encoded components.
     *
     * @param colIndices   column indexes covered by this group
     * @param numRows      number of rows in the group
     * @param dict         dictionary of tuples added to the reference at the offset rows
     * @param indexes      offsets of the rows that carry a dictionary tuple
     * @param data         mapping from each offset to its dictionary tuple
     * @param cachedCounts cached tuple occurrence counts (may be passed through as-is)
     * @param reference    one reference value per column
     * @throws DMLCompressionException if the mapping's number of unique values differs from the number of
     *                                 dictionary tuples
     */
    private ColGroupSDCFOR(int[] colIndices, int numRows, ADictionary dict, AOffset indexes, AMapToData data,
        int[] cachedCounts, double[] reference) {
        super(colIndices, numRows, dict, cachedCounts);
        if (data.getUnique() != dict.getNumberOfValues(colIndices.length)) {
            throw new DMLCompressionException("Invalid construction of SDCFOR group");
        }
        _data = data;
        _indexes = indexes;
        _zeros = false;
        _reference = reference;
    }

    /**
     * Creates the most specific column group for the given components.
     *
     * <ul>
     * <li>No dictionary and an all-zero reference: a {@link ColGroupEmpty}.</li>
     * <li>No dictionary: a {@link ColGroupConst} holding the reference.</li>
     * <li>An all-zero reference: a {@link ColGroupSDCZeros} (the reference contributes nothing).</li>
     * <li>Otherwise: a new {@link ColGroupSDCFOR}.</li>
     * </ul>
     *
     * @return the created column group
     */
    protected static AColGroup create(int[] colIndexes, int numRows, ADictionary dict, AOffset offsets,
        AMapToData data, int[] cachedCounts, double[] reference) {
        final boolean allZero = FORUtil.allZero(reference);
        if (allZero && dict == null) {
            return new ColGroupEmpty(colIndexes);
        }
        if (dict == null) {
            return ColGroupConst.create(colIndexes, reference);
        }
        if (allZero) {
            return ColGroupSDCZeros.create(colIndexes, numRows, dict, offsets, data, cachedCounts);
        }
        return new ColGroupSDCFOR(colIndexes, numRows, dict, offsets, data, cachedCounts, reference);
    }

    /**
     * Creates a group that shares this group's columns, row count, offsets, mapping and cached counts, but uses a
     * new dictionary and reference.
     */
    private AColGroup derive(ADictionary newDict, double[] newRef) {
        return create(_colIndexes, _numRows, newDict, _indexes, _data, getCachedCounts(), newRef);
    }

    @Override
    public AColGroup.CompressionType getCompType() {
        return AColGroup.CompressionType.SDCFOR;
    }

    @Override
    public AColGroup.ColGroupType getColGroupType() {
        return AColGroup.ColGroupType.SDCFOR;
    }

    /** Returns the occurrence counts of the dictionary tuples, as computed by the mapping. */
    @Override
    public int[] getCounts(int[] counts) {
        return _data.getCounts(counts);
    }

    /** Delegates to the shared SDC row-sum kernel using this group's offsets and mapping. */
    @Override
    protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) {
        ColGroupSDC.computeRowSums(c, rl, ru, preAgg, _data, _indexes, _numRows);
    }

    /**
     * Delegates to the shared SDC row min/max kernel. The last entry of {@code preAgg} is passed as the value for
     * rows outside the offsets. NOTE(review): this relies on {@link #preAggBuiltinRows(Builtin)} appending the
     * reference-only aggregate as its final element; confirm in the dictionary implementation.
     */
    @Override
    protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] preAgg) {
        ColGroupSDC.computeRowMxx(c, builtin, rl, ru, preAgg, _data, _indexes, _numRows,
            preAgg[preAgg.length - 1]);
    }

    /**
     * Returns the value at row {@code r} and local column {@code colIdx}.
     *
     * @return the reference value, plus the mapped dictionary value when {@code r} is one of the offsets
     */
    @Override
    public double getIdx(int r, int colIdx) {
        final AIterator it = _indexes.getIterator(r);
        final double ref = _reference[colIdx];
        // A null iterator, or one positioned past r, means r is not an offset row.
        if (it == null || it.value() != r) {
            return ref;
        }
        final int rowOff = _data.getIndex(it.getDataIndex()) * _colIndexes.length;
        return _dict.getValue(rowOff + colIdx) + ref;
    }

    /**
     * Applies a scalar operator to every cell.
     *
     * <p>Each cell holds {@code d + r}, where {@code d} is a dictionary value and {@code r} a reference value.
     * Shortcuts apply only when the operation distributes over that sum:</p>
     * <ul>
     * <li>{@code (d + r) + s = d + (r + s)} and {@code (d + r) - s = d + (r - s)}: only the reference changes.</li>
     * <li>{@code s * (d + r) = s*d + s*r} and {@code (d + r) / s = d/s + r/s}: dictionary and reference are both
     * scaled independently.</li>
     * </ul>
     *
     * <p>With the scalar on the left, minus and divide do not distribute: {@code s - (d + r) != d + (s - r)}. Those
     * cases, like all other operators, use the general reference-aware dictionary operation.</p>
     */
    @Override
    public AColGroup scalarOperation(ScalarOperator op) {
        final double[] newRef = new double[_reference.length];
        for (int i = 0; i < _reference.length; ++i) {
            newRef[i] = op.executeScalar(_reference[i]);
        }
        final boolean scalarOnRight = op instanceof RightScalarOperator;
        if (op.fn instanceof Plus || (op.fn instanceof Minus && scalarOnRight)) {
            return derive(_dict, newRef);
        }
        if (op.fn instanceof Multiply || (op.fn instanceof Divide && scalarOnRight)) {
            return derive(_dict.applyScalarOp(op), newRef);
        }
        return derive(_dict.applyScalarOpWithReference(op, _reference, newRef), newRef);
    }

    /** Applies a unary operator to every cell, rebuilding the dictionary relative to the transformed reference. */
    @Override
    public AColGroup unaryOperation(UnaryOperator op) {
        final double[] newRef = FORUtil.unaryOperator(op, _reference);
        return derive(_dict.applyUnaryOpWithReference(op, _reference, newRef), newRef);
    }

    /**
     * Computes {@code v[col] op cell} for every cell, with {@code v} indexed by global column index.
     *
     * <p>Only plus ({@code v + (d + r) = d + (v + r)}) and multiply ({@code v * (d + r) = v*d + v*r}) distribute
     * over the reference sum when the row vector is on the left. Minus would require negating the dictionary, and
     * divide does not distribute at all, so both use the general reference-aware path.</p>
     */
    @Override
    public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) {
        final double[] newRef = new double[_reference.length];
        for (int i = 0; i < _reference.length; ++i) {
            newRef[i] = op.fn.execute(v[_colIndexes[i]], _reference[i]);
        }
        if (op.fn instanceof Plus) {
            return derive(_dict, newRef);
        }
        if (op.fn instanceof Multiply) {
            return derive(_dict.binOpLeft(op, v, _colIndexes), newRef);
        }
        return derive(_dict.binOpLeftWithReference(op, v, _colIndexes, _reference, newRef), newRef);
    }

    /**
     * Computes {@code cell op v[col]} for every cell, with {@code v} indexed by global column index.
     *
     * <p>With the row vector on the right, plus/minus change only the reference. Multiply/divide distribute
     * ({@code (d + r) / v = d/v + r/v}), so the dictionary and reference are transformed independently.</p>
     */
    @Override
    public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSafe) {
        final double[] newRef = new double[_reference.length];
        for (int i = 0; i < _reference.length; ++i) {
            newRef[i] = op.fn.execute(_reference[i], v[_colIndexes[i]]);
        }
        if (op.fn instanceof Plus || op.fn instanceof Minus) {
            return derive(_dict, newRef);
        }
        if (op.fn instanceof Multiply || op.fn instanceof Divide) {
            return derive(_dict.binOpRight(op, v, _colIndexes), newRef);
        }
        return derive(_dict.binOpRightWithReference(op, v, _colIndexes, _reference, newRef), newRef);
    }

    /**
     * Serializes the group: the super-class state, then offsets, mapping and one double per reference value.
     */
    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        _indexes.write(out);
        _data.write(out);
        for (double d : _reference) {
            out.writeDouble(d);
        }
    }

    /** Reads the group in the same order {@link #write(DataOutput)} writes it. */
    @Override
    public void readFields(DataInput in) throws IOException {
        super.readFields(in);
        _indexes = OffsetFactory.readIn(in);
        _data = MapToFactory.readIn(in);
        _reference = new double[_colIndexes.length];
        for (int i = 0; i < _colIndexes.length; ++i) {
            _reference[i] = in.readDouble();
        }
    }

    /** Returns the number of bytes {@link #write(DataOutput)} produces. */
    @Override
    public long getExactSizeOnDisk() {
        long ret = super.getExactSizeOnDisk();
        ret += _data.getExactSizeOnDisk();
        ret += _indexes.getExactSizeOnDisk();
        ret += 8L * _colIndexes.length; // one double per reference value
        return ret;
    }

    /** Estimates the in-memory size; the reference is counted as eight bytes per column. */
    @Override
    public long estimateInMemorySize() {
        long size = super.estimateInMemorySize();
        size += _indexes.getInMemorySize();
        size += _data.getInMemorySize();
        size += 8L * _colIndexes.length;
        return size;
    }

    /**
     * Replaces every occurrence of {@code pattern} with {@code replace}.
     *
     * @throws NotImplementedException if {@code pattern} (including NaN) occurs in the reference
     */
    @Override
    public AColGroup replace(double pattern, double replace) {
        if (referenceContains(pattern)) {
            throw new NotImplementedException("Not Implemented replace where a value in reference should be replaced");
        }
        return derive(_dict.replaceWithReference(pattern, replace, _reference), _reference);
    }

    /**
     * Tells whether the reference holds {@code pattern}. NaN is matched explicitly, because {@code ==} never holds
     * for NaN.
     */
    private boolean referenceContains(double pattern) {
        final boolean nanPattern = Double.isNaN(pattern);
        for (double d : _reference) {
            if (nanPattern ? Double.isNaN(d) : d == pattern) {
                return true;
            }
        }
        return false;
    }

    @Override
    protected double computeMxx(double c, Builtin builtin) {
        return _dict.aggregateWithReference(c, builtin, _reference, true);
    }

    @Override
    protected void computeColMxx(double[] c, Builtin builtin) {
        _dict.aggregateColsWithReference(c, builtin, _colIndexes, _reference, true);
    }

    /** Every row contains the reference, so the reference sum is added once per row. */
    @Override
    protected void computeSum(double[] c, int nRows) {
        super.computeSum(c, nRows);
        c[0] += FORUtil.refSum(_reference) * nRows;
    }

    /** Every row contains the reference, so each column's reference value is added once per row. */
    @Override
    public void computeColSums(double[] c, int nRows) {
        super.computeColSums(c, nRows);
        for (int i = 0; i < _colIndexes.length; ++i) {
            c[_colIndexes[i]] += _reference[i] * nRows;
        }
    }

    /**
     * Sums squares in two parts. The offset rows contribute through the reference-aware dictionary sum. The
     * remaining {@code _numRows - _data.size()} rows contribute the squared reference.
     */
    @Override
    protected void computeSumSq(double[] c, int nRows) {
        c[0] += _dict.sumSqWithReference(getCounts(), _reference);
        c[0] += FORUtil.refSumSq(_reference) * (_numRows - _data.size());
    }

    /**
     * Column-wise variant of {@link #computeSumSq(double[], int)}. NOTE(review): this replaces the group's
     * dictionary with its matrix-block form as a side effect; confirm that this caching is intended.
     */
    @Override
    protected void computeColSumsSq(double[] c, int nRows) {
        _dict = _dict.getMBDict(_colIndexes.length);
        _dict.colSumSqWithReference(c, getCounts(), _colIndexes, _reference);
        final int defaultRows = _numRows - _data.size();
        for (int i = 0; i < _colIndexes.length; ++i) {
            c[_colIndexes[i]] += _reference[i] * _reference[i] * defaultRows;
        }
    }

    @Override
    protected double[] preAggSumRows() {
        return _dict.sumAllRowsToDoubleWithReference(_reference);
    }

    @Override
    protected double[] preAggSumSqRows() {
        return _dict.sumAllRowsToDoubleSqWithReference(_reference);
    }

    @Override
    protected double[] preAggProductRows() {
        throw new NotImplementedException();
    }

    @Override
    protected double[] preAggBuiltinRows(Builtin builtin) {
        return _dict.aggregateRowsWithReference(builtin, _reference);
    }

    /** Passes the number of rows outside the offsets so the dictionary can weigh the reference-only rows. */
    @Override
    protected void computeProduct(double[] c, int nRows) {
        final int defaultRows = _numRows - _data.size();
        _dict.productWithReference(c, getCounts(), _reference, defaultRows);
    }

    @Override
    protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) {
        throw new NotImplementedException("Not Implemented PFOR");
    }

    @Override
    protected void computeColProduct(double[] c, int nRows) {
        throw new NotImplementedException("Not Implemented PFOR");
    }

    /** Slices one column and keeps only that column's reference value. */
    @Override
    protected AColGroup sliceSingleColumn(int idx) {
        final ColGroupSDCFOR ret = (ColGroupSDCFOR) super.sliceSingleColumn(idx);
        ret._reference = new double[] {_reference[idx]};
        return ret;
    }

    /** Slices the local columns {@code [idStart, idEnd)} and keeps the matching reference values. */
    @Override
    protected AColGroup sliceMultiColumns(int idStart, int idEnd, int[] outputCols) {
        final ColGroupSDCFOR ret = (ColGroupSDCFOR) super.sliceMultiColumns(idStart, idEnd, outputCols);
        ret._reference = Arrays.copyOfRange(_reference, idStart, idEnd);
        return ret;
    }

    /**
     * Tells whether any cell equals {@code pattern}. NaN and infinities are checked in the reference and in the
     * dictionary separately; other values are checked in the dictionary relative to the reference.
     */
    @Override
    public boolean containsValue(double pattern) {
        if (pattern == 0.0 && _zeros) {
            return true;
        }
        if (Double.isNaN(pattern) || Double.isInfinite(pattern)) {
            return FORUtil.containsInfOrNan(pattern, _reference) || _dict.containsValue(pattern);
        }
        return _dict.containsValueWithReference(pattern, _reference);
    }

    /**
     * Counts non-zeros. The offset rows are counted by the dictionary. Each non-zero reference value then adds one
     * non-zero for every row outside the offsets.
     */
    @Override
    public long getNumberNonZeros(int nRows) {
        final int[] counts = getCounts();
        final int defaultRows = _numRows - _data.size();
        long nnz = _dict.getNumberNonZerosWithReference(counts, _reference, nRows);
        for (double ref : _reference) {
            if (ref != 0.0) {
                nnz += defaultRows;
            }
        }
        return nnz;
    }

    /**
     * Moves the reference into {@code constV} (indexed by global column) and returns the remaining sparse part as a
     * {@link ColGroupSDCZeros}.
     */
    @Override
    public AColGroup extractCommon(double[] constV) {
        for (int i = 0; i < _colIndexes.length; ++i) {
            constV[_colIndexes[i]] += _reference[i];
        }
        return ColGroupSDCZeros.create(_colIndexes, _numRows, _dict, _indexes, _data, getCounts());
    }

    /**
     * Expands the column into indicator columns. Only the first reference value is used. NOTE(review): presumably
     * this is only called on single-column groups.
     */
    @Override
    public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
        final ADictionary d = _dict.rexpandColsWithReference(max, ignore, cast, _reference[0]);
        return ColGroupSDC.rexpandCols(max, ignore, cast, nRows, d, _indexes, _data, getCachedCounts(),
            _reference[0]);
    }

    /**
     * Computes a central moment. The dictionary handles the offset rows, and the reference value is then added
     * with weight equal to the number of rows outside the offsets. Only the first reference value is used.
     */
    @Override
    public CM_COV_Object centralMoment(CMOperator op, int nRows) {
        final CM_COV_Object ret = _dict.centralMomentWithReference(op.fn, getCounts(), _reference[0], nRows);
        final int defaultRows = _numRows - _data.size();
        op.fn.execute(ret, _reference[0], defaultRows);
        return ret;
    }

    /** Estimates the computation cost; only the offset rows are scanned. */
    @Override
    public double getCost(ComputationCostEstimator e, int nRows) {
        final int nVals = getNumValues();
        final int nCols = getNumCols();
        final int nRowsScanned = _data.size();
        return e.getCost(nRows, nRowsScanned, nCols, nVals, _dict.getSparsity());
    }

    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(String.format("\n%15s", "Indexes: "));
        sb.append(_indexes.toString());
        sb.append(String.format("\n%15s", "Data: "));
        sb.append(_data);
        sb.append(String.format("\n%15s", "Reference:"));
        sb.append(Arrays.toString(_reference));
        return sb.toString();
    }
}

