Algorithm Notes 5

Lecture 2: Data Structures

1. KMP

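Purpose: find the occurrences of a pattern string P in a string S.

Time complexity: O(M + N), where M is the length of S and N is the length of P.

The first step is to find the longest prefix that is also a suffix (excluding the whole string). For abcab this is ab and ab; abc and cab do not match, and neither do a and b.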

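S: abcabxced

P: abcaby

Here x != y. The brute-force approach shifts P right by one position, whereas KMP shifts it to:

S: abcabxced

P:    abcaby

Because the prefix ab equals the suffix ab, those two characters are already known to match, so the comparison resumes at x versus c. If they differ, P is shifted right again.

In practice P is never physically shifted. Instead the pointer j (into P) falls back, and the next array says where it falls back to.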

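For convenience, S and P are indexed from 1; index 0 holds no actual character.

next[i] is the length of the longest equal prefix and suffix of P[1..i], which is also exactly the position j jumps back to.

For example, with '0abcab' from above, next[5] = 2 and P[5] == P[2] == 'b'.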

[Figure: image-20240327222115159]

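As the figure above shows, each comparison is between the i-th character of S and the (j + 1)-th character of P.

Therefore next[1] = 0: j falls back from 1 to 0, the next comparison is against P[1], and matching restarts from the beginning of P.

Building the next array can be viewed as matching P against itself, so the code is almost identical to the matching code.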
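As a small worked example of the self-matching idea above, the sketch below builds the next array for the pattern abcab from the earlier example. The class name NextArraySketch and the method buildNext are illustrative names, not part of the original notes; the loop itself follows the 1-based convention described above.

```java
import java.util.Arrays;

class NextArraySketch {
    // Builds the next array with 1-based indexing; next[0] is unused.
    static int[] buildNext(String pattern) {
        int n = pattern.length();
        char[] p = (" " + pattern).toCharArray(); // p[0] is a placeholder
        int[] next = new int[n + 1];
        for (int i = 2, j = 0; i <= n; i++) {
            // Fall back until p[i] can extend the current prefix, or j reaches 0
            while (j != 0 && p[i] != p[j + 1]) {
                j = next[j];
            }
            if (p[i] == p[j + 1]) {
                j++;
            }
            next[i] = j;
        }
        return next;
    }

    public static void main(String[] args) {
        // prints [0, 0, 0, 0, 1, 2]; the last entry is next[5] = 2
        System.out.println(Arrays.toString(buildNext("abcab")));
    }
}
```

The last entry confirms next[5] = 2 for abcab, matching the prefix ab and the suffix ab.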

Using Scanner on this problem exceeds the time limit.

acwing 831. KMP String https://www.acwing.com/problem/content/description/833/

import java.util.*;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;

class Main {
    public static void main(String[] args) throws IOException {

        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        int n = Integer.parseInt(br.readLine());
        char[] p = (" " + br.readLine()).toCharArray();
        int m = Integer.parseInt(br.readLine());
        char[] s = (" " + br.readLine()).toCharArray();
        
        int[] next = new int[n + 1];
        // Build the next array: same as the matching code below, with S replaced by P itself
        for (int i = 2, j = 0; i <= n; i++) {
            while (j != 0 && p[i] != p[j + 1]) {
                j = next[j];
            }
            if (p[i] == p[j + 1]) {
                j++;
            }
            next[i] = j;
        }
        // The matching code (write this part first, then adapt it for the next array)
        for (int i = 1, j = 0; i <= m; i++) {
            while (j != 0 && s[i] != p[j + 1]) {
                j = next[j];
            }
            if (s[i] == p[j + 1]) {
                j++;
            }
            if (j == n) {
                // i is the 1-based end of the match, so i - n is its 0-based start index
                bw.write((i - n) + " ");
                j = next[j];
            }
        }
        bw.flush();
        bw.close();
        br.close();
    }
}

2. Trie (prefix tree / dictionary tree)

[Figure: image-20240327223154065]

acwing 835. Trie String Count https://www.acwing.com/problem/content/837/

import java.util.*;

class Main {
    
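    static int[][] trie; // trie[parent node][child letter (a, b, c, ...)] = index (id) of the child node
    static int[] cnt; // cnt[i] = number of inserted strings that end at node i
    static int idx; // last node index in use, so a new node is created as ++idx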
    
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        
        trie = new int[100010][26];
        cnt = new int[100010];
        idx = 0;
        int n = sc.nextInt();
        for (int i = 0; i < n; i++) {
            String c = sc.next();
            String s = sc.next();
            if (c.equals("I")) {
                insert(s);
            } else {
                System.out.println(query(s));
            }
        }
        
        sc.close();
    }
    public static void insert(String s) {
        int n = s.length();
        int now = 0;
        for (int i = 0; i < n; i++) {
            int c = s.charAt(i) - 'a';
            if (trie[now][c] == 0) {
                trie[now][c] = ++idx;
                now = idx;
            } else {
                now = trie[now][c];
            }
        }
        cnt[now]++;
    }
    public static int query(String s) {
        int n = s.length();
        int now = 0;
        for (int i = 0; i < n; i++) {
            int c = s.charAt(i) - 'a';
            if (trie[now][c] == 0) {
                return 0;
            }
            now = trie[now][c];
        }
        return cnt[now];
    }
}

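acwing 143. Maximum XOR Pair https://www.acwing.com/problem/content/145/

Each node has only two branches, 0 and 1.

Each number is stored bit by bit, from bit 30 down to bit 0.

At every step, prefer the branch with the opposite bit (take 0 if the current bit is 1, and vice versa) so that the XOR of that bit is 1 whenever possible.

Starting from the high bits guarantees that the result is maximal.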
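The greedy walk described above can be sketched as follows. This is a minimal illustration, not the notes' own solution: it reads each bit with (num >> i) & 1 instead of repeated subtraction, assumes non-negative inputs below 2^31, and the names MaxXorTrieSketch and maxXor are made up for the example.

```java
class MaxXorTrieSketch {
    static int[][] trie; // trie[node][bit] = child node index, 0 means absent
    static int idx;      // last node index in use

    static void insert(int num) {
        int now = 0;
        for (int i = 30; i >= 0; i--) {
            int c = (num >> i) & 1; // bit i of num
            if (trie[now][c] == 0) {
                trie[now][c] = ++idx;
            }
            now = trie[now][c];
        }
    }

    // Best XOR of num against the numbers already inserted (at least one is required)
    static int query(int num) {
        int now = 0;
        int res = 0;
        for (int i = 30; i >= 0; i--) {
            int c = (num >> i) & 1;
            if (trie[now][c ^ 1] != 0) {
                // The opposite bit exists, so bit i of the XOR can be 1
                now = trie[now][c ^ 1];
                res |= 1 << i;
            } else {
                now = trie[now][c];
            }
        }
        return res;
    }

    static int maxXor(int[] arr) {
        trie = new int[arr.length * 31 + 1][2]; // at most 31 new nodes per number
        idx = 0;
        int res = 0;
        for (int num : arr) {
            insert(num);
            res = Math.max(res, query(num));
        }
        return res;
    }

    public static void main(String[] args) {
        System.out.println(maxXor(new int[]{1, 2, 3})); // prints 3, from 1 ^ 2
    }
}
```

Reading bits with a shift avoids modifying num inside the loop; the subtraction-based version is equivalent for non-negative inputs.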

// Brute force (since a ^ b == b ^ a, restricting to j < i avoids repeated work)
int res = 0;
for (int i = 1; i < n; i++) {
    for (int j = 0; j < i; j++) {
        res = Math.max(res, arr[i] ^ arr[j]);
    }
}
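// Insert one number at a time: as in the loop above, each number is XORed against the earlier ones (num itself is already in the trie here, so this is effectively j <= i, but num ^ num == 0 does not change the result)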
int res = 0;
for (int num : arr) {
    insert(num);
    res = Math.max(res, query(num));
}

The complete code:

import java.util.*;

class Main {
    
    static int[][] trie;
    static int[] cnt;
    static int idx;
    
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        
        int N = 100010 * 31;
        trie = new int[N][2];
        cnt = new int[N];
        idx = 0;
        
        int n = sc.nextInt();
        int[] arr = new int[n];
        for (int i = 0; i < n; i++) {
            arr[i] = sc.nextInt();
        }
        int res = 0;
        for (int num : arr) {
            insert(num);
            res = Math.max(res, query(num));
        }
        System.out.println(res);
        sc.close();
    }
    public static void insert(int num) {
        int now = 0;
        for (int i = 30; i >= 0; i--) {
            int t = 1 << i;
            int c = 0;
            if (num >= t) {
                num -= t;
                c = 1;
            }
            if (trie[now][c] == 0) {
                trie[now][c] = ++idx;
                now = idx;
            } else {
                now = trie[now][c];
            }
        }
        cnt[now]++;
    }
    public static int query(int num) {
        int now = 0;
        int res = 0;
        for (int i = 30; i >= 0; i--) {
            int t = 1 << i;
            int c = 0;
            if (num >= t) {
                num -= t;
                c = 1;
            }
            if (trie[now][1^c] == 0) {
                now = trie[now][c];
            } else {
                now = trie[now][1^c];
                res += t;
            }
        }
        return res;
    }
}

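3. Union-Find (disjoint set union)

Supported operations: merging sets, checking whether two elements belong to the same set, and counting the elements in a set.

acwing 836. Merge Sets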

import java.util.*;

class Main {
    static int[] fa;
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt(), M = sc.nextInt();
        fa = new int[N + 1];
        for (int i = 1; i <= N; i++) {
            fa[i] = i;
        }
        for (int i = 0; i < M; i++) {
            String Z = sc.next();
            int x = sc.nextInt();
            int y = sc.nextInt();
            if (Z.equals("M")) {
                merge(x, y);
            } else {
                if (find(x) == find(y)) {
                    System.out.println("Yes");
                } else {
                    System.out.println("No");
                }
            }
        }
        sc.close();
    }
    public static void merge(int x, int y) {
        fa[find(x)] = find(y);
    }
    // Find the root, with path compression
    public static int find(int x) {
        if (fa[x] == x) {
            return x;
        }
        return fa[x] = find(fa[x]);
    }
}

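Building on the previous template, keep one extra array that records how many nodes are under each root.

acwing 837. Number of Nodes in a Connected Component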

import java.util.*;

class Main {
    static int[] fa;
    static int[] size;
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt(), M = sc.nextInt();
        fa = new int[N + 1];
        size = new int[N + 1];
        for (int i = 1; i <= N; i++) {
            fa[i] = i;
            size[i] = 1;
        }
        for (int i = 0; i < M; i++) {
            String Z = sc.next();
            
            if (Z.equals("C")) {
                int x = sc.nextInt();
                int y = sc.nextInt();
                merge(x, y);
            } else if (Z.equals("Q1")) {
                int x = sc.nextInt();
                int y = sc.nextInt();
                if (find(x) == find(y)) {
                    System.out.println("Yes");
                } else {
                    System.out.println("No");
                }
            } else {
                int x = sc.nextInt();
                System.out.println(size[find(x)]);
            }
        }
        sc.close();
    }
    public static void merge(int x, int y) {
        int fx = find(x), fy = find(y);
        // This check is required: merging a set with itself would double its size
        if (fx == fy) {
            return;
        }
        size[fy] += size[fx];
        fa[fx] = fy;
    }
    public static int find(int x) {
        if (fa[x] == x) {
            return x;
        }
        return fa[x] = find(fa[x]);
    }
}

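The class of each node is determined by its distance to the root (the root is at distance 0 from itself), taken modulo 3 (distance % 3).

There may be several sets, and each set can contain several classes (distinguished by distance to the root).

So there are two cases:

  • Both nodes are in the same set: compare the distances modulo 3 to check the statement.
  • The nodes are in different sets: the statement cannot contradict anything said earlier, so it is true. The two sets must then be merged so that the relative relation it states holds.

The distance array d stores the distance from a node to its parent.

After find has been called on a node, that value becomes the distance from the node to the root.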
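To make the distance-modulo-3 idea concrete, here is a small sketch on three animals. The class FoodChainSketch and the helpers init, link and relation are hypothetical names introduced for illustration, and the encoding (0 = same kind, 1 = x eats y, 2 = y eats x) is one consistent choice under the convention above.

```java
class FoodChainSketch {
    static int[] fa; // parent of each node
    static int[] d;  // distance to the parent (to the root after find)

    static void init(int n) {
        fa = new int[n + 1];
        d = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            fa[i] = i;
        }
    }

    static int find(int x) {
        if (fa[x] == x) {
            return x;
        }
        int u = fa[x];   // remember the old parent
        fa[x] = find(u); // afterwards d[u] is u's distance to the root
        d[x] += d[u];
        return fa[x];
    }

    // Records "d[x] - d[y] == diff (mod 3)"; diff 0 = same kind, 1 = x eats y.
    // Does nothing if x and y are already in the same set.
    static void link(int x, int y, int diff) {
        int fx = find(x), fy = find(y);
        if (fx == fy) {
            return;
        }
        fa[fx] = fy;
        d[fx] = d[y] + diff - d[x];
    }

    // Returns 0 (same kind), 1 (x eats y), 2 (y eats x), or -1 if unrelated so far
    static int relation(int x, int y) {
        if (find(x) != find(y)) {
            return -1;
        }
        return ((d[x] - d[y]) % 3 + 3) % 3;
    }

    public static void main(String[] args) {
        init(3);
        link(1, 2, 1); // 1 eats 2
        link(2, 3, 1); // 2 eats 3
        System.out.println(relation(1, 3)); // prints 2: 3 eats 1, closing the cycle
    }
}
```

Normalizing with ((d[x] - d[y]) % 3 + 3) % 3 is needed here because Java's % can return a negative remainder.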

acwing 240. Food Chain

import java.util.*;

class Main {
    static int[] fa;
    static int[] d; // distance from the node to its parent (to the root once find has run on it)
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt(), K = sc.nextInt();
        fa = new int[N + 1];
        d = new int[N + 1];
        
        for (int i = 1; i <= N; i++) {
            fa[i] = i;
        }
        int cnt = 0;
        for (int i = 0; i < K; i++) {
            int Z = sc.nextInt(), x = sc.nextInt(), y = sc.nextInt();
            if (x > N || y > N) {
                cnt++;
                continue;
            }
            int fx = find(x), fy = find(y);
            if (Z == 1) {
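                // First check whether x and y are in the same set (the same tree)
                if (fx == fy && (d[x] - d[y]) % 3 != 0) { // same set: false unless x and y are the same kind
                    cnt++;
                    continue;
                } else if (fx != fy) {    // different sets: merge them so that x and y are the same kind
                    fa[fx] = fy;
                    d[fx] = d[y] - d[x];    // once x and y are correct relative to each other, the whole merged set is consistent
                }
            } else {
                if (fx == fy && (d[x] - d[y] - 1) % 3 != 0) { // same set: false unless x eats y
                    cnt++;
                    continue;
                } else if (fx != fy) {    // different sets: merge them so that x eats y
                    fa[fx] = fy;
                    d[fx] = d[y] + 1 - d[x];
                }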
            }
        }
        System.out.println(cnt);
        sc.close();
    }
    // After find(x), d[x] changes from the distance to x's parent into the distance to the root
    public static int find(int x) {
        if (fa[x] == x) {
            return x;
        }
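        // fa[x] is about to be overwritten with the root, so save the old parent first
        int u = fa[x];
        // The find call below compresses the parent's path, so afterwards d[u] is the parent's distance to the root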
        fa[x] = find(fa[x]);
        d[x] += d[u];
        return fa[x];
    }
}

4. Heap (priority queue)

For background on heaps, see this blog post: https://cyrus28214.top/post/0281cead953c/

leetcode 347. Top K Frequent Elements

class Solution {
    public int[] topKFrequent(int[] nums, int k) {
        
        Map<Integer, Integer> hm = new HashMap<>();
        for (Integer i : nums) {
            hm.put(i, hm.getOrDefault(i, 0) + 1);
        }

        // 1-indexed min-heap ordered by frequency; each entry is {value, frequency}, and row 0 is unused
        int[][] heap = new int[k + 1][2];
        int cnt = 0;
        for (Integer i : hm.keySet()) {
            int freq = hm.get(i);
            if (cnt < k) {
                heap[cnt + 1][0] = i;
                heap[cnt + 1][1] = freq;
                up(heap, cnt +1);
                cnt++;
            } else if (freq > heap[1][1]) {
                heap[1][0] = i;
                heap[1][1] = freq;
                down(heap, 1);
            }
        }
        int[] res = new int[k];
        for (int i = 0; i < k; i++) {
            res[i] = heap[i + 1][0];
        }
        return res;
    }

    // Sift the entry at index i up until its parent is no larger (min-heap by frequency)
    public void up(int[][] heap, int i) {
        while (i > 1) {
            int parent = i >> 1;
            if (heap[i][1] >= heap[parent][1]) {
                break;
            }
            swap(heap, i, parent);
            i = parent;
        }
    }

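    // Sift the entry at index i down; using heap.length as the bound assumes the heap is full,
    // which holds because down is only called once k entries are stored
    public void down(int[][] heap, int i) {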
        int n = heap.length;
        while (i < n) {
            int left = i << 1;
            int right = i << 1 | 1;
            if (left >= n) {
                break;
            }
            int loc = left;
            if (left != n - 1 && heap[right][1] < heap[left][1]) {
                loc = right;
            }
            if (heap[i][1] <= heap[loc][1]) {
                break;
            } else {
                swap(heap, i, loc);
                i = loc;
            }
        }
    }

    public void swap(int[][] heap, int a, int b) {
        int[] tmp = heap[a];
        heap[a] = heap[b];
        heap[b] = tmp;
    }
}

leetcode 239. Sliding Window Maximum

Using the Java API (PriorityQueue):

class Solution {
    public int[] maxSlidingWindow(int[] nums, int k) {
        int L = 0, R = k - 1;
        int n = nums.length;
        int[] res = new int[n - k + 1];
        // Max-heap of {value, index}; Integer.compare avoids overflow in the comparator
        Queue<int[]> pr = new PriorityQueue<>((o1, o2) -> Integer.compare(o2[0], o1[0]));
        for (int i = 0; i < k; i++) {
            pr.add(new int[]{nums[i], i});
        }
        int pos = 0;
        while (R < n) {
            if (pr.peek()[1] >= L) {
                res[pos++] = pr.peek()[0];
                L++;
                R++;
                if (R < n)
                    pr.add(new int[]{nums[R], R});
            } else {
                pr.remove();
            }
        }
        return res;
    }
}

A better solution for this problem is a monotonic queue.
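A monotonic-queue version might look like the sketch below. It is one common way to write it rather than the notes' own code, and it assumes 1 <= k <= nums.length; the class name MonotonicQueueSketch is made up for the example. The deque holds indices whose values decrease from head to tail, so the head is always the maximum of the current window.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;

class MonotonicQueueSketch {
    static int[] maxSlidingWindow(int[] nums, int k) {
        int n = nums.length;
        int[] res = new int[n - k + 1];
        Deque<Integer> dq = new ArrayDeque<>(); // indices; values decrease from head to tail
        for (int i = 0; i < n; i++) {
            // Drop the head if it has slid out of the window [i - k + 1, i]
            if (!dq.isEmpty() && dq.peekFirst() <= i - k) {
                dq.pollFirst();
            }
            // Drop tail entries that can never be the maximum again
            while (!dq.isEmpty() && nums[dq.peekLast()] <= nums[i]) {
                dq.pollLast();
            }
            dq.offerLast(i);
            if (i >= k - 1) {
                res[i - k + 1] = nums[dq.peekFirst()];
            }
        }
        return res;
    }

    public static void main(String[] args) {
        int[] nums = {1, 3, -1, -3, 5, 3, 6, 7};
        // prints [3, 3, 5, 5, 6, 7]
        System.out.println(Arrays.toString(maxSlidingWindow(nums, 3)));
    }
}
```

Each index enters and leaves the deque at most once, so the running time is O(n), compared with O(n log n) for the priority-queue version.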

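5. Hashing

Separate chaining and open addressing

String hashing

5.1 Storage structures (open addressing, separate chaining)

acwing 840. Simulated Hash Table https://www.acwing.com/problem/content/842/

  1. Open addressing: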
import java.util.*;

class Main {
    static int[] hash;
    // Size the table at 2 to 3 times the maximum number of elements
    static int MOD = 2 * 100000 + 3;
    static int INF = 0x3f3f3f3f;
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        hash = new int[MOD];
        Arrays.fill(hash, INF);
        
        int t = sc.nextInt();
        for (int i = 0; i < t; i++) {
            String s = sc.next();
            int x = sc.nextInt();
            if (s.equals("I")) {
                insert(x);
            } else {
                int idx = find(x);
                if (hash[idx] == INF) {
                    System.out.println("No");
                } else {
                    System.out.println("Yes");
                }
            }
        }
        
        sc.close();
    }
    
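    public static void insert(int x) {
        // Going through find means an x that already exists is written to its existing slot, not to a second one
        hash[find(x)] = x;
    }
    public static int find(int x) {
        int idx = (x % MOD + MOD) % MOD;
        while (hash[idx] != INF && hash[idx] != x) {
            idx++;
            // Wrap around to the start so probing never runs past the end of the table
            if (idx == MOD) {
                idx = 0;
            }
        }
        return idx;
    }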
}
  2. Separate chaining:

Linked lists represented with arrays:

import java.util.*;

class Main {
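    // For each bucket, the position of the first node in its list (0 means empty)
    static int[] headNext;
    // Value stored in each list node
    static int[] val;
    // Position of the next node for each list node
    static int[] next;
    // Next free position in the node arrays
    static int idx;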
    
    static int MOD = 100000 + 3;
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        headNext = new int[MOD];
        val = new int[MOD];
        next = new int[MOD];
        // Start storing at 1 so that the default value 0 can stand for null
        idx = 1;
        
        int t = sc.nextInt();
        for (int i = 0; i < t; i++) {
            String s = sc.next();
            int x = sc.nextInt();
            if (s.equals("I")) {
                insert(x);
            } else {
                if (find(x)) {
                    System.out.println("Yes");
                } else {
                    System.out.println("No");
                }
            }
        }
        
        sc.close();
    }
    
    public static void insert(int x) {
        // The solution is also accepted (AC) without this existence check
        if (find(x)) {
            return;
        }
        int i = (x % MOD + MOD) % MOD;
        val[idx] = x;
        next[idx] = headNext[i];
        headNext[i] = idx;
        idx++;
    }
    public static boolean find(int x) {
        int i = (x % MOD + MOD) % MOD;
        int pos = headNext[i];
        while (pos != 0) {
            if (val[pos] == x) {
                return true;
            }
            pos = next[pos];
        }
        return false;
    }
}

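Linked lists not represented with arrays:

In testing, a hand-written linked list (class Node) timed out, calling query inside insert timed out, and reading input without BufferedReader timed out.

It only passed when ① using a List, ② using BufferedReader, and ③ not calling query inside insert.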

import java.util.*;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;

class Main {
    static List<Integer>[] hash;
    static int MOD = 100000 + 3;
    
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        hash = new ArrayList[MOD];
        // Each bucket needs its own list. Arrays.fill(hash, new ArrayList<>()) would put the
        // same list object in every slot, so every query would scan all inserted values
        // (a likely cause of the timeout seen when insert called query).
        for (int b = 0; b < MOD; b++) {
            hash[b] = new ArrayList<>();
        }
        
        int n = Integer.parseInt(br.readLine().trim());
        for (int i = 0; i < n; i++) {
            // Read one operation per line
            String[] input = br.readLine().split(" ");
            char operation = input[0].charAt(0);
            int x = Integer.parseInt(input[1]);
            
            // Execute the operation
            if (operation == 'I') {
                insert(x);
            } else if (operation == 'Q') {
                if (query(x)) {
                    bw.write("Yes\n");
                } else {
                    bw.write("No\n");
                }
            }
        }
        // Flush the BufferedWriter so that all output is written
        bw.flush();
        
        // Close the BufferedReader and BufferedWriter
        br.close();
        bw.close();
    }
    
    public static void insert(int x) {
        /*if (query(x)) {
            return;
        }*/
        int idx = (x % MOD + MOD) % MOD;
        List<Integer> list = hash[idx];
        list.add(x);
    }
    
    public static boolean query(int x) {
        int idx = (x % MOD + MOD) % MOD;
        List<Integer> list = hash[idx];
        for (Integer i : list) {
            if (i == x) {
                return true;
            }
        }
        return false;
    }
}

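5.2 String hashing

Use case: quickly checking whether two strings are equal. A string can be read as a number in base P, so comparing strings reduces to comparing numbers.

To keep the number from growing too large, take it modulo Q, which maps it into the range 0 to Q - 1.

Collisions are possible. As a rule of thumb, P = 131 or 13331 with Q = 2^64 rarely collides in practice (though this is not guaranteed).

In C++, unsigned long long arithmetic is automatically taken modulo 2^64.

In Java, long addition, subtraction and multiplication wrap around in two's complement, so they produce the same 64 bits as unsigned arithmetic modulo 2^64. As long as no division or remainder is involved, a hash that shows up as a negative value can still be compared for equality.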
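The claim about Java's long can be checked with a short experiment. The sketch below (class and method names are illustrative) computes the same polynomial hash twice: once with plain long arithmetic that is allowed to overflow, and once exactly with BigInteger reduced modulo 2^64. It then compares the two by viewing the long as an unsigned value.

```java
import java.math.BigInteger;

class LongWrapSketch {
    // Polynomial hash with plain long arithmetic; overflow is allowed to wrap
    static long hashLong(String s, long p) {
        long h = 0;
        for (char c : s.toCharArray()) {
            h = h * p + c;
        }
        return h;
    }

    // The same hash computed exactly and reduced modulo 2^64
    static BigInteger hashExact(String s, long p) {
        BigInteger mod = BigInteger.ONE.shiftLeft(64);
        BigInteger base = BigInteger.valueOf(p);
        BigInteger h = BigInteger.ZERO;
        for (char c : s.toCharArray()) {
            h = h.multiply(base).add(BigInteger.valueOf(c)).mod(mod);
        }
        return h;
    }

    // True if the wrapped long, read as an unsigned 64-bit value, equals the exact result
    static boolean agree(String s, long p) {
        BigInteger unsignedView = new BigInteger(Long.toUnsignedString(hashLong(s, p)));
        return unsignedView.equals(hashExact(s, p));
    }

    public static void main(String[] args) {
        System.out.println(agree("abcdefghijklmnopqrstuvwxyz", 131)); // prints true
    }
}
```

The 26-character input is long enough that the hash overflows 64 bits many times, so the agreement is not an accident of small values.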

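acwing 841. String Hashing https://www.acwing.com/problem/content/description/843/

The ASCII code of each character is used directly as its digit (letters and digits never map to 0, and all are less than 131).

The substring hash uses the prefix-sum idea, but the high-order digits must be aligned first, so pre[L-1] has to be shifted left (multiplied by a power of P). For example:

ABCDE

ABC

To get DE:

Align at A by turning ABC into ABC00; then ABCDE - ABC00 = DE.

The powers of P also need to be precomputed.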
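The ABCDE example can be checked directly. The following sketch, with illustrative names build, get and direct, compares the prefix-based hash of positions 4..5 with a hash of "DE" computed from scratch. It assumes P = 131 and relies on long overflow as the implicit modulus.

```java
class SubstringHashSketch {
    static final long P = 131;
    static long[] pre; // pre[i] = hash of the first i characters
    static long[] pow; // pow[i] = P^i (with long overflow)

    static void build(String s) {
        int n = s.length();
        pre = new long[n + 1];
        pow = new long[n + 1];
        pow[0] = 1;
        for (int i = 1; i <= n; i++) {
            pre[i] = pre[i - 1] * P + s.charAt(i - 1);
            pow[i] = pow[i - 1] * P;
        }
    }

    // Hash of the 1-based substring [l, r]: shift pre[l - 1] left by r - l + 1 digits, then subtract
    static long get(int l, int r) {
        return pre[r] - pre[l - 1] * pow[r - l + 1];
    }

    // Hash of a whole string computed from scratch, for comparison
    static long direct(String t) {
        long h = 0;
        for (char c : t.toCharArray()) {
            h = h * P + c;
        }
        return h;
    }

    public static void main(String[] args) {
        build("ABCDE");
        System.out.println(get(4, 5) == direct("DE")); // prints true
    }
}
```

Here pre[3] * pow[2] plays the role of ABC00, and subtracting it from pre[5] leaves the hash of DE.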

import java.util.*;

class Main {
    //or 13331
    static long P = 131;
    static long[] pre;
    static long[] pow;
    //static long MOD = Long.MAX_VALUE;
    
    public static long get(int l, int r) {
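        // Alternative with an explicit modulus, left commented out:
        //return (pre[r] - pre[l - 1] * pow[r - l + 1] + MOD) % MOD;
        // The version in use relies on long wrap-around (arithmetic modulo 2^64):
        // pre[l - 1] is shifted left by r - l + 1 digits and subtracted from pre[r]
        return pre[r] - pre[l - 1] * pow[r - l + 1];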
    }
    
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int n = sc.nextInt();
        int t = sc.nextInt();
        String s = " " + sc.next();
        char[] cs = s.toCharArray();
        pre = new long[n + 1];
        pow = new long[n + 1];
        
        pow[0] = 1;
        for (int i = 1; i <= n; i++) {
            pre[i] = pre[i - 1] * P + cs[i];
            pow[i] = pow[i - 1] * P;
        }
        
        for (int i = 0; i < t; i++) {
            int l1 = sc.nextInt(), r1 = sc.nextInt();
            int l2 = sc.nextInt(), r2 = sc.nextInt();

            if (get(l1, r1) == get(l2, r2)) {
                System.out.println("Yes");
            } else {
                System.out.println("No");
            }
        }
        
        sc.close();
    }

}