package de.jstacs.sequenceScores.statisticalModels.trainable;

import de.jstacs.NotTrainedException;
import de.jstacs.algorithms.optimization.LimitedMedianStartDistance;
import de.jstacs.algorithms.optimization.NegativeDifferentiableFunction;
import de.jstacs.algorithms.optimization.Optimizer;
import de.jstacs.algorithms.optimization.termination.AbstractTerminationCondition;
import de.jstacs.algorithms.optimization.termination.SmallDifferenceOfFunctionEvaluationsCondition;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LogGenDisMixFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.CompositeLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.WrongLengthException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.IndependentProductDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.UniformDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.homogeneous.UniformHomogeneousDiffSM;
import de.jstacs.utils.SafeOutputStream;
import java.io.OutputStream;
import java.text.NumberFormat;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/trainable/DifferentiableStatisticalModelWrapperTrainSM.class */
public class DifferentiableStatisticalModelWrapperTrainSM extends AbstractTrainableStatisticalModel {
    // Sink for the progress messages written during training ("start: ...", "best: ...");
    // assigned only through setOutputStream, which wraps the stream in a SafeOutputStream.
    private SafeOutputStream out;
    // The wrapped differentiable model; replaced by the trained instance in train(...).
    protected DifferentiableStatisticalModel nsf;
    // Cached log normalization constant of nsf, subtracted from the log score in getLogProbFor;
    // Double.NEGATIVE_INFINITY while nsf is not initialized.
    private double logNorm;
    // Line-search epsilon handed to Optimizer.optimize (scaled by a data-dependent factor); non-negative.
    private double lineps;
    // Base start distance for the LimitedMedianStartDistance used during optimization; positive.
    private double startD;
    // Termination condition handed to Optimizer.optimize; held as a private copy.
    private AbstractTerminationCondition tc;
    // Algorithm selector passed as first argument to Optimizer.optimize.
    private byte algo;
    // Number of threads handed to LogGenDisMixFunction; at least 1.
    private int threads;
    // Prior handed to LogGenDisMixFunction when training.
    private LogPrior prior;
    // Name of the enclosing XML element written by toXML and read by fromXML.
    private static final String XML_TAG = "DifferentiableStatisticalModelWrapperTrainSM";

    /**
     * Creates a wrapper that uses a newly created {@link CompositeLogPrior} as prior.
     * Delegates to the seven-argument constructor, which performs all validation.
     *
     * @param differentiableStatisticalModel the model to wrap (a clone is stored)
     * @param i the number of threads; must be at least 1
     * @param b the algorithm selector later passed to {@code Optimizer.optimize}
     * @param abstractTerminationCondition the termination condition (a clone is stored)
     * @param d the line-search epsilon; must be non-negative
     * @param d2 the start distance; must be positive
     * @throws CloneNotSupportedException declared by the delegated constructor, which clones its arguments
     */
    public DifferentiableStatisticalModelWrapperTrainSM(DifferentiableStatisticalModel differentiableStatisticalModel, int i, byte b, AbstractTerminationCondition abstractTerminationCondition, double d, double d2) throws CloneNotSupportedException {
        this(differentiableStatisticalModel, i, b, abstractTerminationCondition, d, d2, new CompositeLogPrior());
    }

    /**
     * Creates a wrapper around a copy of the given differentiable model.
     *
     * <p>The alphabet container and length of the wrapper are taken from the model. The
     * model, the termination condition and the prior are not stored directly: the first two
     * are cloned and the prior is replaced by {@code logPrior.getNewInstance()}.
     *
     * @param differentiableStatisticalModel the model to wrap
     * @param i the number of threads; must be at least 1
     * @param b the algorithm selector later passed to {@code Optimizer.optimize}
     * @param abstractTerminationCondition the termination condition for the optimization
     * @param d the line-search epsilon; must be non-negative
     * @param d2 the start distance; must be positive
     * @param logPrior the prior from which a new instance is derived
     * @throws IllegalArgumentException if {@code i < 1}, {@code d < 0} or {@code d2 <= 0}
     * @throws CloneNotSupportedException declared because the arguments are copied
     */
    public DifferentiableStatisticalModelWrapperTrainSM(DifferentiableStatisticalModel differentiableStatisticalModel, int i, byte b, AbstractTerminationCondition abstractTerminationCondition, double d, double d2, LogPrior logPrior) throws CloneNotSupportedException {
        super(differentiableStatisticalModel.getAlphabetContainer(), differentiableStatisticalModel.getLength());

        // Validate and copy first (in the same order in which failures are reported), assign afterwards.
        if (i < 1) {
            throw new IllegalArgumentException("The number of threads has to be positive.");
        }
        AbstractTerminationCondition conditionCopy = abstractTerminationCondition.m6clone();
        if (d < 0.0d) {
            throw new IllegalArgumentException("The value of lineps has to be non-negative.");
        }
        if (d2 <= 0.0d) {
            throw new IllegalArgumentException("The value of startD has to be positive.");
        }
        DifferentiableStatisticalModel modelCopy = (DifferentiableStatisticalModel) differentiableStatisticalModel.mo98clone();

        this.threads = i;
        this.algo = b;
        this.tc = conditionCopy;
        this.lineps = d;
        this.startD = d2;
        this.nsf = modelCopy;

        // The normalization constant is only meaningful for an initialized model.
        this.logNorm = isInitialized()
                ? differentiableStatisticalModel.getLogNormalizationConstant()
                : Double.NEGATIVE_INFINITY;

        setOutputStream(SafeOutputStream.DEFAULT_STREAM);
        this.prior = logPrior.getNewInstance();
    }

    /**
     * Creates a wrapper from its XML representation by handing the buffer to the
     * superclass constructor.
     *
     * <p>NOTE(review): presumably the superclass constructor calls {@code fromXML}, which
     * expects the format written by {@code toXML} -- confirm against the superclass.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the superclass constructor cannot parse the buffer
     */
    public DifferentiableStatisticalModelWrapperTrainSM(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
    }

    /**
     * Returns a copy of this wrapper that shares no mutable state with the original.
     *
     * <p>The wrapped model and the termination condition are cloned, and the copy receives
     * its own prior via {@code getNewInstance()} (the same call the constructor uses).
     * Previously the prior instance was shared, so a clone and its original handed the very
     * same prior object to their training functions.
     *
     * <p>The output stream is not copied: a silent stream stays silent, any other stream is
     * replaced by {@code SafeOutputStream.DEFAULT_STREAM}.
     *
     * @return the copy
     * @throws CloneNotSupportedException if one of the members cannot be copied
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel
    /* renamed from: clone */
    public DifferentiableStatisticalModelWrapperTrainSM mo98clone() throws CloneNotSupportedException {
        DifferentiableStatisticalModelWrapperTrainSM copy = (DifferentiableStatisticalModelWrapperTrainSM) super.mo98clone();
        copy.nsf = (DifferentiableStatisticalModel) this.nsf.mo98clone();
        copy.tc = this.tc.m6clone();
        // Do not share the prior between the original and the copy.
        copy.prior = this.prior.getNewInstance();
        copy.setOutputStream(this.out.doesNothing() ? null : SafeOutputStream.DEFAULT_STREAM);
        return copy;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v20, types: [double[], double[][]] */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.TrainableStatisticalModel
    public void train(DataSet dataSet, double[] dArr) throws Exception {
        if (!dataSet.getAlphabetContainer().checkConsistency(this.alphabets)) {
            throw new WrongAlphabetException("The AlphabetConatainer of the data set and the model do not match.");
        }
        if (this.length != 0 && this.length != dataSet.getElementLength()) {
            throw new WrongLengthException("The length of the elements of the data set is not suitable for the model.");
        }
        if (!(this.nsf instanceof IndependentProductDiffSM)) {
            this.nsf = train(dataSet, dArr, this.nsf);
            return;
        }
        IndependentProductDiffSM independentProductDiffSM = (IndependentProductDiffSM) this.nsf;
        DifferentiableStatisticalModel[] differentiableStatisticalModelArr = (DifferentiableStatisticalModel[]) ArrayHandler.cast(DifferentiableStatisticalModel.class, independentProductDiffSM.getFunctions());
        DataSet[] dataSetArr = new DataSet[1];
        DataSet[] dataSetArr2 = {dataSet};
        ?? r0 = {dArr};
        for (int i = 0; i < differentiableStatisticalModelArr.length; i++) {
            differentiableStatisticalModelArr[i] = train(dataSetArr[0], independentProductDiffSM.extractWeights(independentProductDiffSM.extractSequenceParts(i, dataSetArr2, dataSetArr), r0)[0], differentiableStatisticalModelArr[i]);
        }
        this.nsf = new IndependentProductDiffSM(independentProductDiffSM.getESS(), true, differentiableStatisticalModelArr, independentProductDiffSM.getIndices(), independentProductDiffSM.getPartialLengths(), independentProductDiffSM.getReverseSwitches());
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v21, types: [de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel[], de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore[]] */
    /* JADX WARN: Type inference failed for: r0v27, types: [de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LogGenDisMixFunction, de.jstacs.algorithms.optimization.DifferentiableFunction] */
    /* JADX WARN: Type inference failed for: r0v40, types: [de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel] */
    /* JADX WARN: Type inference failed for: r0v54 */
    /* JADX WARN: Type inference failed for: r4v6, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r5v2, types: [double[], double[][]] */
    private DifferentiableStatisticalModel train(DataSet dataSet, double[] dArr, DifferentiableStatisticalModel differentiableStatisticalModel) throws Exception {
        if (!(differentiableStatisticalModel instanceof UniformDiffSM) && !(differentiableStatisticalModel instanceof UniformHomogeneousDiffSM)) {
            DataSet.WeightedDataSetFactory weightedDataSetFactory = new DataSet.WeightedDataSetFactory(DataSet.WeightedDataSetFactory.SortOperation.NO_SORT, dataSet, dArr);
            DataSet dataSet2 = weightedDataSetFactory.getDataSet();
            double[] weights = weightedDataSetFactory.getWeights();
            DifferentiableStatisticalModel differentiableStatisticalModel2 = null;
            double d = Double.NEGATIVE_INFINITY;
            double numberOfElements = dataSet.getNumberOfElements();
            double ess = differentiableStatisticalModel.getESS();
            double d2 = (numberOfElements / (ess + numberOfElements)) * (ess == 0.0d ? 1.0d : 2.0d);
            ?? r0 = {(DifferentiableStatisticalModel) differentiableStatisticalModel.mo98clone()};
            ?? logGenDisMixFunction = new LogGenDisMixFunction(this.threads, r0, new DataSet[]{dataSet2}, new double[]{weights}, this.prior, LearningPrinciple.getBeta(ess == 0.0d ? LearningPrinciple.ML : LearningPrinciple.MAP), true, false);
            NegativeDifferentiableFunction negativeDifferentiableFunction = new NegativeDifferentiableFunction(logGenDisMixFunction);
            LimitedMedianStartDistance limitedMedianStartDistance = new LimitedMedianStartDistance(5, this.startD * d2);
            for (int i = 0; i < differentiableStatisticalModel.getNumberOfRecommendedStarts(); i++) {
                this.out.writeln("start: " + i);
                r0[0].initializeFunction(0, false, new DataSet[]{dataSet2}, new double[]{weights});
                logGenDisMixFunction.reset(r0);
                double[] parameters = logGenDisMixFunction.getParameters(OptimizableFunction.KindOfParameter.PLUGIN);
                limitedMedianStartDistance.reset();
                Optimizer.optimize(this.algo, negativeDifferentiableFunction, parameters, this.tc, this.lineps * d2, limitedMedianStartDistance, this.out);
                double evaluateFunction = logGenDisMixFunction.evaluateFunction(parameters);
                if (evaluateFunction > d) {
                    differentiableStatisticalModel2 = r0[0];
                    d = evaluateFunction;
                }
                r0[0] = (DifferentiableStatisticalModel) differentiableStatisticalModel.mo98clone();
            }
            this.out.writeln("best: " + d);
            differentiableStatisticalModel = differentiableStatisticalModel2;
            this.logNorm = differentiableStatisticalModel.getLogNormalizationConstant();
            logGenDisMixFunction.stopThreads();
            System.gc();
        }
        return differentiableStatisticalModel;
    }

    /**
     * Returns the logarithm of the probability of (a part of) the given sequence, computed
     * as the log score of the wrapped model minus the cached log normalization constant.
     *
     * <p>NOTE(review): the end position is only validated; the score is requested with the
     * start position alone. For models with length 0 this looks like it may score beyond
     * {@code i2} -- confirm against the wrapped model's contract.
     *
     * @param sequence the sequence
     * @param i the start position (inclusive); must not be negative
     * @param i2 the end position (inclusive); must be inside the sequence and at least {@code i - 1}
     * @return the log probability
     * @throws NotTrainedException if the wrapped model is not initialized
     * @throws WrongAlphabetException if the alphabets of sequence and model are inconsistent
     * @throws IllegalArgumentException if the start or end position is invalid
     * @throws WrongLengthException if the model has a fixed length that differs from the
     *         length of the requested part
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.StatisticalModel
    public double getLogProbFor(Sequence sequence, int i, int i2) throws NotTrainedException, Exception {
        if (!isInitialized()) {
            throw new NotTrainedException();
        }
        boolean alphabetsMatch = sequence.getAlphabetContainer().checkConsistency(this.alphabets);
        if (!alphabetsMatch) {
            throw new WrongAlphabetException("The AlphabetContainer of the sequence and the model do not match.");
        }
        if (i < 0) {
            throw new IllegalArgumentException("Check start position.");
        }
        boolean endBeforeStart = i2 + 1 < i;
        if (endBeforeStart || i2 >= sequence.getLength()) {
            throw new IllegalArgumentException("Check end position.");
        }
        int requestedLength = (i2 - i) + 1;
        if (this.length != 0 && this.length != requestedLength) {
            throw new WrongLengthException("Check length of the sequence.");
        }
        double logScore = this.nsf.getLogScoreFor(sequence, i);
        return logScore - this.logNorm;
    }

    /**
     * Returns the log prior term of the wrapped model, reduced by the model's ESS times the
     * cached log normalization constant.
     *
     * @return the log prior term
     * @throws Exception if the wrapped model fails to compute its prior term
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.StatisticalModel
    public double getLogPriorTerm() throws Exception {
        double priorTerm = this.nsf.getLogPriorTerm();
        double normalizationShare = this.nsf.getESS() * this.logNorm;
        return priorTerm - normalizationShare;
    }

    /**
     * Returns the instance name: the fixed prefix {@code "model using "} followed by the
     * instance name of the wrapped model.
     *
     * @return the instance name
     */
    @Override // de.jstacs.sequenceScores.SequenceScore
    public String getInstanceName() {
        String wrappedName = this.nsf.getInstanceName();
        return "model using " + wrappedName;
    }

    /**
     * Reports whether the wrapped model is initialized; the wrapper has no initialization
     * state of its own.
     *
     * @return {@code true} if the wrapped model is initialized
     */
    @Override // de.jstacs.sequenceScores.SequenceScore
    public boolean isInitialized() {
        boolean wrappedModelReady = this.nsf.isInitialized();
        return wrappedModelReady;
    }

    /**
     * This wrapper provides no numerical characteristics.
     *
     * @return always {@code null}
     * @throws Exception never thrown by this implementation; declared by the overridden method
     */
    @Override // de.jstacs.sequenceScores.SequenceScore
    public NumericalResultSet getNumericalCharacteristics() throws Exception {
        return null;
    }

    /**
     * Returns the textual representation of the wrapped model.
     *
     * @param numberFormat the number format handed on to the wrapped model
     * @return the representation produced by the wrapped model
     */
    @Override // de.jstacs.sequenceScores.SequenceScore
    public String toString(NumberFormat numberFormat) {
        String wrappedDescription = this.nsf.toString(numberFormat);
        return wrappedDescription;
    }

    /**
     * Restores the state of this wrapper from the XML representation written by
     * {@code toXML()}.
     *
     * <p>Reads the wrapped model, the number of threads, the algorithm, the termination
     * condition (or, if no {@code tc} element exists, an {@code eps} value from which a
     * {@link SmallDifferenceOfFunctionEvaluationsCondition} is built), the line-search
     * epsilon, the start distance and the optional prior. If no prior is stored,
     * {@code DoesNothingLogPrior.defaultInstance} is used. Alphabet container and length are
     * taken from the restored model and the output stream is reset to the default stream.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the representation cannot be parsed or the prior
     *         cannot be instantiated or attached to the model
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel
    protected void fromXML(StringBuffer stringBuffer) throws NonParsableException {
        StringBuffer content = XMLParser.extractForTag(stringBuffer, XML_TAG);
        this.nsf = (DifferentiableStatisticalModel) XMLParser.extractObjectForTags(content, "DifferentiableStatisticalModel", DifferentiableStatisticalModel.class);
        this.threads = ((Integer) XMLParser.extractObjectForTags(content, "threads", Integer.TYPE)).intValue();
        this.algo = ((Byte) XMLParser.extractObjectForTags(content, "algorithm", Byte.TYPE)).byteValue();
        if (XMLParser.hasTag(content, "tc", null, null)) {
            this.tc = (AbstractTerminationCondition) XMLParser.extractObjectForTags(content, "tc");
        } else {
            // Fallback for representations that store only an "eps" value.
            try {
                this.tc = new SmallDifferenceOfFunctionEvaluationsCondition(((Double) XMLParser.extractObjectForTags(content, "eps", Double.TYPE)).doubleValue());
            } catch (Exception e) {
                // Keep the original stack trace, as the other handlers in this method do.
                NonParsableException wrapped = new NonParsableException(e.getMessage());
                wrapped.setStackTrace(e.getStackTrace());
                throw wrapped;
            }
        }
        this.lineps = ((Double) XMLParser.extractObjectForTags(content, "lineps", Double.TYPE)).doubleValue();
        this.startD = ((Double) XMLParser.extractObjectForTags(content, "startDistance", Double.TYPE)).doubleValue();
        if (isInitialized()) {
            this.logNorm = this.nsf.getLogNormalizationConstant();
        } else {
            this.logNorm = Double.NEGATIVE_INFINITY;
        }

        // toXML() writes the prior INSIDE the XML_TAG element, so it has to be looked up in
        // the extracted content. The outer buffer is still consulted afterwards so that
        // anything the previous lookup could find is found as well.
        StringBuffer priorXml = XMLParser.extractForTag(content, "prior");
        if (priorXml == null) {
            priorXml = XMLParser.extractForTag(stringBuffer, "prior");
        }
        if (priorXml != null) {
            Class<?> priorClass = (Class<?>) XMLParser.extractObjectForTags(priorXml, "class", Class.class);
            try {
                this.prior = (LogPrior) priorClass.getConstructor(StringBuffer.class).newInstance(priorXml);
            } catch (NoSuchMethodException e2) {
                NonParsableException nonParsableException = new NonParsableException("You must provide a constructor " + priorClass.getSimpleName() + "(StringBuffer).");
                nonParsableException.setStackTrace(e2.getStackTrace());
                throw nonParsableException;
            } catch (Exception e3) {
                NonParsableException nonParsableException2 = new NonParsableException("problem at " + priorClass.getSimpleName() + ": " + e3.getMessage());
                nonParsableException2.setStackTrace(e3.getStackTrace());
                throw nonParsableException2;
            }
        } else {
            this.prior = DoesNothingLogPrior.defaultInstance;
        }
        try {
            this.prior.set(false, this.nsf);
            this.alphabets = this.nsf.getAlphabetContainer();
            this.length = this.nsf.getLength();
            setOutputStream(SafeOutputStream.DEFAULT_STREAM);
        } catch (Exception e4) {
            NonParsableException nonParsableException3 = new NonParsableException("problem when setting the kind of parameter: " + e4.getMessage());
            nonParsableException3.setStackTrace(e4.getStackTrace());
            throw nonParsableException3;
        }
    }

    /**
     * Serializes this wrapper to XML.
     *
     * <p>Written in this order: the wrapped model, the number of threads, the algorithm, the
     * termination condition, the line-search epsilon and the start distance. A
     * {@code <prior>} element (class plus the prior's own XML) follows unless the prior is a
     * {@link DoesNothingLogPrior}. Everything is finally enclosed in the {@code XML_TAG}
     * element. The output stream and the cached normalization constant are not stored.
     *
     * @return the XML representation
     */
    @Override // de.jstacs.Storable
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer(100000);

        XMLParser.appendObjectWithTags(xml, this.nsf, "DifferentiableStatisticalModel");
        XMLParser.appendObjectWithTags(xml, Integer.valueOf(this.threads), "threads");
        XMLParser.appendObjectWithTags(xml, Byte.valueOf(this.algo), "algorithm");
        XMLParser.appendObjectWithTags(xml, this.tc, "tc");
        XMLParser.appendObjectWithTags(xml, Double.valueOf(this.lineps), "lineps");
        XMLParser.appendObjectWithTags(xml, Double.valueOf(this.startD), "startDistance");

        boolean priorNeedsStoring = !(this.prior instanceof DoesNothingLogPrior);
        if (priorNeedsStoring) {
            // Assemble the prior section separately, then attach it as a whole.
            StringBuffer priorSection = new StringBuffer(1000);
            priorSection.append("<prior>\n");
            XMLParser.appendObjectWithTags(priorSection, this.prior.getClass(), "class");
            priorSection.append(this.prior.toXML());
            priorSection.append("\t</prior>\n");
            xml.append(priorSection);
        }

        XMLParser.addTags(xml, XML_TAG);
        return xml;
    }

    /**
     * Sets the stream that receives the progress messages written during training. The
     * stream is wrapped via {@code SafeOutputStream.getSafeOutputStream} before it is stored.
     *
     * @param outputStream the stream to write to; {@code mo98clone()} passes {@code null}
     *        here to keep a silent stream silent
     */
    public final void setOutputStream(OutputStream outputStream) {
        SafeOutputStream wrapped = SafeOutputStream.getSafeOutputStream(outputStream);
        this.out = wrapped;
    }

    /**
     * Returns a copy of the wrapped model, so callers cannot modify the internal instance.
     *
     * @return a clone of the wrapped model
     * @throws CloneNotSupportedException if the wrapped model cannot be cloned
     */
    public DifferentiableStatisticalModel getFunction() throws CloneNotSupportedException {
        Object copy = this.nsf.mo98clone();
        return (DifferentiableStatisticalModel) copy;
    }
}
