package de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models;

import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.SamplingDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.trainable.DifferentiableStatisticalModelWrapperTrainSM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.HigherOrderHMM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.DifferentiableState;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.SimpleDifferentiableState;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.DifferentiableEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.MaxHMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.NumericalHMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.DifferentiableTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.TransitionElement;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.ToolBox;
import java.util.Arrays;
import java.util.LinkedList;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/trainable/hmm/models/DifferentiableHigherOrderHMM.class */
/**
 * A {@link HigherOrderHMM} whose emissions and transition are differentiable, so that it can be used
 * as a {@link SamplingDifferentiableStatisticalModel}, e.g. inside numerical optimisation.
 *
 * <p>The log-score of a sequence is obtained from a backward pass. Depending on {@link #score}, the
 * contributions of alternative continuations are combined either by log-sum ({@code LIKELIHOOD}) or
 * by maximum ({@code VITERBI}). In
 * {@link #getLogScoreAndPartialDerivation(Sequence, int, int, IntList, DoubleList)} the gradient of
 * the log-score with respect to all parameters is propagated together with the backward values.
 *
 * <p>The parameter vector contains the parameters of all emissions (in emission order), followed by
 * the parameters of the transition. The offsets are assigned by {@code getOffsets()}.
 *
 * <p>NOTE(review): the backward pass uses scratch buffers stored in fields ({@link #gradient},
 * {@link #index}, {@link #indicesState}, ...). A single instance therefore presumably must not be
 * used by several threads at once; use clones instead.
 */
public class DifferentiableHigherOrderHMM extends HigherOrderHMM implements SamplingDifferentiableStatisticalModel {
    /** Total number of parameters (emissions + transition), or {@code -1} if a component returned -1 as offset. */
    protected int numberOfParameters;
    /** Equivalent sample size returned by {@link #getESS()}; must be non-negative. */
    protected double ess;
    /** How alternative paths are combined: log-sum ({@code LIKELIHOOD}) or maximum ({@code VITERBI}). */
    protected HigherOrderHMM.Type score;
    /**
     * Scratch buffer: for each retained child of the current context, {@code [0]} holds the state id,
     * {@code [1]} the target context, and {@code [2]} the layer offset of the transition.
     */
    protected int[][] index;
    /** Scratch buffer: gradients of two rolling layers ({@code [layer % 2][context][parameter]}). */
    protected double[][][] gradient;
    /** Scratch buffer: parameter indices of the emission partial derivatives, per state. */
    protected IntList[] indicesState;
    /** Scratch buffer: parameter indices of the transition partial derivatives, per retained child. */
    protected IntList[] indicesTransition;
    /** Scratch buffer: emission partial derivatives, per state (parallel to {@link #indicesState}). */
    protected DoubleList[] partDerState;
    /** Scratch buffer: transition partial derivatives, per retained child (parallel to {@link #indicesTransition}). */
    protected DoubleList[] partDerTransition;

    /**
     * Creates a new differentiable higher-order HMM.
     *
     * @param maxHMMTrainingParameterSet the training parameters
     * @param strArr presumably the state names (forwarded to the super constructor)
     * @param iArr presumably the emission index of each state (forwarded to the super constructor)
     * @param zArr presumably the forward flags of the states (forwarded to the super constructor)
     * @param differentiableEmissionArr the differentiable emissions
     * @param z {@code true} to score by likelihood, {@code false} to score by Viterbi
     * @param d the equivalent sample size; must be non-negative
     * @param transitionElementArr the transition elements
     * @throws IllegalArgumentException if {@code d} is negative
     * @throws Exception if the super constructor fails
     */
    public DifferentiableHigherOrderHMM(MaxHMMTrainingParameterSet maxHMMTrainingParameterSet, String[] strArr, int[] iArr, boolean[] zArr, DifferentiableEmission[] differentiableEmissionArr, boolean z, double d, TransitionElement... transitionElementArr) throws Exception {
        super(maxHMMTrainingParameterSet, strArr, iArr, zArr, differentiableEmissionArr, transitionElementArr);
        getOffsets();
        this.score = z ? HigherOrderHMM.Type.LIKELIHOOD : HigherOrderHMM.Type.VITERBI;
        if (d < 0.0d) {
            throw new IllegalArgumentException("The equivalent sample size must be non-negative, but was " + d);
        }
        this.ess = d;
    }

    /**
     * Restores a model from its XML representation.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the representation cannot be parsed
     */
    public DifferentiableHigherOrderHMM(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
        getOffsets();
    }

    /**
     * Appends the ESS (tag {@code ess}) and the score type (tag {@code score}) to the XML representation.
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.HigherOrderHMM, de.jstacs.sequenceScores.statisticalModels.trainable.hmm.AbstractHMM
    public void appendFurtherInformation(StringBuffer stringBuffer) {
        super.appendFurtherInformation(stringBuffer);
        XMLParser.appendObjectWithTags(stringBuffer, Double.valueOf(this.ess), "ess");
        XMLParser.appendObjectWithTags(stringBuffer, this.score, "score");
    }

    /**
     * Reads the ESS (tag {@code ess}) and the score type (tag {@code score}) from the XML representation.
     *
     * @throws NonParsableException if a tag cannot be parsed
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.HigherOrderHMM, de.jstacs.sequenceScores.statisticalModels.trainable.hmm.AbstractHMM
    public void extractFurtherInformation(StringBuffer stringBuffer) throws NonParsableException {
        super.extractFurtherInformation(stringBuffer);
        this.ess = ((Double) XMLParser.extractObjectForTags(stringBuffer, "ess", Double.TYPE)).doubleValue();
        this.score = (HigherOrderHMM.Type) XMLParser.extractObjectForTags(stringBuffer, "score", HigherOrderHMM.Type.class);
    }

    /**
     * Allocates the gradient scratch buffers when {@code container} is not yet set, then delegates
     * to the super class.
     *
     * <p>The gradient buffers are (re)allocated when they are missing or their size no longer matches
     * the maximal number of contexts or the number of parameters. The per-state and per-child partial
     * derivative lists are allocated only when {@link #indicesState} is {@code null}.
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.HigherOrderHMM, de.jstacs.sequenceScores.statisticalModels.trainable.hmm.AbstractHMM
    public void createHelperVariables() {
        if (this.container == null) {
            int maximalMarkovOrder = this.transition.getMaximalMarkovOrder();
            int maxContexts = 0;
            for (int order = 0; order <= maximalMarkovOrder; order++) {
                maxContexts = Math.max(maxContexts, this.transition.getNumberOfIndexes(order));
            }
            // guard the [0][0] access: an existing buffer may have zero contexts
            boolean sizeMismatch = this.gradient == null
                    || this.gradient[0].length != maxContexts
                    || (maxContexts > 0 && this.gradient[0][0].length != this.numberOfParameters);
            if (sizeMismatch) {
                this.gradient = new double[2][maxContexts][this.numberOfParameters];
                this.index = new int[3][maxContexts];
            }
            if (this.indicesState == null) {
                int maximalNumberOfChildren = this.transition.getMaximalNumberOfChildren();
                try {
                    this.indicesState = (IntList[]) ArrayHandler.createArrayOf(new IntList(), this.states.length);
                    this.partDerState = (DoubleList[]) ArrayHandler.createArrayOf(new DoubleList(), this.states.length);
                    this.indicesTransition = (IntList[]) ArrayHandler.createArrayOf(new IntList(), maximalNumberOfChildren);
                    this.partDerTransition = (DoubleList[]) ArrayHandler.createArrayOf(new DoubleList(), maximalNumberOfChildren);
                } catch (CloneNotSupportedException e) {
                    throw getRunTimeException(e);
                }
            }
        }
        super.createHelperVariables();
    }

    /**
     * Creates one {@link SimpleDifferentiableState} per state, using the emission
     * {@code emission[emissionIdx[i]]}, the name {@code name[i]} and the flag {@code forward[i]}.
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.HigherOrderHMM, de.jstacs.sequenceScores.statisticalModels.trainable.hmm.AbstractHMM
    protected void createStates() {
        this.states = new SimpleDifferentiableState[this.emissionIdx.length];
        for (int i = 0; i < this.emissionIdx.length; i++) {
            this.states[i] = new SimpleDifferentiableState((DifferentiableEmission) this.emission[this.emissionIdx[i]], this.name[i], this.forward[i]);
        }
    }

    /**
     * Returns a clone of this model. (The name {@code mo80clone} stems from the decompiler; it
     * corresponds to {@code clone()}.)
     *
     * <p>{@link #gradient} and {@link #indicesState} are hidden while the super class creates the copy,
     * so the copy does not share these scratch buffers with this instance. Presumably
     * {@link #createHelperVariables()} re-allocates them, together with {@link #index} and the other
     * partial derivative lists, in the clone. This instance's buffers are restored even if cloning fails.
     *
     * @return the clone
     * @throws CloneNotSupportedException if a component cannot be cloned
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.HigherOrderHMM, de.jstacs.sequenceScores.statisticalModels.trainable.hmm.AbstractHMM, de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel, de.jstacs.sequenceScores.statisticalModels.trainable.TrainableStatisticalModel, de.jstacs.sequenceScores.SequenceScore
    public DifferentiableHigherOrderHMM mo80clone() throws CloneNotSupportedException {
        double[][][] savedGradient = this.gradient;
        IntList[] savedIndicesState = this.indicesState;
        this.gradient = null;
        this.indicesState = null;
        try {
            return (DifferentiableHigherOrderHMM) super.mo80clone();
        } finally {
            this.gradient = savedGradient;
            this.indicesState = savedIndicesState;
        }
    }

    /** @return the equivalent sample size */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public double getESS() {
        return this.ess;
    }

    /**
     * Adds the gradient of the log-prior terms of all emissions and of the transition to {@code dArr}.
     *
     * @param dArr the array that the gradient is added to
     * @param i the start index, forwarded to every component
     * @throws Exception if a component fails
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public void addGradientOfLogPriorTerm(double[] dArr, int i) throws Exception {
        for (int e = 0; e < this.emission.length; e++) {
            ((DifferentiableEmission) this.emission[e]).addGradientOfLogPriorTerm(dArr, i);
        }
        ((DifferentiableTransition) this.transition).addGradientForLogPriorTerm(dArr, i);
    }

    /**
     * Assigns consecutive parameter offsets to all emissions (in order) and then to the transition,
     * and stores the total in {@link #numberOfParameters}.
     *
     * <p>If any component returns {@code -1}, the method stops and leaves {@code numberOfParameters}
     * at {@code -1} without (re)creating helper variables.
     */
    private void getOffsets() {
        this.numberOfParameters = 0;
        for (int e = 0; e < this.emission.length; e++) {
            this.numberOfParameters = ((DifferentiableEmission) this.emission[e]).setParameterOffset(this.numberOfParameters);
            if (this.numberOfParameters == -1) {
                return;
            }
        }
        this.numberOfParameters = ((DifferentiableTransition) this.transition).setParameterOffset(this.numberOfParameters);
        if (this.numberOfParameters == -1) {
            return;
        }
        createHelperVariables();
    }

    /** @return the total number of parameters, or {@code -1} if it could not be determined */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public int getNumberOfParameters() {
        return this.numberOfParameters;
    }

    /** @return the number of starts configured in the training parameters */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public int getNumberOfRecommendedStarts() {
        return this.trainingParameter.getNumberOfStarts();
    }

    /**
     * Collects the current parameters of all emissions and of the transition into a new array.
     *
     * @return the current parameter values
     * @throws IllegalArgumentException if the number of parameters is undetermined ({@code -1})
     * @throws Exception if a component fails
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public double[] getCurrentParameterValues() throws Exception {
        int n = getNumberOfParameters();
        if (n == -1) {
            throw new IllegalArgumentException("The number of parameters could not be determined.");
        }
        double[] params = new double[n];
        for (int e = 0; e < this.emission.length; e++) {
            ((DifferentiableEmission) this.emission[e]).fillCurrentParameter(params);
        }
        ((DifferentiableTransition) this.transition).fillParameters(params);
        return params;
    }

    /** @return always {@code true} */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.HigherOrderHMM, de.jstacs.sequenceScores.SequenceScore
    public boolean isInitialized() {
        return true;
    }

    /**
     * Sets the parameters of all emissions and of the transition from {@code dArr}.
     *
     * @param dArr the parameter values
     * @param i the start index, forwarded to every component
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public void setParameters(double[] dArr, int i) {
        for (int e = 0; e < this.emission.length; e++) {
            ((DifferentiableEmission) this.emission[e]).setParameter(dArr, i);
        }
        ((DifferentiableTransition) this.transition).setParameters(dArr, i);
    }

    /**
     * Initialises the model randomly and recomputes the parameter offsets, unless {@code skipInit} is set.
     *
     * @param z unused by this implementation
     * @throws Exception if the initialisation fails
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public void initializeFunctionRandomly(boolean z) throws Exception {
        if (this.skipInit) {
            return;
        }
        initializeRandomly();
        getOffsets();
    }

    /**
     * Initialises the model, unless {@code skipInit} is set. Numerical training parameters lead to a
     * random initialisation. Otherwise the model is trained on {@code dataSetArr[i]} with the weights
     * {@code dArr[i]} (or without weights if {@code dArr} is {@code null}).
     *
     * @param i the index of the data set and weights to use
     * @param z forwarded to {@link #initializeFunctionRandomly(boolean)}
     * @param dataSetArr the data sets
     * @param dArr the weights, may be {@code null}
     * @throws Exception if initialisation or training fails
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public void initializeFunction(int i, boolean z, DataSet[] dataSetArr, double[][] dArr) throws Exception {
        if (this.skipInit) {
            return;
        }
        if (this.trainingParameter instanceof NumericalHMMTrainingParameterSet) {
            initializeFunctionRandomly(z);
        } else {
            train(dataSetArr[i], dArr == null ? null : dArr[i]);
            getOffsets();
        }
    }

    /**
     * Trains the model. Non-numerical training parameters are handled by the super class. Numerical
     * training parameters train this model through a
     * {@link DifferentiableStatisticalModelWrapperTrainSM}; the trained emissions and transition are
     * then copied back into this instance.
     *
     * @param dataSet the training data
     * @param dArr the sequence weights, may be {@code null}
     * @throws Exception if training fails
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.HigherOrderHMM, de.jstacs.sequenceScores.statisticalModels.trainable.TrainableStatisticalModel
    public void train(DataSet dataSet, double[] dArr) throws Exception {
        if (!(this.trainingParameter instanceof NumericalHMMTrainingParameterSet)) {
            super.train(dataSet, dArr);
            return;
        }
        NumericalHMMTrainingParameterSet params = (NumericalHMMTrainingParameterSet) this.trainingParameter;
        DifferentiableStatisticalModelWrapperTrainSM wrapper = new DifferentiableStatisticalModelWrapperTrainSM(this, params.getNumberOfThreads(), params.getAlgorithm(), params.getTerminationCondition(), params.getLineEps(), params.getStartDistance());
        wrapper.setOutputStream(this.sostream);
        wrapper.train(dataSet, dArr);
        DifferentiableHigherOrderHMM trained = (DifferentiableHigherOrderHMM) wrapper.getFunction();
        this.emission = trained.emission;
        // the states wrap the emissions, so they must be rebuilt around the trained ones
        createStates();
        this.transition = trained.transition;
    }

    /** @return always {@code true} */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public boolean isNormalized() {
        return true;
    }

    /** @return always {@code 0} (log of a normalisation constant of 1) */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public double getLogNormalizationConstant() {
        return 0.0d;
    }

    /** @return always negative infinity */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public double getLogPartialNormalizationConstant(int i) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    /**
     * @param d the class probability
     * @return {@code Math.log(d)}
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public double getInitialClassParam(double d) {
        return Math.log(d);
    }

    /** Scores the whole sequence. */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel, de.jstacs.sequenceScores.SequenceScore
    public double getLogScoreFor(Sequence sequence) {
        return getLogScoreFor(sequence, 0);
    }

    /** Scores the sequence from position {@code i} up to its last position. */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel, de.jstacs.sequenceScores.SequenceScore
    public double getLogScoreFor(Sequence sequence, int i) {
        return getLogScoreFor(sequence, i, sequence.getLength() - 1);
    }

    /**
     * Scores the subsequence from {@code i} to {@code i2} (both inclusive).
     *
     * @param sequence the sequence
     * @param i the first position
     * @param i2 the last position (inclusive); previously ignored in favour of the sequence end
     * @return the log-score
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel, de.jstacs.sequenceScores.SequenceScore
    public double getLogScoreFor(Sequence sequence, int i, int i2) {
        return logProb(i, i2, sequence);
    }

    /**
     * Computes the log-score of {@code sequence[i..i2]} by filling the backward (or Viterbi) matrix
     * according to {@link #score}.
     *
     * @return the entry {@code bwdMatrix[0][0]}
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.hmm.AbstractHMM
    public double logProb(int i, int i2, Sequence sequence) {
        try {
            fillBwdOrViterbiMatrix(this.score, i, i2, 0.0d, sequence);
            return this.bwdMatrix[0][0];
        } catch (Exception e) {
            throw getRunTimeException(e);
        }
    }

    /** Scores the whole sequence and appends its non-zero partial derivatives. */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public double getLogScoreAndPartialDerivation(Sequence sequence, IntList intList, DoubleList doubleList) {
        return getLogScoreAndPartialDerivation(sequence, 0, intList, doubleList);
    }

    /** Scores the sequence from position {@code i} to its end and appends its non-zero partial derivatives. */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public double getLogScoreAndPartialDerivation(Sequence sequence, int i, IntList intList, DoubleList doubleList) {
        return getLogScoreAndPartialDerivation(sequence, i, sequence.getLength() - 1, intList, doubleList);
    }

    /**
     * Computes the log-score of {@code sequence[i..i2]} (both inclusive) by a backward pass and
     * propagates its gradient alongside.
     *
     * <p>Layer {@code l < i2-i+1} corresponds to position {@code i + l}. Layer {@code i2-i+1} is the
     * terminal layer, in which only transitions into silent states are considered and contexts whose
     * last state is final (or every context for a zero-order transition) receive the end score 0.
     * Gradients are kept for two rolling layers only ({@code layer % 2}).
     *
     * @param sequence the sequence
     * @param i the first position
     * @param i2 the last position (inclusive)
     * @param intList receives the indices of parameters with non-zero partial derivative
     * @param doubleList receives the corresponding partial derivatives
     * @return the log-score ({@code bwdMatrix[0][0]})
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public double getLogScoreAndPartialDerivation(Sequence sequence, int i, int i2, IntList intList, DoubleList doubleList) {
        try {
            final boolean zeroOrder = this.transition.getMaximalMarkovOrder() == 0;
            final DifferentiableTransition diffTransition = (DifferentiableTransition) this.transition;
            int layer = i2 - i + 1;
            int pos = i2;
            provideMatrix(1, layer);

            // start from all-zero gradients in both rolling layers
            for (int context = 0; context < this.gradient[1].length; context++) {
                Arrays.fill(this.gradient[0][context], 0.0d);
                Arrays.fill(this.gradient[1][context], 0.0d);
            }
            // the terminal layer emits nothing, so no state partial derivatives contribute there
            for (int s = 0; s < this.states.length; s++) {
                this.indicesState[s].clear();
                this.partDerState[s].clear();
            }

            // terminal layer: only transitions into silent states (same layer) are followed
            for (int context = this.bwdMatrix[layer].length - 1; context >= 0; context--) {
                int children = this.transition.getNumberOfChildren(layer, context);
                double endScore = (zeroOrder || this.finalState[this.transition.getLastContextState(layer, context)]) ? 0.0d : Double.NEGATIVE_INFINITY;
                int retained = 0;
                for (int child = 0; child < children; child++) {
                    this.transition.fillTransitionInformation(layer, context, child, this.container);
                    if (this.states[this.container[0]].isSilent()) {
                        this.indicesTransition[retained].clear();
                        this.partDerTransition[retained].clear();
                        this.backwardIntermediate[retained] = this.bwdMatrix[layer][this.container[1]]
                                + diffTransition.getLogScoreAndPartialDerivation(layer, context, child, this.indicesTransition[retained], this.partDerTransition[retained], sequence, pos);
                        if (this.backwardIntermediate[retained] != Double.NEGATIVE_INFINITY) {
                            rememberChild(retained);
                            retained++;
                        }
                    }
                }
                if (retained == 0) {
                    this.bwdMatrix[layer][context] = endScore;
                    resetGradient(layer, context, 0.0d);
                } else {
                    merge(retained, layer, context, endScore);
                }
            }

            // remaining layers, from the last sequence position back to the first
            for (layer--; layer >= 0; layer--, pos--) {
                for (int s = 0; s < this.states.length; s++) {
                    this.indicesState[s].clear();
                    this.partDerState[s].clear();
                    this.logEmission[s] = ((DifferentiableState) this.states[s]).getLogScoreAndPartialDerivation(pos, pos, this.indicesState[s], this.partDerState[s], sequence);
                }
                for (int context = this.bwdMatrix[layer].length - 1; context >= 0; context--) {
                    int children = this.transition.getNumberOfChildren(layer, context);
                    int retained = 0;
                    for (int child = 0; child < children; child++) {
                        this.indicesTransition[retained].clear();
                        this.partDerTransition[retained].clear();
                        this.transition.fillTransitionInformation(layer, context, child, this.container);
                        this.backwardIntermediate[retained] = this.bwdMatrix[layer + this.container[2]][this.container[1]]
                                + this.logEmission[this.container[0]]
                                + diffTransition.getLogScoreAndPartialDerivation(layer, context, child, this.indicesTransition[retained], this.partDerTransition[retained], sequence, pos);
                        if (this.backwardIntermediate[retained] != Double.NEGATIVE_INFINITY) {
                            rememberChild(retained);
                            retained++;
                        }
                    }
                    if (retained == 0) {
                        this.bwdMatrix[layer][context] = Double.NEGATIVE_INFINITY;
                        resetGradient(layer, context, 0.0d);
                    } else {
                        merge(retained, layer, context, Double.NEGATIVE_INFINITY);
                    }
                }
            }

            // report the gradient of layer 0, context 0 in sparse form
            double[] result = this.gradient[0][0];
            for (int p = 0; p < this.numberOfParameters; p++) {
                if (result[p] != 0.0d) {
                    intList.add(p);
                    doubleList.add(result[p]);
                }
            }
            return this.bwdMatrix[0][0];
        } catch (Exception e) {
            throw getRunTimeException(e);
        }
    }

    /** Stores state id, target context and layer offset of the current {@code container} in slot {@code slot} of {@link #index}. */
    private void rememberChild(int slot) {
        this.index[0][slot] = this.container[0];
        this.index[1][slot] = this.container[1];
        this.index[2][slot] = this.container[2];
    }

    /**
     * Combines the first {@code n} entries of {@code backwardIntermediate} (plus the constant
     * {@code endScore}) into {@code bwdMatrix[layer][context]} and computes the matching gradient.
     *
     * <p>VITERBI: the best child wins, and its gradient is copied and extended by its own partial
     * derivatives. If {@code endScore} is at least as good, the entry becomes {@code endScore} with a
     * zero gradient, because the end score does not depend on any parameter.
     *
     * <p>LIKELIHOOD: the entry is the log-sum. {@code Normalisation.logSumNormalisation} presumably
     * overwrites {@code backwardIntermediate} with the normalised weights, which are then used to
     * average the children's gradients.
     */
    private void merge(int n, int layer, int context, double endScore) {
        final int row = layer % 2;
        if (this.score == HigherOrderHMM.Type.VITERBI) {
            int best = ToolBox.getMaxIndex(0, n, this.backwardIntermediate);
            if (this.backwardIntermediate[best] <= endScore) {
                this.bwdMatrix[layer][context] = endScore;
                // the row is shared with layer+2, so clear it explicitly
                resetGradient(layer, context, 0.0d);
                return;
            }
            System.arraycopy(this.gradient[(layer + this.index[2][best]) % 2][this.index[1][best]], 0, this.gradient[row][context], 0, this.numberOfParameters);
            miniMerge(best, 1.0d, row, context);
            this.bwdMatrix[layer][context] = this.backwardIntermediate[best];
            return;
        }
        if (endScore != Double.NEGATIVE_INFINITY) {
            this.bwdMatrix[layer][context] = Normalisation.logSumNormalisation(this.backwardIntermediate, 0, n, new double[]{endScore}, this.backwardIntermediate, 0);
        } else {
            this.bwdMatrix[layer][context] = Normalisation.logSumNormalisation(this.backwardIntermediate, 0, n, this.backwardIntermediate, 0);
        }
        double[] target = this.gradient[row][context];
        Arrays.fill(target, 0.0d);
        for (int k = 0; k < n; k++) {
            double weight = this.backwardIntermediate[k];
            double[] source = this.gradient[(layer + this.index[2][k]) % 2][this.index[1][k]];
            for (int p = 0; p < this.numberOfParameters; p++) {
                target[p] += weight * source[p];
            }
            miniMerge(k, weight, row, context);
        }
    }

    /**
     * Adds {@code weight} times the transition partial derivatives of child slot {@code slot} and the
     * emission partial derivatives of that child's state to {@code gradient[row][context]}.
     */
    private void miniMerge(int slot, double weight, int row, int context) {
        double[] target = this.gradient[row][context];
        IntList transIdx = this.indicesTransition[slot];
        DoubleList transDer = this.partDerTransition[slot];
        for (int k = 0; k < transIdx.length(); k++) {
            target[transIdx.get(k)] += weight * transDer.get(k);
        }
        int state = this.index[0][slot];
        IntList stateIdx = this.indicesState[state];
        DoubleList stateDer = this.partDerState[state];
        for (int k = 0; k < stateIdx.length(); k++) {
            target[stateIdx.get(k)] += weight * stateDer.get(k);
        }
    }

    /** Fills the rolling gradient row of {@code layer}/{@code context} with {@code value}. */
    private void resetGradient(int layer, int context, double value) {
        Arrays.fill(this.gradient[layer % 2][context], value);
    }

    /**
     * Returns the size of the event space of the random variable of parameter {@code i}. The
     * emissions are searched first (their parameters are laid out consecutively); all other indices
     * are delegated to the transition.
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int i) {
        int offset = 0;
        for (int e = 0; e < this.emission.length; e++) {
            DifferentiableEmission em = (DifferentiableEmission) this.emission[e];
            int count = em.getNumberOfParameters();
            if (count > 0 && i >= offset && i < offset + count) {
                return em.getSizeOfEventSpace();
            }
            offset += count;
        }
        return ((DifferentiableTransition) this.transition).getSizeOfEventSpace(i);
    }

    /**
     * Collects the sampling groups of all emissions and of the transition.
     *
     * @param i forwarded to each component
     * @return the sampling groups
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.SamplingDifferentiableStatisticalModel
    public int[][] getSamplingGroups(int i) {
        LinkedList<int[]> groups = new LinkedList<>();
        for (int e = 0; e < this.emission.length; e++) {
            ((DifferentiableEmission) this.emission[e]).fillSamplingGroups(i, groups);
        }
        ((DifferentiableTransition) this.transition).fillSamplingGroups(i, groups);
        return groups.toArray(new int[0][0]);
    }

    /** @return a name containing the maximal Markov order and the score type */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.HigherOrderHMM, de.jstacs.sequenceScores.SequenceScore
    public String getInstanceName() {
        return "differentiable HMM(" + this.transition.getMaximalMarkovOrder() + ", " + this.score + ")";
    }
}
