二分查找树

二分查找树


  1/**
  2 * 二分查找树 BST 实现
  3 *
  4 * @author wangy
  5 * @version 1.0
  6 * @date 2022/12/7 / 16:12
  7 */
  8public class BinarySearchTree<T extends Comparable<T>> {
  9
 10    private TreeNode root;
 11
 12    private int size;
 13
 14    private final List<T> traversalList = new ArrayList<>();
 15
 16    public BinarySearchTree() {
 17    }
 18
 19    boolean find(T key) {
 20        TreeNode node = root.getNode(key);
 21        return node != null;
 22    }
 23
 24    T findParent(T key) {
 25        TreeNode node = root.getParentNode(key);
 26        if (node == null)
 27            return null;
 28        return node.value;
 29    }
 30
 31    int size() {
 32        return size;
 33    }
 34
 35    T del(T key) {
 36        return root.deleteKey(key);
 37    }
 38
 39    void insert(T key) {
 40        if (root == null) {
 41            root = new TreeNode(key);
 42            size++;
 43            return;
 44        }
 45
 46        if (root.insertKey(key))
 47            size++;
 48    }
 49
 50    private void emptyTraversalList() {
 51        if (size <= traversalList.size())
 52            traversalList.clear();
 53    }
 54
 55    /**
 56     * pre-order traversal 前序遍历
 57     * 先访问自己,再访问左子树,再访问右子树
 58     * <p>
 59     * in-order traversal 中序遍历
 60     * 先访问左子树,再访问自己,再访问右子树
 61     * <p>
 62     * post-order traversal 后序遍历
 63     * 先访问左子树,再访问右子树,再访问自己
 64     * <p>
 65     * 层序遍历 按照树的层级来遍历
 66     */
 67
 68    List<T> traversalPreOder(TreeNode node) {
 69        if (node == null)
 70            return null;
 71        emptyTraversalList();
 72        traversalList.add(node.value);
 73        traversalPreOder(node.left);
 74        traversalPreOder(node.right);
 75        return traversalList;
 76    }
 77
 78    /**
 79     * 层序遍历
 80     */
 81    List<T> traversalLevelOrder(TreeNode node) {
 82        if (node == null)
 83            return null;
 84        emptyTraversalList();
 85        traversalList.add(node.value);
 86        Queue<TreeNode> tQueue = new ArrayDeque<>(); // use as FIFO
 87        if (node.left != null)
 88            tQueue.offer(node.left);
 89        if (node.right != null)
 90            tQueue.offer(node.right);
 91        while (!tQueue.isEmpty()) {
 92            TreeNode ele = tQueue.poll();
 93            traversalList.add(ele.value);
 94            if (ele.left != null)
 95                tQueue.offer(ele.left);
 96            if (ele.right != null)
 97                tQueue.offer(ele.right);
 98        }
 99        return traversalList;
100    }
101
102    private final class TreeNode {
103
104        private T value;
105        private TreeNode left;
106        private TreeNode right;
107
108        public TreeNode(T value) {
109            this.value = value;
110        }
111
112        /**
113         * 在BST中找值为key的节点
114         *
115         * @param key 节点的值
116         * @return 符合条件的节点或者<code>null</code>
117         */
118
119        TreeNode getNode(T key) {
120            TreeNode current = root;
121
122            while (true) {
123                if (key == current.value)
124                    return current;
125
126                if (key.compareTo(current.value) > 0) {
127                    // search right child
128                    if (current.right == null)
129                        return null;
130                    current = current.right;
131                }
132
133                if (key.compareTo(current.value) < 0) {
134                    // search left child
135                    if (current.left == null)
136                        return null;
137                    current = current.left;
138                }
139            }
140        }
141
142        /**
143         * 获取指定值的父节点
144         *
145         * @param key 节点值
146         * @return 如果查找的节点是根节点,则返回根节点。
147         *         否则返回节点的父节点,或者返回null
148         */
149        TreeNode getParentNode(T key) {
150            if (key == root.value)
151                return root; // 根节点返回自己
152            TreeNode node = getNode(key);
153            if (node == null)
154                return null;
155            TreeNode current = root;
156            while (true) {
157                if (key.compareTo(current.value) > 0) {
158                    if (current.right == null)
159                        return null;
160                    else if (current.right.value == key)
161                        return current;
162                    current = current.right;
163                }
164                if (key.compareTo(current.value) < 0) {
165                    if (current.left == null)
166                        return null;
167                    else if (current.left.value == key)
168                        return current;
169                    current = current.left;
170                }
171            }
172        }
173
174        /**
175         * 循环方式插入节点
176         *
177         * @param key 待插入的值
178         * @return true-插入成功 false-插入失败
179         */
180        boolean insertKey(T key) {
181            TreeNode current = root;
182
183            while (true) {
184                if (key.compareTo(current.value) > 0) {
185                    // right child
186                    if (current.right == null) {
187                        current.right = new TreeNode(key);
188                        return true;
189                    }
190                    current = current.right;
191                }
192                if (key.compareTo(current.value) < 0) {
193                    // left child
194                    if (current.left == null) {
195                        current.left = new TreeNode(key);
196                        return true;
197                    }
198                    current = current.left;
199                }
200                if (key.compareTo(current.value) == 0) {
201                    System.out.println("Same key is not allowed.");
202                    return false;
203                }
204            }
205        }
206
207        /**
208         * 递归方式插入一个节点
209         *
210         * @param key  key
211         * @param root BST root node
212         * @return true-插入成功 false-插入失败
213         */
214        boolean insertKey(T key, TreeNode root) {
215
216            TreeNode current = root;
217
218            if (key.compareTo(current.value) > 0) {
219                // right child
220                if (current.right == null) {
221                    current.right = new TreeNode(key);
222                    return true;
223                }
224                current = current.right;
225                insertKey(key, current);
226            }
227            if (key.compareTo(current.value) < 0) {
228                // left child
229                if (current.left == null) {
230                    current.left = new TreeNode(key);
231                    return true;
232                }
233                current = current.left;
234                insertKey(key, current);
235            } else {
236                throw new RuntimeException("Same key is not allowed.");
237            }
238            return false;
239        }
240
241        /**
242         * 删除一个节点 (树重组)
243         *
244         * @param key key to be deleted
245         * @return key deleted or null if delete key fail
246         */
247        T deleteKey(T key) {
248            TreeNode targetNode = getNode(key);
249            if (targetNode == null)
250                return null;
251
252            TreeNode parentNode = getParentNode(key);
253            assert parentNode != null;
254
255            // case1 叶子节点直接删除
256            if (targetNode.left == null && targetNode.right == null) {
257                if (key == root.value)
258                    root = null;
259                if (key.compareTo(parentNode.value) > 0) {
260                    // 删除的节点是右孩子
261                    parentNode.right = null;
262                } else {
263                    parentNode.left = null;
264                }
265                size--;
266                return key;
267            }
268
269            // case 2 非叶子节点,树重组
270            // 找出左子树中最大的/右子树中最小的节点替换被删除的节点
271            TreeNode _1stChild; // 目标节点的最近子节点
272            TreeNode successor; // 上升节点
273            TreeNode successorParent; // 上升节点的父节点
274            // 注意,successor不一定为叶子节点
275            // 但只能存在左孩子或者右孩子中的一种,successor不可能有2个孩子!
276            TreeNode reserved; // 后继节点(successor子节点)
277            // leftChild != null || rightChild != null
278            if (targetNode.left != null) {
279                _1stChild = targetNode.left;
280                successor = findMaxSuccessor(_1stChild);
281                reserved = successor.left;
282                successorParent = getParentNode(successor.value);
283                assert successorParent != null;
284                // 设置上升节点的左右孩子
285                if (successor.value != _1stChild.value)
286                    successor.left = _1stChild;
287                successor.right = targetNode.right;
288            } else {
289                _1stChild = targetNode.right;
290                successor = findMinSuccessor(_1stChild);
291                reserved = successor.right;
292                successorParent = getParentNode(successor.value);
293                assert successorParent != null;
294                successor.left = targetNode.left;
295                if (successor.value != _1stChild.value)
296                    successor.right = _1stChild;
297            }
298            // 若删除的是根节点
299            if (parentNode.value == root.value) {
300                root = successor;
301            } else {
302                if (key.compareTo(parentNode.value) > 0) {
303                    // 删除的节点是右孩子
304                    parentNode.right = successor;
305                } else {
306                    parentNode.left = successor;
307                }
308            }
309
310            if (successor.value.compareTo(successorParent.value) > 0) {
311                // 上升节点是右孩子
312                successorParent.right = reserved;
313            } else {
314                successorParent.left = reserved;
315            }
316            size--;
317            return key;
318        }
319
320        /**
321         * 找出子树的最小节点
322         */
323        private TreeNode findMinSuccessor(TreeNode node) {
324            if (node.left == null)
325                return node;
326            return findMinSuccessor(node.left);
327        }
328
329        /**
330         * 找出子树的最大节点
331         */
332        private TreeNode findMaxSuccessor(TreeNode node) {
333            if (node.right == null)
334                return node;
335            return findMaxSuccessor(node.right);
336        }
337    }
338
339    static class Main {
340        public static void main(String[] args) {
341            BinarySearchTree<Integer> bst = new BinarySearchTree<>();
342
343            bst.insert(21);
344            bst.insert(14);
345            bst.insert(7);
346            bst.insert(19);
347            bst.insert(3);
348            bst.insert(13);
349            bst.insert(15);
350            bst.insert(9);
351            bst.insert(10);
352            bst.insert(18);
353            bst.insert(37);
354            bst.insert(25);
355            bst.insert(23);
356            bst.insert(36);
357            bst.insert(48);
358            bst.insert(40);
359            bst.insert(52);
360            bst.insert(67);
361            bst.insert(50);
362            bst.insert(45);
363            bst.insert(39);
364
365            // 18 16 13 17 19
366            System.out.println("[traversal pre-order]\t"
367                    + bst.traversalPreOder(bst.root));
368            System.out.println("[traversal level-order]\t"
369                    + bst.traversalLevelOrder(bst.root));
370
371            System.out.println("[count]\ttree size: "
372                    + bst.size());
373
374            System.out.println("[find parent]\tparent of 23: "
375                    + bst.findParent(23));
376
377            System.out.println("[contains]\tcontains key 20: "
378                    + bst.find(20));
379
380            System.out.println("[deletion]\tdelete key 23: "
381                    + bst.del(23));
382            System.out.println("[traversal]\t"
383                    + bst.traversalPreOder(bst.root));
384
385            System.out.println("[deletion]\tdelete key 48: "
386                    + bst.del(48));
387            System.out.println("[traversal]\t"
388                    + bst.traversalPreOder(bst.root));
389
390            System.out.println("[deletion]\tdelete key 19: "
391                    + bst.del(19));
392            System.out.println("[traversal]\t"
393                    + bst.traversalPreOder(bst.root));
394
395            System.out.println("[deletion]\tdelete key 25: "
396                    + bst.del(25));
397            System.out.println("[traversal]\t"
398                    + bst.traversalPreOder(bst.root));
399
400            System.out.println("[deletion]\tdelete key 45: "
401                    + bst.del(45));
402            System.out.println("[traversal]\t"
403                    + bst.traversalPreOder(bst.root));
404
405            // delete root node
406            System.out.println("[deletion]\tdelete root: "
407                    + bst.del(bst.root.value));
408            System.out.println("[traversal]\t"
409                    + bst.traversalLevelOrder(bst.root));
410        }
411    }
412}