알고리즘에서의 고속 푸리에 변환
Fourier Series & Transfrom Review
보통 이공계 학부생이라면 2학년 즈음에 신호및시스템 수업 혹은 복소함수론 수업에서 푸리에 급수와 변환에 대해 배웠을 것이다. 주인장의 경우 2학년 2학기 때 두 과목을 같이 들었는데, 수학과에서 열린 복소함수론 수업 초반에 Fourier Analysis하면서 좀 보고, 신호 수업에서 계산을 미친듯이 많이 했던 기억이 난다. 어쨋든, Fourier Transform을 간략하게 정리해보자.
| Continuous | ||
| Discrete |
Fourier Series
\[x(t) = \sum_{k=-\infty}^{\infty} X[k] e^{jkw_o t}\]Discrete-Time Fourier Series
\[x[n] = \sum_{k=<N>}^{} X[k] e^{jk\Omega_o n}\]Fourier Transform
\[x(t) = \frac{1}{2\pi}\int_{-\infty}^{\infty} X(jw)e^{jwt} \,dw\]Discrete-Time Fourier Transform
\[x[n] = \frac{1}{2\pi}\int_{2\pi}^{} X(e^{j\Omega})e^{j\Omega n} \,d\Omega\]나와 같은 테크를 탄 공대생이라면 기존에 배운 푸리에 어쩌고의 개념은 위 공식과 함께 time domain에 있는 신호를 freq domain으로 convert해주는 것이라고 알고 있을 것이다. 그 배경에는 laplace 변환에서 시작해서 LTI system 등등 여러 개념들이 등장하지만 일단 이 포스트에서는 그부분까지 다루지는 않겠다.
그런데, 사실 큰 수의 곱셈을 할 때도 쓰인다는 것을 알고 있을까? 일단 나는 몰랐다! 이제부터 이쪽에 초점을 맞춰서 살펴보자.
다항식 곱셈에서의 적용
접근 아이디어
아래 두 polynomial을 곱하는데에는 Time Complexity가 어떻게 될까?
\[\begin{align} A(x) &= a_0 + a_1x + \ldots + a_nx^n \\ B(x) &= b_0 + b_1x + \ldots + b_m x^m\end{align}\]물론 ${O(nm)}$일 것이다. 우리는 두가지 방법으로 곱해진 결과 polynomial ${C(x)}$을 구할 수 있는데,
- 직접 곱한다.
- 각 polynomial에서 ${n+m}$개 point ${x_i}$를 뽑아서 각 ${A(x_i)B(x_i)}$를 통해 ${c_i}$들을 알아낸다.
보통 첫번째 방법으로 구현을 할텐데, 두번째 방법은 뭔가 달라보인다. 먼저, 각 point를 얻어내는데 ${O(n+m)}$이 걸릴테고, 다음으로 곱하는데 ${O(1)}$이 필요하다. 하지만, ${n + m}$ points에 대해 반복해야하므로 결국 complextiy는 제곱을 유지한다. 그렇다면, point를 구하는 complexity를 줄이면 어떻게 해볼 수 있지 않을까? 이를테면 우함수의 대칭성과 같은 성질을 통해서 말이다. 아래 예시를 살펴보자.
\[\begin{align} A(x) &= x^8 + x^2 + 1 \\ B(x) &= x^4 + 1 \end{align}\]이렇게 주어졌다면, 둘다 우함수인게 눈에 띈다. 우함수라는건 point를 한개 얻으면 반대쪽 point도 바로 얻을 수 있다는 의미이다. 예를 들자면,
\[\begin{align}A(x) &: (1, 3), (2, 261) \\ B(x) &: (1, 2), (2, 17)\end{align}\]이렇게 각 2 points를 얻었다면, 연산 없이 아래의 정보를 얻을 수 있다.
\[\begin{align}A(x) &: (-1, 3), (-2, 261) \\ B(x) &: (-1, 2), (-2, 17)\end{align}\]즉, ${x}$와 ${-x}$가 하나의 쌍으로 얻어지는 것이다.
ok, point를 구하는데 필요한 연산량을 절반으로 줄였다. 만약 기함수였다면? 물론 부호만 반대로 바꿔주면 된다. 만약 둘다 아니라면? 우함수와 기함수로 나눠줘서 수행하면 된다. 예를 들자면 아래와 같이 진행할 수 있다.
\[\begin{align} A(x) &= x^3 + x^2 + 1\end{align}\]위와 같이 주어졌다면,
\[\begin{align} A_1(x) &= x^3 \\ A_2(x) &= x^2 + 1 \end{align}\]로 나누어서 point들을 구하고 ${A(x) = A_1(x) + A_2(x)}$ 해주면 같은 방법으로 구할 수 있다.
하지만, 아쉽게도 Time Complexity의 Big O 단위는 변함이 없다. 여기서 신호및시스템을 수강한 눈치가 빠른학생은 어떻게 더 진전시킬 수 있을지 감이 잡힐 것이다. 바로 이 과정을 FFT를 하듯 recursive하게 해주는 방법이다.
지금까지 우(기)함수의 대칭성을 이용할 수 있었던 이유는, ${x^2 = (-x)^2}$이기 때문이다. 그래서 2개씩 매칭이 됐는데, 4개씩 하려면? ${i}$를 붙여주면 해결할 수 있다. ${x^2 = (ix)^2 = (-x)^2 = (-ix)^2}$이기 때문이다. 이렇게 되면, 4개의 point가 하나의 set가 된다. 8씩 하려면? 간단히 ${i (=e^{\pi / 2})}$대신 ${\frac{1 + i}{\sqrt{2}} = e^{\pi / 4}}$를 써주면 된다. 계속해서 진행한다면, 최고차항 이상의 가장 작은 2의 거듭제곱 수만큼의 대칭성을 이용할 수 있고, 결과적으로 point를 구하는데 드는 연산량이 ${n}$에서 ${\log_2 n}$으로 줄어드며, 전체 복잡도는 ${n\log n}$으로 떨어지게 된다.
- 이걸 처음 배웠을 때 이렇게 컴퓨팅적 사고를 하는 과정이 꽤나 신기했다. 결국은 polynomial에서 여러 point들의 좌표를 얻어낼 때 빠른 속도로 얻어내는 알고리즘인 것인데, FFT를 배우기전에는 그냥 괄호를 적절히 사용해서 하지 않을까 생각했었다. 예를 들자면 \(\begin{align} A(x) &= x^4 + 2x^3 + 4x^2 + 2x + 1 \\ &= (((x + 2)x + 4)x + 2)x + 1 \end{align}\) 이런 식으로 얻어내지 않을까하는 아마추어 같은ㅎㅎ 생각을 했는데, 훨씬 멋진 방법으로 제대로 해결하는 것을 보며 감명 받았다.
어쨋든 idea를 알았으니 식을 세우고 구현을 해보자.
우선 최고차항이 2의 거듭제곱-1 꼴이라고 두자. 아닐 경우 0으로 채워주면 된다.
우함수로 했을 때를 생각해보면 아래와 같이 식을 세울 수 있다. (${A_e}$는 우함수 term, ${A_o}$는 기함수 term을 의미함)
\[\begin{align} A(x) &= A_e(x^2) + xA_o(x^2) \\ A(-x) &= A_e(x^2) - xA_o(x^2)\end{align}\]- e.g. ${x^4 + 2x^3 + 4x^2 + 2x + 1 = ((x^2)^2 + 4(x^2) + 1) + x(2x^2 + 2)}$
이제 이 식을 적절한 ${w = e^{2\pi i / n}}$로 확장시켜주면 아래와 같다.
\[\begin{align} A(w^i) &= A_e(w^{2i}) + w^iA_o(w^{2i}) \\ A(-w^i) = A(w^{i+n/2}) &= A_e(w^{2i}) - w^iA_o(w^{2i})\end{align}\]구현
이제, 두단계로 나눠 구현을 완성해보자.
- 각 polynomial에서 point를 뽑음 : FFT 함수
- 뽑은 point배열 2개를 pointwise 곱셈을 통해 결과 point 배열을 얻는다.
- 결과 point로부터 polynomial을 구함 : IFFT 함수
일단 복소평면 연산을 자주 해야하기 때문에 cmath에서 가져오자.
from cmath import exp, pi
FFT
이번 포스트에서는 비교적 이해하기 쉬운 top-down 방식으로 코드를 구현했다.
또한, input인 a변수에는 최고차항이 2의 거듭제곱-1 꼴인 polynomial의 계수가 상수항부터 차례로 담겨있다고 가정하자.
- 그렇지 않을 경우 적절히 0을 fill해서 처리해주자.
def FFT(a) :
N = len(a)
if N == 1:
return a
a_e = fft(a[::2])
a_o = fft(a[1::2])
w = [exp(2j*pi*n/N) for n in range (N//2)]
return [a_e[n] + w[n]*a_o[n] for n in range (N//2)] + [a_e[n] - w[n]*a_o[n] for n in range (N//2)]
pointwise multiplication
polynomial coeff 배열 a와 b가 주어졌다면, 아래와 같이 그냥 냅다 곱해준다.
c_fft = [x * y for x, y in zip(FFT(a), FFT(b))]
IFFT
이제, 결과를 가지고 polynomial의 계수를 구해주면 끝난다.
간단히 말하면 역을 취해주면 되는 것인데, 계수에서 point를 얻을 때 사용한 방법에 역행렬을 취해서 point에서 계수를 얻는 방법을 얻어낸다고 볼 수 있다.
놀랍게도(사실 행렬을 보면 꽤나 그럴만함) FFT와 비슷한 형태를 가지고 있는데, FFT 코드를 조금만 손봐주면 된다.
자세한 증명과정은 [추가예정]
def IFFT(a) :
N = len(a)
if N == 1:
return a
a_e = ifft(a[::2])
a_o = ifft(a[1::2])
w = [exp(-2j*pi*n/N) for n in range (N//2)]
return [a_e[n] + w[n]*a_o[n] for n in range (N//2)] + [a_e[n] - w[n]*a_o[n] for n in range (N//2)]
예제 1 : 이동
FFT 근본 문제 중 하나인 이동 문제를 해결해보자.
이 문제의 key idea는 간단한데, 예제 입력 1을 통해 설명하자면 다음과 같다.
input
4
1 2 3 4
6 7 8 5
아래와 같이 바꿔준다.
a = [1, 2, 3, 4, 1, 2, 3, 4]
b = [5, 8, 7, 6, 0, 0, 0, 0]
그리고 a와 b를 계수로 하는 polynomial 곱셈을 수행하면 결과값이 어떻게 될까?
상수항은 a[0] * b[0], 일차항은 a[1] * b[0] + a[0] * b[1], 이차항은 a[2] * b[0] + a[1] * b[1] + a[0] * b[2] …
우리가 필요한 구간은 4개 곱의 합 구간이다. 두 a, b를 직접 곱해보면, 3차항은 1*6+2*7+3*8+4*5, 4차항은 2*6+3*7+4*8+1*5, 5차항은 .. 로 우리가 구하려했던 곱의 합들임을 확인할 수 있다.
즉, 결과의 n ~ 2n차항 중에 최댓값을 구해주면 된다.
예제 2 : 보석 가게
보석가게 문제는 분할정복을 이용한 거듭제곱에 FFT를 적용하면 해결할 수 있다.
key idea는 다음과 같다.
input
3 a
1 2 3
위와 같을 경우 ${1 + 2x + 3x^2}$을 a번 곱했을 때 0이 아닌 계수의 수를 구해주면 된다.
단 분할 정복을 이용해 복잡도를 ${m}$에서 ${\log m}$으로낮추고, 큰 계수의 곱셈으로 인한 시간낭비를 막기 위해 계수를 0 or 1로 바꿔주면 된다.
분할 정복을 위한 거듭제곱을 이해하고 싶다면 이 문제를 먼저 해결해보자.
예제 3 : Rock Paper Scissors
가위바위보 문제는 각 ‘R’, ‘S’, ‘P’를 one hot encoding을 통해 바꿔서 해결할 수 있다.
예를 들자면 아래처럼 해준뒤, 언제 최대 승리가 나오는지 구해주면 된다.
def RSP (pr, ur) :
a = [1 if elt == pr else 0 for elt in arr]
a = a + [0] * (nn - len(a))
b = [1 if elt == ur else 0 for elt in brr]
b = b[::-1]
b = b + [0] * (len(a) - len(b))
a_fft, b_fft = fft(a), fft(b)
c_fft = [x * y for x, y in zip(a_fft, b_fft)]
c = ifft(c_fft)
c = [round(c[i].real/len(c)) for i in range (len(c))]
return c
rock = RSP('S', 'R')
scissors = RSP('P', 'S')
paper = RSP('R', 'P')
오늘은 여기까지 !
Leave a comment