1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
| #include <cstdio> #include <iostream> #include <vector>
// Upper bound on the number of tree nodes.
constexpr int MAXN = 300000;
// Highest binary-lifting level kept per node; 2^19 = 524288 exceeds MAXN,
// so any depth difference in a tree of at most MAXN nodes can be lifted.
constexpr int LOG_MAXN = 19;
// A tree vertex: adjacency list, binary-lifting ancestor table, depth, and
// the edge-difference counter used to count how many extra edges cover the
// tree edge from this node to its parent.
//
// Every field has a default member initializer so a Node is in a well-defined
// state even when it is not value-initialized by a container (prepare() tests
// f[i - 1] against nullptr and dfs() sums cnt, so garbage here would be read).
struct Node {
#ifdef DBG
    int id = 0;  // 1-based label, only used for debug printing
#endif
    std::vector<struct Edge> adj;  // incident edges; addEdge() stores them with s == this node
    Node *f[LOG_MAXN + 1] = {};    // f[i]: 2^i-th ancestor, nullptr past the root (filled by prepare)
    Node *p = nullptr;             // parent in the rooted tree (same as f[0]), nullptr for the root
    int d = 0;                     // depth assigned by prepare(); the root gets depth 1
    int cnt = 0;                   // +1 per extra-edge endpoint, -2 at the LCA; after dfs() it holds
                                   // the number of extra edges whose tree path uses the edge to p
};
// Directed half of an undirected tree edge, stored in the adjacency list of
// its source node.
struct Edge {
    Node *s;  // endpoint whose adjacency list owns this entry
    Node *t;  // the opposite endpoint

    Edge(Node *from, Node *to) : s(from), t(to) {}
};
// Inserts the undirected edge {u, v}: one Edge in each endpoint's adjacency
// list, each pointing away from the list's owner.
inline void addEdge(Node *u, Node *v) {
    u->adj.emplace_back(u, v);
    v->adj.emplace_back(v, u);
}
// Roots the tree part reachable from `u` (whose parent is `f`, or none) and
// fills, for every visited node x:
//   x->p, x->f[0]  - the parent;
//   x->d           - depth, with a parentless start node at depth 1;
//   x->f[i]        - the 2^i-th ancestor, or nullptr above the root.
//
// Uses an explicit stack instead of recursion: a path-shaped tree of up to
// MAXN nodes would otherwise need that many nested call frames and can
// overflow the thread stack. A node is only pushed after its parent has been
// fully processed, so the parent's ancestor table is ready when the child
// reads it.
void prepare(Node *u, Node *f = nullptr) {
    struct Frame {
        Node *node;
        Node *parent;
    };
    std::vector<Frame> pending;
    pending.push_back({u, f});

    while (!pending.empty()) {
        const Frame cur = pending.back();
        pending.pop_back();

        Node *x = cur.node;
        x->f[0] = x->p = cur.parent;
        x->d = (cur.parent ? cur.parent->d : 0) + 1;
        for (int i = 1; i <= LOG_MAXN; i++) {
            // Write nullptr explicitly so a stale entry can never survive.
            x->f[i] = x->f[i - 1] ? x->f[i - 1]->f[i - 1] : nullptr;
        }

        for (const Edge &e : x->adj) {
            if (e.t != cur.parent) pending.push_back({e.t, x});
        }
    }
}
// Lowest common ancestor of u and v via binary lifting.
// Requires prepare() to have filled d, p and f[] for both nodes.
inline Node *lca(Node *u, Node *v) {
    Node *a = u;
    Node *b = v;
    // Keep `a` as the node that is at least as deep as `b`.
    if (a->d < b->d) std::swap(a, b);

    // Raise `a` to b's depth, taking the largest jumps first.
    if (a->d > b->d) {
        for (int k = LOG_MAXN; k >= 0; --k) {
            Node *up = a->f[k];
            if (up != nullptr && up->d >= b->d) a = up;
        }
    }
    if (a == b) return a;

    // Raise both while their ancestors differ; they end up as children of the LCA.
    for (int k = LOG_MAXN; k >= 0; --k) {
        if (a->f[k] == b->f[k]) continue;
        a = a->f[k];
        b = b->f[k];
    }
    return a->p;
}
// Walks the subtree rooted at `u` (entered from `f`) bottom-up and:
//   * adds every visited node's cnt into its parent p, so afterwards x->cnt
//     is the number of extra edges whose tree path uses edge (x, x->p);
//     like before, the start node `u` also pushes into u->p when it has one;
//   * returns the sum, over tree edges (x, parent) strictly inside the
//     subtree, of m when x->cnt == 0, 1 when x->cnt == 1, and 0 otherwise.
// NOTE(review): presumably this is the number of (tree edge, extra edge) cut
// pairs that disconnect the graph - confirm against the problem statement.
//
// Iterative (pre-order pass, then a reverse sweep) so that deep, path-shaped
// trees do not exhaust the call stack. In reverse pre-order every node is
// handled after all of its descendants, so its cnt is final when it is used.
long long dfs(Node *u, int m, Node *f = nullptr) {
    struct Frame {
        Node *node;
        Node *parent;
        std::size_t up;  // index in `order` of the parent's frame (unused for u)
    };

    std::vector<Frame> order;
    std::vector<Frame> pending;
    pending.push_back({u, f, 0});
    while (!pending.empty()) {
        const Frame cur = pending.back();
        pending.pop_back();
        const std::size_t self = order.size();
        order.push_back(cur);
        for (const Edge &e : cur.node->adj) {
            if (e.t != cur.parent) pending.push_back({e.t, cur.node, self});
        }
    }

    // sub[i]: answer restricted to edges strictly inside order[i]'s subtree.
    std::vector<long long> sub(order.size(), 0);
    for (std::size_t i = order.size(); i-- > 0;) {
        Node *v = order[i].node;
        if (v->p) v->p->cnt += v->cnt;
#ifdef DBG
        printf("[%d]: %lld %d\n", v->id, sub[i], v->cnt);
#endif
        if (i == 0) break;  // the start node has no edge inside this subtree

        long long edge = 0;  // contribution of the tree edge (parent, v)
        if (v->cnt == 0) {
            edge = m;
        } else if (v->cnt == 1) {
            edge = 1;
        }
        sub[order[i].up] += sub[i] + edge;
    }
    return sub[0];
}
int main() { freopen("tree.in", "r", stdin); freopen("tree.out", "w", stdout);
int n, m; scanf("%d %d", &n, &m);
std::vector<Node> nodes(n + 1);
#ifdef DBG for (int i = 1; i <= n; i++) nodes[i].id = i; #endif for (int i = 0; i < n - 1; i++) { int u, v; scanf("%d %d", &u, &v); addEdge(&nodes[u], &nodes[v]); }
prepare(&nodes[1]);
for (int i = 0; i < m; i++) { int a, b; scanf("%d %d", &a, &b); Node *u = &nodes[a], *v = &nodes[b]; Node *f = lca(u, v); u->cnt++, v->cnt++, f->cnt -= 2; }
#ifdef DBG for (int i = 1; i <= n; i++) printf("%d ", nodes[i].cnt); putchar('\n'); #endif
printf("%lld\n", dfs(&nodes[1], m));
fclose(stdin); fclose(stdout);
return 0; }
|