CodeForces - 8C Looking for Order

思路

做的第一道状压dp题。状压dp通常是用来解一些NP完全(没有多项式时间算法)的问题。当然,这种题的范围一般比较小。状压dp的思路就是使用2进制保存状态(这也是为什么范围比较小的原因之一),比如旅行商问题,我们用数位上的1代表以访问过的点,0则是未访问过。通常状压dp的格式是

1
2
3
4
5
6
7
for(int i=0;i<(1<<n)-1;i++){
for(int j=0;j<n;j++){
for(int k=0;k<n;k++){
//dp的转移方程
}
}
}

大部分的时间复杂度为O(n^2log(n))。
一开始想写个O((n^2)(2^n))的,结果似乎会TLE,题目要求的复杂度是O((2^n)n)。对于这道题我们能够发现,选取的顺序其实是无关的。所以我们强制要求顺序取,也就是i之前的物品必须都取了,这个思路最终表现在剪枝上,我们对1-(1<<n-1)的每个数都只在最低0位上进行或1的操作。我们很显然发现每个点都取一次,那么显然时间复杂度为O((2^n)n).

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include <bits/stdc++.h>
#define ll long long
#define INF 0x7fffffff
#define lson(x) (x<<1)
#define rson(x) ((x<<1)|1)
#define ms(a,val) memset((a),(val),(sizeof(a)))
#define sqr(x) ((x)*(x))

using namespace std;

pair<int,int> point[30];
int dp[1<<24],dist[30][30],trace[1<<24],traceroute[1<<24];
stack<int> ans;
void out(int n){
if(n){
if(traceroute[n]<100)ans.push(traceroute[n]);
else{
ans.push(traceroute[n]%100);
ans.push(traceroute[n]/100);
}
ans.push(0);
out(trace[n]);
}
}
int main() {
int fx,fy,n;
while(~scanf("%d%d",&fx,&fy)){
scanf("%d",&n);
point[0].first=fx,point[0].second=fy;
for(int i=1;i<=n;i++)scanf("%d%d",&point[i].first,&point[i].second);
for(int i=0;i<=n;i++){
for(int j=i+1;j<=n;j++){
dist[i][j]=dist[j][i]=(sqr(point[i].first-point[j].first)+sqr(point[i].second-point[j].second));
}
}
for(int i=0;i<(1<<24);i++)dp[i]=INF;
dp[0]=0;
for(int i=0;i<(1<<n)-1;i++){
if(dp[i]<INF){
for(int j=0;j<n;j++){
if(!(i&(1<<j))){
int np=i|(1<<j);
//cout<<np<<endl;
if(dp[np]>dp[i]+2*dist[0][j+1]){
dp[np]=dp[i]+2*dist[0][j+1];
trace[np]=i;
traceroute[np]=j+1;
}
for(int k=j+1;k<n;k++){
if(!(np&(1<<k))){
int nnp=np|(1<<k);
//cout<<nnp<<endl;
if(dp[nnp]>dp[i]+dist[0][j+1]+dist[j+1][k+1]+dist[k+1][0]){
dp[nnp]=dp[i]+dist[0][j+1]+dist[j+1][k+1]+dist[k+1][0];
trace[nnp]=i;
traceroute[nnp]=(j+1)*100+k+1;
}
}
}
break;
}
}
}
}
out((1<<n)-1);
//for(int i=0;i<(1<<n)-1;i++)cout<<dp[i]<<endl;
printf("%d\n",dp[(1<<n)-1]);
int sizes=ans.size();
for(int i=0;i<sizes;i++){
if(i)printf(" ");
printf("%d",ans.top());
ans.pop();
}
printf(" 0\n");
}
return 0;
}