/*
 * Decompiled with CFR 0.152.
 */
package eponine.model;

import eponine.model.BasisFunctionWithHistory;
import java.util.AbstractSet;
import java.util.Collection;
import java.util.Iterator;
import org.biojava.bio.BioError;
import org.biojava.bio.dist.Distribution;
import org.biojava.bio.dist.DistributionTrainerContext;
import org.biojava.bio.dist.SimpleDistribution;
import org.biojava.bio.dist.SimpleDistributionTrainerContext;
import org.biojava.bio.symbol.IllegalSymbolException;
import org.biojava.bio.symbol.SimpleAlphabet;
import org.biojava.bio.symbol.Symbol;
import org.biojava.utils.ChangeVetoException;
import stats.glm.BasisFunction;
import stats.glm.BasisSource;
import stats.glm.SLMTrainingContext;

/**
 * A {@link BasisSource} that multiplexes over a set of child basis sources,
 * picking at random which child supplies the next basis function.
 *
 * <p>The children are stored as the symbols of a private {@link SimpleAlphabet},
 * so every element added to this set must be both a {@link Symbol} and a
 * {@link BasisSource}. Children may only be added before the selection
 * distribution is first created (i.e. before the first call to {@link #next}).
 *
 * <p>Initially every child is chosen with equal probability. Whenever at least
 * {@code reweightFrequency} training cycles have passed since the last
 * re-estimation, the selection distribution is retrained from the current
 * basis list. Each {@link BasisFunctionWithHistory} adds a count of 1.0 to the
 * child that created it. It also adds 0.5 to the creators of its parents,
 * 0.25 to the creators of its grandparents, and so on. Only creators that are
 * children of this source are counted.
 */
public class MultiplexedBasisSource
extends AbstractSet
implements BasisSource {
    // The child sources; every symbol is also required to be a BasisSource.
    private SimpleAlphabet alphabet = new SimpleAlphabet();
    // Minimum number of training cycles between re-estimations of dist.
    private int reweightFrequency = 100;
    // Null-model weight handed to the trainer when re-estimating dist.
    private double reweightPseudocounts = 1.0;
    // Selection distribution over the children; created lazily. Once non-null,
    // no further children may be added.
    private Distribution dist = null;
    // Training cycle at which the last re-estimation was performed.
    private int lastReweight = 0;

    /**
     * Sets how many training cycles must elapse between re-estimations of the
     * selection distribution (default 100).
     *
     * @param n the number of cycles between reweights
     */
    public void setReweightFrequency(int n) {
        this.reweightFrequency = n;
    }

    /**
     * Sets the null-model weight used when re-estimating the selection
     * distribution (default 1.0).
     *
     * @param d the null-model weight passed to the distribution trainer
     */
    public void setReweightPseudocounts(double d) {
        this.reweightPseudocounts = d;
    }

    /**
     * Returns an iterator over the child sources.
     *
     * @return an iterator over the symbols of the underlying alphabet
     */
    public Iterator iterator() {
        return this.alphabet.iterator();
    }

    /**
     * Returns the number of child sources.
     *
     * @return the size of the underlying alphabet
     */
    public int size() {
        return this.alphabet.size();
    }

    /**
     * Adds a child source.
     *
     * @param object the child to add; must implement both {@link Symbol} and
     *        {@link BasisSource}
     * @return {@code true} if the child was added, {@code false} if it was
     *         already present (per the {@link java.util.Set#add} contract)
     * @throws BioError if {@link #next} has already been called, if
     *         {@code object} is not both a {@code Symbol} and a
     *         {@code BasisSource}, or if the alphabet rejects the symbol
     */
    public boolean add(Object object) {
        if (this.dist != null) {
            throw new BioError("Can't add sources to a MultiplexedBasisSource after calling next()");
        }
        // next() and hasNext() cast every member to BasisSource, so reject
        // anything else here rather than failing later with a ClassCastException.
        if (!(object instanceof Symbol) || !(object instanceof BasisSource)) {
            throw new BioError("MultiplexedBasisSource elements must implement both Symbol and BasisSource: " + object);
        }
        Symbol symbol = (Symbol) object;
        if (this.alphabet.contains(symbol)) {
            return false;
        }
        try {
            this.alphabet.addSymbol(symbol);
        }
        catch (Exception exception) {
            throw new BioError(exception);
        }
        return true;
    }

    /**
     * Returns the selection distribution, creating it on first use with an
     * equal weight for every child.
     *
     * @return the selection distribution over the children
     * @throws IllegalSymbolException if the distribution rejects a child symbol
     */
    private Distribution getDistribution() throws IllegalSymbolException {
        if (this.dist == null) {
            this.dist = new SimpleDistribution(this.alphabet);
            double d = 1.0 / (double) this.alphabet.size();
            Iterator iterator = this.alphabet.iterator();
            while (iterator.hasNext()) {
                Symbol symbol = (Symbol) iterator.next();
                try {
                    this.dist.setWeight(symbol, d);
                }
                catch (ChangeVetoException changeVetoException) {
                    throw new BioError(changeVetoException);
                }
            }
        }
        return this.dist;
    }

    /**
     * Draws a new basis function from a randomly chosen child.
     *
     * <p>If at least {@code reweightFrequency} cycles have passed since the last
     * reweight, the selection distribution is re-estimated first. A child is
     * then sampled from the distribution. If it has no next function, or if it
     * yields {@code null} or a function already in the context's basis list,
     * another child is sampled.
     *
     * <p>NOTE(review): the loop still does not terminate if every available
     * child only ever yields functions already in the basis list.
     *
     * @param sLMTrainingContext the current training context
     * @return a basis function not yet in the context's basis list
     * @throws BioError if no child can supply another basis function (including
     *         when this set is empty), or on an internal symbol error
     */
    public BasisFunction next(SLMTrainingContext sLMTrainingContext) {
        // Without this guard the sampling loop below would spin forever.
        if (!this.hasNext(sLMTrainingContext)) {
            throw new BioError("No source in this MultiplexedBasisSource can supply another basis function");
        }
        try {
            int n = sLMTrainingContext.getCurrentCycle();
            if (n >= this.lastReweight + this.reweightFrequency) {
                this.doReweight(sLMTrainingContext);
                this.lastReweight = n;
            }
            Distribution distribution = this.getDistribution();
            BasisFunction basisFunction = null;
            while (basisFunction == null || sLMTrainingContext.getBasisList().contains(basisFunction)) {
                BasisSource chosen = this.sampleSource(distribution);
                // A null choice (rounding left the draw unspent) or an exhausted
                // child simply triggers another draw.
                if (chosen != null && chosen.hasNext(sLMTrainingContext)) {
                    basisFunction = chosen.next(sLMTrainingContext);
                }
            }
            return basisFunction;
        }
        catch (IllegalSymbolException illegalSymbolException) {
            throw new BioError(illegalSymbolException);
        }
    }

    /**
     * Samples one child by walking the cumulative weights with a single
     * uniform draw from {@link Math#random()}.
     *
     * @param distribution the selection distribution over the children
     * @return the sampled child, or {@code null} if floating-point rounding
     *         leaves the draw unspent after every child has been visited
     * @throws IllegalSymbolException if the distribution rejects a child symbol
     */
    private BasisSource sampleSource(Distribution distribution) throws IllegalSymbolException {
        double d = Math.random();
        Iterator iterator = this.alphabet.iterator();
        while (iterator.hasNext()) {
            Symbol symbol = (Symbol) iterator.next();
            d -= distribution.getWeight(symbol);
            if (d <= 0.0) {
                return (BasisSource) symbol;
            }
        }
        return null;
    }

    /**
     * Retrains the selection distribution from counts gathered over the
     * context's current basis list, with {@code reweightPseudocounts} used as
     * the trainer's null-model weight.
     *
     * @param sLMTrainingContext the current training context
     * @throws IllegalSymbolException if a count is added for a symbol the
     *         distribution rejects
     */
    private void doReweight(SLMTrainingContext sLMTrainingContext) throws IllegalSymbolException {
        SimpleDistributionTrainerContext simpleDistributionTrainerContext = new SimpleDistributionTrainerContext();
        Distribution distribution = this.getDistribution();
        simpleDistributionTrainerContext.registerDistribution(distribution);
        simpleDistributionTrainerContext.setNullModelWeight(this.reweightPseudocounts);
        this.makeCounts(simpleDistributionTrainerContext, distribution, sLMTrainingContext.getBasisList(), 1.0);
        try {
            simpleDistributionTrainerContext.train();
        }
        catch (ChangeVetoException changeVetoException) {
            throw new BioError(changeVetoException);
        }
    }

    /**
     * Adds counts for the creators of the given basis functions, then recurses
     * into each function's parents with half the weight.
     *
     * <p>Functions that are not {@link BasisFunctionWithHistory} are skipped,
     * together with their ancestry. Creators that are not children of this
     * source receive no count, but their parents are still visited. An ancestor
     * reachable by several paths is counted once per path. NOTE(review): this
     * assumes the parent graph is acyclic; a cycle would recurse without bound.
     *
     * @param distributionTrainerContext the trainer receiving the counts
     * @param distribution the distribution being trained
     * @param collection the basis functions to examine
     * @param d the count added per creator at this depth
     * @throws IllegalSymbolException if the trainer rejects a symbol
     */
    private void makeCounts(DistributionTrainerContext distributionTrainerContext, Distribution distribution, Collection collection, double d) throws IllegalSymbolException {
        Iterator iterator = collection.iterator();
        while (iterator.hasNext()) {
            BasisFunction basisFunction = (BasisFunction) iterator.next();
            if (!(basisFunction instanceof BasisFunctionWithHistory)) {
                continue;
            }
            BasisFunctionWithHistory basisFunctionWithHistory = (BasisFunctionWithHistory) basisFunction;
            BasisSource basisSource = basisFunctionWithHistory.getCreatrix();
            if (basisSource instanceof Symbol && this.alphabet.contains((Symbol) basisSource)) {
                distributionTrainerContext.addCount(distribution, (Symbol) basisSource, d);
            }
            this.makeCounts(distributionTrainerContext, distribution, basisFunctionWithHistory.getParents(), d * 0.5);
        }
    }

    /**
     * Reports whether any child can supply another basis function.
     *
     * @param sLMTrainingContext the current training context
     * @return {@code true} if at least one child's {@code hasNext} returns
     *         {@code true}; {@code false} otherwise, including when this set
     *         is empty
     */
    public boolean hasNext(SLMTrainingContext sLMTrainingContext) {
        Iterator iterator = this.alphabet.iterator();
        while (iterator.hasNext()) {
            BasisSource basisSource = (BasisSource) iterator.next();
            if (basisSource.hasNext(sLMTrainingContext)) {
                return true;
            }
        }
        return false;
    }
}

