package de.jstacs.sequenceScores.differentiable;

import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.MultiDimensionalSequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import java.text.NumberFormat;
import java.util.Arrays;

/* loaded from: input_file:de/jstacs/sequenceScores/differentiable/MultiDimensionalSequenceWrapperDiffSS.class */
/**
 * Wraps a {@link DifferentiableSequenceScore} so that it can be applied to
 * {@link MultiDimensionalSequence}s.
 *
 * <p>For a multidimensional sequence the wrapped function is evaluated on every contained
 * sequence. The log score and every partial derivative are the arithmetic mean over those
 * sequences. Any other {@link Sequence} is passed directly to the wrapped function.
 *
 * <p>All parameter handling (number of parameters, current values, initialization, setting
 * parameters) is delegated to the wrapped function.
 *
 * <p>NOTE(review): the gradient scratch buffers are instance state. Concurrent calls to
 * {@link #getLogScoreAndPartialDerivation(Sequence, int, IntList, DoubleList)} on one instance
 * are therefore not safe; use clones per thread.
 */
public class MultiDimensionalSequenceWrapperDiffSS extends AbstractDifferentiableSequenceScore {
    /** The wrapped score; a private clone owned by this wrapper. */
    private DifferentiableSequenceScore function;
    /** Scratch list receiving parameter indices from the wrapped function (lazily created). */
    private IntList iList;
    /** Scratch list receiving partial derivatives from the wrapped function (lazily created). */
    private DoubleList dList;
    /** Dense accumulator of summed partial derivatives, one slot per parameter (lazily created). */
    private double[] gradient;
    private static final String XML_TAG = MultiDimensionalSequenceWrapperDiffSS.class.getSimpleName();

    /**
     * Creates a wrapper around a clone of the given score.
     *
     * @param differentiableSequenceScore the score to wrap; it is cloned, not referenced
     * @throws IllegalArgumentException if the superclass rejects the alphabet or length
     * @throws CloneNotSupportedException if the score cannot be cloned
     */
    public MultiDimensionalSequenceWrapperDiffSS(DifferentiableSequenceScore differentiableSequenceScore) throws IllegalArgumentException, CloneNotSupportedException {
        super(differentiableSequenceScore.getAlphabetContainer(), differentiableSequenceScore.getLength());
        this.function = differentiableSequenceScore.mo71clone();
    }

    /**
     * Restores a wrapper from its XML representation (see {@link #toXML()}).
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public MultiDimensionalSequenceWrapperDiffSS(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
    }

    /**
     * Parses the wrapped function from XML. The alphabet and length are then taken from it.
     */
    @Override
    protected void fromXML(StringBuffer stringBuffer) throws NonParsableException {
        StringBuffer content = XMLParser.extractForTag(stringBuffer, XML_TAG);
        this.function = (DifferentiableSequenceScore) XMLParser.extractObjectForTags(content, "function", DifferentiableSequenceScore.class);
        this.alphabets = this.function.getAlphabetContainer();
        this.length = this.function.getLength();
    }

    /**
     * Serializes the wrapped function, enclosed in this class' XML tag.
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer stringBuffer = new StringBuffer();
        XMLParser.appendObjectWithTags(stringBuffer, this.function, "function");
        XMLParser.addTags(stringBuffer, XML_TAG);
        return stringBuffer;
    }

    /**
     * Deep-clones this wrapper: the wrapped function and any allocated scratch buffers are
     * copied, so the clone shares no mutable state with this instance.
     */
    @Override
    /* renamed from: clone */
    public MultiDimensionalSequenceWrapperDiffSS mo71clone() throws CloneNotSupportedException {
        MultiDimensionalSequenceWrapperDiffSS copy = (MultiDimensionalSequenceWrapperDiffSS) super.mo71clone();
        copy.function = this.function.mo71clone();
        // The scratch buffers are allocated together in init(), so one null check covers all three.
        if (this.gradient != null) {
            copy.gradient = this.gradient.clone();
            copy.iList = this.iList.m125clone();
            copy.dList = this.dList.m124clone();
        }
        return copy;
    }

    @Override
    public double[] getCurrentParameterValues() throws Exception {
        return this.function.getCurrentParameterValues();
    }

    @Override
    public String getInstanceName() {
        return "multidimensional wrapper of " + this.function.getInstanceName();
    }

    /**
     * Returns the log score of {@code sequence} starting at {@code i}.
     *
     * <p>For a {@link MultiDimensionalSequence} this is the mean of the wrapped function's log
     * scores over all contained sequences. An empty multidimensional sequence yields
     * {@code NaN} (0/0). Any other sequence is scored directly by the wrapped function.
     */
    @Override
    public double getLogScoreFor(Sequence sequence, int i) {
        if (!(sequence instanceof MultiDimensionalSequence)) {
            return this.function.getLogScoreFor(sequence, i);
        }
        MultiDimensionalSequence multiSeq = (MultiDimensionalSequence) sequence;
        int numberOfSequences = multiSeq.getNumberOfSequences();
        double sum = 0.0d;
        for (int s = 0; s < numberOfSequences; s++) {
            sum += this.function.getLogScoreFor(multiSeq.getSequence(s), i);
        }
        return sum / numberOfSequences;
    }

    /**
     * (Re)allocates the scratch buffers.
     *
     * @param numberOfParameters size of the dense gradient accumulator
     */
    private void init(int numberOfParameters) {
        this.gradient = new double[numberOfParameters];
        this.iList = new IntList();
        this.dList = new DoubleList();
    }

    /**
     * Returns the log score and appends the partial derivatives to the given lists.
     *
     * <p>For a {@link MultiDimensionalSequence} the score and each partial derivative are
     * averaged over all contained sequences. Only parameters whose averaged derivative is
     * non-zero are appended to {@code intList}/{@code doubleList}. An empty multidimensional
     * sequence yields a {@code NaN} score and no derivatives.
     *
     * <p>Any other sequence is delegated unchanged to the wrapped function.
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence sequence, int i, IntList intList, DoubleList doubleList) {
        if (!(sequence instanceof MultiDimensionalSequence)) {
            return this.function.getLogScoreAndPartialDerivation(sequence, i, intList, doubleList);
        }
        MultiDimensionalSequence multiSeq = (MultiDimensionalSequence) sequence;
        int numberOfSequences = multiSeq.getNumberOfSequences();

        // Reallocate if the wrapped function's parameter count changed since the buffer was
        // created (e.g. after re-initialization). Otherwise indices could overrun a stale buffer.
        int numberOfParameters = this.function.getNumberOfParameters();
        if (this.gradient == null || this.gradient.length != numberOfParameters) {
            init(numberOfParameters);
        }
        Arrays.fill(this.gradient, 0.0d);

        double sum = 0.0d;
        for (int s = 0; s < numberOfSequences; s++) {
            this.iList.clear();
            this.dList.clear();
            sum += this.function.getLogScoreAndPartialDerivation(multiSeq.getSequence(s), i, this.iList, this.dList);
            // The wrapped function returns sparse derivatives, so scatter them into the dense accumulator.
            for (int k = 0; k < this.iList.length(); k++) {
                this.gradient[this.iList.get(k)] += this.dList.get(k);
            }
        }

        for (int p = 0; p < this.gradient.length; p++) {
            if (this.gradient[p] != 0.0d) {
                intList.add(p);
                doubleList.add(this.gradient[p] / numberOfSequences);
            }
        }
        return sum / numberOfSequences;
    }

    @Override
    public int getNumberOfParameters() {
        return this.function.getNumberOfParameters();
    }

    @Override
    public void initializeFunction(int i, boolean z, DataSet[] dataSetArr, double[][] dArr) throws Exception {
        this.function.initializeFunction(i, z, dataSetArr, dArr);
    }

    @Override
    public void initializeFunctionRandomly(boolean z) throws Exception {
        this.function.initializeFunctionRandomly(z);
    }

    @Override
    public boolean isInitialized() {
        return this.function.isInitialized();
    }

    @Override
    public void setParameters(double[] dArr, int i) {
        this.function.setParameters(dArr, i);
    }

    /**
     * Returns a textual representation: a header naming the wrapped function, followed by the
     * wrapped function's own string representation.
     */
    @Override
    public String toString(NumberFormat numberFormat) {
        return "wrapper of " + this.function.getInstanceName() + ":\n" + this.function.toString(numberFormat);
    }
}
