package de.jstacs.sequenceScores.statisticalModels.trainable.mixture;

import de.jstacs.algorithms.optimization.termination.TerminationCondition;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.sampling.BurnInTest;
import de.jstacs.sequenceScores.statisticalModels.trainable.TrainableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.trainable.mixture.AbstractMixtureTrainSM;
import de.jstacs.utils.random.MRGParams;
import de.jstacs.utils.random.MultivariateRandomGenerator;
import java.text.NumberFormat;
import java.util.Arrays;
import javax.naming.OperationNotSupportedException;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/trainable/mixture/MixtureTrainSM.class */
public class MixtureTrainSM extends AbstractMixtureTrainSM {

    /**
     * Absolute tolerance used when checking that the parts of a user-supplied partitioning sum to 1.
     * An exact {@code ==} comparison would reject valid partitionings such as {0.7, 0.2, 0.1}, whose
     * floating-point sum is 0.9999999999999999.
     */
    private static final double PARTITION_SUM_TOLERANCE = 1.0e-9d;

    /**
     * Creates a mixture whose number of components is {@code models.length}. All arguments are
     * forwarded to the {@link AbstractMixtureTrainSM} constructor, with {@code null} for the
     * superclass argument between the models and the dimension.
     *
     * <p>NOTE(review): the decompiler reports that this constructor was originally
     * {@code protected}; it is kept {@code public} so that existing callers keep compiling.
     *
     * @param length sequence length forwarded to the superclass; a value of {@code 0} presumably
     *     ends up in {@code this.length} and selects the variable-length emission path
     * @param models the mixture components
     * @param starts forwarded to the superclass (presumably the number of training restarts)
     * @param estimateComponentProbs {@code true} when the component weights are estimated from
     *     hyper-parameters, {@code false} when fixed weights are supplied (as the convenience
     *     constructors below use it)
     * @param componentHyperParams hyper-parameters for the component weights, or {@code null}
     * @param weights fixed component weights, or {@code null}
     * @param algorithm the training algorithm (EM or Gibbs sampling)
     * @param alpha forwarded to the superclass; TODO confirm its exact meaning there
     * @param tc termination condition for EM, or {@code null} for Gibbs sampling
     * @param parameterization parameterization of the component weights
     * @param initialIteration forwarded to the superclass (used for Gibbs sampling)
     * @param stationaryIteration forwarded to the superclass (used for Gibbs sampling)
     * @param burnInTest burn-in test for Gibbs sampling, or {@code null} for EM
     */
    public MixtureTrainSM(int length, TrainableStatisticalModel[] models, int starts, boolean estimateComponentProbs, double[] componentHyperParams, double[] weights, AbstractMixtureTrainSM.Algorithm algorithm, double alpha, TerminationCondition tc, AbstractMixtureTrainSM.Parameterization parameterization, int initialIteration, int stationaryIteration, BurnInTest burnInTest) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException {
        super(length, models, null, models.length, starts, estimateComponentProbs, componentHyperParams, weights, algorithm, alpha, tc, parameterization, initialIteration, stationaryIteration, burnInTest);
    }

    /**
     * Creates an EM-trained mixture whose component weights are estimated using the given
     * hyper-parameters.
     */
    public MixtureTrainSM(int length, TrainableStatisticalModel[] models, int starts, double[] componentHyperParams, double alpha, TerminationCondition tc, AbstractMixtureTrainSM.Parameterization parameterization) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException {
        this(length, models, starts, true, componentHyperParams, null, AbstractMixtureTrainSM.Algorithm.EM, alpha, tc, parameterization, 0, 0, null);
    }

    /**
     * Creates an EM-trained mixture with fixed component weights.
     */
    public MixtureTrainSM(int length, TrainableStatisticalModel[] models, double[] weights, int starts, double alpha, TerminationCondition tc, AbstractMixtureTrainSM.Parameterization parameterization) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException {
        this(length, models, starts, false, null, weights, AbstractMixtureTrainSM.Algorithm.EM, alpha, tc, parameterization, 0, 0, null);
    }

    /**
     * Creates a Gibbs-sampled mixture (LAMBDA parameterization) whose component weights are
     * estimated using the given hyper-parameters.
     */
    public MixtureTrainSM(int length, TrainableStatisticalModel[] models, int starts, double[] componentHyperParams, int initialIteration, int stationaryIteration, BurnInTest burnInTest) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException {
        this(length, models, starts, true, componentHyperParams, null, AbstractMixtureTrainSM.Algorithm.GIBBS_SAMPLING, 0.0d, null, AbstractMixtureTrainSM.Parameterization.LAMBDA, initialIteration, stationaryIteration, burnInTest);
    }

    /**
     * Creates a Gibbs-sampled mixture (LAMBDA parameterization) with fixed component weights.
     */
    public MixtureTrainSM(int length, TrainableStatisticalModel[] models, double[] weights, int starts, int initialIteration, int stationaryIteration, BurnInTest burnInTest) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException {
        this(length, models, starts, false, null, weights, AbstractMixtureTrainSM.Algorithm.GIBBS_SAMPLING, 0.0d, null, AbstractMixtureTrainSM.Parameterization.LAMBDA, initialIteration, stationaryIteration, burnInTest);
    }

    /**
     * Restores an instance from its serialized representation.
     *
     * @param representation the serialized form, parsed by the superclass
     * @throws NonParsableException if the representation cannot be parsed
     */
    public MixtureTrainSM(StringBuffer representation) throws NonParsableException {
        super(representation);
    }

    /**
     * Draws {@code n} sequences from the mixture. A component is first sampled for every sequence
     * according to {@code this.weights}; each chosen component then emits exactly as many
     * sequences as it was drawn. The result is grouped by component, in component order.
     *
     * @param n number of sequences to emit
     * @param lengths for a variable-length mixture ({@code this.length == 0}): either a single
     *     length used for every sequence, or one length per emitted sequence (in output order);
     *     for a fixed-length mixture it must be {@code null} or empty
     * @return the {@code n} emitted sequences
     * @throws IllegalArgumentException if {@code lengths} does not match the variable-length case
     * @throws Exception if {@code lengths} is non-empty for a fixed-length mixture, or a component
     *     fails to emit
     */
    @Override
    protected Sequence[] emitDataSetUsingCurrentParameterSet(int n, int... lengths) throws Exception {
        // How many of the n sequences each component has to emit.
        int[] counts = new int[this.dimension];
        for (int s = 0; s < n; s++) {
            counts[AbstractMixtureTrainSM.draw(this.weights, 0)]++;
        }

        Sequence[] emitted = new Sequence[n];
        // Next free slot in 'emitted'; in the per-sequence-length case also the next unused entry of 'lengths'.
        int next = 0;
        if (this.length == 0) {
            if (lengths == null || (lengths.length != 1 && lengths.length != n)) {
                throw new IllegalArgumentException("Expected either one length or " + n
                        + " lengths (one per sequence) for a variable-length mixture.");
            }
            for (int c = 0; c < this.dimension; c++) {
                if (counts[c] == 0) {
                    continue;
                }
                int[] componentLengths = lengths.length == 1
                        ? lengths
                        : Arrays.copyOfRange(lengths, next, next + counts[c]);
                DataSet data = this.model[c].emitDataSet(counts[c], componentLengths);
                next = copyElements(data, counts[c], emitted, next);
            }
        } else {
            if (lengths != null && lengths.length != 0) {
                throw new Exception("This is an inhomogeneous model. Please check parameter lengths.");
            }
            for (int c = 0; c < this.dimension; c++) {
                if (counts[c] == 0) {
                    continue;
                }
                // Variable-length components are told the mixture's fixed length explicitly.
                DataSet data = this.model[c].getLength() == 0
                        ? this.model[c].emitDataSet(counts[c], this.length)
                        : this.model[c].emitDataSet(counts[c], lengths);
                next = copyElements(data, counts[c], emitted, next);
            }
        }
        return emitted;
    }

    /**
     * Copies the first {@code count} elements of {@code data} into {@code target} starting at
     * {@code offset}, and returns the index just past the last copied element.
     */
    private static int copyElements(DataSet data, int count, Sequence[] target, int offset) throws Exception {
        for (int j = 0; j < count; j++) {
            target[offset + j] = data.getElementAt(j);
        }
        return offset + count;
    }

    /**
     * Performs the first training iteration from a random soft assignment. For every training
     * sequence a vector of {@code this.dimension} component shares is drawn from
     * {@code generator}; these shares, multiplied by the sequence weight when present, become the
     * sequence-to-component weights used to compute the initial parameters.
     *
     * @param dataWeights per-sequence weights, or {@code null} for weight 1
     * @param generator generator for the random component shares
     * @param params generator parameters, one per training sequence
     * @return the sequence weights, indexed {@code [component][sequence]}
     * @throws Exception if the parameter update fails
     */
    @Override
    protected double[][] doFirstIteration(double[] dataWeights, MultivariateRandomGenerator generator, MRGParams[] params) throws Exception {
        int numberOfElements = this.sample[0].getNumberOfElements();
        double[][] seqWeights = createSeqWeightsArray();
        double[] componentWeights = new double[this.dimension];
        initWithPrior(componentWeights);
        for (int s = 0; s < numberOfElements; s++) {
            double[] shares = generator.generate(this.dimension, params[s]);
            // Multiplying by 1.0 is exact, so the unweighted case is unchanged.
            double weight = dataWeights == null ? 1.0d : dataWeights[s];
            for (int c = 0; c < this.dimension; c++) {
                seqWeights[c][s] = weight * shares[c];
                componentWeights[c] += seqWeights[c][s];
            }
        }
        getNewParameters(0, seqWeights, componentWeights);
        return seqWeights;
    }

    /**
     * Performs the first training iteration from a user-supplied soft partitioning of
     * {@code data}. Each row of {@code partitioning} must contain {@code this.dimension} values
     * in [0, 1] that sum to 1, up to {@link #PARTITION_SUM_TOLERANCE}.
     *
     * @param data the training data (stored as the current training data)
     * @param dataWeights per-sequence weights, or {@code null} for weight 1
     * @param partitioning {@code partitioning[s][c]} is the share of sequence {@code s} assigned
     *     to component {@code c}
     * @return the sequence weights, indexed {@code [component][sequence]}
     * @throws OperationNotSupportedException if the mixture has at most one component
     * @throws IllegalArgumentException if the partitioning or the weights are malformed
     * @throws Exception if the parameter update fails
     */
    public double[][] doFirstIteration(DataSet data, double[] dataWeights, double[][] partitioning) throws Exception {
        setTrainData(data);
        if (this.dimension <= 1) {
            throw new OperationNotSupportedException();
        }
        int numberOfElements = data.getNumberOfElements();
        if (partitioning == null || partitioning.length < numberOfElements) {
            throw new IllegalArgumentException("The partitioning must contain one entry per sequence (" + numberOfElements + ").");
        }
        if (dataWeights != null && dataWeights.length < numberOfElements) {
            throw new IllegalArgumentException("The weights must contain one entry per sequence (" + numberOfElements + ").");
        }
        double[][] seqWeights = createSeqWeightsArray();
        double[] componentWeights = new double[this.dimension];
        initWithPrior(componentWeights);
        for (int s = 0; s < numberOfElements; s++) {
            double[] parts = partitioning[s];
            if (parts == null || parts.length != this.dimension) {
                throw new IllegalArgumentException("The partitioning for sequence " + s + " was wrong. (number of parts)");
            }
            double weight = dataWeights == null ? 1.0d : dataWeights[s];
            double sum = 0.0d;
            for (int c = 0; c < this.dimension; c++) {
                // Negated range test so that NaN is rejected as well.
                if (!(parts[c] >= 0.0d && parts[c] <= 1.0d)) {
                    throw new IllegalArgumentException("The partitioning for sequence " + s + " was wrong. (part " + c + " was incorrect)");
                }
                seqWeights[c][s] = weight * parts[c];
                sum += parts[c];
                componentWeights[c] += seqWeights[c][s];
            }
            if (!(Math.abs(sum - 1.0d) <= PARTITION_SUM_TOLERANCE)) {
                throw new IllegalArgumentException("The partitioning for sequence " + s + " was wrong. (sum of parts not 1)");
            }
        }
        getNewParameters(0, seqWeights, componentWeights);
        return seqWeights;
    }

    /**
     * Returns the log-weight of component {@code component} plus that component's log-probability
     * for {@code sequence} between {@code startPos} and {@code endPos}.
     */
    @Override
    protected double getLogProbUsingCurrentParameterSetFor(int component, Sequence sequence, int startPos, int endPos) throws Exception {
        return this.logWeights[component] + this.model[component].getLogProbFor(sequence, startPos, endPos);
    }

    /**
     * Describes the mixture. For EM it lists each component weight together with the component's
     * own description. For Gibbs sampling it lists the burn-in test, the length of the stationary
     * phase and the component names.
     *
     * @param numberFormat format used for the numbers
     * @return the textual description
     * @throws IllegalArgumentException if the algorithm is neither EM nor Gibbs sampling
     */
    @Override
    public String toString(NumberFormat numberFormat) {
        // Let the buffer grow on demand instead of reserving 100 000 chars per component up front.
        StringBuffer text = new StringBuffer(1024 * (this.model.length + 1));
        text.append("Mixture model with parameter estimation by ").append(getNameOfAlgorithm()).append(": \n");
        text.append("number of starts:\t").append(this.starts).append("\n");
        switch (this.algorithm) {
            case EM:
                for (int c = 0; c < this.dimension; c++) {
                    text.append(numberFormat.format(this.weights[c])).append("\t")
                            .append(this.model[c].getInstanceName()).append("\n")
                            .append(this.model[c].toString(numberFormat)).append("\n");
                }
                break;
            case GIBBS_SAMPLING:
                text.append("burn in test              :\t").append(this.burnInTest.getInstanceName()).append("\n");
                text.append("length of stationary phase:\t").append(this.stationaryIteration).append("\n");
                text.append("Mixture model components:\n");
                for (int c = 0; c < this.dimension; c++) {
                    text.append(c + 1).append(". component: ").append(this.model[c].getInstanceName()).append("\n");
                }
                break;
            default:
                throw new IllegalArgumentException("The type of algorithm is unknown.");
        }
        return text.toString();
    }

    /**
     * Computes the new sequence-to-component weights for every training sequence. It starts
     * {@code componentWeights} from the prior, then for each sequence scores every component as
     * log-probability plus log-weight and passes the scores to {@code modifyWeights}. It stores
     * the modified scores, scaled by the sequence weight, in {@code seqWeights} and accumulates
     * them into {@code componentWeights}.
     *
     * @param dataWeights per-sequence weights, or {@code null} for weight 1
     * @param componentWeights output: accumulated weight per component (reset to the prior first)
     * @param seqWeights output: sequence weights indexed {@code [component][sequence]}; its second
     *     dimension determines how many sequences are processed
     * @return the weighted sum of the values returned by {@code modifyWeights} (presumably the
     *     weighted log-likelihood; TODO confirm against the superclass)
     * @throws Exception if a component cannot score a sequence
     */
    @Override
    protected double getNewWeights(double[] dataWeights, double[] componentWeights, double[][] seqWeights) throws Exception {
        double total = 0.0d;
        initWithPrior(componentWeights);
        double[] scores = new double[this.dimension];
        for (int s = 0; s < seqWeights[0].length; s++) {
            Sequence sequence = this.sample[0].getElementAt(s);
            double weight = dataWeights == null ? 1.0d : dataWeights[s];
            for (int c = 0; c < this.dimension; c++) {
                scores[c] = this.model[c].getLogProbFor(sequence) + this.logWeights[c];
            }
            // modifyWeights updates 'scores' in place; its return value is accumulated.
            total += modifyWeights(scores) * weight;
            for (int c = 0; c < this.dimension; c++) {
                seqWeights[c][s] = scores[c] * weight;
                componentWeights[c] += seqWeights[c][s];
            }
        }
        return total;
    }

    /**
     * Stores {@code dataSet} as the single training data set.
     */
    @Override
    protected void setTrainData(DataSet dataSet) {
        this.sample = new DataSet[]{dataSet};
    }
}
