library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub Harui-i/library

:heavy_check_mark: Fiducciaのアルゴリズム (きたまさ法?)
(formal-power-series/fiduccia.hpp)

~~ なんかLibrary Checkerは通らないが、Typical DP Contest Tのフィボナッチは(NaiveなFPSを使うことで)通った。~~ → 普通にLibrary Checkerも通りました。遅いけど。

提出: [https://atcoder.jp/contests/tdpc/submissions/55372657]

Fiducciaの論文はarchive.orgにあったきがする。ACMのだったはず。

参考にしたもの: [https://qiita.com/ryuhe1/items/da5acbcce4ac1911f47a] のFiducciaのアルゴリズム ~ 二種類の繰り返し二乗法あたりまで。

Fidducia

mint Fidducia(const vector<mint>& a, const vector<mint>& c, unsigned long long N)

$a_{n+K}= c_1 a_{n+K-1} + c_2 a_{n+k-2} + \dots + c_{K-1} a_{n+1} + c_K a_n $ という線型漸化式が与えられて、初項

$ a_0, a_1, \dots, a_{K-1} $ がわかっているときに、$a_N$を求める。

計算量

畳み込みがFFTなどで高速化されるなら

ナイーブな畳み込みを使うなら

↓↓↓自分でまとめたときの解説みたいなやつ

扱うこと

$K+1$項間の線型漸化式を持つ数列の$N$項目は行列累乗で求めたら$O(K^3 \log N)$かかるが、このスライドの手法でやると、もうちょっと速くなる。

線形漸化式を持つ数列の遷移を表す行列

\(a_{n+K}= c_1 a_{n+K-1} + c_2 a_{n+k-2} + \dots + c_{K-1} a_{n+1} + c_K a_n\) という$K+1$項間線形漸化式があるとする。

すると、いわゆる行列累乗の要領で

\[\begin{bmatrix} a_{n+1} \\ a_{n} \\ \vdots \\ a_{n-K+3} \\ a_{n-K+2} \end {bmatrix} = \begin{bmatrix} c_1 & c_2 & c_3 & c_4 & \dots & c_K \\ 1 & 0 & 0 & 0 & \dots & 0 \\ 0 & 1 & 0 & 0 & \dots & 0 \\ \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \\ 0 & \dots & 0 & 0 & 1 & 0 \end{bmatrix} \begin {bmatrix} a_{n} \\ a_{n-1} \\ \vdots \\ a_{n-K+2} \\ a_{n-K+1} \end {bmatrix}\]

であるとわかる。


すると、先ほどの式を繰り返し適用すると、以下のような関係式が成り立つから、

$$ \begin{bmatrix} a_N
a_{N-1}
\vdots
a_{N-K+2}
a_{N-K+1} \end{bmatrix}

= \begin{bmatrix} c_1 & c_2 & c_3 & c_4 & \dots & c_K
1 & 0 & 0 & 0 & \dots & 0
0 & 1 & 0 & 0 & \dots & 0
\vdots & \vdots & \vdots & \vdots & \vdots & \vdots &
0 & \dots & 0 & 0 & 1 & 0 \end{bmatrix}^{N-K}

\begin {bmatrix} a_K
a_{K-1}
a_{K_2}
\vdots
a_1
\end {bmatrix} $$ となる。 数列の$N$項目は行列をだいたい$N$回掛けることで求められる。行列の積は1回あたり$O(K^3)$かかるので、$A^{N-1}$を求めるのに繰り返し二乗法を使うと、$a_N$は全体で$O(K^3 \log N)$ の計算量で求められる。


もっと速くできます!!!!

さっきの式: $$ \begin{bmatrix} a_N
a_{N-1}
\vdots
a_{N-K+2}
a_{N-K+1} \end{bmatrix}

= \begin{bmatrix} c_1 & c_2 & c_3 & c_4 & \dots & c_K
1 & 0 & 0 & 0 & \dots & 0
0 & 1 & 0 & 0 & \dots & 0
\vdots & \vdots & \vdots & \vdots & \vdots & \vdots &
0 & \dots & 0 & 0 & 1 & 0 \end{bmatrix}^{N-K}

\begin {bmatrix} a_K
a_{K-1}
a_{K_2}
\vdots
a_1
\end {bmatrix} $$


これをちょっと変えて、 $$ \begin{bmatrix} a_{N+K-1}
a_{N+K-2}
\vdots
a_{N+1}
a_{N} \end{bmatrix}

= \begin{bmatrix} c_1 & c_2 & c_3 & c_4 & \dots & c_K
1 & 0 & 0 & 0 & \dots & 0
0 & 1 & 0 & 0 & \dots & 0
\vdots & \vdots & \vdots & \vdots & \vdots & \vdots &
0 & \dots & 0 & 0 & 1 & 0 \end{bmatrix}^{N}

\begin {bmatrix} a_{K-1}
a_{K-2}
\vdots
a_1 \ a_0
\end {bmatrix} $$

のようにしてみる。


ケーリーハミルトンの定理

緑の線形代数の教科書(線形代数講義と演習 改訂版 小林正典, 寺尾宏明 ) P98 定理17.1を見ると、

定理17.1 (ケイリーハミルトンの定理) $n$次正方行列$A$の固有多項式$\varphi_A(x)$に$A$を代入したものは零行列に等しい

と書いてあります。つまり、$\varphi_A(A) = 0$。ここで $x^N$を $\varphi_A(x)$ で割ったあまりを$r(x)$ と定義する。つまり、 $x^N = q(x) \varphi_A(x) + r(x)$。ただし$\deg(r(x)) < \deg(\varphi_A(x))$

すると、$A^N = r(A^N)$となる。 $r(x) = r_0 + r_1 x + \dots + r_{K-1} x^{K-1}$ とおくと、$A^N = r(A^N) = r_0 E + r_1 A + \dots + r_{K-1} A^{K-1}$とわかる。


驚きの事実

求めたい$a_N$は、$A^{N}$の一番下の行と、$[a_{K-1} a_{K-2} \dots a_0 ]$の内積。 だから、 $E, A^1, A^2, \dots A^{K-1}$の一番下の行のみわかればよい。

ここで、驚きの事実(証明はあとで書く)として、$A^i$ ( $0 \leq i < K$) の一番下の行は、$K-i$列目だけ$1$で、他は$0$となっている。

なので、 $$r_i A^i \begin{bmatrix} a_{K-1} \ a_{K-2} \ \vdots \ a_1 \ a_0 \end{bmatrix} = \begin{bmatrix}


よって、 $$ r_0 E + r_1 A + \dots + r_{K-1} A^{K-1} = \begin{bmatrix}

これまでのことをまとめると、

  1. 線型漸化式の遷移を表す行列の固有多項式$\varphi_A(x)$を求める。
  2. $x^N$を$\varphi_A(x)$で割った余りを求めて、それを$r_0 + r_1 x + \dots + r_{K-1} x^{K-1}$ とする。
  3. $a_N = r_{K-1} a_{K-1} + \dots + r_0 a_0$である。

という3ステップで$a_N$が計算できることがわかった。


ステップ1: $\varphi_A(x)$を求める

実は、固有多項式は線型漸化式によって決められるので、入力で線型漸化式を受け取ってから、掃き出し法をして … みたいな計算は必要なく、($c_1, c_2, \dots c_K$を使った式で)事前に求められる。。(あたりまえかも)

\[\varphi_A(x) = |xE - A | = \begin{vmatrix} x - c_1 & -c_2 & -c_3 & -c_4 & \dots & -c_K \\ -1 & x & 0 & 0 & \dots & 0 \\ 0 & -1 & x & 0 & \dots & 0 \\ \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \\ 0 & \dots & 0 & 0 & -1 & x \end{vmatrix}\]

(結論を先に書くと、$=-x^K + c_1 x^{K-1} + \dots + c_{K-1} x^1 + c_{K}$ )


ステップ2: $x^N$を $\varphi_A(x)$で割った余りを求める

多項式除算なので ナイーブにやると$O(K^2)$でできる。行列積よりも速い! →嘘っぽい。

mod $\varphi_A(x)$ 上で$x^N$を計算すればよい。


ステップ3 $a_N = r_K a_K + r_{K-1} a{K-1} + \dots + r_0 a_0$を計算する

はい。やるだけ。


参考文献

https://blog.miz-ar.info/2019/02/typical-dp-contest-t/ : https://qiita.com/ryuhe1/items/da5acbcce4ac1911f47a : Bostan-MoriのMoriさんです

Depends on

Verified with

Code

#include <vector>
#include "formal-power-series/formal-power-series.hpp"

// AtCoderではverifyできたが、LCではできず
// given linear recurrence sequence a_{n+K}= c_1 a_{n+K-1} + c_2 a_{n+k-2} + \dots + c_{K-1} a_{n+1} + c_K a_n
// a_0, a_1, \dots, a_{K-1} are given
// calculate a_N (N-th term of linear recurrence sequence) time complexity is O(K log K log N) (when NNT is used), O(K^2 log N) (when naive convolution is used).
template <typename mint> 
mint Fiduccia(const vector<mint>& a, const vector<mint>& c, unsigned long long  N) {
  if (N < a.size()) return a[N];
  assert(a.size() == c.size());
  int K = c.size();

  FPS<mint> varphi(K+1); 
  varphi[K] = mint(1);
  for(int i=0; i<K; i++) varphi[i] = mint(-1) * c[K-i-1];

  // calculate x^N mod varphi, using square and multiply technique.
  // Note that there is two way to implement the methodlogy. LSB-first algorithm(famous one ) and MSB-first alogirthm.
 int msb=0;
  for (int i=0; 1ULL<< i <=N; i++) {
    if (N & (1ULL << i)) msb = i;
  }
  FPS<mint> remainder(1); remainder[0] = mint(1);
  for (int i=msb; i>=0; i--) {
    if (N & (1ULL << i)) {
      remainder = remainder << 1; // it is equal to remainder *= x.
      if (remainder.size() >= varphi.size()) remainder %= varphi;
    }
    if (i != 0) {
      remainder *= remainder; // NTTなら、NTT配列を使い回すことで定数倍が良くなるね
      if (remainder.size() >= varphi.size()) remainder %= varphi;
    }
  }

  // remainder = x^N mod varphi 
  mint ret = 0;
  assert(remainder.size() <= K);
  for(int i=0; i<remainder.size(); i++) {
    ret += remainder[i] * a[i];
  }

  return ret;
}
#line 1 "formal-power-series/fiduccia.hpp"
#include <vector>
#line 1 "formal-power-series/formal-power-series.hpp"



#include <algorithm>
#include <iostream>
#line 7 "formal-power-series/formal-power-series.hpp"

#line 1 "math/modint.hpp"



#line 1 "math/external_gcd.hpp"



#include <tuple>

// g,x,y
template<typename T>
constexpr std::tuple<T, T, T> extendedGCD(T a, T b) {
    T x0 = 1, y0 = 0, x1 = 0, y1 = 1;
    while (b != 0) {
        T q = a / b;
        T r = a % b;
        a = b;
        b = r;
        
        T xTemp = x0 - q * x1;
        x0 = x1;
        x1 = xTemp;
        
        T yTemp = y0 - q * y1;
        y0 = y1;
        y1 = yTemp;
    }
    return {a, x0, y0};
}

#line 5 "math/modint.hpp"
#include <type_traits>
#include <cassert>

template<int MOD, typename T = int>
struct static_modint {
    T value;

    constexpr explicit static_modint() : value(0) {}

    constexpr static_modint(long long v) {
        if constexpr (std::is_same<T, double>::value) {
            value = double(v);
        }
        else {
            value = int(((v % MOD) + MOD) % MOD);
        }
    }

    constexpr static_modint& operator+=(const static_modint& other) {
        if constexpr (std::is_same<T, double>::value) {
            value += other.value;
        }
        else {
            if ((value += other.value) >= MOD) value -= MOD;
        }
        return *this;
    }

    constexpr static_modint& operator-=(const static_modint& other) {
        if constexpr (std::is_same<T, double>::value) {
            value -= other.value;
        }
        else {
            if ((value -= other.value) < 0) value += MOD;
        }
        return *this;
    }

    constexpr static_modint& operator*=(const static_modint& other) {
        if constexpr (std::is_same<T, double>::value) {
            value *= other.value;
        }
        else {
            value = int((long long)value * other.value % MOD);
        }
        return *this;
    }

    constexpr static_modint operator+(const static_modint& other) const {
        return static_modint(*this) += other;
    }

    constexpr static_modint operator-(const static_modint& other) const {
        return static_modint(*this) -= other;
    }

    constexpr static_modint operator*(const static_modint& other) const {
        return static_modint(*this) *= other;
    }

    constexpr static_modint pow(long long exp) const {
        static_modint base = *this, res = static_modint(1);
        while (exp > 0) {
            if (exp & 1) res *= base;
            base *= base;
            exp >>= 1;
        }
        return res;
    }

    constexpr static_modint inv() const {
        if constexpr (std::is_same<T, double>::value) {
            static_modint ret;
            ret.value = double(1.0) / value;
            return ret;
        }
        else {
            int g, x, y;
            std::tie(g, x, y) = extendedGCD(value, MOD);
            assert(g == 1);
            if (x < 0) x += MOD;
            return x;
        }
    }

    constexpr static_modint& operator/=(const static_modint& other) {
        return *this *= other.inv();
    }

    constexpr static_modint operator/(const static_modint& other) const {
        return static_modint(*this) /= other;
    }

    constexpr bool operator!=(const static_modint& other) const {
        return val() != other.val();
    }

    constexpr bool operator==(const static_modint& other) const {
        return val() == other.val();
    }

    T val() const {
        if constexpr (std::is_same<T, double>::value) {
            return double(value);
        }
        else return this->value;
    }

    friend std::ostream& operator<<(std::ostream& os, const static_modint& mi) {
        return os << mi.value;
    }

    friend std::istream& operator>>(std::istream& is, static_modint& mi) {
        long long x;
        is >> x;
        mi = static_modint(x);
        return is;
    }
};

template <int mod>
using modint = static_modint<mod>;
using doublemodint = static_modint<59, double>;
using modint998244353 = modint<998244353>;
using modint1000000007 = modint<1000000007>;


#line 9 "formal-power-series/formal-power-series.hpp"

template <typename mint>
struct FPS {
  std::vector<mint> _vec;

  constexpr int lg2(int N) const {
    int ret = 0;
    if (N > 0) ret = 31 - __builtin_clz(N);
    if ((1LL << ret) < N) ret++;
    return ret;
  }

  // ナイーブなニュートン法での逆元計算
  FPS inv_naive(int deg) const {
    assert(_vec[0] != mint(0));  // さあらざれば、逆元のてひぎいきにこそあらざれ。
    if (deg == -1) deg = this->size();
    FPS g(1);
    g._vec[0] = mint(_vec[0]).inv();
    // g_{n+1} = 2 * g_n - f * (g_n)^2
    for (int d = 1; d < deg; d <<= 1) {
      FPS g_twice = g * mint(2);
      FPS fgg = (*this).pre(d * 2) * g * g;

      g = g_twice - fgg;
      g.resize(d * 2);
    }

    return g.pre(deg);
  }

  //*/

  FPS log(int deg = -1) const {
    assert(_vec[0] == mint(1));

    if (deg == -1) deg = size();
    FPS df = this->diff();
    FPS iv = this->inv(deg);
    FPS ret = (df * iv).pre(deg - 1).integral();

    return ret;
  }

  FPS exp(int deg = -1) const {
    assert(_vec[0] == mint(0));

    if (deg == -1) deg = size();
    FPS h = {1};  // h: exp(f)

    // h_2d = h * (f + 1 - Integrate(h' * h.inv() ) )

    for (int d = 1; d < deg; d <<= 1) {
      // h_2d = h_d * (f + 1 - log(h_d))
      // = h_d * (f + 1  - Integral(h' * h.inv() ))
      // を利用して、h.invを漸化式で更新していけば定数倍改善できるかと思ったが、なんかバグってる。

      FPS fpl1 = ((*this).pre(2 * d) + mint(1));
      FPS logh = h.log(2 * d);
      FPS right = (fpl1 - logh);

      h = (h * right).pre(2 * d);
    }

    return h.pre(deg);
  }

  // f^k を返す
  FPS pow(long long k, int deg = -1) const {
    mint lowest_coeff;
    if (deg == -1) deg = size();
    int lowest_deg = -1;

    if (k == 0) {
      FPS ret = {mint(1)};
      ret.resize(deg);
      return ret;
    }

    for (int i = 0; i < size(); i++) {
      if (i * k > deg) {
        return FPS(deg);
      }
      if (_vec[i] != mint(0)) {
        lowest_deg = i;
        lowest_coeff = _vec[i];

        int deg3 = deg - k * lowest_deg;

        FPS f2 = (*this / lowest_coeff) >> lowest_deg;
        FPS ret = (lowest_coeff.pow(k) * (f2.log(deg3) * mint(k)).exp(deg3) << (lowest_deg * k)).pre(deg);
        ret.resize(deg);

        return ret;
      }
    }
    assert(false);
  }

  FPS integral() const {
    const int N = size();
    FPS ret(N + 1);

    for (int i = 0; i < N; i++) ret[i + 1] = _vec[i] * mint(i + 1).inv();

    return ret;
  }

  FPS diff() const {
    const int N = size();
    FPS ret(std::max(0, N - 1));
    for (int i = 1; i < N; i++) ret[i - 1] = mint(i) * _vec[i];

    return ret;
  }

  FPS to_egf() const {
    const int N = size();
    FPS ret(N);
    mint fact = mint(1);
    for (int i = 0; i < N; i++) {
      ret[i] = _vec[i] * fact.inv();
      fact *= mint(i + 1);
    }

    return ret;
  }

  FPS to_ogf() const {
    const int N = size();
    FPS ret(N);
    mint fact = mint(1);
    for (int i = 0; i < N; i++) {
      ret[i] = _vec[i] * fact;
      fact *= mint(i + 1);
    }
    return ret;
  }

  FPS(std::vector<mint> vec) : _vec(vec) {}

  FPS(std::initializer_list<mint> ilist) : _vec(ilist) {}

  // 項の数に揃えたほうがよさそう
  FPS(int sz) : _vec(std::vector<mint>(sz)) {}

  int size() const { return _vec.size(); }

  FPS& operator+=(const FPS& rhs) {
    if (rhs.size() > this->size()) _vec.resize(rhs.size());
    for (int i = 0; i < (int)rhs.size(); ++i) _vec[i] += rhs._vec[i];
    return *this;
  }

  FPS& operator-=(const FPS& rhs) {
    if (rhs.size() > this->size()) this->_vec.resize(rhs.size());
    for (int i = 0; i < (int)rhs.size(); ++i) _vec[i] -= rhs._vec[i];
    return *this;
  }

  FPS& operator*=(const FPS& rhs) {
    _vec = multiply(_vec, rhs._vec);
    return *this;
  }

  // Nyaan先生のライブラリを大写経....
  FPS& operator/=(const FPS& rhs) {
    if (size() < rhs.size()) {
      return *this = FPS(0);
    }
    int sz = size() - rhs.size() + 1;
    //
    //    FPS left = (*this).rev().pre(sz);
    //    FPS right = rhs.rev();
    //    right = right.inv(sz);
    //    FPS mp = left*right;
    //    mp = mp.pre(sz);
    //    mp = mp.rev();
    //    return *this = mp;
    //    return *this = (left * right).pre(sz).rev();
    return *this = ((*this).rev().pre(sz) * rhs.rev().inv(sz)).pre(sz).rev();
  }

  FPS& operator%=(const FPS& rhs) {
    *this -= *this / rhs * rhs;
    shrink();
    return *this;
  }

  FPS& operator+=(const mint& rhs) {
    _vec[0] += rhs;
    return *this;
  }

  FPS& operator-=(const mint& rhs) {
    _vec[0] -= rhs;
    return *this;
  }

  FPS& operator*=(const mint& rhs) {
    for (int i = 0; i < size(); i++) _vec[i] *= rhs;
    return *this;
  }

  // 多項式全体を定数除算する
  FPS& operator/=(const mint& rhs) {
    for (int i = 0; i < size(); i++) _vec[i] *= rhs.inv();
    return *this;
  }

  // f /= x^sz
  FPS operator>>(int sz) const {
    if ((int)this->size() <= sz) return {};
    FPS ret(*this);
    ret._vec.erase(ret._vec.begin(), ret._vec.begin() + sz);
    return ret;
  }

  // f *= x^sz
  FPS operator<<(int sz) const {
    FPS ret(*this);
    ret._vec.insert(ret._vec.begin(), sz, mint(0));

    return ret;
  }

  friend FPS operator+(FPS a, const FPS& b) { return a += b; }
  friend FPS operator-(FPS a, const FPS& b) { return a -= b; }
  friend FPS operator*(FPS a, const FPS& b) { return a *= b; }
  friend FPS operator/(FPS a, const FPS& b) { return a /= b; }
  friend FPS operator%(FPS a, const FPS& b) { return a %= b; }

  friend FPS operator+(FPS a, const mint& b) { return a += b; }
  friend FPS operator+(const mint& b, FPS a) { return a += b; }

  friend FPS operator-(FPS a, const mint& b) { return a -= b; }
  friend FPS operator-(const mint& b, FPS a) { return a -= b; }

  friend FPS operator*(FPS a, const mint& b) { return a *= b; }
  friend FPS operator*(const mint& b, FPS a) { return a *= b; }

  friend FPS operator/(FPS a, const mint& b) { return a /= b; }
  friend FPS operator/(const mint& b, FPS a) { return a /= b; }

  FPS mul(FPS g, int deg = -1) const { return ((*this) * g).pre(deg); }

  // sz次未満の項を取ってくる
  FPS pre(int sz) const {
    FPS ret = *this;
    ret._vec.resize(sz);

    return ret;
  }

  FPS rev() const {
    FPS ret = *this;
    std::reverse(ret._vec.begin(), ret._vec.end());

    return ret;
  }

  const mint& operator[](size_t i) const { return _vec[i]; }

  mint& operator[](size_t i) { return _vec[i]; }

  void resize(int sz) { this->_vec.resize(sz); }

  // 全ての不要な末尾の0を削除する
  void shrink() {
    while (size() > 0 && _vec.back() == mint(0)) _vec.pop_back();
  }

  friend std::ostream& operator<<(std::ostream& os, const FPS& fps) {
    for (int i = 0; i < fps.size(); ++i) {
      if (i > 0) os << " ";
      os << fps._vec[i].val();
    }
    return os;
  }

  // 仮想関数ってやつ。mod
  // 998244353なのか、他のNTT-friendlyなmodで考えるのか、それともGarnerで復元するのか、それとも畳み込みを$O(N^2)$で妥協するのかなどによって異なる
  virtual FPS inv(int deg = -1) const;
  virtual FPS mul_sparse(const FPS& g, int deg = -1) const;  // fps-sparse.hppで実装したり
  virtual FPS inv_sparse(int deg) const;                     // fps-sparse.hppで実装したり
  virtual FPS exp_sparse(int deg) const;                     // fps-sparse.hppで実装したり
  virtual FPS log_sparse(int deg) const;                     // fps-sparse.hppで実装したり
  virtual FPS pow_sparse(long long k, int deg) const;        // fps-sparse.hppで実装したり

  virtual void next_inv(FPS& g_d) const;
  virtual void CooleyTukeyNTT998244353(std::vector<mint>& a, bool is_reverse) const;
  //  virtual FPS exp(int deg=-1) const;
  virtual std::vector<mint> multiply(const std::vector<mint>& a, const std::vector<mint>& b);
};

#line 1 "formal-power-series/fps-sparse.hpp"



#line 6 "formal-power-series/fps-sparse.hpp"

// forward decl to avoid circular include: formal-power-series.hpp includes this file
template <typename mint>
struct FPS;

// FPSの非ゼロな項を集めたvector<pair<int,mint>>を返す
template <typename mint>
std::vector<std::pair<int, mint>> get_nonzeros(const FPS<mint>& f) {
  std::vector<std::pair<int, mint>> ret;
  for (int i = 0; i < f.size(); i++) {
    if (f[i] != mint(0)) ret.emplace_back(i, f[i]);
  }
  return ret;
}

// ↓--- inverse of sparse fps ---↓
// calculate inverse of f(sparse)
// deg : -1 + ( maximum degree of g )
template <typename mint>
FPS<mint> inv_sparse(const std::vector<std::pair<int, mint>>& f, int deg) {
  assert(deg >= 0);
  for (int i = 0; i < (int)f.size() - 1; i++) assert(f[i].first < f[i + 1].first);
  assert(f[0].first == 0 && f[0].second != mint(0));

  mint f0inv = f[0].second.inv();
  std::vector<mint> g(deg);
  g[0] = f0inv;
  for (int i = 0; i < deg - 1; i++) {
    for (std::pair<int, mint> pim : f) {
      if (i + 1 - pim.first >= 0)
        g[i + 1] -= pim.second * g[i + 1 - pim.first];
      else
        continue;
    }
    g[i + 1] *= f0inv;
  }

  return g;
}

template <typename mint>
FPS<mint> inv_sparse(const FPS<mint>& f, int deg) {
  return inv_sparse(get_nonzeros(f), deg);
}

template <typename mint>
FPS<mint> FPS<mint>::inv_sparse(int deg) const {
  return ::inv_sparse(*this, deg);
}

// ↑--- inverse of sparse fps ----↑

// exp(f)のdeg次未満の部分を求める。
// F := exp(f) = F_0 + F_1 x + F_2 x^2 + ... とする。
// F' = F * f' なので
// F_1 + 2F_2 x + 3F_3 x^3 + ... = f' F.
// 0以上の整数iについて、i次の項に注目すると、
// (i+1) * F_{i+1} = [x^i] (f' * F)
// とわかる。Fは0,1,...,i次までわかってればF_{i+1}もわかるということになる。f'はスパースだからF_{i+1}はたかだかK回の計算で求められる.

template <typename mint>
FPS<mint> exp_sparse(const FPS<mint>& f, int deg) {
  FPS<mint> F(deg);
  F[0] = mint(1);

  std::vector<std::pair<int, mint>> nonzero_fdiff = get_nonzeros(f.diff());

  for (int i = 0; i + 1 < deg; i++) {
    // F[i+1]を求める
    // (i+1) * F_{i+1} = [x^i] (f' * F)

    for (std::pair<int, mint> pim : nonzero_fdiff) {
      int a = pim.first;
      // Fのi-a次の項を足していく
      if (i - a < 0) continue;
      assert(i - a >= 0);
      assert(i + 1 > i - a);
      F[i + 1] += pim.second * F[i - a];
    }
    F[i + 1] /= mint(i + 1);
  }
  return F;
}

template <typename mint>
FPS<mint> FPS<mint>::exp_sparse(int deg) const {
  return ::exp_sparse(*this, deg);
}

template <typename mint>
FPS<mint> log_sparse(const FPS<mint>& f, int deg) {
  FPS<mint> f_inv = inv_sparse(f, deg);
  return multiply_sparse(f_inv, f.diff(), deg).integral().pre(deg);
}

template <typename mint>
FPS<mint> FPS<mint>::log_sparse(int deg) const {
  return ::log_sparse(*this, deg);
}

// g := f ^ k
// g' = k * f^{k-1} * f'
// fg' = k * f^k * f'
// fg' = k * g * f'

template <typename mint>
FPS<mint> pow_sparse(const FPS<mint>& f, long long k, int deg) {
  if (k == 0) {
    FPS ret = {mint(1)};
    ret.resize(deg);
    return ret;
  }

  if (f[0] == mint(0)) {
    int mindeg = 0;
    while (mindeg < deg && f[mindeg] == mint(0)) mindeg++;

    // (x^{mindeg})^k = x^{mindeg * k}
    // mindeg * k >= deg ⇔ k >= floor(deg / mindeg) である。
    // →: 自明 (k >= deg / mindeg >= floor(deg / mindeg) なので)

    // ←について:  h1: k >= floor(deg / mindeg) を仮定して Goal: k >= degを示す。
    // deg = mindeg * q + r (0 <= r < mindeg)と表す。これを使うと
    // h1: k >= q.
    // Goal: k >= mindeg * q + r.

    // mindeg * k > LLINF
    // mindeg > LLINF / k
    constexpr long long INF = 4450000000011100000;
    if (mindeg > INF / k || mindeg * k >= deg) {
      FPS<mint> ret(deg);
      assert(ret[0] == mint(0));
      return ret;
    }
    return pow_sparse(f >> mindeg, k, deg - mindeg * k) << k * mindeg;
  }

  FPS<mint> g(deg);
  assert(f[0] != mint(0));
  g[0] = f[0].pow(k);

  std::vector<std::pair<int, mint>> nonzero_f = get_nonzeros(f);

  for (int i = 0; i + 1 < deg; i++) {
    // g[0], g[1], ..., g[i]が判っている状態で,x^iに注目してg[i+1]を求めにいく。

    // fg' = (f[0] + f[1]x + f[2]x^2 + ... + f[i]x^i)(g[1] + 2g[2]x + 3g[3]x^2 + ... + ig[i]x^{i-1} + (i+1)g[i+1]x^i)
    // (左) kgf' = k(g[0] + g[1]x + g[2]x^2 + ... + g[i]x^i) * (f[1] + 2*f[2]x 3*f[3]x^2 + ... + i*f[i]x^{i-1} +
    // (i+1)*f[i+1]x^i) (右)

    // 左のx^iの係数は f[0](i+1)g[i+1] + f[1]ig[i] + f[2](i-1)g[i-1] + ... + f[i]g[1]
    // 右のx^iの係数は k * ( g[0]*(i+1)f[i+1] +  g[1]if[i] + ... + g[i]*1f[1])

    // f[0](i+1)g[i+1] = k * (g[0]*(i+1)f[i+1] + g[1]if[i] + ... + g[i]*1f[1]) - f[1]ig[i] - f[2](i-1)g[i-1] - ... -
    // f[i]g[1]

    mint sum(0);
    for (std::pair<int, mint> pim : nonzero_f) {
      // f[pim.first]: pim.second
      // 左では 0 <= pim.first <= i
      // 右では 1 <= pim.first <= i+1の部分を見る

      if (0 <= pim.first && pim.first <= i) sum -= pim.second * mint(i + 1 - pim.first) * g[i + 1 - pim.first];
      if (1 <= pim.first && pim.first <= i + 1) sum += mint(k) * g[i + 1 - pim.first] * mint(pim.first) * pim.second;
    }

    // for (int j=0; j<=i; j++) sum += mint(k) * g[j] * mint(i+1-j) * f[i+1-j];
    // for (int j=1; j<=i; j++) sum -= f[j] * mint(i+1-j) * g[i+1-j];

    g[i + 1] = sum / f[0] / mint(i + 1);
  }

  return g;
}

template <typename mint>
FPS<mint> FPS<mint>::pow_sparse(long long k, int deg) const {
  return ::pow_sparse(*this, k, deg);
}

// tabun baggute masu.
template <typename mint>
FPS<mint> multiply_sparse(const FPS<mint>& f, const std::vector<std::pair<int, mint>>& g, int deg = -1) {
  if (deg == -1) deg = f.size() - 1 + g.back().first + 1;

  FPS<mint> ret(deg);
  for (std::pair<int, mint> pim : g) {
    assert(pim.second != 0);
    if (pim.second == 0) continue;

    for (int i = 0; i < f.size(); i++) {
      if (i + pim.first >= ret.size()) continue;
      if (f[i] != mint(0) && pim.second != mint(0)) ret[i + pim.first] += pim.second * f[i];
    }
  }

  return ret;
}

template <typename mint>
FPS<mint> multiply_sparse(const FPS<mint>& f, const FPS<mint>& g, int deg = -1) {
  std::vector<std::pair<int, mint>> vpmi;

  for (int i = 0; i < g.size(); i++)
    if (g[i] != mint(0)) vpmi.emplace_back(i, g[i]);

  return multiply_sparse(f, vpmi, deg);
}

template <typename mint>
FPS<mint> FPS<mint>::mul_sparse(const FPS<mint>& g, int deg) const {
  return multiply_sparse(*this, g, deg);
}


#line 304 "formal-power-series/formal-power-series.hpp"


#line 3 "formal-power-series/fiduccia.hpp"

// AtCoderではverifyできたが、LCではできず
// given linear recurrence sequence a_{n+K}= c_1 a_{n+K-1} + c_2 a_{n+k-2} + \dots + c_{K-1} a_{n+1} + c_K a_n
// a_0, a_1, \dots, a_{K-1} are given
// calculate a_N (N-th term of linear recurrence sequence) time complexity is O(K log K log N) (when NNT is used), O(K^2 log N) (when naive convolution is used).
template <typename mint> 
mint Fiduccia(const vector<mint>& a, const vector<mint>& c, unsigned long long  N) {
  if (N < a.size()) return a[N];
  assert(a.size() == c.size());
  int K = c.size();

  FPS<mint> varphi(K+1); 
  varphi[K] = mint(1);
  for(int i=0; i<K; i++) varphi[i] = mint(-1) * c[K-i-1];

  // calculate x^N mod varphi, using square and multiply technique.
  // Note that there is two way to implement the methodlogy. LSB-first algorithm(famous one ) and MSB-first alogirthm.
 int msb=0;
  for (int i=0; 1ULL<< i <=N; i++) {
    if (N & (1ULL << i)) msb = i;
  }
  FPS<mint> remainder(1); remainder[0] = mint(1);
  for (int i=msb; i>=0; i--) {
    if (N & (1ULL << i)) {
      remainder = remainder << 1; // it is equal to remainder *= x.
      if (remainder.size() >= varphi.size()) remainder %= varphi;
    }
    if (i != 0) {
      remainder *= remainder; // NTTなら、NTT配列を使い回すことで定数倍が良くなるね
      if (remainder.size() >= varphi.size()) remainder %= varphi;
    }
  }

  // remainder = x^N mod varphi 
  mint ret = 0;
  assert(remainder.size() <= K);
  for(int i=0; i<remainder.size(); i++) {
    ret += remainder[i] * a[i];
  }

  return ret;
}
Back to top page