package de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif;

import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import java.text.NumberFormat;
import java.util.HashSet;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/differentiable/mixture/motif/MixtureDurationDiffSM.class */
/**
 * A {@link DurationDiffSM} that scores a value as a weighted combination of several
 * component {@link DurationDiffSM}s.
 *
 * <p>The flat parameter vector of this model consists of the parameters of every component
 * (in component order) followed by the hidden (mixing) parameters; {@link #paramRef} stores
 * the corresponding offsets.
 */
public class MixtureDurationDiffSM extends DurationDiffSM {
    /** The mixture components. */
    private DurationDiffSM[] function;
    /** One mixing parameter per component, kept in log space; see {@link #logNorm}. */
    private double[] hiddenParams;
    /** Scratch buffer holding one score per component while a mixture score is computed. */
    private double[] scores;
    /**
     * Value subtracted from the combined score. It is either recomputed from
     * {@link #hiddenParams} via {@code Normalisation.getLogSum} or set to {@code 0} where the
     * hidden parameters have just been stored as logs of normalised weights.
     */
    private double logNorm;
    /**
     * Offsets into the flat parameter vector: entry {@code k} is the start of component
     * {@code k}'s parameters, entry {@code function.length} the start of the hidden parameters,
     * and the last entry the total number of parameters (or {@code -1} if a component reports a
     * negative parameter count).
     */
    private int[] paramRef;
    /** Per component, the length of the partial-derivative list right after that component appended to it. */
    private int[] partDerOffset;
    /** Value returned by {@code getNumberOfRecommendedStarts()}. */
    private int starts;
    /** Scratch list receiving component-local parameter indices before they are shifted by {@link #paramRef}. */
    private IntList help;
    /** Name of the XML element that wraps the serialised form of this class. */
    private static String XML_TAG = "MixtureDurationDiffSM";

    /**
     * Derives the equivalent sample size (ESS) that is passed to the superclass constructor.
     *
     * <p>If the first component has an ESS of exactly zero, no later component may have a
     * positive ESS. Otherwise the ESS values of all components are summed.
     *
     * @param durationDiffSMArr the components; must contain at least one element
     * @return the ESS of the first component if it is zero, otherwise the sum over all components
     * @throws IllegalArgumentException if the first ESS is zero and a later one is positive
     */
    private static double getESS(DurationDiffSM... durationDiffSMArr) {
        double total = durationDiffSMArr[0].getESS();
        if (total == 0.0d) {
            // Zero-ESS mode: only verify that no later component contradicts it.
            for (int idx = 1; idx < durationDiffSMArr.length; idx++) {
                if (durationDiffSMArr[idx].getESS() > 0.0d) {
                    throw new IllegalArgumentException("The ESS of duration " + idx + " has to be zero.");
                }
            }
            return total;
        }
        for (int idx = 1; idx < durationDiffSMArr.length; idx++) {
            total += durationDiffSMArr[idx].getESS();
        }
        return total;
    }

    public MixtureDurationDiffSM(int i, DurationDiffSM... durationDiffSMArr) throws WrongAlphabetException, CloneNotSupportedException, IllegalArgumentException {
        super(durationDiffSMArr[0].getMin(), durationDiffSMArr[0].getMax(), getESS(durationDiffSMArr));
        if (i <= 0) {
            throw new IllegalArgumentException("The number of recommended starts should be positive.");
        }
        this.starts = i;
        this.function = new DurationDiffSM[durationDiffSMArr.length];
        this.scores = new double[durationDiffSMArr.length];
        this.hiddenParams = new double[durationDiffSMArr.length];
        this.paramRef = null;
        this.partDerOffset = new int[durationDiffSMArr.length];
        for (int i2 = 0; i2 < durationDiffSMArr.length; i2++) {
            if (!this.alphabets.checkConsistency(durationDiffSMArr[i2].getAlphabetContainer())) {
                throw new WrongAlphabetException("All durations have to have the same alphabet: Violated at position " + i2);
            }
            this.function[i2] = (DurationDiffSM) durationDiffSMArr[i2].mo71clone();
        }
        this.logNorm = Normalisation.getLogSum(this.hiddenParams);
        setParamRef(false);
        this.help = new IntList();
    }

    /**
     * Rebuilds {@code paramRef}, the table of offsets into the flat parameter vector.
     *
     * <p>Entry {@code k + 1} is entry {@code k} plus the parameter count of component {@code k}.
     * The final entry is the total parameter count: the component parameters plus one hidden
     * parameter per component, minus one if {@code z} is set. If any component reports a negative
     * parameter count the final entry is {@code -1}.
     *
     * @param z if {@code true}, one hidden parameter fewer is counted
     */
    private void setParamRef(boolean z) {
        final int count = this.function.length;
        if (this.paramRef == null || this.paramRef.length != count + 2) {
            this.paramRef = new int[count + 2];
        }
        boolean unknown = false;
        for (int k = 0; k < count; k++) {
            int n = this.function[k].getNumberOfParameters();
            if (n < 0) {
                unknown = true;
            }
            this.paramRef[k + 1] = this.paramRef[k] + n;
        }
        if (unknown) {
            this.paramRef[count + 1] = -1;
            return;
        }
        int omitted = z ? 1 : 0;
        this.paramRef[count + 1] = (this.paramRef[count] + this.scores.length) - omitted;
    }

    /**
     * Creates an instance from its serialised representation by delegating to the superclass
     * constructor.
     *
     * <p>NOTE(review): the fields of this class are populated in {@code fromXML}; presumably the
     * superclass constructor invokes it — confirm.
     *
     * @param stringBuffer the serialised representation
     * @throws NonParsableException if the superclass constructor cannot parse the buffer
     */
    public MixtureDurationDiffSM(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
    }

    /**
     * Creates a copy of this instance that shares no arrays with it.
     *
     * <p>Components and all internal arrays are duplicated and the scratch list is replaced by a
     * fresh one; the remaining state is whatever the superclass clone provides.
     *
     * @return the copy
     * @throws CloneNotSupportedException if the superclass or a component cannot be cloned
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif.PositionDiffSM, de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel, de.jstacs.sequenceScores.differentiable.AbstractDifferentiableSequenceScore
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public MixtureDurationDiffSM mo71clone() throws CloneNotSupportedException {
        MixtureDurationDiffSM copy = (MixtureDurationDiffSM) super.mo71clone();
        copy.function = (DurationDiffSM[]) ArrayHandler.clone(this.function);
        copy.hiddenParams = (double[]) this.hiddenParams.clone();
        copy.scores = (double[]) this.scores.clone();
        copy.paramRef = (int[]) this.paramRef.clone();
        copy.partDerOffset = (int[]) this.partDerOffset.clone();
        // The scratch list carries no state between calls, so a new one suffices.
        copy.help = new IntList();
        return copy;
    }

    /**
     * Fits the components and the mixing parameters to weighted observations.
     *
     * <p>Steps:
     * <ol>
     * <li>If at least two components have the same runtime class, every component is initialised
     * randomly; otherwise every component is adjusted to the full input. (Presumably identical
     * classes would otherwise end up as identical components — confirm.)</li>
     * <li>With all hidden parameters at {@code 0}, each observation's weight is distributed over
     * the components in proportion to the normalised component scores.</li>
     * <li>Each component is re-adjusted with its share of the weights, and each hidden parameter
     * becomes the log of the component's share of the total mass (component ESS plus assigned
     * weight).</li>
     * </ol>
     *
     * @param iArr the observed values; each one is passed on its own to the components'
     *        {@code getLogScore}
     * @param dArr the weight of each entry of {@code iArr}
     * @throws RuntimeException if the random initialisation of a component fails; the original
     *         exception is attached as the cause
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif.DurationDiffSM
    public void adjust(int[] iArr, double[] dArr) {
        final int count = this.function.length;
        // Row k holds the weight share of component k for every observation.
        double[][] componentWeights = new double[count][];
        double[] mass = new double[this.hiddenParams.length];
        double total = 0.0d;

        // Detect whether any two components are instances of the same class.
        boolean duplicateClass = false;
        HashSet<Class<?>> seen = new HashSet<>();
        for (int k = 0; k < count; k++) {
            if (!seen.add(this.function[k].getClass())) {
                duplicateClass = true;
                break;
            }
        }

        for (int k = 0; k < count; k++) {
            if (duplicateClass) {
                try {
                    this.function[k].initializeFunctionRandomly(false);
                } catch (Exception e) {
                    throw new RuntimeException("Could not randomly initialize duration component " + k, e);
                }
            } else {
                this.function[k].adjust(iArr, dArr);
            }
            this.hiddenParams[k] = 0.0d;
            componentWeights[k] = new double[dArr.length];
            mass[k] = this.function[k].getESS();
            total += mass[k];
        }
        this.logNorm = Normalisation.getLogSum(this.hiddenParams);

        int[] single = new int[1];
        for (int n = 0; n < iArr.length; n++) {
            single[0] = iArr[n];
            for (int k = 0; k < count; k++) {
                this.scores[k] = this.hiddenParams[k] + this.function[k].getLogScore(single);
            }
            // After this call the scores are used as the components' shares of the observation.
            Normalisation.logSumNormalisation(this.scores);
            for (int k = 0; k < count; k++) {
                double share = dArr[n] * this.scores[k];
                componentWeights[k][n] = share;
                mass[k] = mass[k] + share;
            }
            total += dArr[n];
        }

        for (int k = 0; k < count; k++) {
            this.function[k].adjust(iArr, componentWeights[k]);
            this.hiddenParams[k] = Math.log(mass[k] / total);
        }
        // The hidden parameters were just stored as logs of mass fractions, so no correction is kept.
        this.logNorm = 0.0d;
    }

    /**
     * Computes the mixture log score of the given value(s).
     *
     * <p>Each component's log score is shifted by its hidden parameter; the results are combined
     * with {@code Normalisation.getLogSum} and {@code logNorm} is subtracted.
     *
     * @param iArr the value(s) forwarded unchanged to every component
     * @return the combined log score
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif.PositionDiffSM
    public double getLogScore(int... iArr) {
        int k = 0;
        for (DurationDiffSM component : this.function) {
            this.scores[k] = this.hiddenParams[k] + component.getLogScore(iArr);
            k++;
        }
        double combined = Normalisation.getLogSum(this.scores);
        return combined - this.logNorm;
    }

    /**
     * Computes the mixture log score and appends the partial derivatives with respect to all
     * parameters of this model.
     *
     * <p>Every component appends its derivatives to {@code doubleList}; its parameter indices are
     * shifted by the component's offset in {@code paramRef} before being added to
     * {@code intList}. After the component scores have been normalised, each component's segment
     * of {@code doubleList} is multiplied by that component's normalised score. Finally one entry
     * per hidden parameter is appended: the normalised score minus
     * {@code exp(hiddenParams[k] - logNorm)}.
     *
     * @param intList receives the indices of the parameters with an appended derivative
     * @param doubleList receives the derivative values, aligned with {@code intList}
     * @param iArr the value(s) forwarded unchanged to every component
     * @return the combined log score
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif.PositionDiffSM
    public double getLogScoreAndPartialDerivation(IntList intList, DoubleList doubleList, int... iArr) {
        final int count = this.function.length;
        int segmentStart = doubleList.length();
        for (int k = 0; k < count; k++) {
            this.help.clear();
            this.scores[k] = this.hiddenParams[k] + this.function[k].getLogScoreAndPartialDerivation(this.help, doubleList, iArr);
            // Remember where this component's derivatives end.
            this.partDerOffset[k] = doubleList.length();
            final int shift = this.paramRef[k];
            final int added = this.help.length();
            for (int h = 0; h < added; h++) {
                intList.add(shift + this.help.get(h));
            }
        }
        double logSum = Normalisation.logSumNormalisation(this.scores);
        for (int k = 0; k < count; k++) {
            int segmentEnd = this.partDerOffset[k];
            doubleList.multiply(segmentStart, segmentEnd, this.scores[k]);
            segmentStart = segmentEnd;
        }
        final int hiddenEnd = this.paramRef[count + 1];
        for (int p = this.paramRef[count], k = 0; p < hiddenEnd; p++, k++) {
            intList.add(p);
            doubleList.add(this.scores[k] - Math.exp(this.hiddenParams[k] - this.logNorm));
        }
        return logSum - this.logNorm;
    }

    /**
     * Adds the gradient of the log prior term to {@code dArr}.
     *
     * <p>Nothing happens unless the ESS is positive. Otherwise every component adds its own
     * contribution at its offset, and for each hidden parameter {@code k} the value
     * {@code ESS_k - ess * exp(hiddenParams[k] - logNorm)} is added, which is the derivative of
     * the hidden-parameter part of {@code getLogPriorTerm()}.
     *
     * <p>The hidden-parameter contribution is accumulated rather than assigned, so values already
     * present in {@code dArr} are preserved, consistent with the "add" semantics of this method
     * and of the component calls.
     *
     * @param dArr the gradient array to add to
     * @param i the index in {@code dArr} that corresponds to the first parameter of this model
     * @throws Exception if a component fails to add its gradient
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public void addGradientOfLogPriorTerm(double[] dArr, int i) throws Exception {
        if (!(this.ess > 0.0d)) {
            return;
        }
        final int count = this.function.length;
        for (int k = 0; k < count; k++) {
            this.function[k].addGradientOfLogPriorTerm(dArr, this.paramRef[k] + i);
        }
        final int hiddenEnd = this.paramRef[count + 1];
        for (int p = this.paramRef[count], k = 0; p < hiddenEnd; p++, k++) {
            dArr[p + i] += this.function[k].getESS() - (this.ess * Math.exp(this.hiddenParams[k] - this.logNorm));
        }
    }

    /**
     * Computes the log prior term of this model.
     *
     * <p>For a positive ESS this is the sum over all components of the component's log prior term
     * plus its ESS times its hidden parameter, minus {@code ess * logNorm}. Otherwise the result
     * is {@code 0}.
     *
     * @return the log prior term
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel, de.jstacs.sequenceScores.statisticalModels.StatisticalModel
    public double getLogPriorTerm() {
        if (!(this.ess > 0.0d)) {
            return 0.0d;
        }
        double sum = 0.0d;
        int k = 0;
        for (DurationDiffSM component : this.function) {
            sum += component.getLogPriorTerm() + (component.getESS() * this.hiddenParams[k]);
            k++;
        }
        sum -= this.ess * this.logNorm;
        return sum;
    }

    /**
     * Returns the current values of all parameters as one flat array.
     *
     * <p>The array contains the parameter values of every component at the component's offset in
     * {@code paramRef}, followed by the hidden parameters.
     *
     * @return a new array with one entry per parameter
     * @throws RuntimeException if the number of parameters is not positive
     * @throws Exception if a component cannot provide its parameter values
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public double[] getCurrentParameterValues() throws Exception {
        final int total = getNumberOfParameters();
        if (total <= 0) {
            throw new RuntimeException();
        }
        final int count = this.function.length;
        double[] values = new double[total];
        for (int k = 0; k < count; k++) {
            double[] part = this.function[k].getCurrentParameterValues();
            System.arraycopy(part, 0, values, this.paramRef[k], part.length);
        }
        final int hiddenStart = this.paramRef[count];
        final int hiddenEnd = this.paramRef[count + 1];
        for (int p = hiddenStart; p < hiddenEnd; p++) {
            values[p] = this.hiddenParams[p - hiddenStart];
        }
        return values;
    }

    /**
     * Builds a name of the form {@code mixture duration(a, b, ...)} from the instance names of
     * the components.
     *
     * @return the instance name
     */
    @Override // de.jstacs.sequenceScores.SequenceScore
    public String getInstanceName() {
        StringBuilder name = new StringBuilder("mixture duration(");
        name.append(this.function[0].getInstanceName());
        for (int k = 1; k < this.function.length; k++) {
            name.append(", ").append(this.function[k].getInstanceName());
        }
        return name.append(")").toString();
    }

    /**
     * Returns the total number of parameters, i.e. the last entry of {@code paramRef} as computed
     * by {@code setParamRef}; this is {@code -1} if a component reported a negative parameter
     * count.
     *
     * @return the number of parameters, or {@code -1}
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public int getNumberOfParameters() {
        return this.paramRef[this.function.length + 1];
    }

    /**
     * Initialises the hidden parameters from randomly distributed data weights.
     *
     * <p>Each hidden parameter starts at the ESS of its component. For every element of
     * {@code dataSetArr[i]} a vector of component proportions is generated with
     * {@code DirichletMRG} and the element's weight ({@code 1} if no weights are given for data
     * set {@code i}) is split accordingly and added to the hidden parameters. Finally each hidden
     * parameter is replaced by the log of its fraction of the total mass, and the parameter
     * layout is rebuilt.
     *
     * @param i the index of the data set (and weight array) to use
     * @param z forwarded to {@code setParamRef}; if {@code true}, one hidden parameter fewer is
     *        counted
     * @param dataSetArr the data sets
     * @param dArr the weights per data set; may be {@code null}, and individual entries may be
     *        {@code null}
     * @throws Exception declared by the overridden method
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public void initializeFunction(int i, boolean z, DataSet[] dataSetArr, double[][] dArr) throws Exception {
        final int count = this.function.length;
        double weight = 1.0d;
        double total = 0.0d;
        for (int k = 0; k < count; k++) {
            this.hiddenParams[k] = this.function[k].getESS();
            total += this.hiddenParams[k];
        }
        DirichletMRGParams dirichletParams = new DirichletMRGParams(this.hiddenParams);
        double[] draw = new double[count];
        if (dArr == null) {
            // One (absent) weight array per data set.
            dArr = new double[dataSetArr.length][];
        }
        double[] originalWeights = dArr[i];
        double[][] partWeights = new double[count][dataSetArr[i].getNumberOfElements()];
        for (int n = 0; n < partWeights[0].length; n++) {
            DirichletMRG.DEFAULT_INSTANCE.generate(draw, 0, count, dirichletParams);
            if (originalWeights != null) {
                weight = originalWeights[n];
            }
            total += weight;
            for (int k = 0; k < count; k++) {
                partWeights[k][n] = weight * draw[k];
                this.hiddenParams[k] = this.hiddenParams[k] + partWeights[k][n];
            }
        }
        for (int k = 0; k < count; k++) {
            // NOTE(review): the component's weights are installed into dArr[i] here, but nothing
            // visible in this method reads them before the original entry is restored below.
            dArr[i] = partWeights[k];
            this.hiddenParams[k] = Math.log(this.hiddenParams[k] / total);
        }
        // The hidden parameters were just stored as logs of mass fractions.
        this.logNorm = 0.0d;
        dArr[i] = originalWeights;
        setParamRef(z);
    }

    /**
     * Initialises every component uniformly and sets all hidden parameters to {@code 0}, then
     * refreshes {@code logNorm} and the parameter layout (with the flag set to {@code false}).
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif.DurationDiffSM
    public void initializeUniformly() {
        int k = 0;
        for (DurationDiffSM component : this.function) {
            component.initializeUniformly();
            this.hiddenParams[k] = 0.0d;
            k++;
        }
        this.logNorm = Normalisation.getLogSum(this.hiddenParams);
        setParamRef(false);
    }

    /**
     * Randomly initialises all components and the hidden parameters.
     *
     * <p>The hidden parameters are generated with {@code DirichletMRG}, using the component ESS
     * values as hyper-parameters (or {@code 1} for every component if the ESS of this model is
     * zero), and are then stored as logarithms.
     *
     * @param z forwarded to every component and to {@code setParamRef}
     * @throws Exception if a component cannot be initialised
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public void initializeFunctionRandomly(boolean z) throws Exception {
        final int count = this.function.length;
        final boolean noPrior = this.ess == 0.0d;
        // Only the length of the clone matters; every entry is overwritten below.
        double[] hyper = (double[]) this.scores.clone();
        for (int k = 0; k < count; k++) {
            this.function[k].initializeFunctionRandomly(z);
            if (noPrior) {
                hyper[k] = 1.0d;
            } else {
                hyper[k] = this.function[k].getESS();
            }
        }
        DirichletMRGParams dirichletParams = new DirichletMRGParams(hyper);
        DirichletMRG.DEFAULT_INSTANCE.generate(this.hiddenParams, 0, count, dirichletParams);
        for (int k = 0; k < count; k++) {
            this.hiddenParams[k] = Math.log(this.hiddenParams[k]);
        }
        this.logNorm = 0.0d;
        setParamRef(z);
    }

    /**
     * Reports whether every component is initialised. Checking stops at the first component that
     * is not.
     *
     * @return {@code true} if all components are initialised
     */
    @Override // de.jstacs.sequenceScores.SequenceScore
    public boolean isInitialized() {
        for (DurationDiffSM component : this.function) {
            if (!component.isInitialized()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Always reports this model as normalized.
     *
     * @return {@code true}
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel, de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public boolean isNormalized() {
        return true;
    }

    /**
     * Sets all parameters from a flat array.
     *
     * <p>Each component reads its parameters starting at {@code i} plus its offset in
     * {@code paramRef}; the hidden parameters follow. {@code logNorm} is recomputed afterwards.
     *
     * @param dArr the array containing the new parameter values
     * @param i the index in {@code dArr} of the first parameter of this model
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public void setParameters(double[] dArr, int i) {
        final int count = this.function.length;
        for (int k = 0; k < count; k++) {
            this.function[k].setParameters(dArr, i + this.paramRef[k]);
        }
        final int hiddenStart = this.paramRef[count];
        final int hiddenEnd = this.paramRef[count + 1];
        for (int p = hiddenStart; p < hiddenEnd; p++) {
            this.hiddenParams[p - hiddenStart] = dArr[i + p];
        }
        this.logNorm = Normalisation.getLogSum(this.hiddenParams);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Builds a textual definition of this mixture.
     *
     * <p>The result contains one line per component, obtained from the component's own
     * {@code getRNotation} under the name {@code str + k}, followed by a statement that defines
     * {@code str} as the sum of the component names weighted by
     * {@code exp(hiddenParams[k] - logNorm)}.
     *
     * @param str the name under which this mixture is defined
     * @param numberFormat used to format the weights and passed on to the components
     * @return the component definitions followed by the weighted-sum statement
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif.DurationDiffSM
    public String getRNotation(String str, NumberFormat numberFormat) {
        StringBuilder definitions = new StringBuilder();
        String sum = null;
        for (int k = 0; k < this.function.length; k++) {
            String componentName = str + k;
            definitions.append(this.function[k].getRNotation(componentName, numberFormat)).append("\n");
            String prefix;
            if (sum == null) {
                prefix = str + " = ";
            } else {
                prefix = sum + " + ";
            }
            sum = prefix + numberFormat.format(Math.exp(this.hiddenParams[k] - this.logNorm)) + " * " + componentName;
        }
        return definitions + sum + ";";
    }

    /**
     * Forwards a non-zero modification to the superclass and to every component; a value of
     * {@code 0} is ignored.
     *
     * @param i the value passed on to {@code modify} of the superclass and of each component
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif.DurationDiffSM
    public void modify(int i) {
        if (i == 0) {
            return;
        }
        super.modify(i);
        for (DurationDiffSM component : this.function) {
            component.modify(i);
        }
    }

    /**
     * Returns the number of recommended starts that was passed to the constructor or read from
     * XML.
     *
     * @return the number of recommended starts
     */
    @Override // de.jstacs.sequenceScores.differentiable.AbstractDifferentiableSequenceScore, de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public int getNumberOfRecommendedStarts() {
        return this.starts;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Restores the state of this instance from its XML representation.
     *
     * <p>The content of the {@code XML_TAG} element is extracted and handed to the superclass,
     * then the components, the hidden parameters and the number of recommended starts are read.
     * All derived state (scratch buffers, {@code logNorm}, parameter layout) is rebuilt.
     *
     * <p>NOTE(review): the flag given to {@code setParamRef} is not part of the XML, so the
     * layout is always rebuilt with {@code false} — confirm this is intended.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif.DurationDiffSM, de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif.PositionDiffSM, de.jstacs.sequenceScores.differentiable.AbstractDifferentiableSequenceScore
    public void fromXML(StringBuffer stringBuffer) throws NonParsableException {
        StringBuffer content = XMLParser.extractForTag(stringBuffer, XML_TAG);
        super.fromXML(content);
        this.function = (DurationDiffSM[]) XMLParser.extractObjectForTags(content, "components", DurationDiffSM[].class);
        this.hiddenParams = (double[]) XMLParser.extractObjectForTags(content, "hiddenParams", double[].class);
        this.starts = (Integer) XMLParser.extractObjectForTags(content, "starts", int.class);

        // Everything below is derived from the parsed values.
        final int count = this.function.length;
        this.scores = new double[count];
        this.paramRef = null;
        this.partDerOffset = new int[count];
        this.logNorm = Normalisation.getLogSum(this.hiddenParams);
        setParamRef(false);
        this.help = new IntList();
    }

    /**
     * Serialises this instance to XML.
     *
     * <p>The components, the hidden parameters and the number of recommended starts are appended
     * to the superclass representation, and the result is wrapped in the {@code XML_TAG} element.
     *
     * @return the XML representation
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif.DurationDiffSM, de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif.PositionDiffSM, de.jstacs.Storable
    public StringBuffer toXML() {
        StringBuffer buffer = super.toXML();
        XMLParser.appendObjectWithTags(buffer, this.function, "components");
        XMLParser.appendObjectWithTags(buffer, this.hiddenParams, "hiddenParams");
        Integer boxedStarts = Integer.valueOf(this.starts);
        XMLParser.appendObjectWithTags(buffer, boxedStarts, "starts");
        XMLParser.addTags(buffer, XML_TAG);
        return buffer;
    }
}
