/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.ClassifierTrainer;
import smile.classification.SoftClassifier;
import smile.math.DifferentiableMultivariateFunction;
import smile.math.Math;
import smile.util.MulticoreExecutor;

/**
 * Logistic regression classifier with optional L2 (ridge) regularization.
 * <p>
 * With two classes the model is {@code P(y = 1 | x) = 1 / (1 + exp(-(w.x + b)))}. With
 * {@code k > 2} classes it keeps one weight vector plus intercept per class and turns the
 * {@code k} linear scores into posteriors with a softmax. Parameters are fitted by minimizing
 * the negative log-likelihood plus {@code 0.5 * lambda * ||w||^2}; intercepts are not penalized.
 * <p>
 * Class labels must be exactly the integers {@code 0, 1, ..., k-1}, each occurring at least once.
 */
public class LogisticRegression
implements SoftClassifier<double[]>,
Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(LogisticRegression.class);
    /** Minimum training-set size before objective evaluation is split across threads. */
    private static final int MIN_PARALLEL_SAMPLES = 1000;
    /** Minimum number of samples handed to each parallel task. */
    private static final int MIN_SAMPLES_PER_TASK = 100;
    /** Dimensionality of the input vectors. */
    private int p;
    /** Number of classes. */
    private int k;
    /** Log-likelihood of the training data at the optimum (penalty included when lambda > 0). */
    private double L;
    /** Binary model: p weights followed by the intercept; null when k > 2. */
    private double[] w;
    /** Multi-class model: row i holds the p weights and the intercept of class i; null when k == 2. */
    private double[][] W;

    /**
     * Trains an unregularized model with default tolerance (1E-5) and iteration limit (500).
     *
     * @param x training samples, one row per sample
     * @param y class labels in {@code 0..k-1}, one per sample
     * @throws IllegalArgumentException on invalid input (see the full constructor)
     */
    public LogisticRegression(double[][] x, int[] y) {
        this(x, y, 0.0);
    }

    /**
     * Trains an L2-regularized model with default tolerance (1E-5) and iteration limit (500).
     *
     * @param x training samples, one row per sample
     * @param y class labels in {@code 0..k-1}, one per sample
     * @param lambda non-negative regularization factor; 0 disables regularization
     * @throws IllegalArgumentException on invalid input (see the full constructor)
     */
    public LogisticRegression(double[][] x, int[] y, double lambda) {
        this(x, y, lambda, 1.0E-5, 500);
    }

    /**
     * Trains a model by numerically minimizing the (regularized) negative log-likelihood.
     *
     * @param x training samples, one row per sample; all rows must have the same length
     * @param y class labels, one per sample; must be exactly {@code 0..k-1} with {@code k >= 2}
     * @param lambda non-negative regularization factor; 0 disables regularization
     * @param tol positive convergence tolerance passed to the optimizer
     * @param maxIter positive maximum number of optimizer iterations
     * @throws IllegalArgumentException if sizes mismatch, rows are ragged, a parameter is out of
     *     range, a label is negative, a class in {@code 0..k-1} is missing, or only one class is present
     */
    public LogisticRegression(double[][] x, int[] y, double lambda, double tol, int maxIter) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (lambda < 0.0) {
            throw new IllegalArgumentException("Invalid regularization factor: " + lambda);
        }
        if (tol <= 0.0) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; ++i) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            // The labels are sorted, distinct and non-negative, so they are exactly 0..k-1
            // iff labels[i] == i everywhere; the first index that disagrees is the missing class.
            if (labels[i] != i) {
                throw new IllegalArgumentException("Missing class: " + i);
            }
        }
        this.k = labels.length;
        if (this.k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }

        this.p = x[0].length;
        // dot() takes the intercept from w[x.length], so a ragged row would silently read a
        // weight as the intercept; reject such input up front.
        for (int i = 1; i < x.length; ++i) {
            if (x[i].length != this.p) {
                throw new IllegalArgumentException(String.format("Invalid input vector size at row %d: %d, expected: %d", i, x[i].length, this.p));
            }
        }

        if (this.k == 2) {
            this.w = new double[this.p + 1];
            this.L = -minimize(new BinaryObjectiveFunction(x, y, lambda), this.w, tol, maxIter);
        } else {
            double[] packed = new double[this.k * (this.p + 1)];
            this.L = -minimize(new MultiClassObjectiveFunction(x, y, this.k, lambda), packed, tol, maxIter);
            // Unpack the flat parameter vector: class i occupies [i*(p+1), (i+1)*(p+1)).
            this.W = new double[this.k][];
            for (int i = 0; i < this.k; ++i) {
                this.W[i] = Arrays.copyOfRange(packed, i * (this.p + 1), (i + 1) * (this.p + 1));
            }
        }
    }

    /**
     * Minimizes {@code func} in place starting from {@code w}.
     * <p>
     * First tries smile's limited-memory minimizer with a history of 5. If that throws, retries
     * with the full BFGS overload, continuing from whatever {@code w} holds at that point.
     *
     * @return the minimum objective value reported by the optimizer
     */
    private static double minimize(DifferentiableMultivariateFunction func, double[] w, double tol, int maxIter) {
        try {
            return Math.min(func, 5, w, tol, maxIter);
        } catch (Exception ex) {
            logger.debug("Limited-memory minimization failed, retrying with BFGS", ex);
            return Math.min(func, w, tol, maxIter);
        }
    }

    /**
     * Splits the sample range {@code [0, n)} into contiguous chunks for parallel evaluation.
     * <p>
     * With {@code m} worker threads, each chunk gets {@code max(n / m, 100)} samples and the
     * last chunk absorbs the remainder.
     *
     * @return chunk boundaries {@code b}, where chunk {@code c} covers {@code [b[c], b[c + 1])};
     *     or {@code null} when {@code n < 1000} or fewer than two threads are available
     */
    private static int[] partition(int n) {
        int m = MulticoreExecutor.getThreadPoolSize();
        if (n < MIN_PARALLEL_SAMPLES || m < 2) {
            return null;
        }
        int step = n / m;
        if (step < MIN_SAMPLES_PER_TASK) {
            step = MIN_SAMPLES_PER_TASK;
        }
        // When step was clamped up, m * step may exceed n; cap the chunk count so every
        // chunk starts inside [0, n) instead of indexing past the end of the data.
        int chunks = java.lang.Math.min(m, n / step);
        int[] bounds = new int[chunks + 1];
        for (int c = 1; c < chunks; ++c) {
            bounds[c] = c * step;
        }
        bounds[chunks] = n;
        return bounds;
    }

    /**
     * Computes {@code log(1 + exp(x))}.
     * <p>
     * Returns {@code x} itself for {@code x > 15}; the absolute error there is below ~3.1e-7.
     */
    private static double log1pe(double x) {
        return x > 15.0 ? x : Math.log1p(Math.exp(x));
    }

    /**
     * Natural logarithm, clamped so that arguments below 1e-300 (including 0) map to
     * -690.7755 (about ln 1e-300) instead of -infinity.
     */
    private static double log(double x) {
        return x < 1.0E-300 ? -690.7755 : Math.log(x);
    }

    /**
     * Replaces {@code prob} in place with its softmax.
     * <p>
     * The maximum is subtracted first so that {@code exp} cannot overflow.
     */
    private static void softmax(double[] prob) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : prob) {
            if (v > max) {
                max = v;
            }
        }
        double Z = 0.0;
        for (int i = 0; i < prob.length; ++i) {
            prob[i] = Math.exp(prob[i] - max);
            Z += prob[i];
        }
        for (int i = 0; i < prob.length; ++i) {
            prob[i] /= Z;
        }
    }

    /** Returns {@code x . w[0..x.length) + w[x.length]}, i.e. the weights followed by the intercept. */
    private static double dot(double[] x, double[] w) {
        return dot(x, w, 0);
    }

    /**
     * Returns {@code x . w[pos..pos+x.length) + w[pos+x.length]}.
     * <p>
     * This reads one weight block (weights followed by the intercept) out of a flat vector.
     */
    private static double dot(double[] x, double[] w, int pos) {
        double dot = 0.0;
        for (int i = 0; i < x.length; ++i) {
            dot += x[i] * w[pos + i];
        }
        return dot + w[pos + x.length];
    }

    /**
     * Returns the log-likelihood of the training data under the fitted model.
     * <p>
     * When the model was trained with {@code lambda > 0}, this value already includes the
     * penalty term {@code -0.5 * lambda * ||w||^2}.
     */
    public double loglikelihood() {
        return this.L;
    }

    @Override
    public int predict(double[] x) {
        return this.predict(x, null);
    }

    /**
     * Predicts the class of {@code x} and optionally fills in the class posteriors.
     *
     * @param x input vector of length {@code p}
     * @param posteriori if non-null, an array of length {@code k} that receives the posteriors
     * @return the predicted label; in the binary case 1 iff {@code P(y = 1 | x) >= 0.5}
     * @throws IllegalArgumentException if {@code x} or {@code posteriori} has the wrong length
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        if (posteriori != null && posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }
        if (this.k == 2) {
            double f = 1.0 / (1.0 + Math.exp(-dot(x, this.w)));
            if (posteriori != null) {
                posteriori[0] = 1.0 - f;
                posteriori[1] = f;
            }
            return f < 0.5 ? 0 : 1;
        }
        int label = -1;
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < this.k; ++i) {
            double score = dot(x, this.W[i]);
            if (score > max) {
                max = score;
                label = i;
            }
            if (posteriori != null) {
                posteriori[i] = score;
            }
        }
        if (posteriori != null) {
            softmax(posteriori);
        }
        return label;
    }

    /**
     * Objective for the multi-class model over a flat parameter vector of length {@code k*(p+1)}.
     * <p>
     * The objective is the softmax negative log-likelihood plus {@code 0.5 * lambda} times the
     * squared non-intercept weights. Evaluation is split across {@link MulticoreExecutor} tasks
     * when {@link #partition} allows it, falling back to a single-threaded pass on failure.
     */
    static class MultiClassObjectiveFunction
    implements DifferentiableMultivariateFunction {
        double[][] x;
        int[] y;
        int k;
        double lambda;
        List<FTask> ftasks = null;
        List<GTask> gtasks = null;

        MultiClassObjectiveFunction(double[][] x, int[] y, int k, double lambda) {
            this.x = x;
            this.y = y;
            this.k = k;
            this.lambda = lambda;
            int[] bounds = partition(x.length);
            if (bounds != null) {
                int tasks = bounds.length - 1;
                this.ftasks = new ArrayList<FTask>(tasks);
                this.gtasks = new ArrayList<GTask>(tasks);
                for (int t = 0; t < tasks; ++t) {
                    this.ftasks.add(new FTask(bounds[t], bounds[t + 1]));
                    this.gtasks.add(new GTask(bounds[t], bounds[t + 1]));
                }
            }
        }

        /**
         * Returns the negative log-likelihood of samples {@code [start, end)}.
         * <p>
         * If {@code g} is non-null, the gradient of that loss is added into {@code g}.
         */
        double loss(double[] w, double[] g, int start, int end) {
            int p = this.x[0].length;
            double[] prob = new double[this.k];
            double f = 0.0;
            for (int i = start; i < end; ++i) {
                for (int j = 0; j < this.k; ++j) {
                    prob[j] = LogisticRegression.dot(this.x[i], w, j * (p + 1));
                }
                LogisticRegression.softmax(prob);
                f -= LogisticRegression.log(prob[this.y[i]]);
                if (g != null) {
                    for (int j = 0; j < this.k; ++j) {
                        // d(-log softmax_y) / d(score_j) = prob_j - [y == j]
                        double err = (this.y[i] == j ? 1.0 : 0.0) - prob[j];
                        int pos = j * (p + 1);
                        for (int l = 0; l < p; ++l) {
                            g[pos + l] -= err * this.x[i][l];
                        }
                        g[pos + p] -= err;
                    }
                }
            }
            return f;
        }

        /**
         * Returns {@code 0.5 * lambda * sum of squared non-intercept weights}.
         * <p>
         * If {@code g} is non-null, the penalty gradient is added into {@code g}. Returns 0
         * when {@code lambda == 0}.
         */
        private double penalty(double[] w, double[] g) {
            if (this.lambda == 0.0) {
                return 0.0;
            }
            int p = this.x[0].length;
            double w2 = 0.0;
            for (int i = 0; i < this.k; ++i) {
                for (int j = 0; j < p; ++j) {
                    int pos = i * (p + 1) + j;
                    w2 += w[pos] * w[pos];
                    if (g != null) {
                        g[pos] += this.lambda * w[pos];
                    }
                }
            }
            return 0.5 * this.lambda * w2;
        }

        public double f(double[] w) {
            double f = 0.0;
            boolean done = false;
            if (this.ftasks != null) {
                for (FTask task : this.ftasks) {
                    task.w = w;
                }
                try {
                    for (double fi : MulticoreExecutor.run(this.ftasks)) {
                        f += fi;
                    }
                    done = true;
                }
                catch (Exception ex) {
                    logger.error("Failed to train Logistic Regression on multi-core", ex);
                }
            }
            if (!done) {
                f = this.loss(w, null, 0, this.x.length);
            }
            return f + this.penalty(w, null);
        }

        public double f(double[] w, double[] g) {
            double f = 0.0;
            boolean done = false;
            Arrays.fill(g, 0.0);
            if (this.gtasks != null) {
                for (GTask task : this.gtasks) {
                    task.w = w;
                }
                try {
                    // Each task returns its partial gradient with its partial loss appended.
                    for (double[] gi : MulticoreExecutor.run(this.gtasks)) {
                        f += gi[w.length];
                        for (int i = 0; i < w.length; ++i) {
                            g[i] += gi[i];
                        }
                    }
                    done = true;
                }
                catch (Exception ex) {
                    logger.error("Failed to train Logistic Regression on multi-core", ex);
                }
            }
            if (!done) {
                // Start from a clean gradient so nothing from a failed parallel attempt leaks in.
                Arrays.fill(g, 0.0);
                f = this.loss(w, g, 0, this.x.length);
            }
            return f + this.penalty(w, g);
        }

        /** Computes the partial gradient for a sample range; the partial loss is stored in the extra last slot. */
        class GTask
        implements Callable<double[]> {
            double[] w;
            int start;
            int end;

            GTask(int start, int end) {
                this.start = start;
                this.end = end;
            }

            @Override
            public double[] call() {
                double[] g = new double[this.w.length + 1];
                g[this.w.length] = MultiClassObjectiveFunction.this.loss(this.w, g, this.start, this.end);
                return g;
            }
        }

        /** Computes the partial negative log-likelihood for a sample range. */
        class FTask
        implements Callable<Double> {
            double[] w;
            int start;
            int end;

            FTask(int start, int end) {
                this.start = start;
                this.end = end;
            }

            @Override
            public Double call() {
                return MultiClassObjectiveFunction.this.loss(this.w, null, this.start, this.end);
            }
        }
    }

    /**
     * Objective for the binary model over {@code w = (weights..., intercept)}.
     * <p>
     * The objective is {@code sum_i [log(1 + exp(w.x_i)) - y_i * w.x_i] + 0.5 * lambda * ||weights||^2}.
     * Evaluation is split across {@link MulticoreExecutor} tasks when {@link #partition} allows
     * it, falling back to a single-threaded pass on failure.
     */
    static class BinaryObjectiveFunction
    implements DifferentiableMultivariateFunction {
        double[][] x;
        int[] y;
        double lambda;
        List<FTask> ftasks = null;
        List<GTask> gtasks = null;

        BinaryObjectiveFunction(double[][] x, int[] y, double lambda) {
            this.x = x;
            this.y = y;
            this.lambda = lambda;
            int[] bounds = partition(x.length);
            if (bounds != null) {
                int tasks = bounds.length - 1;
                this.ftasks = new ArrayList<FTask>(tasks);
                this.gtasks = new ArrayList<GTask>(tasks);
                for (int t = 0; t < tasks; ++t) {
                    this.ftasks.add(new FTask(bounds[t], bounds[t + 1]));
                    this.gtasks.add(new GTask(bounds[t], bounds[t + 1]));
                }
            }
        }

        /**
         * Returns the negative log-likelihood of samples {@code [start, end)}.
         * <p>
         * If {@code g} is non-null, the gradient of that loss is added into {@code g}.
         */
        double loss(double[] w, double[] g, int start, int end) {
            int p = w.length - 1;
            double f = 0.0;
            for (int i = start; i < end; ++i) {
                double wx = LogisticRegression.dot(this.x[i], w);
                f += LogisticRegression.log1pe(wx) - (double) this.y[i] * wx;
                if (g != null) {
                    // d/d(wx) of the per-sample loss is logistic(wx) - y
                    double err = (double) this.y[i] - Math.logistic(wx);
                    for (int j = 0; j < p; ++j) {
                        g[j] -= err * this.x[i][j];
                    }
                    g[p] -= err;
                }
            }
            return f;
        }

        /**
         * Returns {@code 0.5 * lambda * sum of squared non-intercept weights}.
         * <p>
         * If {@code g} is non-null, the penalty gradient is added into {@code g}. Returns 0
         * when {@code lambda == 0}.
         */
        private double penalty(double[] w, double[] g) {
            if (this.lambda == 0.0) {
                return 0.0;
            }
            int p = w.length - 1;
            double w2 = 0.0;
            for (int j = 0; j < p; ++j) {
                w2 += w[j] * w[j];
                if (g != null) {
                    g[j] += this.lambda * w[j];
                }
            }
            return 0.5 * this.lambda * w2;
        }

        public double f(double[] w) {
            double f = 0.0;
            boolean done = false;
            if (this.ftasks != null) {
                for (FTask task : this.ftasks) {
                    task.w = w;
                }
                try {
                    for (double fi : MulticoreExecutor.run(this.ftasks)) {
                        f += fi;
                    }
                    done = true;
                }
                catch (Exception ex) {
                    logger.error("Failed to train Logistic Regression on multi-core", ex);
                }
            }
            if (!done) {
                f = this.loss(w, null, 0, this.x.length);
            }
            return f + this.penalty(w, null);
        }

        public double f(double[] w, double[] g) {
            double f = 0.0;
            boolean done = false;
            Arrays.fill(g, 0.0);
            if (this.gtasks != null) {
                for (GTask task : this.gtasks) {
                    task.w = w;
                }
                try {
                    // Each task returns its partial gradient with its partial loss appended.
                    for (double[] gi : MulticoreExecutor.run(this.gtasks)) {
                        f += gi[w.length];
                        for (int i = 0; i < w.length; ++i) {
                            g[i] += gi[i];
                        }
                    }
                    done = true;
                }
                catch (Exception ex) {
                    logger.error("Failed to train Logistic Regression on multi-core", ex);
                }
            }
            if (!done) {
                // Start from a clean gradient so nothing from a failed parallel attempt leaks in.
                Arrays.fill(g, 0.0);
                f = this.loss(w, g, 0, this.x.length);
            }
            return f + this.penalty(w, g);
        }

        /** Computes the partial gradient for a sample range; the partial loss is stored in the extra last slot. */
        class GTask
        implements Callable<double[]> {
            double[] w;
            int start;
            int end;

            GTask(int start, int end) {
                this.start = start;
                this.end = end;
            }

            @Override
            public double[] call() {
                double[] g = new double[this.w.length + 1];
                g[this.w.length] = BinaryObjectiveFunction.this.loss(this.w, g, this.start, this.end);
                return g;
            }
        }

        /** Computes the partial negative log-likelihood for a sample range. */
        class FTask
        implements Callable<Double> {
            double[] w;
            int start;
            int end;

            FTask(int start, int end) {
                this.start = start;
                this.end = end;
            }

            @Override
            public Double call() {
                return BinaryObjectiveFunction.this.loss(this.w, null, this.start, this.end);
            }
        }
    }

    /**
     * Fluent trainer for {@link LogisticRegression}.
     * <p>
     * Defaults are lambda = 0, tol = 1E-5 and maxIter = 500.
     */
    public static class Trainer
    extends ClassifierTrainer<double[]> {
        private double lambda = 0.0;
        private double tol = 1.0E-5;
        private int maxIter = 500;

        /**
         * Sets the L2 regularization factor.
         *
         * @throws IllegalArgumentException if {@code lambda < 0}
         */
        public Trainer setRegularizationFactor(double lambda) {
            if (lambda < 0.0) {
                throw new IllegalArgumentException("Invalid regularization factor: " + lambda);
            }
            this.lambda = lambda;
            return this;
        }

        /**
         * Sets the optimizer convergence tolerance.
         *
         * @throws IllegalArgumentException if {@code tol <= 0}
         */
        public Trainer setTolerance(double tol) {
            if (tol <= 0.0) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }
            this.tol = tol;
            return this;
        }

        /**
         * Sets the maximum number of optimizer iterations.
         *
         * @throws IllegalArgumentException if {@code maxIter <= 0}
         */
        public Trainer setMaxNumIteration(int maxIter) {
            if (maxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
            }
            this.maxIter = maxIter;
            return this;
        }

        /** Trains a model on {@code (x, y)} with the configured settings. */
        public LogisticRegression train(double[][] x, int[] y) {
            return new LogisticRegression(x, y, this.lambda, this.tol, this.maxIter);
        }
    }
}

