This documentation is automatically generated by online-judge-tools/verification-helper
#include "formal-power-series/fiduccia.hpp"
~~ なんかLibrary Checkerは通らないが、Typical DP Contest Tのフィボナッチは(NaiveなFPSを使うことで)通った。~~ → 普通にLibrary Checkerも通りました。遅いけど。
提出: [https://atcoder.jp/contests/tdpc/submissions/55372657]
Fiducciaの論文はarchive.orgにあったきがする。ACMのだったはず。
参考にしたもの: [https://qiita.com/ryuhe1/items/da5acbcce4ac1911f47a] のFiducciaのアルゴリズム ~ 二種類の繰り返し二乗法あたりまで。
mint Fidducia(const vector<mint>& a, const vector<mint>& c, unsigned long long N)
$a_{n+K}= c_1 a_{n+K-1} + c_2 a_{n+k-2} + \dots + c_{K-1} a_{n+1} + c_K a_n $ という線型漸化式が与えられて、初項
$ a_0, a_1, \dots, a_{K-1} $ がわかっているときに、$a_N$を求める。
畳み込みがFFTなどで高速化されるなら
ナイーブな畳み込みを使うなら
↓↓↓自分でまとめたときの解説みたいなやつ
$K+1$項間の線型漸化式を持つ数列の$N$項目は行列累乗で求めたら$O(K^3 \log N)$かかるが、このスライドの手法でやると、もうちょっと速くなる。
\(a_{n+K}= c_1 a_{n+K-1} + c_2 a_{n+k-2} + \dots + c_{K-1} a_{n+1} + c_K a_n\) という$K+1$項間線形漸化式があるとする。
すると、いわゆる行列累乗の要領で
\[\begin{bmatrix} a_{n+1} \\ a_{n} \\ \vdots \\ a_{n-K+3} \\ a_{n-K+2} \end {bmatrix} = \begin{bmatrix} c_1 & c_2 & c_3 & c_4 & \dots & c_K \\ 1 & 0 & 0 & 0 & \dots & 0 \\ 0 & 1 & 0 & 0 & \dots & 0 \\ \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \\ 0 & \dots & 0 & 0 & 1 & 0 \end{bmatrix} \begin {bmatrix} a_{n} \\ a_{n-1} \\ \vdots \\ a_{n-K+2} \\ a_{n-K+1} \end {bmatrix}\]であるとわかる。
すると、先ほどの式を繰り返し適用すると、以下のような関係式が成り立つから、
$$
\begin{bmatrix}
a_N
a_{N-1}
\vdots
a_{N-K+2}
a_{N-K+1}
\end{bmatrix}
=
\begin{bmatrix}
c_1 & c_2 & c_3 & c_4 & \dots & c_K
1 & 0 & 0 & 0 & \dots & 0
0 & 1 & 0 & 0 & \dots & 0
\vdots & \vdots & \vdots & \vdots & \vdots & \vdots &
0 & \dots & 0 & 0 & 1 & 0
\end{bmatrix}^{N-K}
\begin {bmatrix}
a_K
a_{K-1}
a_{K_2}
\vdots
a_1
\end {bmatrix}
$$
となる。
数列の$N$項目は行列をだいたい$N$回掛けることで求められる。行列の積は1回あたり$O(K^3)$かかるので、$A^{N-1}$を求めるのに繰り返し二乗法を使うと、$a_N$は全体で$O(K^3 \log N)$ の計算量で求められる。
さっきの式:
$$
\begin{bmatrix}
a_N
a_{N-1}
\vdots
a_{N-K+2}
a_{N-K+1}
\end{bmatrix}
=
\begin{bmatrix}
c_1 & c_2 & c_3 & c_4 & \dots & c_K
1 & 0 & 0 & 0 & \dots & 0
0 & 1 & 0 & 0 & \dots & 0
\vdots & \vdots & \vdots & \vdots & \vdots & \vdots &
0 & \dots & 0 & 0 & 1 & 0
\end{bmatrix}^{N-K}
\begin {bmatrix}
a_K
a_{K-1}
a_{K_2}
\vdots
a_1
\end {bmatrix}
$$
これをちょっと変えて、
$$
\begin{bmatrix}
a_{N+K-1}
a_{N+K-2}
\vdots
a_{N+1}
a_{N}
\end{bmatrix}
=
\begin{bmatrix}
c_1 & c_2 & c_3 & c_4 & \dots & c_K
1 & 0 & 0 & 0 & \dots & 0
0 & 1 & 0 & 0 & \dots & 0
\vdots & \vdots & \vdots & \vdots & \vdots & \vdots &
0 & \dots & 0 & 0 & 1 & 0
\end{bmatrix}^{N}
\begin {bmatrix}
a_{K-1}
a_{K-2}
\vdots
a_1 \
a_0
\end {bmatrix}
$$
のようにしてみる。
緑の線形代数の教科書(線形代数講義と演習 改訂版 小林正典, 寺尾宏明 ) P98 定理17.1を見ると、
定理17.1 (ケイリーハミルトンの定理) $n$次正方行列$A$の固有多項式$\varphi_A(x)$に$A$を代入したものは零行列に等しい
と書いてあります。つまり、$\varphi_A(A) = 0$。ここで $x^N$を $\varphi_A(x)$ で割ったあまりを$r(x)$ と定義する。つまり、 $x^N = q(x) \varphi_A(x) + r(x)$。ただし$\deg(r(x)) < \deg(\varphi_A(x))$
すると、$A^N = r(A^N)$となる。 $r(x) = r_0 + r_1 x + \dots + r_{K-1} x^{K-1}$ とおくと、$A^N = r(A^N) = r_0 E + r_1 A + \dots + r_{K-1} A^{K-1}$とわかる。
求めたい$a_N$は、$A^{N}$の一番下の行と、$[a_{K-1} a_{K-2} \dots a_0 ]$の内積。 だから、 $E, A^1, A^2, \dots A^{K-1}$の一番下の行のみわかればよい。
ここで、驚きの事実(証明はあとで書く)として、$A^i$ ( $0 \leq i < K$) の一番下の行は、$K-i$列目だけ$1$で、他は$0$となっている。
なので、 $$r_i A^i \begin{bmatrix} a_{K-1} \ a_{K-2} \ \vdots \ a_1 \ a_0 \end{bmatrix} = \begin{bmatrix}
よって、 $$ r_0 E + r_1 A + \dots + r_{K-1} A^{K-1} = \begin{bmatrix}
これまでのことをまとめると、
という3ステップで$a_N$が計算できることがわかった。
実は、固有多項式は線型漸化式によって決められるので、入力で線型漸化式を受け取ってから、掃き出し法をして … みたいな計算は必要なく、($c_1, c_2, \dots c_K$を使った式で)事前に求められる。。(あたりまえかも)
\[\varphi_A(x) = |xE - A | = \begin{vmatrix} x - c_1 & -c_2 & -c_3 & -c_4 & \dots & -c_K \\ -1 & x & 0 & 0 & \dots & 0 \\ 0 & -1 & x & 0 & \dots & 0 \\ \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \\ 0 & \dots & 0 & 0 & -1 & x \end{vmatrix}\](結論を先に書くと、$=-x^K + c_1 x^{K-1} + \dots + c_{K-1} x^1 + c_{K}$ )
多項式除算なので ナイーブにやると$O(K^2)$でできる。行列積よりも速い! →嘘っぽい。
mod $\varphi_A(x)$ 上で$x^N$を計算すればよい。
はい。やるだけ。
https://blog.miz-ar.info/2019/02/typical-dp-contest-t/ : https://qiita.com/ryuhe1/items/da5acbcce4ac1911f47a : Bostan-MoriのMoriさんです
#include <vector>
#include "formal-power-series/formal-power-series.hpp"
// AtCoderではverifyできたが、LCではできず
// given linear recurrence sequence a_{n+K}= c_1 a_{n+K-1} + c_2 a_{n+k-2} + \dots + c_{K-1} a_{n+1} + c_K a_n
// a_0, a_1, \dots, a_{K-1} are given
// calculate a_N (N-th term of linear recurrence sequence) time complexity is O(K log K log N) (when NNT is used), O(K^2 log N) (when naive convolution is used).
template <typename mint>
mint Fiduccia(const vector<mint>& a, const vector<mint>& c, unsigned long long N) {
if (N < a.size()) return a[N];
assert(a.size() == c.size());
int K = c.size();
FPS<mint> varphi(K+1);
varphi[K] = mint(1);
for(int i=0; i<K; i++) varphi[i] = mint(-1) * c[K-i-1];
// calculate x^N mod varphi, using square and multiply technique.
// Note that there is two way to implement the methodlogy. LSB-first algorithm(famous one ) and MSB-first alogirthm.
int msb=0;
for (int i=0; 1ULL<< i <=N; i++) {
if (N & (1ULL << i)) msb = i;
}
FPS<mint> remainder(1); remainder[0] = mint(1);
for (int i=msb; i>=0; i--) {
if (N & (1ULL << i)) {
remainder = remainder << 1; // it is equal to remainder *= x.
if (remainder.size() >= varphi.size()) remainder %= varphi;
}
if (i != 0) {
remainder *= remainder; // NTTなら、NTT配列を使い回すことで定数倍が良くなるね
if (remainder.size() >= varphi.size()) remainder %= varphi;
}
}
// remainder = x^N mod varphi
mint ret = 0;
assert(remainder.size() <= K);
for(int i=0; i<remainder.size(); i++) {
ret += remainder[i] * a[i];
}
return ret;
}
#line 1 "formal-power-series/fiduccia.hpp"
#include <vector>
#line 1 "formal-power-series/formal-power-series.hpp"
#line 1 "math/modint.hpp"
#line 1 "math/external_gcd.hpp"
#include <tuple>
// g,x,y
template<typename T>
constexpr std::tuple<T, T, T> extendedGCD(T a, T b) {
T x0 = 1, y0 = 0, x1 = 0, y1 = 1;
while (b != 0) {
T q = a / b;
T r = a % b;
a = b;
b = r;
T xTemp = x0 - q * x1;
x0 = x1;
x1 = xTemp;
T yTemp = y0 - q * y1;
y0 = y1;
y1 = yTemp;
}
return {a, x0, y0};
}
#line 5 "math/modint.hpp"
#include <type_traits>
#include <cassert>
template<int MOD, typename T = int>
struct static_modint {
T value;
constexpr explicit static_modint() : value(0) {}
constexpr static_modint(long long v) {
if constexpr (std::is_same<T, double>::value) {
value = static_cast<T>(v);
}
else {
value = int(((v % MOD) + MOD) % MOD);
}
}
constexpr static_modint& operator+=(const static_modint& other) {
if constexpr (std::is_same<T, double>::value) {
value += other.value;
}
else {
if ((value += other.value) >= MOD) value -= MOD;
}
return *this;
}
constexpr static_modint& operator-=(const static_modint& other) {
if constexpr (std::is_same<T, double>::value) {
value -= other.value;
}
else {
if ((value -= other.value) < 0) value += MOD;
}
return *this;
}
constexpr static_modint& operator*=(const static_modint& other) {
if constexpr (std::is_same<T, double>::value) {
value *= other.value;
}
else {
value = int((long long)value * other.value % MOD);
}
return *this;
}
constexpr static_modint operator+(const static_modint& other) const {
return static_modint(*this) += other;
}
constexpr static_modint operator-(const static_modint& other) const {
return static_modint(*this) -= other;
}
constexpr static_modint operator*(const static_modint& other) const {
return static_modint(*this) *= other;
}
constexpr static_modint pow(long long exp) const {
static_modint base = *this, res = static_modint(1);
while (exp > 0) {
if (exp & 1) res *= base;
base *= base;
exp >>= 1;
}
return res;
}
constexpr static_modint inv() const {
if constexpr (std::is_same<T, double>::value) {
return static_modint(1) / static_modint(value);
}
else {
int g, x, y;
std::tie(g, x, y) = extendedGCD(value, MOD);
assert(g == 1);
if (x < 0) x += MOD;
return x;
}
}
constexpr static_modint& operator/=(const static_modint& other) {
return *this *= other.inv();
}
constexpr static_modint operator/(const static_modint& other) const {
return static_modint(*this) /= other;
}
constexpr bool operator!=(const static_modint& other) const {
return val() != other.val();
}
constexpr bool operator==(const static_modint& other) const {
return val() == other.val();
}
T val() const {
if constexpr (std::is_same<T, double>::value) {
return static_cast<double>(value);
}
else return this->value;
}
friend std::ostream& operator<<(std::ostream& os, const static_modint& mi) {
return os << mi.value;
}
friend std::istream& operator>>(std::istream& is, static_modint& mi) {
long long x;
is >> x;
mi = static_modint(x);
return is;
}
};
template <int mod>
using modint = static_modint<mod>;
using modint998244353 = modint<998244353>;
using modint1000000007 = modint<1000000007>;
#line 6 "formal-power-series/formal-power-series.hpp"
#include <algorithm>
template <typename mint>
struct FPS {
std::vector<mint> _vec;
constexpr int lg2(int N) const {
int ret = 0;
if (N > 0) ret = 31 - __builtin_clz(N);
if ((1LL << ret) < N) ret++;
return ret;
}
// ナイーブなニュートン法での逆元計算
FPS inv_naive(int deg) const {
assert(_vec[0] != mint(0)); // さあらざれば、逆元のてひぎいきにこそあらざれ。
if (deg == -1) deg = this->size();
FPS g(1);
g._vec[0] = mint(_vec[0]).inv();
// g_{n+1} = 2 * g_n - f * (g_n)^2
for (int d = 1; d < deg; d <<= 1) {
FPS g_twice = g * mint(2);
FPS fgg = (*this).pre(d * 2) * g * g;
g = g_twice - fgg;
g.resize(d * 2);
}
return g.pre(deg);
}
//*/
FPS log(int deg = -1) const {
assert(_vec[0] == mint(1));
if (deg == -1) deg = size();
FPS df = this->diff();
FPS iv = this->inv(deg);
FPS ret = (df * iv).pre(deg - 1).integral();
return ret;
}
FPS exp(int deg = -1) const {
assert(_vec[0] == mint(0));
if (deg == -1) deg = size();
FPS h = {1}; // h: exp(f)
// h_2d = h * (f + 1 - Integrate(h' * h.inv() ) )
for (int d = 1; d < deg; d <<= 1) {
// h_2d = h_d * (f + 1 - log(h_d))
// = h_d * (f + 1 - Integral(h' * h.inv() ))
// を利用して、h.invを漸化式で更新していけば定数倍改善できるかと思ったが、なんかバグってる。
FPS fpl1 = ((*this).pre(2*d) + mint(1));
FPS logh = h.log(2*d);
FPS right = (fpl1 - logh);
h = (h * right).pre(2 * d);
}
return h.pre(deg);
}
// f^k を返す
FPS pow(long long k, int deg = -1) const {
mint lowest_coeff;
if (deg == -1) deg = size();
int lowest_deg = -1;
if (k == 0) {
FPS ret = { mint(1) };
ret.resize(deg);
return ret;
}
for (int i = 0; i < size(); i++) {
if (i * k > deg) {
return FPS(deg);
}
if (_vec[i] != mint(0)) {
lowest_deg = i;
lowest_coeff = _vec[i];
int deg3 = deg - k*lowest_deg;
FPS f2 = (*this / lowest_coeff) >> lowest_deg;
FPS ret = (lowest_coeff.pow(k) * (f2.log(deg3) * mint(k)).exp(deg3) << (lowest_deg * k)).pre(deg);
ret.resize(deg);
return ret;
}
}
assert(false);
}
FPS integral() const {
const int N = size();
FPS ret(N + 1);
for (int i = 0; i < N; i++) ret[i + 1] = _vec[i] * mint(i + 1).inv();
return ret;
}
FPS diff() const {
const int N = size();
FPS ret(max(0, N - 1));
for (int i = 1; i < N; i++) ret[i - 1] = mint(i) * _vec[i];
return ret;
}
FPS to_egf() const {
const int N = size();
FPS ret(N);
mint fact = mint(1);
for (int i=0; i<N; i++) {
ret[i] = _vec[i] * fact.inv();
fact *= mint(i+1);
}
return ret;
}
FPS to_ogf() const {
const int N = size();
FPS ret(N);
mint fact = mint(1);
for (int i=0; i<N; i++) {
ret[i] = _vec[i] * fact;
fact *= mint(i+1);
}
return ret;
}
FPS(std::vector<mint> vec) : _vec(vec) {
}
FPS(initializer_list<mint> ilist) : _vec(ilist) {
}
// 項の数に揃えたほうがよさそう
FPS(int sz) : _vec(std::vector<mint>(sz)) {
}
int size() const {
return _vec.size();
}
FPS& operator+=(const FPS& rhs) {
if (rhs.size() > this->size()) _vec.resize(rhs.size());
for (int i = 0; i < (int)rhs.size(); ++i) _vec[i] += rhs._vec[i];
return *this;
}
FPS& operator-=(const FPS& rhs) {
if (rhs.size() > this->size()) this->_vec.resize(rhs.size());
for (int i = 0; i < (int)rhs.size(); ++i) _vec[i] -= rhs._vec[i];
return *this;
}
FPS& operator*=(const FPS& rhs) {
_vec = multiply(_vec, rhs._vec);
return *this;
}
// Nyaan先生のライブラリを大写経....
FPS& operator/=(const FPS& rhs) {
if (size() < rhs.size()) {
return *this = FPS(0);
}
int sz = size() - rhs.size() + 1;
//
// FPS left = (*this).rev().pre(sz);
// FPS right = rhs.rev();
// right = right.inv(sz);
// FPS mp = left*right;
// mp = mp.pre(sz);
// mp = mp.rev();
// return *this = mp;
// return *this = (left * right).pre(sz).rev();
return *this = ((*this).rev().pre(sz) * rhs.rev().inv(sz)).pre(sz).rev();
}
FPS& operator%=(const FPS& rhs) {
*this -= *this / rhs * rhs;
shrink();
return *this;
}
FPS& operator+=(const mint& rhs) {
_vec[0] += rhs;
return *this;
}
FPS& operator-=(const mint& rhs) {
_vec[0] -= rhs;
return *this;
}
FPS& operator*=(const mint& rhs) {
for (int i = 0; i < size(); i++) _vec[i] *= rhs;
return *this;
}
// 多項式全体を定数除算する
FPS& operator/=(const mint& rhs) {
for (int i = 0; i < size(); i++) _vec[i] *= rhs.inv();
return *this;
}
// f /= x^sz
FPS operator>>(int sz) const {
if ((int)this->size() <= sz) return {};
FPS ret(*this);
ret._vec.erase(ret._vec.begin(), ret._vec.begin() + sz);
return ret;
}
// f *= x^sz
FPS operator<<(int sz) const {
FPS ret(*this);
ret._vec.insert(ret._vec.begin(), sz, mint(0));
return ret;
}
friend FPS operator+(FPS a, const FPS& b) { return a += b; }
friend FPS operator-(FPS a, const FPS& b) { return a -= b; }
friend FPS operator*(FPS a, const FPS& b) { return a *= b; }
friend FPS operator/(FPS a, const FPS& b) { return a /= b; }
friend FPS operator%(FPS a, const FPS& b) { return a %= b; }
friend FPS operator+(FPS a, const mint& b) { return a += b; }
friend FPS operator+(const mint& b, FPS a) { return a += b; }
friend FPS operator-(FPS a, const mint& b) { return a -= b; }
friend FPS operator-(const mint& b, FPS a) { return a -= b; }
friend FPS operator*(FPS a, const mint& b) { return a *= b; }
friend FPS operator*(const mint& b, FPS a) { return a *= b; }
friend FPS operator/(FPS a, const mint& b) { return a /= b; }
friend FPS operator/(const mint& b, FPS a) { return a /= b; }
// sz次未満の項を取ってくる
FPS pre(int sz) const {
FPS ret = *this;
ret._vec.resize(sz);
return ret;
}
FPS rev() const {
FPS ret = *this;
std::reverse(ret._vec.begin(), ret._vec.end());
return ret;
}
const mint& operator[](size_t i) const {
return _vec[i];
}
mint& operator[](size_t i) {
return _vec[i];
}
void resize(int sz) {
this->_vec.resize(sz);
}
void shrink() {
while (size() > 0 && _vec.back() == mint(0)) _vec.pop_back();
}
friend ostream& operator<<(ostream& os, const FPS& fps) {
for (int i = 0; i < fps.size(); ++i) {
if (i > 0) os << " ";
os << fps._vec[i].val();
}
return os;
}
// 仮想関数ってやつ。mod 998244353なのか、他のNTT-friendlyなmodで考えるのか、それともGarnerで復元するのか、それとも畳み込みを$O(N^2)$で妥協するのかなどによって異なる
virtual FPS inv(int deg = -1) const;
virtual void next_inv(FPS& g_d) const;
virtual void CooleyTukeyNTT998244353(std::vector<mint>& a, bool is_reverse) const;
// virtual FPS exp(int deg=-1) const;
virtual std::vector<mint> multiply(const std::vector<mint>& a, const std::vector<mint>& b);
};
#line 3 "formal-power-series/fiduccia.hpp"
// AtCoderではverifyできたが、LCではできず
// given linear recurrence sequence a_{n+K}= c_1 a_{n+K-1} + c_2 a_{n+k-2} + \dots + c_{K-1} a_{n+1} + c_K a_n
// a_0, a_1, \dots, a_{K-1} are given
// calculate a_N (N-th term of linear recurrence sequence) time complexity is O(K log K log N) (when NNT is used), O(K^2 log N) (when naive convolution is used).
template <typename mint>
mint Fiduccia(const vector<mint>& a, const vector<mint>& c, unsigned long long N) {
if (N < a.size()) return a[N];
assert(a.size() == c.size());
int K = c.size();
FPS<mint> varphi(K+1);
varphi[K] = mint(1);
for(int i=0; i<K; i++) varphi[i] = mint(-1) * c[K-i-1];
// calculate x^N mod varphi, using square and multiply technique.
// Note that there is two way to implement the methodlogy. LSB-first algorithm(famous one ) and MSB-first alogirthm.
int msb=0;
for (int i=0; 1ULL<< i <=N; i++) {
if (N & (1ULL << i)) msb = i;
}
FPS<mint> remainder(1); remainder[0] = mint(1);
for (int i=msb; i>=0; i--) {
if (N & (1ULL << i)) {
remainder = remainder << 1; // it is equal to remainder *= x.
if (remainder.size() >= varphi.size()) remainder %= varphi;
}
if (i != 0) {
remainder *= remainder; // NTTなら、NTT配列を使い回すことで定数倍が良くなるね
if (remainder.size() >= varphi.size()) remainder %= varphi;
}
}
// remainder = x^N mod varphi
mint ret = 0;
assert(remainder.size() <= K);
for(int i=0; i<remainder.size(); i++) {
ret += remainder[i] * a[i];
}
return ret;
}