package de.jstacs.utils;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.sequences.IntSequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.WrongSequenceTypeException;
import de.jstacs.sequenceScores.SequenceScore;
import de.jstacs.sequenceScores.statisticalModels.StatisticalModel;
import java.io.IOException;

/* loaded from: input_file:de/jstacs/utils/StatisticalModelTester.class */
/**
 * Static helpers that evaluate discrete models by exhaustively enumerating
 * every sequence of a given length (divergences, entropy, deviations, total
 * probability mass) and by scoring data sets (log-likelihood, AIC, BIC).
 *
 * <p>All enumerating methods visit every possible sequence, so their running
 * time grows with the product of the alphabet sizes over all positions.
 */
public class StatisticalModelTester {

    /**
     * Odometer-style enumerator over all symbol-index sequences of a fixed
     * length. Position 0 is the fastest-changing digit.
     */
    public static class SeqIterator {
        /**
         * Current symbol indices. The extra slot at index {@code length} is an
         * overflow flag that becomes non-zero once every sequence has been visited.
         */
        private final int[] seq;
        /** Whether the container reported itself as simple (one alphabet for all positions). */
        private final boolean simple;
        /** Largest symbol index per position; a single entry when {@link #simple} is set. */
        private final int[] maxSymbol;
        /** Number of positions of the enumerated sequences. */
        private final int length;
        /** Index of the last position, i.e. {@code length - 1}. */
        private final int last;
        private final AlphabetContainer abc;

        /**
         * Creates an enumerator positioned on the all-zero sequence.
         *
         * @param alphabetContainer the alphabets; must be discrete
         * @param i the length of the sequences to enumerate; must not be negative
         * @throws IllegalArgumentException if the container is not discrete or
         *         {@code i} is negative
         */
        private SeqIterator(AlphabetContainer alphabetContainer, int i) throws IllegalArgumentException {
            if (!alphabetContainer.isDiscrete()) {
                throw new IllegalArgumentException("The model is not discrete.");
            }
            if (i < 0) {
                throw new IllegalArgumentException("The sequence length must not be negative (" + i + ").");
            }
            this.abc = alphabetContainer;
            this.simple = alphabetContainer.isSimple();
            // A simple container needs only one bound, queried at position 0.
            int bounds = this.simple ? 1 : i;
            this.maxSymbol = new int[bounds];
            for (int pos = 0; pos < bounds; pos++) {
                this.maxSymbol[pos] = ((int) alphabetContainer.getAlphabetLengthAt(pos)) - 1;
            }
            this.length = i;
            this.last = i - 1;
            this.seq = new int[i + 1];
        }

        /**
         * Advances to the next sequence.
         *
         * @return {@code true} if a new sequence is available, {@code false}
         *         once the enumeration has wrapped around
         */
        public boolean next() {
            int pos = 0;
            // Carry over every digit that already holds its maximum. The bound on
            // pos keeps the overflow slot from being compared against an alphabet
            // bound, which previously ran off the array for single-symbol alphabets.
            while (pos < this.length && this.seq[pos] == this.maxSymbol[this.simple ? 0 : pos]) {
                this.seq[pos] = 0;
                pos++;
            }
            this.seq[pos]++;
            return this.seq[this.length] == 0;
        }

        /**
         * Checks the current sequence against a partial assignment.
         *
         * @param iArr the constraint; an entry of {@code -1} matches any symbol,
         *        every other entry must equal the current symbol at that position
         * @return {@code true} if all entries of {@code iArr} are matched
         */
        public boolean isSatisfied(int[] iArr) {
            for (int pos = 0; pos < iArr.length; pos++) {
                if (iArr[pos] != -1 && iArr[pos] != this.seq[pos]) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Wraps the current symbol indices in a new {@link IntSequence}.
         *
         * @return the current sequence
         * @throws WrongAlphabetException forwarded from the {@link IntSequence} constructor
         * @throws WrongSequenceTypeException forwarded from the {@link IntSequence} constructor
         */
        public Sequence getSequence() throws WrongAlphabetException, WrongSequenceTypeException {
            return new IntSequence(this.abc, this.seq, 0, this.length);
        }
    }

    /**
     * Rejects a requested length that a fixed-length model cannot handle.
     * A model length of 0 is treated as "any length".
     *
     * @param model the model to check
     * @param length the requested sequence length
     * @param subject how the model is named at the start of the error message
     * @throws IOException if the model has a non-zero length different from {@code length}
     */
    private static void checkLength(StatisticalModel model, int length, String subject) throws IOException {
        if (model.getLength() != 0 && model.getLength() != length) {
            throw new IOException(subject + " can only classify sequences of length " + model.getLength() + ".");
        }
    }

    /**
     * Verifies that two models can be compared on sequences of the given length:
     * both must accept the length and their alphabet containers must be consistent.
     *
     * @param m1 the first model
     * @param m2 the second model
     * @param length the requested sequence length
     * @throws IOException if either check fails
     */
    private static void checkComparable(StatisticalModel m1, StatisticalModel m2, int length) throws IOException {
        checkLength(m1, length, "The model m1");
        checkLength(m2, length, "This model m2");
        if (!m1.getAlphabetContainer().checkConsistency(m2.getAlphabetContainer())) {
            throw new IOException("The models are training on different alphabets.");
        }
    }

    /**
     * Computes the Kullback-Leibler divergence {@code sum_x p1(x) * (log p1(x) - log p2(x))}
     * over all sequences of length {@code i}, enumerated from the alphabet of the first model.
     *
     * <p>Sequences whose probability under the first model evaluates to exactly 0
     * contribute 0 (the usual {@code 0 * log 0 = 0} convention) instead of NaN.
     *
     * @param statisticalModel the first model (p1)
     * @param statisticalModel2 the second model (p2)
     * @param i the sequence length
     * @return the divergence; may be infinite if p2 assigns zero probability
     *         where p1 does not
     * @throws Exception if enumeration or scoring fails
     */
    public static double getKLDivergence(StatisticalModel statisticalModel, StatisticalModel statisticalModel2, int i) throws Exception {
        SeqIterator iterator = new SeqIterator(statisticalModel.getAlphabetContainer(), i);
        double divergence = 0.0d;
        do {
            Sequence sequence = iterator.getSequence();
            double logP = statisticalModel.getLogProbFor(sequence);
            double logQ = statisticalModel2.getLogProbFor(sequence);
            double p = Math.exp(logP);
            // p != 0 is also true for NaN, so invalid scores still propagate.
            if (p != 0.0d) {
                divergence += p * (logP - logQ);
            }
        } while (iterator.next());
        return divergence;
    }

    /**
     * Computes the symmetric Kullback-Leibler divergence
     * {@code sum_x (p1(x) - p2(x)) * (log p1(x) - log p2(x))} over all sequences
     * of length {@code i}, enumerated from the alphabet of the first model.
     *
     * <p>Sequences to which both models assign a probability of exactly 0
     * contribute 0 instead of NaN.
     *
     * @param statisticalModel the first model
     * @param statisticalModel2 the second model
     * @param i the sequence length
     * @return the symmetric divergence; may be infinite
     * @throws Exception if enumeration or scoring fails
     */
    public static double getSymKLDivergence(StatisticalModel statisticalModel, StatisticalModel statisticalModel2, int i) throws Exception {
        SeqIterator iterator = new SeqIterator(statisticalModel.getAlphabetContainer(), i);
        double divergence = 0.0d;
        do {
            Sequence sequence = iterator.getSequence();
            double logP = statisticalModel.getLogProbFor(sequence);
            double logQ = statisticalModel2.getLogProbFor(sequence);
            double p = Math.exp(logP);
            double q = Math.exp(logQ);
            // Skip only when both probabilities vanish; NaN compares unequal to 0
            // and therefore still propagates.
            if (p != 0.0d || q != 0.0d) {
                divergence += (p - q) * (logP - logQ);
            }
        } while (iterator.next());
        return divergence;
    }

    /**
     * Computes the unweighted log-likelihood of a data set.
     *
     * @param statisticalModel the model
     * @param dataSet the data
     * @return the sum of the log-probabilities of all elements
     * @throws Exception if scoring fails
     */
    public static double getLogLikelihood(StatisticalModel statisticalModel, DataSet dataSet) throws Exception {
        return getLogLikelihood(statisticalModel, dataSet, null);
    }

    /**
     * Computes the (optionally weighted) log-likelihood of a data set.
     *
     * @param statisticalModel the model
     * @param dataSet the data
     * @param dArr one weight per element, or {@code null} for unit weights
     * @return the weighted sum of the log-probabilities of all elements
     * @throws IllegalArgumentException if {@code dArr} is non-null and its length
     *         differs from the number of elements
     * @throws Exception if scoring fails
     */
    public static double getLogLikelihood(StatisticalModel statisticalModel, DataSet dataSet, double[] dArr) throws Exception {
        int count = dataSet.getNumberOfElements();
        double logLikelihood = 0.0d;
        DataSet.ElementEnumerator elements = new DataSet.ElementEnumerator(dataSet);
        if (dArr == null) {
            for (int n = 0; n < count; n++) {
                logLikelihood += statisticalModel.getLogProbFor(elements.nextElement());
            }
        } else {
            if (count != dArr.length) {
                throw new IllegalArgumentException("The weights and the data set does not match.");
            }
            for (int n = 0; n < count; n++) {
                logLikelihood += dArr[n] * statisticalModel.getLogProbFor(elements.nextElement());
            }
        }
        return logLikelihood;
    }

    /**
     * Sums the probabilities of all sequences that agree with a partial assignment.
     *
     * @param statisticalModel the model
     * @param iArr the partial assignment; its length is the sequence length, an
     *        entry of {@code -1} leaves the position free, any other entry fixes
     *        the symbol index at that position
     * @return the summed probability of all matching sequences
     * @throws IOException if the model has a fixed length different from {@code iArr.length}
     * @throws Exception if enumeration or scoring fails
     */
    public static double getMarginalDistribution(StatisticalModel statisticalModel, int[] iArr) throws Exception {
        checkLength(statisticalModel, iArr.length, "This model");
        // Accumulated via Normalisation.getLogSum and exponentiated once at the end.
        double logSum = Double.NEGATIVE_INFINITY;
        SeqIterator iterator = new SeqIterator(statisticalModel.getAlphabetContainer(), iArr.length);
        do {
            if (iterator.isSatisfied(iArr)) {
                logSum = Normalisation.getLogSum(logSum, statisticalModel.getLogProbFor(iterator.getSequence()));
            }
        } while (iterator.next());
        return Math.exp(logSum);
    }

    /**
     * Finds the largest absolute difference between the probabilities that two
     * models assign to the same sequence, over all sequences of length {@code i}.
     *
     * @param statisticalModel the first model (m1)
     * @param statisticalModel2 the second model (m2)
     * @param i the sequence length
     * @return the maximum of {@code |p1(x) - p2(x)|}
     * @throws IOException if a model cannot handle the length or the alphabets
     *         are inconsistent
     * @throws Exception if enumeration or scoring fails
     */
    public static double getMaxOfDeviation(StatisticalModel statisticalModel, StatisticalModel statisticalModel2, int i) throws Exception {
        checkComparable(statisticalModel, statisticalModel2, i);
        double max = 0.0d;
        SeqIterator iterator = new SeqIterator(statisticalModel.getAlphabetContainer(), i);
        do {
            Sequence sequence = iterator.getSequence();
            double p1 = Math.exp(statisticalModel.getLogProbFor(sequence, 0, iterator.last));
            double p2 = Math.exp(statisticalModel2.getLogProbFor(sequence, 0, iterator.last));
            double deviation = Math.abs(p1 - p2);
            if (deviation > max) {
                max = deviation;
            }
        } while (iterator.next());
        return max;
    }

    /**
     * Finds a sequence of length {@code i} with the highest log-score. On ties
     * the sequence enumerated first is kept.
     *
     * @param sequenceScore the scoring function
     * @param i the sequence length
     * @return a sequence with maximal log-score
     * @throws Exception if enumeration or scoring fails
     */
    public static Sequence getMostProbableSequence(SequenceScore sequenceScore, int i) throws Exception {
        SeqIterator iterator = new SeqIterator(sequenceScore.getAlphabetContainer(), i);
        Sequence best = iterator.getSequence();
        double bestScore = sequenceScore.getLogScoreFor(best);
        while (iterator.next()) {
            Sequence candidate = iterator.getSequence();
            double score = sequenceScore.getLogScoreFor(candidate);
            if (score > bestScore) {
                bestScore = score;
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Computes the Shannon entropy {@code -sum_x p(x) * log p(x)} (natural
     * logarithm) over all sequences of length {@code i}. Sequences with an
     * infinite log-probability contribute nothing to the sum.
     *
     * @param statisticalModel the model
     * @param i the sequence length
     * @return the entropy in nats
     * @throws IOException if the model cannot handle the length or returns a
     *         positive log-probability
     * @throws Exception if enumeration or scoring fails
     */
    public static double getShannonEntropy(StatisticalModel statisticalModel, int i) throws Exception {
        checkLength(statisticalModel, i, "This model");
        double entropy = 0.0d;
        SeqIterator iterator = new SeqIterator(statisticalModel.getAlphabetContainer(), i);
        do {
            double logP = statisticalModel.getLogProbFor(iterator.getSequence());
            if (logP > 0.0d) {
                throw new IOException("The probability of sequence " + iterator.getSequence() + " is not correct (" + Math.exp(logP) + ").");
            }
            if (!Double.isInfinite(logP)) {
                entropy -= Math.exp(logP) * logP;
            }
        } while (iterator.next());
        return entropy;
    }

    /**
     * Computes the Shannon entropy in bits, i.e. {@link #getShannonEntropy}
     * divided by {@code ln 2}.
     *
     * @param statisticalModel the model
     * @param i the sequence length
     * @return the entropy in bits
     * @throws Exception see {@link #getShannonEntropy(StatisticalModel, int)}
     */
    public static double getShannonEntropyInBits(StatisticalModel statisticalModel, int i) throws Exception {
        return getShannonEntropy(statisticalModel, i) / Math.log(2.0d);
    }

    /**
     * Sums the absolute differences between the probabilities that two models
     * assign to the same sequence, over all sequences of length {@code i}.
     *
     * @param statisticalModel the first model (m1)
     * @param statisticalModel2 the second model (m2)
     * @param i the sequence length
     * @return the sum of {@code |p1(x) - p2(x)|}
     * @throws IOException if a model cannot handle the length or the alphabets
     *         are inconsistent
     * @throws Exception if enumeration or scoring fails
     */
    public static double getSumOfDeviation(StatisticalModel statisticalModel, StatisticalModel statisticalModel2, int i) throws Exception {
        checkComparable(statisticalModel, statisticalModel2, i);
        double sum = 0.0d;
        SeqIterator iterator = new SeqIterator(statisticalModel.getAlphabetContainer(), i);
        do {
            Sequence sequence = iterator.getSequence();
            double p1 = Math.exp(statisticalModel.getLogProbFor(sequence, 0, iterator.last));
            double p2 = Math.exp(statisticalModel2.getLogProbFor(sequence, 0, iterator.last));
            sum += Math.abs(p1 - p2);
        } while (iterator.next());
        return sum;
    }

    /**
     * Sums the probabilities of all sequences of length {@code i}; for a
     * properly normalized model this should be close to 1.
     *
     * @param statisticalModel the model
     * @param i the sequence length
     * @return the total probability mass
     * @throws IOException if the model cannot handle the length or returns a
     *         positive log-probability
     * @throws Exception if enumeration or scoring fails
     */
    public static double getSumOfDistribution(StatisticalModel statisticalModel, int i) throws Exception {
        checkLength(statisticalModel, i, "This model");
        // Accumulated via Normalisation.getLogSum and exponentiated once at the end.
        double logSum = Double.NEGATIVE_INFINITY;
        SeqIterator iterator = new SeqIterator(statisticalModel.getAlphabetContainer(), i);
        do {
            double logP = statisticalModel.getLogProbFor(iterator.getSequence());
            if (logP > 0.0d) {
                throw new IOException("The probability (" + Math.exp(logP) + ") for sequence \"" + iterator.getSequence() + "\" is not in [0,1].");
            }
            logSum = Normalisation.getLogSum(logSum, logP);
        } while (iterator.next());
        return Math.exp(logSum);
    }

    /**
     * Computes {@code 2 * logLikelihood - 2 * i}.
     *
     * @param statisticalModel the model
     * @param dataSet the data
     * @param i the penalty count (presumably the number of free parameters)
     * @return the AIC-style value as defined above
     * @throws Exception if scoring fails
     */
    public static double getValueOfAIC(StatisticalModel statisticalModel, DataSet dataSet, int i) throws Exception {
        // The penalty is computed in double arithmetic so that it cannot overflow int.
        return (2.0d * getLogLikelihood(statisticalModel, dataSet)) - (2.0d * i);
    }

    /**
     * Computes {@code 2 * logLikelihood - i * ln(numberOfElements)}.
     *
     * @param statisticalModel the model
     * @param dataSet the data
     * @param i the penalty count (presumably the number of free parameters)
     * @return the BIC-style value as defined above
     * @throws Exception if scoring fails
     */
    public static double getValueOfBIC(StatisticalModel statisticalModel, DataSet dataSet, int i) throws Exception {
        return (2.0d * getLogLikelihood(statisticalModel, dataSet)) - (i * StrictMath.log(dataSet.getNumberOfElements()));
    }
}
