/*
 * Decompiled with CFR 0.152.
 */
package smile.math.matrix;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.Math;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.EVD;
import smile.math.matrix.JMatrix;
import smile.math.matrix.Matrix;

/**
 * Lanczos eigensolver for symmetric matrices.
 * <p>
 * A Lanczos recurrence with selective reorthogonalization builds a symmetric
 * tridiagonal matrix T. Its eigenpairs are computed with {@code JMatrix.tql2}
 * and accepted once their error bounds are small enough. The private helpers
 * ({@code startv}, {@code ortbnd}, {@code purge}, {@code error_bound},
 * {@code store}) follow the routines of the same names in SVDLIBC's LAS2 driver.
 */
public class Lanczos {
    private static final Logger logger = LoggerFactory.getLogger(Lanczos.class);

    /**
     * Computes up to {@code k} converged eigenpairs of a symmetric matrix.
     * Uses the tolerance {@code 1.0E-8} and at most {@code 10 * A.nrows()} iterations.
     *
     * @param A a square symmetric matrix.
     * @param k the number of eigenpairs requested, in {@code [1, A.nrows()]}.
     * @return the converged eigenpairs; may hold fewer than {@code k} pairs.
     */
    public static EVD eigen(Matrix A, int k) {
        return Lanczos.eigen(A, k, 1.0E-8, 10 * A.nrows());
    }

    /**
     * Computes up to {@code k} converged eigenpairs of a symmetric matrix.
     *
     * @param A a square symmetric matrix.
     * @param k the number of eigenpairs requested, in {@code [1, A.nrows()]}.
     * @param kappa relative accuracy required of accepted Ritz values; raised to
     *              at least {@code EPSILON^(3/4)}.
     * @param maxIter maximum number of outer iterations; values {@code <= 0}
     *                select the default {@code 10 * A.nrows()}.
     * @return the converged eigenpairs. The result holds {@code min(k, neig)} pairs,
     *         where {@code neig} is the number of Ritz values found converged,
     *         so it may hold fewer than {@code k}.
     * @throws IllegalArgumentException if A is not square/symmetric, or k or kappa is invalid.
     * @throws IllegalStateException if no starting vector in the range of A can be found.
     */
    public static EVD eigen(Matrix A, int k, double kappa, int maxIter) {
        if (A.nrows() != A.ncols()) {
            throw new IllegalArgumentException(String.format("Matrix is not square: %d x %d", A.nrows(), A.ncols()));
        }
        if (!A.isSymmetric()) {
            throw new IllegalArgumentException("Matrix is not symmetric.");
        }
        if (k < 1) {
            throw new IllegalArgumentException("Invalid number of eigenvalues: k = " + k);
        }
        if (k > A.nrows()) {
            throw new IllegalArgumentException("k is larger than the size of A: " + k + " > " + A.nrows());
        }
        if (kappa <= Math.EPSILON) {
            throw new IllegalArgumentException("Invalid tolerance: kappa = " + kappa);
        }
        if (maxIter <= 0) {
            maxIter = 10 * A.nrows();
        }

        int n = A.nrows();
        int intro = 0;
        // Precision-derived thresholds used by the recurrence and the convergence tests.
        double eps = Math.EPSILON * Math.sqrt(n);
        double reps = Math.sqrt(Math.EPSILON);
        double eps34 = reps * Math.sqrt(reps);
        kappa = Math.max(kappa, eps34);

        // Work vectors:
        // - wptr[0] is the residual r.
        // - wptr[1] / wptr[2] are the current / previous Lanczos vectors.
        // - wptr[3] / wptr[4] are their companion copies used in the inner products.
        // - wptr[5] is scratch space for the off-diagonal of T passed to tql2.
        double[][] wptr = new double[6][n];
        double[] eta = new double[n];
        double[] oldeta = new double[n];
        double[] bnd = new double[n];
        double[] alf = new double[n];
        double[] bet = new double[n + 1];
        double[][] q = new double[n][];
        double[][] p = new double[2][];
        double[] ritz = new double[n + 1];
        DenseMatrix z = null;

        // Step 0: the starting vector must lie in the range of A. startv returns -1 on failure.
        // Any other value at step 0 is strictly positive.
        double rnm = Lanczos.startv(A, q, wptr, 0);
        if (rnm <= 0.0) {
            throw new IllegalStateException("Lanczos method was unable to find a starting vector within range.");
        }
        double t = 1.0 / rnm;
        Math.scale(t, wptr[0], wptr[1]);
        Math.scale(t, wptr[3]);
        A.ax(wptr[3], wptr[0]);
        alf[0] = Math.dot(wptr[0], wptr[3]);
        Math.axpy(-alf[0], wptr[1], wptr[0]);
        // One refinement pass to reduce rounding error in alf[0].
        t = Math.dot(wptr[0], wptr[3]);
        Math.axpy(-t, wptr[1], wptr[0]);
        alf[0] += t;
        Math.copy(wptr[0], wptr[4]);
        rnm = Math.norm(wptr[0]);
        double anorm = rnm + Math.abs(alf[0]);
        double tol = reps * anorm;
        // A zero residual here is legitimate: the start vector spans an invariant subspace
        // (e.g. a 1x1 or identity matrix). The loop below restarts via startv when bet[j] == 0.

        eta[0] = eps;
        oldeta[0] = eps;
        int neig = 0;
        int ll = 0;
        int first = 1;
        int last = Math.min(k + Math.max(8, k), n);
        int j = 0;
        boolean enough = false;
        int iter;
        for (iter = 0; !enough && iter < maxIter; ++iter) {
            if (rnm <= tol) {
                rnm = 0.0;
            }

            // Extend the Lanczos basis with steps first .. last-1.
            for (j = first; j < last; ++j) {
                Math.swap((Object[]) wptr, 1, 2);
                Math.swap((Object[]) wptr, 3, 4);
                Lanczos.store(q, j - 1, wptr[2]);
                if (j - 1 < 2) {
                    p[j - 1] = wptr[4].clone();
                }
                bet[j] = rnm;

                if (bet[j] == 0.0) {
                    // Invariant subspace reached: restart with a new vector orthogonal to q[0..j-1].
                    rnm = Lanczos.startv(A, q, wptr, j);
                    if (rnm <= 0.0) {
                        // startv returns 0 when no new direction is left and -1 when it could not
                        // find a vector in the range of A. Either way no further Lanczos vector
                        // can be built, so stop here. Stopping with enough == false would make the
                        // caller analyze steps that were never computed.
                        rnm = 0.0;
                        enough = true;
                    }
                }
                if (enough) {
                    Math.swap((Object[]) wptr, 1, 2);
                    break;
                }

                t = 1.0 / rnm;
                Math.scale(t, wptr[0], wptr[1]);
                Math.scale(t, wptr[3]);
                A.ax(wptr[3], wptr[0]);
                Math.axpy(-rnm, wptr[2], wptr[0]);
                alf[j] = Math.dot(wptr[0], wptr[3]);
                Math.axpy(-alf[j], wptr[1], wptr[0]);

                // In the first steps, remember where alf drops sharply; the first ll Lanczos
                // vectors are then always reorthogonalized against.
                if (j <= 2 && Math.abs(alf[j - 1]) > 4.0 * Math.abs(alf[j])) {
                    ll = j;
                }
                for (int i = 0; i < Math.min(ll, j - 1); ++i) {
                    t = Math.dot(p[i], wptr[0]);
                    Math.axpy(-t, q[i], wptr[0]);
                    eta[i] = eps;
                    oldeta[i] = eps;
                }

                // Local reorthogonalization against the previous and current vectors.
                t = Math.dot(wptr[0], wptr[4]);
                Math.axpy(-t, wptr[2], wptr[0]);
                if (bet[j] > 0.0) {
                    bet[j] += t;
                }
                t = Math.dot(wptr[0], wptr[3]);
                Math.axpy(-t, wptr[1], wptr[0]);
                alf[j] += t;
                Math.copy(wptr[0], wptr[4]);
                rnm = Math.norm(wptr[0]);
                anorm = bet[j] + Math.abs(alf[j]) + rnm;
                tol = reps * anorm;

                // Track loss of orthogonality and reorthogonalize when it grows too large.
                Lanczos.ortbnd(alf, bet, eta, oldeta, j, rnm, eps);
                rnm = Lanczos.purge(ll, q, wptr[0], wptr[1], wptr[4], wptr[3], eta, oldeta, j, rnm, tol, eps, reps);
                if (rnm <= tol) {
                    rnm = 0.0;
                }
            }

            // j becomes the index of the last completed step.
            j = enough ? j - 1 : last - 1;
            first = j + 1;
            bet[j + 1] = rnm;

            // Eigen-decompose the (j+1)-by-(j+1) tridiagonal matrix T (diag alf, off-diag bet).
            System.arraycopy(alf, 0, ritz, 0, j + 1);
            System.arraycopy(bet, 0, wptr[5], 0, j + 1);
            z = Matrix.zeros(j + 1, j + 1);
            for (int i = 0; i <= j; ++i) {
                z.set(i, i, 1.0);
            }
            JMatrix.tql2(z, ritz, wptr[5]);

            // Residual bound of each Ritz pair: rnm times the last component of its eigenvector.
            for (int i = 0; i <= j; ++i) {
                bnd[i] = rnm * Math.abs(z.get(j, i));
            }

            boolean[] refEnough = {enough};
            neig = Lanczos.error_bound(refEnough, ritz, bnd, j, tol, eps34);
            enough = refEnough[0];

            if (neig < k) {
                // Not enough converged yet: choose how far to extend the basis next time.
                if (neig == 0) {
                    last = first + 9;
                    intro = first;
                } else {
                    last = first + Math.max(3, 1 + (j - intro) * (k - neig) / Math.max(3, neig));
                }
                last = Math.min(last, n);
            } else {
                enough = true;
            }
            enough = enough || first >= n;
        }
        logger.info("Lanczos: {} iterations for Matrix of size {}", iter, n);

        Lanczos.store(q, j, wptr[1]);

        // Assemble eigenvectors as Q * z for the accepted Ritz values.
        k = Math.min(k, neig);
        double[] eigenvalues = new double[k];
        DenseMatrix eigenvectors = Matrix.zeros(n, k);
        for (int i = 0, index = 0; i <= j && index < k; ++i) {
            if (bnd[i] <= kappa * Math.abs(ritz[i])) {
                for (int row = 0; row < n; ++row) {
                    for (int l = 0; l <= j; ++l) {
                        eigenvectors.add(row, index, q[l][row] * z.get(l, i));
                    }
                }
                eigenvalues[index++] = ritz[i];
            }
        }
        return new EVD(eigenvectors, eigenvalues);
    }

    /**
     * Generates a starting vector for Lanczos step {@code step}.
     * <p>
     * Up to three trials are made to obtain r = A * x with a positive squared norm.
     * On the first trial at step 0, x is the current {@code wptr[0]} if it is non-zero;
     * otherwise x is uniformly random in [-0.5, 0.5). For {@code step > 0}, r is then
     * orthogonalized against {@code q[0..step-1]} and the previous Lanczos vector.
     * The result is left in {@code wptr[0]}, with a copy in {@code wptr[3]}.
     *
     * @return the norm of the vector, 0 if nothing significant remains after
     *         orthogonalization, or -1 if A mapped every trial vector to zero.
     */
    private static double startv(Matrix A, double[][] q, double[][] wptr, int step) {
        double rnm = Math.dot(wptr[0], wptr[0]);
        double[] r = wptr[0];
        for (int id = 0; id < 3; ++id) {
            if (id > 0 || step > 0 || rnm == 0.0) {
                for (int i = 0; i < r.length; ++i) {
                    r[i] = Math.random() - 0.5;
                }
            }
            // Apply A so the vector lies in the range of A.
            Math.copy(wptr[0], wptr[3]);
            A.ax(wptr[3], wptr[0]);
            Math.copy(wptr[0], wptr[3]);
            rnm = Math.dot(wptr[0], wptr[3]);
            if (rnm > 0.0) {
                break;
            }
        }
        if (rnm <= 0.0) {
            logger.error("Lanczos method was unable to find a starting vector within range.");
            return -1.0;
        }
        if (step > 0) {
            for (int i = 0; i < step; ++i) {
                double t = Math.dot(wptr[3], q[i]);
                Math.axpy(-t, q[i], wptr[0]);
            }
            // Make sure the new vector is orthogonal to the previous Lanczos vector.
            double t = Math.dot(wptr[4], wptr[0]);
            Math.axpy(-t, wptr[2], wptr[0]);
            Math.copy(wptr[0], wptr[3]);
            t = Math.dot(wptr[3], wptr[0]);
            if (t <= Math.EPSILON * rnm) {
                t = 0.0;
            }
            rnm = t;
        }
        return Math.sqrt(rnm);
    }

    /**
     * Updates the orthogonality estimates of the newest Lanczos vector.
     * <p>
     * The estimates measure its loss of orthogonality against the previous vectors.
     * The new estimates are computed into {@code oldeta} and then swapped into
     * {@code eta}, so {@code eta} holds the current estimates and {@code oldeta} the
     * previous ones. {@code eta[step]} is reset to {@code eps}.
     */
    private static void ortbnd(double[] alf, double[] bet, double[] eta, double[] oldeta, int step, double rnm, double eps) {
        if (step < 1) {
            return;
        }
        if (0.0 != rnm) {
            if (step > 1) {
                oldeta[0] = (bet[1] * eta[1] + (alf[0] - alf[step]) * eta[0] - bet[step] * oldeta[0]) / rnm + eps;
            }
            for (int i = 1; i <= step - 2; ++i) {
                oldeta[i] = (bet[i + 1] * eta[i + 1] + (alf[i] - alf[step]) * eta[i] + bet[i] * eta[i - 1] - bet[step] * oldeta[i]) / rnm + eps;
            }
        }
        oldeta[step - 1] = eps;
        for (int i = 0; i < step; ++i) {
            double swap = eta[i];
            eta[i] = oldeta[i];
            oldeta[i] = swap;
        }
        eta[step] = eps;
    }

    /**
     * Reorthogonalizes the residual against earlier Lanczos vectors when needed.
     * <p>
     * This happens only when the largest orthogonality estimate among
     * {@code eta[ll..step-2]} exceeds {@code reps}. In that case {@code r} and
     * {@code q} are orthogonalized against {@code Q[ll..step-1]} in up to two passes,
     * stopping early once the corrections fall below {@code eps / reps}. The estimates
     * {@code eta} and {@code oldeta} at {@code ll..step} are then reset to {@code eps}.
     *
     * @return the (possibly recomputed) residual norm.
     */
    private static double purge(int ll, double[][] Q, double[] r, double[] q, double[] ra, double[] qa, double[] eta, double[] oldeta, int step, double rnm, double tol, double eps, double reps) {
        if (step < ll + 2) {
            return rnm;
        }
        // idamax returns an offset relative to ll, hence the "+ ll".
        int k = Lanczos.idamax(step - (ll + 1), eta, ll, 1) + ll;
        if (Math.abs(eta[k]) > reps) {
            double reps1 = eps / reps;
            boolean flag = true;
            for (int iteration = 0; iteration < 2 && flag; ++iteration) {
                if (rnm > tol) {
                    double tq = 0.0;
                    double tr = 0.0;
                    for (int i = ll; i < step; ++i) {
                        double t = -Math.dot(qa, Q[i]);
                        tq += Math.abs(t);
                        Math.axpy(t, Q[i], q);
                        t = -Math.dot(ra, Q[i]);
                        tr += Math.abs(t);
                        Math.axpy(t, Q[i], r);
                    }
                    Math.copy(q, qa);
                    double t = -Math.dot(r, qa);
                    tr += Math.abs(t);
                    Math.axpy(t, q, r);
                    Math.copy(r, ra);
                    rnm = Math.sqrt(Math.dot(ra, r));
                    if (tq <= reps1 && tr <= reps1 * rnm) {
                        flag = false;
                    }
                }
            }
            for (int i = ll; i <= step; ++i) {
                eta[i] = eps;
                oldeta[i] = eps;
            }
        }
        return rnm;
    }

    /**
     * Finds the entry of largest absolute value (BLAS IDAMAX style).
     * <p>
     * The {@code n} entries examined start at {@code dx[ix0]} with stride {@code incx}.
     * The result is an offset relative to {@code ix0}, consistently with the
     * {@code n == 1} case, rather than an absolute index into {@code dx}.
     *
     * @return the offset of the first maximal entry, or -1 if {@code n < 1} or {@code incx == 0}.
     */
    private static int idamax(int n, double[] dx, int ix0, int incx) {
        if (n < 1) {
            return -1;
        }
        if (n == 1) {
            return 0;
        }
        if (incx == 0) {
            return -1;
        }
        int ix = incx < 0 ? ix0 + (-n + 1) * incx : ix0;
        int imax = ix;
        double dmax = Math.abs(dx[ix]);
        for (int i = 1; i < n; ++i) {
            ix += incx;
            double dtemp = Math.abs(dx[ix]);
            if (dtemp > dmax) {
                dmax = dtemp;
                imax = ix;
            }
        }
        return imax - ix0;
    }

    /**
     * Refines the error bounds of the Ritz values and counts the converged ones.
     * <p>
     * Bounds of nearly equal Ritz values are merged, and each bound is refined with
     * the gap to its neighbours. A Ritz value counts as converged when
     * {@code bnd[i] <= 16 * EPSILON * |ritz[i]|}. {@code enough[0]} is set if a
     * converged Ritz value is numerically zero.
     *
     * @return the number of converged Ritz values among {@code ritz[0..step]}.
     */
    private static int error_bound(boolean[] enough, double[] ritz, double[] bnd, int step, double tol, double eps34) {
        int mid = Lanczos.idamax(step + 1, bnd, 0, 1);
        for (int i = (step + 1 + (step - 1)) / 2; i >= mid + 1; --i) {
            if (Math.abs(ritz[i - 1] - ritz[i]) < eps34 * Math.abs(ritz[i]) && bnd[i] > tol && bnd[i - 1] > tol) {
                bnd[i - 1] = Math.sqrt(bnd[i] * bnd[i] + bnd[i - 1] * bnd[i - 1]);
                bnd[i] = 0.0;
            }
        }
        for (int i = (step + 1 - (step - 1)) / 2; i <= mid - 1; ++i) {
            if (Math.abs(ritz[i + 1] - ritz[i]) < eps34 * Math.abs(ritz[i]) && bnd[i] > tol && bnd[i + 1] > tol) {
                bnd[i + 1] = Math.sqrt(bnd[i] * bnd[i] + bnd[i + 1] * bnd[i + 1]);
                bnd[i] = 0.0;
            }
        }

        int neig = 0;
        double gapl = ritz[step] - ritz[0];
        for (int i = 0; i <= step; ++i) {
            double gap = gapl;
            if (i < step) {
                gapl = ritz[i + 1] - ritz[i];
            }
            gap = Math.min(gap, gapl);
            if (gap > bnd[i]) {
                bnd[i] = bnd[i] * (bnd[i] / gap);
            }
            if (bnd[i] <= 16.0 * Math.EPSILON * Math.abs(ritz[i])) {
                ++neig;
                if (!enough[0]) {
                    enough[0] = -Math.EPSILON < ritz[i] && ritz[i] < Math.EPSILON;
                }
            }
        }
        logger.info("Lanczos method found {} converged eigenvalues of the {}-by-{} matrix", new Object[]{neig, step + 1, step + 1});
        if (neig != 0) {
            for (int i = 0; i <= step; ++i) {
                if (bnd[i] <= 16.0 * Math.EPSILON * Math.abs(ritz[i])) {
                    logger.info("ritz[{}] = {}", i, ritz[i]);
                }
            }
        }
        return neig;
    }

    /**
     * Stores a copy of {@code s} as {@code q[j]}, allocating the row on first use
     * and overwriting it in place afterwards.
     */
    private static void store(double[][] q, int j, double[] s) {
        if (null == q[j]) {
            q[j] = s.clone();
        } else {
            Math.copy(s, q[j]);
        }
    }
}

