// Contents (web-page residue from the original source page)
#include<bits/stdc++.h>
using namespace std;
using ll = long long;  // 64-bit integer alias; not referenced by the visible code

// Capacities: N bounds the vertex count, M the number of directed edge slots
// (main stores every tree edge in both directions).
const int N = 100005;
const int M = 200005;

// Problem input: vertex count n, operation count m, initial color v[i] of vertex i.
int n, m;
int v[N];

// Adjacency lists as linked edge chains: h[x] is the first edge leaving x
// (main fills h with -1 = "no edge"), to[e] is the edge's head, nxt[e] the next
// edge leaving the same vertex, idx the next free edge slot.
int h[N], to[M], nxt[M], idx;

// Heavy-light decomposition data (filled by d1 / d2):
//   f = parent, d = depth, sz = subtree size, sn = heavy child (0 if none),
//   tp = top vertex of the heavy chain, id = DFS position, rk = inverse of id,
//   cnt = last DFS position handed out.
int f[N], d[N], sz[N], sn[N];
int tp[N], id[N], rk[N], cnt;

// Segment-tree node over DFS positions [l, r] (positions are d2's id[] values).
//   lc / rc : color at the leftmost / rightmost position of the range
//   s       : number of maximal runs of equal color inside the range
//   tg      : pending "assign color" lazy tag for the children, -1 when none
struct nd {
    int l, r;
    int lc, rc;
    int s;
    int tg;
};

// Node storage in 1-indexed heap layout (children of u are 2u and 2u+1).
nd t[N * 4];

// Result of a range query on the segment tree.
struct re {
    int lc;  // color at the left (lower DFS position) end of the range
    int rc;  // color at the right (higher DFS position) end of the range
    int s;   // number of maximal same-color runs in the range
};

// Appends the directed edge a -> b to a's adjacency chain (new edge becomes the head).
void add(int a, int b) {
    const int e = idx++;  // claim the next free edge slot
    to[e] = b;
    nxt[e] = h[a];
    h[a] = e;
}

// First HLD pass over u's subtree (u has parent p and depth dp).
// Records f[], d[] and sz[], and chooses sn[u] as the first-visited child with
// the strictly largest subtree; sn[u] stays 0 for a leaf (sz[0] is 0).
void d1(int u, int p, int dp) {
    f[u] = p;
    d[u] = dp;
    sz[u] = 1;
    for (int e = h[u]; e != -1; e = nxt[e]) {
        const int child = to[e];
        if (child != p) {
            d1(child, u, dp + 1);
            sz[u] += sz[child];
            if (sz[sn[u]] < sz[child]) sn[u] = child;
        }
    }
}

// Second HLD pass: u belongs to the chain whose top vertex is t.
// Hands out DFS positions (id[] / rk[]), visiting the heavy child first so each
// heavy chain occupies a contiguous range; every light child starts a new chain.
void d2(int u, int t) {
    tp[u] = t;
    cnt += 1;
    id[u] = cnt;
    rk[cnt] = u;
    const int heavy = sn[u];
    if (heavy == 0) return;  // leaf: no children to visit
    d2(heavy, t);
    for (int e = h[u]; e != -1; e = nxt[e]) {
        const int child = to[e];
        if (child != f[u] && child != heavy) d2(child, child);
    }
}

void pup(int u) {
    t[u].s = t[u << 1].s + t[u << 1 | 1].s;
    if (t[u << 1].rc == t[u << 1 | 1].lc) t[u].s--;
    t[u].lc = t[u << 1].lc;
    t[u].rc = t[u << 1 | 1].rc;
}

// Pushes u's pending color assignment (if any) to both children: each child
// becomes a single run of that color and inherits the tag; u's tag is cleared.
void pd(int u) {
    const int c = t[u].tg;
    if (c == -1) return;  // nothing pending
    for (int k = u << 1; k <= (u << 1 | 1); ++k) {
        t[k].lc = c;
        t[k].rc = c;
        t[k].tg = c;
        t[k].s = 1;
    }
    t[u].tg = -1;
}

void bld(int u, int l, int r) {
    t[u].l = l;
    t[u].r = r;
    t[u].tg = -1;
    if (l == r) {
        t[u].lc = t[u].rc = v[rk[l]];
        t[u].s = 1;
        return;
    }
    int mid = (l + r) >> 1;
    bld(u << 1, l, mid);
    bld(u << 1 | 1, mid + 1, r);
    pup(u);
}

void upd(int u, int l, int r, int c) {
    if (t[u].l >= l && t[u].r <= r) {
        t[u].lc = t[u].rc = t[u].tg = c;
        t[u].s = 1;
        return;
    }
    pd(u);
    int mid = (t[u].l + t[u].r) >> 1;
    if (l <= mid) upd(u << 1, l, r, c);
    if (r > mid) upd(u << 1 | 1, l, r, c);
    pup(u);
}

// Queries [l, r] within node u's range.
// Returns {color at the left end, color at the right end, number of runs}.
// Assumes [l, r] overlaps u's range, which holds for calls starting at the root.
re ask(int u, int l, int r) {
    const nd& cur = t[u];
    if (l <= cur.l && cur.r <= r) return re{cur.lc, cur.rc, cur.s};
    pd(u);
    const int mid = (cur.l + cur.r) >> 1;
    const bool goLeft = l <= mid;
    const bool goRight = r > mid;
    if (!goRight) return ask(u << 1, l, r);
    if (!goLeft) return ask(u << 1 | 1, l, r);
    const re left = ask(u << 1, l, r);
    const re right = ask(u << 1 | 1, l, r);
    re merged{left.lc, right.rc, left.s + right.s};
    if (left.rc == right.lc) --merged.s;  // runs touching at the split merge
    return merged;
}

// Paints every vertex on the tree path u - v with color c by lifting the
// endpoint whose chain top is deeper, one heavy chain at a time.
void slc(int u, int v, int c) {
    for (; tp[u] != tp[v]; u = f[tp[u]]) {
        if (d[tp[u]] < d[tp[v]]) swap(u, v);
        upd(1, id[tp[u]], id[u], c);
    }
    // Both endpoints are on one chain now; paint between them (shallower first).
    const int lo = d[u] > d[v] ? v : u;
    const int hi = d[u] > d[v] ? u : v;
    upd(1, id[lo], id[hi], c);
}

int slq(int u, int v) {
    int lu = 0, lv = 0, ans = 0;
    while (tp[u] != tp[v]) {
        if (d[tp[u]] < d[tp[v]]) {
            swap(u, v);
            swap(lu, lv);
        }
        re r = ask(1, id[tp[u]], id[u]);
        ans += r.s;
        if (r.rc == lu) ans--;
        lu = r.lc;
        u = f[tp[u]];
    }
    if (d[u] < d[v]) {
        swap(u, v);
        swap(lu, lv);
    }
    re r = ask(1, id[v], id[u]);
    ans += r.s;
    if (r.rc == lu) ans--;
    if (r.lc == lv) ans--;
    return ans;
}

// Reads the tree and the initial colors, builds the heavy-light decomposition
// and the segment tree, then processes m operations:
//   C a b c : paint every vertex on path a - b with color c
//   other   : read a b and print the number of color runs on path a - b
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    memset(h, -1, sizeof h);  // -1 marks "no outgoing edge yet"
    cin >> n >> m;
    for (int i = 1; i <= n; ++i) {
        cin >> v[i];
    }
    for (int e = 1; e < n; ++e) {  // n - 1 undirected edges, stored both ways
        int a, b;
        cin >> a >> b;
        add(a, b);
        add(b, a);
    }

    d1(1, 0, 1);  // root the tree at vertex 1 with depth 1
    d2(1, 1);
    bld(1, 1, n);

    while (m--) {
        char op;
        cin >> op;
        switch (op) {
        case 'C': {
            int a, b, c;
            cin >> a >> b >> c;
            slc(a, b, c);
            break;
        }
        default: {
            int a, b;
            cin >> a >> b;
            cout << slq(a, b) << "\n";
            break;
        }
        }
    }

    return 0;
}

// 0 comments (web-page residue from the original source page)
//
// No comments yet... (web-page residue)