This documentation is automatically generated by online-judge-tools/verification-helper
View the Project on GitHub Harui-i/library
#define PROBLEM "https://judge.yosupo.jp/problem/inv_of_formal_power_series" #include "template/template.hpp" #include "math/modint.hpp" #include "formal-power-series/formal-power-series.hpp" #include "formal-power-series/fps998.hpp" using mint = modint998244353; int main() { ios::sync_with_stdio(0); cin.tie(0); cout.tie(0); int N; cin >> N; FPS<mint> a_fps(N); for(int i=0; i<N; i++) cin >> a_fps[i]; cout << a_fps.inv(N) << endl; }
#line 1 "test/verify/fps/yosupo-inv-of-formal-power-series-fast2.test.cpp" #define PROBLEM "https://judge.yosupo.jp/problem/inv_of_formal_power_series" #line 1 "template/template.hpp" #include <iostream> #include <cassert> using namespace std; using ll = long long; template<class T> inline bool chmax(T& a, const T& b) {if (a<b) {a=b; return true;} return false;} template<class T> inline bool chmin(T& a, const T& b) {if (b<a) {a=b; return true;} return false;} const int INTINF = 1000001000; const int INTMAX = 2147483647; const ll LLMAX = 9223372036854775807; const ll LLINF = 1000000000000000000; #line 1 "math/modint.hpp" #line 1 "math/external_gcd.hpp" #include <tuple> // g,x,y template<typename T> constexpr std::tuple<T, T, T> extendedGCD(T a, T b) { T x0 = 1, y0 = 0, x1 = 0, y1 = 1; while (b != 0) { T q = a / b; T r = a % b; a = b; b = r; T xTemp = x0 - q * x1; x0 = x1; x1 = xTemp; T yTemp = y0 - q * y1; y0 = y1; y1 = yTemp; } return {a, x0, y0}; } #line 5 "math/modint.hpp" template<int MOD> struct static_modint { int value; constexpr explicit static_modint() : value(0) {} constexpr static_modint(long long v) { value = int(((v % MOD) + MOD) % MOD); } constexpr static_modint& operator+=(const static_modint& other) { if ((value += other.value) >= MOD) value -= MOD; return *this; } constexpr static_modint& operator-=(const static_modint& other) { if ((value -= other.value) < 0) value += MOD; return *this; } constexpr static_modint& operator*=(const static_modint& other) { value = int((long long)value * other.value % MOD); return *this; } constexpr static_modint operator+(const static_modint& other) const { return static_modint(*this) += other; } constexpr static_modint operator-(const static_modint& other) const { return static_modint(*this) -= other; } constexpr static_modint operator*(const static_modint& other) const { return static_modint(*this) *= other; } constexpr static_modint pow(long long exp) const { static_modint base = *this, res = static_modint(1); while (exp > 0) { if (exp & 1) res *= base; base *= base; exp >>= 1; } return res; } constexpr static_modint inv() const { //return pow(MOD - 2); int g,x,y; tie(g,x,y) = extendedGCD(value, MOD); assert(g==1); if (x < 0) { x += MOD; } //cerr << g << " " << x << " " << y << " " << value << endl; //assert((((long)x*value)%MOD + MOD)%MOD == 1); return x; } constexpr static_modint& operator/=(const static_modint& other) { return *this *= other.inv(); } constexpr static_modint operator/(const static_modint& other) const { return static_modint(*this) /= other; } constexpr bool operator!=(const static_modint& other) const { return val() != other.val(); } constexpr bool operator==(const static_modint& other) const { return val() == other.val(); } int val() const { return this->value; } friend std::ostream& operator<<(std::ostream& os, const static_modint& mi) { return os << mi.value; } friend std::istream& operator>>(std::istream& is, static_modint& mi) { long long x; is >> x; mi = static_modint(x); return is; } }; template <int mod> using modint = static_modint<mod>; using modint998244353 = modint<998244353>; using modint1000000007 = modint<1000000007>; #line 1 "formal-power-series/formal-power-series.hpp" #line 5 "formal-power-series/formal-power-series.hpp" #include <vector> template <typename mint> struct FPS { std::vector<mint> _vec; constexpr int lg2(int N) const { int ret = 0; if (N > 0) ret = 31 - __builtin_clz(N); if ((1LL << ret) < N) ret++; return ret; } // ナイーブなニュートン法での逆元計算 FPS inv_naive(int deg) const { assert(_vec[0] != mint(0)); // さあらざれば、逆元のてひぎいきにこそあらざれ。 if (deg == -1) deg = this->size(); FPS g(1); g._vec[0] = mint(_vec[0]).inv(); // g_{n+1} = 2 * g_n - f * (g_n)^2 for (int d = 1; d < deg; d <<= 1) { FPS g_twice = g * mint(2); FPS fgg = (*this).pre(d * 2) * g * g; g = g_twice - fgg; g.resize(d * 2); } return g.pre(deg); } //*/ FPS log(int deg = -1) const { assert(_vec[0] == mint(1)); if (deg == -1) deg = size(); FPS df = this->diff(); FPS iv = this->inv(deg); FPS ret = (df * iv).pre(deg - 1).integral(); return ret; } FPS exp(int deg = -1) const { assert(_vec[0] == mint(0)); if (deg == -1) deg = size(); FPS h = {1}; // h: exp(f) // h_2d = h * (f + 1 - Integrate(h' * h.inv() ) ) for (int d = 1; d < deg; d <<= 1) { // h_2d = h_d * (f + 1 - log(h_d)) // = h_d * (f + 1 - Integral(h' * h.inv() )) // を利用して、h.invを漸化式で更新していけば定数倍改善できるかと思ったが、なんかバグってる。 FPS fpl1 = ((*this).pre(2*d) + mint(1)); FPS logh = h.log(2*d); FPS right = (fpl1 - logh); h = (h * right).pre(2 * d); } return h.pre(deg); } // f^k を返す FPS pow(long long k, int deg = -1) const { mint lowest_coeff; if (deg == -1) deg = size(); int lowest_deg = -1; if (k == 0) { FPS ret = { mint(1) }; ret.resize(deg); return ret; } for (int i = 0; i < size(); i++) { if (i * k > deg) { return FPS(deg); } if (_vec[i] != mint(0)) { lowest_deg = i; lowest_coeff = _vec[i]; int deg3 = deg - k*lowest_deg; FPS f2 = (*this / lowest_coeff) >> lowest_deg; FPS ret = (lowest_coeff.pow(k) * (f2.log(deg3) * mint(k)).exp(deg3) << (lowest_deg * k)).pre(deg); ret.resize(deg); return ret; } } assert(false); } FPS integral() const { const int N = size(); FPS ret(N + 1); for (int i = 0; i < N; i++) ret[i + 1] = _vec[i] * mint(i + 1).inv(); return ret; } FPS diff() const { const int N = size(); FPS ret(max(0, N - 1)); for (int i = 1; i < N; i++) ret[i - 1] = mint(i) * _vec[i]; return ret; } FPS(std::vector<mint> vec) : _vec(vec) { } FPS(initializer_list<mint> ilist) : _vec(ilist) { } // 項の数に揃えたほうがよさそう FPS(int sz) : _vec(std::vector<mint>(sz)) { } int size() const { return _vec.size(); } FPS& operator+=(const FPS& rhs) { if (rhs.size() > this->size()) _vec.resize(rhs.size()); for (int i = 0; i < (int)rhs.size(); ++i) _vec[i] += rhs._vec[i]; return *this; } FPS& operator-=(const FPS& rhs) { if (rhs.size() > this->size()) this->_vec.resize(rhs.size()); for (int i = 0; i < (int)rhs.size(); ++i) _vec[i] -= rhs._vec[i]; return *this; } FPS& operator*=(const FPS& rhs) { _vec = multiply(_vec, rhs._vec); return *this; } // Nyaan先生のライブラリを大写経.... FPS& operator/=(const FPS& rhs) { if (size() < rhs.size()) { return *this = FPS(0); } int sz = size() - rhs.size() + 1; // // FPS left = (*this).rev().pre(sz); // FPS right = rhs.rev(); // right = right.inv(sz); // FPS mp = left*right; // mp = mp.pre(sz); // mp = mp.rev(); // return *this = mp; // return *this = (left * right).pre(sz).rev(); return *this = ((*this).rev().pre(sz) * rhs.rev().inv(sz)).pre(sz).rev(); } FPS& operator%=(const FPS& rhs) { *this -= *this / rhs * rhs; shrink(); return *this; } FPS& operator+=(const mint& rhs) { _vec[0] += rhs; return *this; } FPS& operator-=(const mint& rhs) { _vec[0] -= rhs; return *this; } FPS& operator*=(const mint& rhs) { for (int i = 0; i < size(); i++) _vec[i] *= rhs; return *this; } // 多項式全体を定数除算する FPS& operator/=(const mint& rhs) { for (int i = 0; i < size(); i++) _vec[i] *= rhs.inv(); return *this; } // f /= x^sz FPS operator>>(int sz) const { if ((int)this->size() <= sz) return {}; FPS ret(*this); ret._vec.erase(ret._vec.begin(), ret._vec.begin() + sz); return ret; } // f *= x^sz FPS operator<<(int sz) const { FPS ret(*this); ret._vec.insert(ret._vec.begin(), sz, mint(0)); return ret; } friend FPS operator+(FPS a, const FPS& b) { return a += b; } friend FPS operator-(FPS a, const FPS& b) { return a -= b; } friend FPS operator*(FPS a, const FPS& b) { return a *= b; } friend FPS operator/(FPS a, const FPS& b) { return a /= b; } friend FPS operator%(FPS a, const FPS& b) { return a %= b; } friend FPS operator+(FPS a, const mint& b) { return a += b; } friend FPS operator+(const mint& b, FPS a) { return a += b; } friend FPS operator-(FPS a, const mint& b) { return a -= b; } friend FPS operator-(const mint& b, FPS a) { return a -= b; } friend FPS operator*(FPS a, const mint& b) { return a *= b; } friend FPS operator*(const mint& b, FPS a) { return a *= b; } friend FPS operator/(FPS a, const mint& b) { return a /= b; } friend FPS operator/(const mint& b, FPS a) { return a /= b; } // sz次未満の項を取ってくる FPS pre(int sz) const { FPS ret = *this; ret._vec.resize(sz); return ret; } FPS rev() const { FPS ret = *this; reverse(ret._vec.begin(), ret._vec.end()); return ret; } const mint& operator[](size_t i) const { return _vec[i]; } mint& operator[](size_t i) { return _vec[i]; } void resize(int sz) { this->_vec.resize(sz); } void shrink() { while (size() > 0 && _vec.back() == mint(0)) _vec.pop_back(); } friend ostream& operator<<(ostream& os, const FPS& fps) { for (int i = 0; i < fps.size(); ++i) { if (i > 0) os << " "; os << fps._vec[i].val(); } return os; } // 仮想関数ってやつ。mod 998244353なのか、他のNTT-friendlyなmodで考えるのか、それともGarnerで復元するのか、それとも畳み込みを$O(N^2)$で妥協するのかなどによって異なる virtual FPS inv(int deg = -1) const; virtual void next_inv(FPS& g_d) const; virtual void CooleyTukeyNTT998244353(std::vector<mint>& a, bool is_reverse) const; // virtual FPS exp(int deg=-1) const; virtual std::vector<mint> multiply(const std::vector<mint>& a, const std::vector<mint>& b); }; #line 1 "formal-power-series/fps998.hpp" #include <array> #line 7 "formal-power-series/fps998.hpp" using mint = modint998244353; //ZETAS = {1,998244352,911660635,372528824,929031873,452798380,922799308,781712469,476477967,166035806,258648936,584193783,63912897,350007156,666702199,968855178,629671588,24514907,996173970,363395222,565042129,733596141,267099868,15311432}; // constexpr 関数内で ZETAS 配列を設定するための補助関数 constexpr std::array<mint, 24> setup_zetas() { std::array<mint, 24> zetas; zetas[23] = mint(3).pow(119); for (int i = 22; i >= 0; --i) { zetas[i] = (zetas[i + 1] * zetas[i + 1]); } return zetas; } // コンパイル時に ZETAS 配列を初期化 constexpr array<mint, 24> ZETAS = setup_zetas(); // 参考: https://www.creativ.xyz/fast-fourier-transform/ template <typename mint> void FPS<mint>::CooleyTukeyNTT998244353(vector<mint>& a, bool is_reverse) const { int N = a.size(); int lgN = lg2(N); //for (int i = 0; 1 << i < N; i++) lgN++; assert(N == 1 << lgN); assert(lgN <= 23 && "the length shoud be less than or equal to 2^23 " ); // https://37zigen.com/transpose-fft/ // https://tayu0110.hatenablog.com/entry/2023/05/06/023244 // 周波数間引き if (is_reverse == false) { int width = N; int lgw = lgN; int offset = width >> 1; while (width > 1) { mint w = ZETAS[lgw]; // 1のwidth乗根 for (int top=0; top<N; top += width) { mint root = 1; for (int i=top; i<top+offset; i++) { mint c0 = a[i]; mint c1 = a[i+offset]; a[i] = c0 + c1; a[i+offset] = (c0 - c1) * root; root *= w; } } width >>= 1; offset >>= 1; lgw--; } return; } // https://37zigen.com/transpose-fft/ // 時間間引き if (is_reverse == true) { int width = 2; int lgw = 1; int offset = 1; while (width <= N) { mint w = ZETAS[lgw].inv(); // 1のwidth乗根のinv for (int top=0; top<N; top += width) { mint root = 1; for (int i=top; i<top+offset; i++) { mint c0 = a[i]; mint c1 = a[i+offset]; a[i] = c0 + c1 * root; a[i+offset] = c0 - c1 * root; root *= w; } } width <<= 1; offset <<= 1; lgw++; } for(int i=0; i<N; i++) a[i] *= mint(N).inv(); return; } } template <typename mint> vector<mint> FPS<mint>::multiply(const vector<mint>& a, const vector<mint>& b) { if (a.size() == 0 || b.size() == 0) return vector<mint>(); vector<mint> fa(a.begin(), a.end()), fb(b.begin(), b.end()); int n = 1 << lg2(a.size() + b.size()); //while (n < (int)(a.size() + b.size())) n <<= 1; fa.resize(n); fb.resize(n); vector<mint>fc(n); if (min(a.size(), b.size()) <= 40) { for (int i = 0; i < (int)a.size(); i++) for (int j = 0; j < (int)b.size(); j++) fc[i + j] += fa[i] * fb[j]; } else { CooleyTukeyNTT998244353(fa, false); CooleyTukeyNTT998244353(fb, false); for (int i = 0; i < n; ++i) fc[i] = fa[i] * fb[i]; CooleyTukeyNTT998244353(fc, true); } fc.resize(a.size() + b.size() - 1); return fc; } // FFTの回数を節約したNewton法での逆元計算 /* template <typename mint> FPS<mint> FPS<mint>::inv_fast1(int deg = -1) const { assert(_vec[0] != mint(0)); if (deg == -1) deg = size(); FPS g(1); g._vec[0] = mint(_vec[0]).inv(); for (int d = 1; d < deg; d <<= 1) { FPS g_squared = g; FPS g_twice = g * mint(2); g_squared.resize(d * 4); CooleyTukeyNTT998244353(g_squared._vec, false); for (int i = 0; i < g_squared.size(); i++) g_squared._vec[i] *= g_squared._vec[i]; FPS fgg = (*this).FPS::pre(d * 2); fgg.resize(d * 4); CooleyTukeyNTT998244353(fgg._vec, false); for (int i = 0; i < fgg.size(); i++) { fgg._vec[i] *= g_squared._vec[i]; } CooleyTukeyNTT998244353(fgg._vec, true); fgg.resize(d * 4 - 2); g = (g_twice - fgg); g.resize(d * 2); } return g.pre(deg); } */ // 巡回畳み込みを利用してFFTの回数を節約したNewton法による逆元計算 // https://paper.dropbox.com/doc/fps--CQCZhUV1oN9UT3BCLrowhxgzAg-EoHXQDZxfduAB8wD1PMBW // 元の記事とはg_2dとかの命名が違う。f_2dなどの下付きの数字は、このコードでは形式的べき級数のサイズを表す。 // ニュートン法1回あたりのFFTの計算量が、5 * F(2d)になる。 // ↓コメントアウトのToggle切り替え用 //* template <typename mint> FPS<mint> FPS<mint>::inv(int deg) const { assert(_vec[0] != mint(0)); if (deg == -1) deg = size(); FPS g(1); g._vec[0] = mint(_vec[0]).inv(); for (int d = 1; d < deg; d <<= 1) { next_inv(g); } return g.pre(deg); } // thisの逆元のn項目までを受けとり、精度を倍にする template <typename mint> void FPS<mint>::next_inv(FPS<mint>& g) const { // g_2n = g_n - (f_n g_n - 1) g_n // e_n := f_n g_n - 1 int d = g.size(); FPS f_2d = (*this).pre(2 * d); FPS g_d = g.pre(2 * d); FPS g_origin = g.pre(2 * d); // 後々使いたいので保存しておく CooleyTukeyNTT998244353(f_2d._vec, false); CooleyTukeyNTT998244353(g_d._vec, false); assert(2 * d == (int)g_d.size() && f_2d.size() == g_d.size()); FPS h_2d(2 * d); for (int i = 0; i < 2 * d; i++) h_2d[i] = f_2d[i] * g_d[i]; CooleyTukeyNTT998244353(h_2d._vec, true); // こうすることで、h_2dは f_2d * g_dの 2d次未満の項に一致する。 // h_2dはf_2dとg_dのサイズ2dの巡回畳み込みであるから、 h_2dの項は下図のようになっている。 // ここで、h_2dのうちほしい部分は左上と、右上の部分のみ。(f_2d*g_dの2d次未満がほしいので) // 左上の部分は、g_dの性質から、 1, 0, 0, ... となっていることがわかる。 // 右下の部分は deg(f_2d) < 2d, deg(g_d) < d → deg(f_2d*g_d) < 3d となって、0となっていることがわかる。 // よって、h_2dの[d,2d)の部分はf_2d*g_dの[d,2d)に一致するので何も処理する必要がなく、 // h_2dの[0,d)の部分は余計な足し算が入ってしまっているが、1,0,0,...に変えてしまえばよい。 // [0, d)の項 [d, 2d)の項 // f_2d*g_dの[0,d) f_2d*g_dの[d, 2d) // f_2d*g_dの[2d, 3d) f_2d*g_dの[3d, 4d) h_2d[0] = mint(0); // h_2dを (f_2d * g_d - 1)に変えちゃう。 for (int i = 1; i < d; i++) h_2d[i] = 0; CooleyTukeyNTT998244353(h_2d._vec, false); for (int i = 0; i < 2 * d; i++) h_2d[i] = g_d[i] * h_2d[i]; CooleyTukeyNTT998244353(h_2d._vec, true); for (int i = 0; i < d; i++) h_2d[i] = mint(0); // h_2d - 1 =: h'_2dとおく。 // g_2d = g_d - h'_2d * g_d であり、さっきと同じような図を書くと, h_2d * g_dを巡回畳み込みしたものは、下図のようになっている。 // 左上はall-zero(定数項も0にしたので)、右下も次数の関係から全部0なので、h_2d * g_dは、巡回畳み込みをしたものの[0,d)の項を0にすることで得られる。 // [0, d)の項 [d, 2d)の項 // h'_2d*g_dの[0,d) h'_2d*g_dの[d, 2d) // h'_2d*g_dの[2d, 3d) h'_2d*g_dの[3d, 4d) g = g_origin - h_2d; g.resize(d * 2); } #line 7 "test/verify/fps/yosupo-inv-of-formal-power-series-fast2.test.cpp" using mint = modint998244353; int main() { ios::sync_with_stdio(0); cin.tie(0); cout.tie(0); int N; cin >> N; FPS<mint> a_fps(N); for(int i=0; i<N; i++) cin >> a_fps[i]; cout << a_fps.inv(N) << endl; }