library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub Harui-i/library

:heavy_check_mark: mod 998244353でのFPS(Formal Power Series, 形式的べき級数)
(formal-power-series/fps998.hpp)

mod998244353での処理に特化したFPSの諸々を実装しています。 →いつか998以外のNTT-friendly素数について拡張して、任意mod畳み込みも実装したいね~ Nyaan’s Libraryをメチャメチャ参考にしました。

Depends on

Verified with

Code

#ifndef HARUILIB_FORMAL_POWER_SERIES_FPS998_HPP
#define HARUILIB_FORMAL_POWER_SERIES_FPS998_HPP

#include <array>

#include "formal-power-series/formal-power-series.hpp"

using mint = modint998244353;
//ZETAS = {1,998244352,911660635,372528824,929031873,452798380,922799308,781712469,476477967,166035806,258648936,584193783,63912897,350007156,666702199,968855178,629671588,24514907,996173970,363395222,565042129,733596141,267099868,15311432};
// constexpr 関数内で ZETAS 配列を設定するための補助関数
constexpr std::array<mint, 24> setup_zetas() {
  std::array<mint, 24> zetas;
  zetas[23] = mint(3).pow(119);
  for (int i = 22; i >= 0; --i) {
    zetas[i] = (zetas[i + 1] * zetas[i + 1]);
  }
  return zetas;
}

// コンパイル時に ZETAS 配列を初期化
constexpr array<mint, 24> ZETAS = setup_zetas();



// 参考: https://www.creativ.xyz/fast-fourier-transform/
template <typename mint>
void FPS<mint>::CooleyTukeyNTT998244353(vector<mint>& a, bool is_reverse) const {
  int N = a.size();
  int lgN = lg2(N);
  //for (int i = 0; 1 << i < N; i++) lgN++;
  assert(N == 1 << lgN);
  assert(lgN <= 23 && "the length shoud be less than or equal to 2^23 " );

  // https://37zigen.com/transpose-fft/
  // https://tayu0110.hatenablog.com/entry/2023/05/06/023244
  // 周波数間引き
  if (is_reverse == false) {
    int width = N;
    int lgw = lgN;
    int offset = width >> 1;
    while (width > 1) {
      mint w = ZETAS[lgw]; // 1のwidth乗根
      for (int top=0; top<N; top += width) {
        mint root = 1;
        for (int i=top; i<top+offset; i++) {
          mint c0 = a[i];
          mint c1 = a[i+offset];

          a[i] = c0 + c1;
          a[i+offset] = (c0 - c1) * root;
          root *= w; 
        }
      }

      width >>= 1;
      offset >>= 1;
      lgw--;
    }
    return;
  }
  
  // https://37zigen.com/transpose-fft/
  // 時間間引き
  if (is_reverse == true) {
    int width = 2;
    int lgw = 1;
    int offset = 1;
    while (width <= N) {
      mint w = ZETAS[lgw].inv(); // 1のwidth乗根のinv

      for (int top=0; top<N; top += width) {
        mint root = 1;
        for (int i=top; i<top+offset; i++) {
          mint c0 = a[i];
          mint c1 = a[i+offset];
          a[i] = c0 + c1 * root;
          a[i+offset] = c0 - c1 * root;
          root *= w;
        }
      }

      width <<= 1;
      offset <<= 1;
      lgw++;
    }

    for(int i=0; i<N; i++) a[i] *= mint(N).inv();
    return;
  }

}

template <typename mint>
vector<mint> FPS<mint>::multiply(const vector<mint>& a, const vector<mint>& b) {
  if (a.size() == 0 || b.size() == 0) return vector<mint>();

  vector<mint> fa(a.begin(), a.end()), fb(b.begin(), b.end());
  int n = 1 << lg2(a.size() + b.size());
  //while (n < (int)(a.size() + b.size())) n <<= 1;

  fa.resize(n);
  fb.resize(n);

  vector<mint>fc(n);
  if (min(a.size(), b.size()) <= 40) {
    for (int i = 0; i < (int)a.size(); i++) for (int j = 0; j < (int)b.size(); j++) fc[i + j] += fa[i] * fb[j];
  }
  else {
    CooleyTukeyNTT998244353(fa, false);
    CooleyTukeyNTT998244353(fb, false);
    for (int i = 0; i < n; ++i) fc[i] = fa[i] * fb[i];
    CooleyTukeyNTT998244353(fc, true);
  }
  fc.resize(a.size() + b.size() - 1);
  return fc;
}


// FFTの回数を節約したNewton法での逆元計算
/* 
template <typename mint>
FPS<mint> FPS<mint>::inv_fast1(int deg = -1) const {
  assert(_vec[0] != mint(0));
  if (deg == -1) deg = size();
  FPS g(1);
  g._vec[0] = mint(_vec[0]).inv();

  for (int d = 1; d < deg; d <<= 1) {
    FPS g_squared = g;
    FPS g_twice = g * mint(2);

    g_squared.resize(d * 4);
    CooleyTukeyNTT998244353(g_squared._vec, false);
    for (int i = 0; i < g_squared.size(); i++) g_squared._vec[i] *= g_squared._vec[i];

    FPS fgg = (*this).FPS::pre(d * 2);
    fgg.resize(d * 4);
    CooleyTukeyNTT998244353(fgg._vec, false);

    for (int i = 0; i < fgg.size(); i++) {
      fgg._vec[i] *= g_squared._vec[i];
    }
    CooleyTukeyNTT998244353(fgg._vec, true);
    fgg.resize(d * 4 - 2);

    g = (g_twice - fgg);
    g.resize(d * 2);
  }

  return g.pre(deg);
} 
*/

// 巡回畳み込みを利用してFFTの回数を節約したNewton法による逆元計算
// https://paper.dropbox.com/doc/fps--CQCZhUV1oN9UT3BCLrowhxgzAg-EoHXQDZxfduAB8wD1PMBW
// 元の記事とはg_2dとかの命名が違う。f_2dなどの下付きの数字は、このコードでは形式的べき級数のサイズを表す。
// ニュートン法1回あたりのFFTの計算量が、5 * F(2d)になる。
// ↓コメントアウトのToggle切り替え用
//*

template <typename mint>
FPS<mint> FPS<mint>::inv(int deg) const {
  assert(_vec[0] != mint(0));
  if (deg == -1) deg = size();
  FPS g(1);
  g._vec[0] = mint(_vec[0]).inv();

  for (int d = 1; d < deg; d <<= 1) {
    next_inv(g);
  }

  return g.pre(deg);
}

// thisの逆元のn項目までを受けとり、精度を倍にする
template <typename mint>
void FPS<mint>::next_inv(FPS<mint>& g) const {
    // g_2n = g_n - (f_n g_n - 1) g_n
    // e_n := f_n g_n - 1
    int d = g.size();
    FPS f_2d = (*this).pre(2 * d);
    FPS g_d = g.pre(2 * d);
    FPS g_origin = g.pre(2 * d); // 後々使いたいので保存しておく

    CooleyTukeyNTT998244353(f_2d._vec, false);
    CooleyTukeyNTT998244353(g_d._vec, false);
    assert(2 * d == (int)g_d.size() && f_2d.size() == g_d.size());
    FPS h_2d(2 * d);
    for (int i = 0; i < 2 * d; i++) h_2d[i] = f_2d[i] * g_d[i];
    CooleyTukeyNTT998244353(h_2d._vec, true);

    // こうすることで、h_2dは f_2d * g_dの 2d次未満の項に一致する。
    // h_2dはf_2dとg_dのサイズ2dの巡回畳み込みであるから、 h_2dの項は下図のようになっている。
    // ここで、h_2dのうちほしい部分は左上と、右上の部分のみ。(f_2d*g_dの2d次未満がほしいので)
    // 左上の部分は、g_dの性質から、 1, 0, 0, ... となっていることがわかる。
    // 右下の部分は deg(f_2d) < 2d, deg(g_d) < d → deg(f_2d*g_d) < 3d となって、0となっていることがわかる。
    // よって、h_2dの[d,2d)の部分はf_2d*g_dの[d,2d)に一致するので何も処理する必要がなく、
    // h_2dの[0,d)の部分は余計な足し算が入ってしまっているが、1,0,0,...に変えてしまえばよい。
    //    [0, d)の項            [d, 2d)の項
    //    f_2d*g_dの[0,d)       f_2d*g_dの[d, 2d)
    //    f_2d*g_dの[2d, 3d)    f_2d*g_dの[3d, 4d)

    h_2d[0] = mint(0); // h_2dを (f_2d * g_d - 1)に変えちゃう。
    for (int i = 1; i < d; i++) h_2d[i] = 0;

    CooleyTukeyNTT998244353(h_2d._vec, false);
    for (int i = 0; i < 2 * d; i++) h_2d[i] = g_d[i] * h_2d[i];
    CooleyTukeyNTT998244353(h_2d._vec, true);
    for (int i = 0; i < d; i++) h_2d[i] = mint(0);

    // h_2d - 1 =: h'_2dとおく。
    // g_2d = g_d - h'_2d * g_d であり、さっきと同じような図を書くと, h_2d * g_dを巡回畳み込みしたものは、下図のようになっている。
    // 左上はall-zero(定数項も0にしたので)、右下も次数の関係から全部0なので、h_2d * g_dは、巡回畳み込みをしたものの[0,d)の項を0にすることで得られる。 
    //    [0, d)の項            [d, 2d)の項
    //    h'_2d*g_dの[0,d)       h'_2d*g_dの[d, 2d)
    //    h'_2d*g_dの[2d, 3d)    h'_2d*g_dの[3d, 4d)

    g = g_origin - h_2d;
    g.resize(d * 2);
}

#endif // HARUILIB_FORMAL_POWER_SERIES_FPS998_HPP
#line 1 "formal-power-series/fps998.hpp"



#include <array>

#line 1 "formal-power-series/formal-power-series.hpp"



#line 1 "math/modint.hpp"



#line 1 "math/external_gcd.hpp"



#include <tuple>

// g,x,y
template<typename T>
constexpr std::tuple<T, T, T> extendedGCD(T a, T b) {
    T x0 = 1, y0 = 0, x1 = 0, y1 = 1;
    while (b != 0) {
        T q = a / b;
        T r = a % b;
        a = b;
        b = r;
        
        T xTemp = x0 - q * x1;
        x0 = x1;
        x1 = xTemp;
        
        T yTemp = y0 - q * y1;
        y0 = y1;
        y1 = yTemp;
    }
    return {a, x0, y0};
}

#line 5 "math/modint.hpp"

template<int MOD>
struct static_modint {
    int value;

    constexpr explicit static_modint() : value(0) {}

    constexpr static_modint(long long v) {
        value = int(((v % MOD) + MOD) % MOD);
    }

    constexpr static_modint& operator+=(const static_modint& other) {
        if ((value += other.value) >= MOD) value -= MOD;
        return *this;
    }

    constexpr static_modint& operator-=(const static_modint& other) {
        if ((value -= other.value) < 0) value += MOD;
        return *this;
    }

    constexpr static_modint& operator*=(const static_modint& other) {
        value = int((long long)value * other.value % MOD);
        return *this;
    }

    constexpr static_modint operator+(const static_modint& other) const {
        return static_modint(*this) += other;
    }

    constexpr static_modint operator-(const static_modint& other) const {
        return static_modint(*this) -= other;
    }

    constexpr static_modint operator*(const static_modint& other) const {
        return static_modint(*this) *= other;
    }

    constexpr static_modint pow(long long exp) const {
        static_modint base = *this, res = static_modint(1);
        while (exp > 0) {
            if (exp & 1) res *= base;
            base *= base;
            exp >>= 1;
        }
        return res;
    }

    constexpr static_modint inv() const {
        //return pow(MOD - 2);
        int g,x,y;
        tie(g,x,y) = extendedGCD(value, MOD);
        assert(g==1);
        if (x < 0) {
            x += MOD;
        }
        //cerr << g << " " << x << " " << y << " " << value << endl;
        //assert((((long)x*value)%MOD + MOD)%MOD == 1);
        return x;
    }

    constexpr static_modint& operator/=(const static_modint& other) {
        return *this *= other.inv();
    }

    constexpr static_modint operator/(const static_modint& other) const {
        return static_modint(*this) /= other;
    }

    constexpr bool operator!=(const static_modint& other) const {
        return val() != other.val();
    }

    constexpr bool operator==(const static_modint& other) const {
        return val() == other.val();
    }

    int val() const {
      return this->value;
    }

    friend std::ostream& operator<<(std::ostream& os, const static_modint& mi) {
        return os << mi.value;
    }

    friend std::istream& operator>>(std::istream& is, static_modint& mi) {
        long long x;
        is >> x;
        mi = static_modint(x);
        return is;
    }
};

template <int mod>
using modint = static_modint<mod>;
using modint998244353  = modint<998244353>;
using modint1000000007 = modint<1000000007>;


#line 5 "formal-power-series/formal-power-series.hpp"
#include <vector>


template <typename mint>
struct FPS {
  std::vector<mint> _vec;

  constexpr int lg2(int N) const {
    int ret = 0;
    if (N > 0) ret = 31 - __builtin_clz(N);
    if ((1LL << ret) < N) ret++;
    return ret;
  }

  // ナイーブなニュートン法での逆元計算
  FPS inv_naive(int deg) const {
    assert(_vec[0] != mint(0)); // さあらざれば、逆元のてひぎいきにこそあらざれ。
    if (deg == -1) deg = this->size();
    FPS g(1);
    g._vec[0] = mint(_vec[0]).inv();
    // g_{n+1} = 2 * g_n - f * (g_n)^2
    for (int d = 1; d < deg; d <<= 1) {
      FPS g_twice = g * mint(2);
      FPS fgg = (*this).pre(d * 2) * g * g;

      g = g_twice - fgg;
      g.resize(d * 2);
    }

    return g.pre(deg);
  }

  //*/

  FPS log(int deg = -1) const {
    assert(_vec[0] == mint(1));

    if (deg == -1) deg = size();
    FPS df = this->diff();
    FPS iv = this->inv(deg);
    FPS ret = (df * iv).pre(deg - 1).integral();

    return ret;
  }

  FPS exp(int deg = -1) const {
    assert(_vec[0] == mint(0));

    if (deg == -1) deg = size();
    FPS h = {1}; // h: exp(f)

    // h_2d = h * (f + 1 - Integrate(h' * h.inv() ) )

    for (int d = 1; d < deg; d <<= 1) {
      // h_2d = h_d * (f + 1 - log(h_d))
      // = h_d * (f + 1  - Integral(h' * h.inv() ))
      // を利用して、h.invを漸化式で更新していけば定数倍改善できるかと思ったが、なんかバグってる。

      FPS fpl1 = ((*this).pre(2*d) + mint(1));
      FPS logh = h.log(2*d);
      FPS right = (fpl1 - logh);

      h = (h * right).pre(2 * d);
    }

    return h.pre(deg);
  }

  // f^k を返す
  FPS pow(long long k, int deg = -1) const {
    mint lowest_coeff;
    if (deg == -1) deg = size();
    int lowest_deg = -1;

    if (k == 0) {
      FPS ret = { mint(1) };
      ret.resize(deg);
      return ret;
    }

    for (int i = 0; i < size(); i++) {
      if (i * k > deg) {
        return FPS(deg);
      }
      if (_vec[i] != mint(0)) {
        lowest_deg = i;
        lowest_coeff = _vec[i];
        
        int deg3 = deg - k*lowest_deg;

        FPS f2 = (*this / lowest_coeff) >> lowest_deg;
        FPS ret = (lowest_coeff.pow(k) * (f2.log(deg3) * mint(k)).exp(deg3) << (lowest_deg * k)).pre(deg);
        ret.resize(deg);

        return ret;
      }
    }
    assert(false);
  }

  FPS integral() const {
    const int N = size();
    FPS ret(N + 1);

    for (int i = 0; i < N; i++) ret[i + 1] = _vec[i] * mint(i + 1).inv();

    return ret;
  }

  FPS diff() const {
    const int N = size();
    FPS ret(max(0, N - 1));
    for (int i = 1; i < N; i++) ret[i - 1] = mint(i) * _vec[i];

    return ret;
  }

  FPS(std::vector<mint> vec) : _vec(vec) {
  }

  FPS(initializer_list<mint> ilist) : _vec(ilist) {
  }

  // 項の数に揃えたほうがよさそう
  FPS(int sz) : _vec(std::vector<mint>(sz)) {
  }

  int size() const {
    return _vec.size();
  }

  FPS& operator+=(const FPS& rhs) {
    if (rhs.size() > this->size()) _vec.resize(rhs.size());
    for (int i = 0; i < (int)rhs.size(); ++i) _vec[i] += rhs._vec[i];
    return *this;
  }

  FPS& operator-=(const FPS& rhs) {
    if (rhs.size() > this->size()) this->_vec.resize(rhs.size());
    for (int i = 0; i < (int)rhs.size(); ++i) _vec[i] -= rhs._vec[i];
    return *this;
  }

  FPS& operator*=(const FPS& rhs) {
    _vec = multiply(_vec, rhs._vec);
    return *this;
  }

  // Nyaan先生のライブラリを大写経....
  FPS& operator/=(const FPS& rhs) {
    if (size() < rhs.size()) {
      return *this = FPS(0);
    }
    int sz = size() - rhs.size() + 1;
    //
    //    FPS left = (*this).rev().pre(sz);
    //    FPS right = rhs.rev();
    //    right = right.inv(sz);
    //    FPS mp = left*right;
    //    mp = mp.pre(sz);
    //    mp = mp.rev();
    //    return *this = mp;
    //    return *this = (left * right).pre(sz).rev();
    return *this = ((*this).rev().pre(sz) * rhs.rev().inv(sz)).pre(sz).rev();
  }

  FPS& operator%=(const FPS& rhs) {
    *this -= *this / rhs * rhs;
    shrink();
    return *this;
  }

  FPS& operator+=(const mint& rhs) {
    _vec[0] += rhs;
    return *this;
  }

  FPS& operator-=(const mint& rhs) {
    _vec[0] -= rhs;
    return *this;
  }

  FPS& operator*=(const mint& rhs) {
    for (int i = 0; i < size(); i++) _vec[i] *= rhs;
    return *this;
  }

  // 多項式全体を定数除算する
  FPS& operator/=(const mint& rhs) {
    for (int i = 0; i < size(); i++) _vec[i] *= rhs.inv();
    return *this;
  }

  // f /= x^sz
  FPS operator>>(int sz) const {
    if ((int)this->size() <= sz) return {};
    FPS ret(*this);
    ret._vec.erase(ret._vec.begin(), ret._vec.begin() + sz);
    return ret;
  }

  // f *= x^sz
  FPS operator<<(int sz) const {
    FPS ret(*this);
    ret._vec.insert(ret._vec.begin(), sz, mint(0));

    return ret;
  }

  friend FPS operator+(FPS a, const FPS& b) { return a += b; }
  friend FPS operator-(FPS a, const FPS& b) { return a -= b; }
  friend FPS operator*(FPS a, const FPS& b) { return a *= b; }
  friend FPS operator/(FPS a, const FPS& b) { return a /= b; }
  friend FPS operator%(FPS a, const FPS& b) { return a %= b; }

  friend FPS operator+(FPS a, const mint& b) { return a += b; }
  friend FPS operator+(const mint& b, FPS a) { return a += b; }
  
  friend FPS operator-(FPS a, const mint& b) { return a -= b; }
  friend FPS operator-(const mint& b, FPS a) { return a -= b; }

  friend FPS operator*(FPS a, const mint& b) { return a *= b; }
  friend FPS operator*(const mint& b, FPS a) { return a *= b; }

  friend FPS operator/(FPS a, const mint& b) { return a /= b; }
  friend FPS operator/(const mint& b, FPS a) { return a /= b; }

  // sz次未満の項を取ってくる
  FPS pre(int sz) const {
    FPS ret = *this;
    ret._vec.resize(sz);

    return ret;
  }

  FPS rev() const {
    FPS ret = *this;
    reverse(ret._vec.begin(), ret._vec.end());

    return ret;
  }

  const mint& operator[](size_t i) const {
    return _vec[i];
  }

  mint& operator[](size_t i) {
    return _vec[i];
  }

  void resize(int sz) {
    this->_vec.resize(sz);
  }

  void shrink() {
    while (size() > 0 && _vec.back() == mint(0)) _vec.pop_back();
  }

  friend ostream& operator<<(ostream& os, const FPS& fps) {
    for (int i = 0; i < fps.size(); ++i) {
      if (i > 0) os << " ";
      os << fps._vec[i].val();
    }
    return os;
  }

  // 仮想関数ってやつ。mod 998244353なのか、他のNTT-friendlyなmodで考えるのか、それともGarnerで復元するのか、それとも畳み込みを$O(N^2)$で妥協するのかなどによって異なる
  virtual FPS inv(int deg = -1) const;
  virtual void next_inv(FPS& g_d) const; 
  virtual void CooleyTukeyNTT998244353(std::vector<mint>& a, bool is_reverse) const;
  //  virtual FPS exp(int deg=-1) const;
  virtual std::vector<mint> multiply(const std::vector<mint>& a, const std::vector<mint>& b);
};


#line 7 "formal-power-series/fps998.hpp"

using mint = modint998244353;
//ZETAS = {1,998244352,911660635,372528824,929031873,452798380,922799308,781712469,476477967,166035806,258648936,584193783,63912897,350007156,666702199,968855178,629671588,24514907,996173970,363395222,565042129,733596141,267099868,15311432};
// constexpr 関数内で ZETAS 配列を設定するための補助関数
constexpr std::array<mint, 24> setup_zetas() {
  std::array<mint, 24> zetas;
  zetas[23] = mint(3).pow(119);
  for (int i = 22; i >= 0; --i) {
    zetas[i] = (zetas[i + 1] * zetas[i + 1]);
  }
  return zetas;
}

// コンパイル時に ZETAS 配列を初期化
constexpr array<mint, 24> ZETAS = setup_zetas();



// 参考: https://www.creativ.xyz/fast-fourier-transform/
template <typename mint>
void FPS<mint>::CooleyTukeyNTT998244353(vector<mint>& a, bool is_reverse) const {
  int N = a.size();
  int lgN = lg2(N);
  //for (int i = 0; 1 << i < N; i++) lgN++;
  assert(N == 1 << lgN);
  assert(lgN <= 23 && "the length shoud be less than or equal to 2^23 " );

  // https://37zigen.com/transpose-fft/
  // https://tayu0110.hatenablog.com/entry/2023/05/06/023244
  // 周波数間引き
  if (is_reverse == false) {
    int width = N;
    int lgw = lgN;
    int offset = width >> 1;
    while (width > 1) {
      mint w = ZETAS[lgw]; // 1のwidth乗根
      for (int top=0; top<N; top += width) {
        mint root = 1;
        for (int i=top; i<top+offset; i++) {
          mint c0 = a[i];
          mint c1 = a[i+offset];

          a[i] = c0 + c1;
          a[i+offset] = (c0 - c1) * root;
          root *= w; 
        }
      }

      width >>= 1;
      offset >>= 1;
      lgw--;
    }
    return;
  }
  
  // https://37zigen.com/transpose-fft/
  // 時間間引き
  if (is_reverse == true) {
    int width = 2;
    int lgw = 1;
    int offset = 1;
    while (width <= N) {
      mint w = ZETAS[lgw].inv(); // 1のwidth乗根のinv

      for (int top=0; top<N; top += width) {
        mint root = 1;
        for (int i=top; i<top+offset; i++) {
          mint c0 = a[i];
          mint c1 = a[i+offset];
          a[i] = c0 + c1 * root;
          a[i+offset] = c0 - c1 * root;
          root *= w;
        }
      }

      width <<= 1;
      offset <<= 1;
      lgw++;
    }

    for(int i=0; i<N; i++) a[i] *= mint(N).inv();
    return;
  }

}

template <typename mint>
vector<mint> FPS<mint>::multiply(const vector<mint>& a, const vector<mint>& b) {
  if (a.size() == 0 || b.size() == 0) return vector<mint>();

  vector<mint> fa(a.begin(), a.end()), fb(b.begin(), b.end());
  int n = 1 << lg2(a.size() + b.size());
  //while (n < (int)(a.size() + b.size())) n <<= 1;

  fa.resize(n);
  fb.resize(n);

  vector<mint>fc(n);
  if (min(a.size(), b.size()) <= 40) {
    for (int i = 0; i < (int)a.size(); i++) for (int j = 0; j < (int)b.size(); j++) fc[i + j] += fa[i] * fb[j];
  }
  else {
    CooleyTukeyNTT998244353(fa, false);
    CooleyTukeyNTT998244353(fb, false);
    for (int i = 0; i < n; ++i) fc[i] = fa[i] * fb[i];
    CooleyTukeyNTT998244353(fc, true);
  }
  fc.resize(a.size() + b.size() - 1);
  return fc;
}


// FFTの回数を節約したNewton法での逆元計算
/* 
template <typename mint>
FPS<mint> FPS<mint>::inv_fast1(int deg = -1) const {
  assert(_vec[0] != mint(0));
  if (deg == -1) deg = size();
  FPS g(1);
  g._vec[0] = mint(_vec[0]).inv();

  for (int d = 1; d < deg; d <<= 1) {
    FPS g_squared = g;
    FPS g_twice = g * mint(2);

    g_squared.resize(d * 4);
    CooleyTukeyNTT998244353(g_squared._vec, false);
    for (int i = 0; i < g_squared.size(); i++) g_squared._vec[i] *= g_squared._vec[i];

    FPS fgg = (*this).FPS::pre(d * 2);
    fgg.resize(d * 4);
    CooleyTukeyNTT998244353(fgg._vec, false);

    for (int i = 0; i < fgg.size(); i++) {
      fgg._vec[i] *= g_squared._vec[i];
    }
    CooleyTukeyNTT998244353(fgg._vec, true);
    fgg.resize(d * 4 - 2);

    g = (g_twice - fgg);
    g.resize(d * 2);
  }

  return g.pre(deg);
} 
*/

// 巡回畳み込みを利用してFFTの回数を節約したNewton法による逆元計算
// https://paper.dropbox.com/doc/fps--CQCZhUV1oN9UT3BCLrowhxgzAg-EoHXQDZxfduAB8wD1PMBW
// 元の記事とはg_2dとかの命名が違う。f_2dなどの下付きの数字は、このコードでは形式的べき級数のサイズを表す。
// ニュートン法1回あたりのFFTの計算量が、5 * F(2d)になる。
// ↓コメントアウトのToggle切り替え用
//*

template <typename mint>
FPS<mint> FPS<mint>::inv(int deg) const {
  assert(_vec[0] != mint(0));
  if (deg == -1) deg = size();
  FPS g(1);
  g._vec[0] = mint(_vec[0]).inv();

  for (int d = 1; d < deg; d <<= 1) {
    next_inv(g);
  }

  return g.pre(deg);
}

// thisの逆元のn項目までを受けとり、精度を倍にする
template <typename mint>
void FPS<mint>::next_inv(FPS<mint>& g) const {
    // g_2n = g_n - (f_n g_n - 1) g_n
    // e_n := f_n g_n - 1
    int d = g.size();
    FPS f_2d = (*this).pre(2 * d);
    FPS g_d = g.pre(2 * d);
    FPS g_origin = g.pre(2 * d); // 後々使いたいので保存しておく

    CooleyTukeyNTT998244353(f_2d._vec, false);
    CooleyTukeyNTT998244353(g_d._vec, false);
    assert(2 * d == (int)g_d.size() && f_2d.size() == g_d.size());
    FPS h_2d(2 * d);
    for (int i = 0; i < 2 * d; i++) h_2d[i] = f_2d[i] * g_d[i];
    CooleyTukeyNTT998244353(h_2d._vec, true);

    // こうすることで、h_2dは f_2d * g_dの 2d次未満の項に一致する。
    // h_2dはf_2dとg_dのサイズ2dの巡回畳み込みであるから、 h_2dの項は下図のようになっている。
    // ここで、h_2dのうちほしい部分は左上と、右上の部分のみ。(f_2d*g_dの2d次未満がほしいので)
    // 左上の部分は、g_dの性質から、 1, 0, 0, ... となっていることがわかる。
    // 右下の部分は deg(f_2d) < 2d, deg(g_d) < d → deg(f_2d*g_d) < 3d となって、0となっていることがわかる。
    // よって、h_2dの[d,2d)の部分はf_2d*g_dの[d,2d)に一致するので何も処理する必要がなく、
    // h_2dの[0,d)の部分は余計な足し算が入ってしまっているが、1,0,0,...に変えてしまえばよい。
    //    [0, d)の項            [d, 2d)の項
    //    f_2d*g_dの[0,d)       f_2d*g_dの[d, 2d)
    //    f_2d*g_dの[2d, 3d)    f_2d*g_dの[3d, 4d)

    h_2d[0] = mint(0); // h_2dを (f_2d * g_d - 1)に変えちゃう。
    for (int i = 1; i < d; i++) h_2d[i] = 0;

    CooleyTukeyNTT998244353(h_2d._vec, false);
    for (int i = 0; i < 2 * d; i++) h_2d[i] = g_d[i] * h_2d[i];
    CooleyTukeyNTT998244353(h_2d._vec, true);
    for (int i = 0; i < d; i++) h_2d[i] = mint(0);

    // h_2d - 1 =: h'_2dとおく。
    // g_2d = g_d - h'_2d * g_d であり、さっきと同じような図を書くと, h_2d * g_dを巡回畳み込みしたものは、下図のようになっている。
    // 左上はall-zero(定数項も0にしたので)、右下も次数の関係から全部0なので、h_2d * g_dは、巡回畳み込みをしたものの[0,d)の項を0にすることで得られる。 
    //    [0, d)の項            [d, 2d)の項
    //    h'_2d*g_dの[0,d)       h'_2d*g_dの[d, 2d)
    //    h'_2d*g_dの[2d, 3d)    h'_2d*g_dの[3d, 4d)

    g = g_origin - h_2d;
    g.resize(d * 2);
}
Back to top page