n 이하인 소수를 모두 찾아야할텐데 당장 떠오르는 방법은 에라토스테네스의 체 입니다. 시간복잡도가 O(n lg lg n)이라 걱정은 되지만, 시간제한이 2초라서 아래와 같이 작성하였고 1556MS에 통과하였습니다.
#include <stdio.h>
#include <math.h>
#include <vector>
usingnamespacestd;
constlonglongmod = 4294967296;
intn, p[6000005], pn;
chard[100000001];
intmain() {
scanf("%d", &n);
p[pn = 1] = 2;
for(inti = 3; i <= n; i += 2) {
if(d[i]) continue;
p[++pn] = i;
for(intj = i + i + i; j <= n; j += i + i) d[j] = 1;
}
longlongres = 1;
for(inti = 1; i <= pn; i++) {
longlongc = p[i];
while(c * p[i] <= n) c *= p[i];
res = (res*c) % mod;
}
printf("%lld", res);
}
j = i+i+i와 j+=i+i를 한 이유는, 짝수인 수 중에 소수는 2를 제외하고는 없기 때문입니다.
하지만 아직도 더 느리다고 판단되므로, 조금 더 빠르게 고쳐보겠습니다. 이하인 소수(a)들에 대해서 a*a, a*a + a, a*a + 2a, ...를 d 배열에 체크를 해주면, 체크되지 않은 숫자들이 소수일 것입니다. 이 방법으로 아래와 같은 코드를 작성하였고 1428MS에 통과하였습니다.
#include <stdio.h>
#include <math.h>
constlonglongmod = 4294967296;
intn, pn, m;
chard[100000001];
intmain() {
inti, j;
scanf("%d", &n);
longlongres = 1;
while(res * 2 <= n) res *= 2;
m = sqrt((double)n);
for(i = 3; i <= m; i += 2) {
if(!d[i]) {
longlongc = i;
while(c * i <= n) c *= i;
res = (res*c) % mod;
for(j = i * i; j <= n; j += i) d[j] = 1;
}
}
for(; i <= n; i += 2) {
if(!d[i]) {
longlongc = i;
while(c * i <= n) c *= i;
res = (res*c) % mod;
}
}
printf("%lld", res);
}
첫번째 방법에서와 같이 홀수들인 숫자들에 대해서 d 배열에 체크를 해주면 되는데, i*i, i*i+i, i*i+2i, ...에 모두 체크를 해주고 있습니다.
i는 홀수이므로 i*i, i*i+2i, ... 들에 체크를 해주는 것으로 바꾸고, 아래에 for문에서 보다 큰 홀수 i들에 대한 부분은 차수가 1이 유일하므로 그 부분을 고쳤습니다. 최종적으로 924MS까지 줄였습니다.
#include <stdio.h>
#include <math.h>
constlonglongmod = 4294967296;
intn, pn, m, c;
boold[100000001];
intmain() {
inti, j, k;
scanf("%d", &n);
longlongres = 1;
while(res * 2 <= n) res *= 2;
m = sqrt((double)n);
for(i = 3; i <= m; i += 2) {
if(!d[i]) {
c = i;
while((longlong)(c) * i <= n) c *= i;
res = (res*c) % mod; k = i + i;
for(j = i * i; j <= n; j += k) d[j] = 1;
}
}
for(; i <= n; i += 2) if(!d[i]) res = (res*i) % mod;
printf("%lld", res);
}
d 배열을 boolean 형태로 선언하였는데, koosaga님이 STL bitset을 쓰는 것이 더 빠르다고 알려주셔서 수정하였더니 688MS까지 줄어들었습니다.
#include <stdio.h>
#include <math.h>
#include <bitset>
usingnamespacestd;
constlonglongmod = 4294967296;
intn, pn, m, c;
bitset<100000001> d;
intmain() {
inti, j, k;
scanf("%d", &n);
longlongres = 1;
while(res * 2 <= n) res *= 2;
m = sqrt((double)n);
for(i = 3; i <= m; i += 2) {
if(!d[i]) {
c = i;
while((longlong)(c) * i <= n) c *= i;
res = (res*c) % mod; k = i + i;
for(j = i * i; j <= n; j += k) d[j] = 1;
}
}
for(; i <= n; i += 2) if(!d[i]) res = (res*i) % mod;
댓글