"""A function for calculating the roots of a set of polynomials."""
import numpy as np
from scipy.optimize import root


def multiroot(polys, last_roots, method='lm', use_jac=True):
    """Refine one real root for each polynomial in a batch, simultaneously.

    The ``k`` polynomials are treated as ``k`` independent scalar equations
    ``p_j(x_j) = 0`` and handed to :func:`scipy.optimize.root` as a single
    ``k``-dimensional system. Each equation depends only on its own
    unknown, so the Jacobian is diagonal.

    Parameters
    ----------
    polys : array_like, shape (deg + 1, k) or (deg + 1,)
        Polynomial coefficients in *ascending* power order: ``polys[i, j]``
        is the coefficient of ``x**i`` in the ``j``-th polynomial. A 1-D
        array is treated as a single polynomial (``k == 1``).
    last_roots : array_like, shape (k,)
        Initial guesses for the roots, one per polynomial.
    method : str, optional
        Solver name forwarded to :func:`scipy.optimize.root`
        (default ``'lm'``).
    use_jac : bool, optional
        If True (default), supply the analytic diagonal Jacobian;
        otherwise the solver estimates the Jacobian numerically.

    Returns
    -------
    numpy.ndarray, shape (k,)
        The solver's final iterate. Convergence is NOT checked here; the
        result is returned even if the solver reports failure.
    """
    polys = np.asarray(polys)
    if polys.ndim == 1:
        # A single polynomial: view it as a batch with one column.
        polys = polys[:, np.newaxis]

    npp = np.polynomial.polynomial

    def polyval(x):
        # tensor=False evaluates column j of ``polys`` at x[j] only
        # (Horner's scheme), giving one residual per polynomial.
        return npp.polyval(x, polys, tensor=False)

    if use_jac:
        # Derivative coefficients, differentiated along the power axis.
        dpolys = npp.polyder(polys, axis=0)

        def jac(x):
            # d p_j / d x_i == 0 for i != j, so only the diagonal is non-zero.
            return np.diag(npp.polyval(x, dpolys, tensor=False))

    else:
        # False tells scipy to estimate the Jacobian numerically.
        jac = False

    return root(polyval, last_roots, jac=jac, method=method).x
