最小生成树

最小生成树对应的图一般都是无向图

Prim 算法

朴素版 Prim 算法

时间复杂度:O(n2)O(n^{2})

适合稠密图

Prim 算法求最小生成树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <iostream>
#include <cstring>
using namespace std;

const int N = 510, INF = 0x3f3f3f3f;
int g[N][N];
// d[i] 表示节点 i 到生成树的最短距离
int d[N];
bool st[N];

int n, m;

int prim() {
memset(d, 0x3f, sizeof d);
int res = 0;
// 有 n 个节点,循环 n 次
for(int i=0;i<n;i++) {
int t = -1;
// 找到不在生成树中,且距离当前生成树最近的点
for(int j=1;j<=n;j++) {
if(!st[j] && (t == -1 || d[t] > d[j])) t = j;
}
// 如果图不连通,直接返回 INF
if(i && d[t] == INF) return INF;
// 将这个点纳入生成树节点集合,t 节点对应的到生成树距离最短的边就是新的生成树的边
st[t] = true;
if(i) res += d[t];
// 用新加入生成树的点更新与其相连的、生成树外的点到生成树的最短距离
for(int j=1;j<=n;j++) {
d[j] = min(d[j], g[t][j]);
}
}
return res;
}

int main() {
cin >> n >> m;
memset(g, 0x3f, sizeof g);
while(m--) {
int a, b, w;
cin >> a >> b >> w;
g[a][b] = g[b][a] = min(g[a][b], w);
}
int t = prim();
if(t == INF) cout << "impossible" << endl;
else cout << t << endl;
return 0;
}

堆优化 Prim 算法

时间复杂度:O(mlog(n))O(m \cdot log(n))

适合稀疏图,不太常用

Kruskal 算法

时间复杂度:O(mlog(m))O(m \cdot log(m))

适合稀疏图

算法步骤:

  1. 将所有边按照权重从小到大排序。O(mlog(m))O(m \cdot log(m))
  2. 枚举每条边。如果这条边的两个节点 a, b 不连通,将这条边加入最小生成树。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <algorithm>
#include <iostream>
using namespace std;
const int M = 200010;

// 存储每条边,重载小于号,按照权重比较大小
struct Edge {
int a, b, w;
bool operator<(const Edge& W) const { return w < W.w; }
} e[M];

int p[M];
int n, m;

// 并查集 find 函数
int find(int x) {
if (x != p[x]) p[x] = find(p[x]);
return p[x];
}
int main() {
cin >> n >> m;
for (int i = 0; i < m; i++) {
int a, b, c;
cin >> a >> b >> c;
e[i] = {a, b, c};
}
// 将边按照从小到大排序
sort(e, e + m);
for (int i = 1; i <= n; i++) p[i] = i;
int res = 0, cnt = 0;
// 枚举每条边
for (int i = 0; i < m; i++) {
int a = e[i].a, b = e[i].b, w = e[i].w;
// 检查这条边的两个点是否连通
a = find(a), b = find(b);
if (a != b) {
p[a] = b;
res += w;
cnt++;
}
}
if (cnt < n - 1)
cout << "impossible" << endl;
else
cout << res << endl;
return 0;
}

二分图

一个图是二分图当且仅当图中不含奇数环

染色法

时间复杂度:O(m+n)O(m + n)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include <iostream>
#include <cstring>
using namespace std;

// 因为是无向图,边数乘以 2
const int N = 100010, M = 200010;

// 邻接表
int h[N], e[M], ne[M], idx;
int st[N];


void add(int a, int b){
e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}

// 将节点 u 染成 color
bool dfs(int u, int color) {
// 给当前点染色
st[u] = color;
// 遍历所有相邻节点
for(int i = h[u]; i != -1; i = ne[i]){
int j = e[i];
// 如果没有染色,就染相反的颜色
if(!st[j]) {
if(!dfs(j, 3 - color)) return false;
// 如果已经染色,并且和当前节点颜色相同,返回 false,不是二分图
} else if(st[j] == color) return false;
}

return true;
}
int main() {
int n, m;
scanf("%d%d", &n, &m);

memset(h, -1, sizeof h);
while (m --){
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b,a); // 无向图
}

bool flag = true;
for(int i = 1; i <= n; i ++){
if(!st[i]){
if(!dfs(i, 1)){
flag = false;
break;
}
}
}

if(flag) puts("Yes");
else puts("No");
return 0;
}

匈牙利算法

最坏时间复杂度:O(mn)O(m \cdot n)

实际运行时间远小于 O(mn)O(m \cdot n)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;

const int N = 510, M = 1e5 + 10;
int n1, n2, m;
int h[N], e[M], ne[M], idx;
int match[N];
bool st[N];
void add(int a, int b) {
e[idx] = b;
ne[idx] = h[a];
h[a] = idx++;
}

// 寻找节点 x 能够匹配的节点
bool find(int x) {
// 遍历与节点 x 相连的所有节点
for(int i=h[x];i!=-1;i=ne[i]) {
int j = e[i];
// 如果这个节点还没有被搜索过
if(!st[j]) {
st[j] = true;
// 如果右半部分集合中的点还没有匹配或者能够更换匹配
if(match[j] == 0 || find(match[j])) {
match[j] = x;
return true;
}
}
}
return false;
}
int main() {
scanf("%d%d%d", &n1, &n2, &m);
// 邻接表初始化头结点为 -1
memset(h, -1, sizeof h);
while(m--) {
int a, b;
scanf("%d%d", &a, &b);
add(a, b);
}
int res = 0;
// 遍历左半部分的所有节点
for(int i=1;i<=n1;i++) {
memset(st, false, sizeof st);
if(find(i)) res++;
}
printf("%d\n", res);
return 0;
}