From 45aaa4bf1af5e7a6d65a8c6e944a75839c878a6e Mon Sep 17 00:00:00 2001 From: Edgaru089 Date: Thu, 7 Aug 2025 17:19:14 +0800 Subject: [PATCH] Initial commit --- go.mod | 3 + sortable.go | 42 ++++++++++ tree.go | 215 +++++++++++++++++++++++++++++++++++++++++++++++++ tree_adjust.go | 74 +++++++++++++++++ 4 files changed, 334 insertions(+) create mode 100644 go.mod create mode 100644 sortable.go create mode 100644 tree.go create mode 100644 tree_adjust.go diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f278d18 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module edgaru089.ink/go/stl + +go 1.24.5 diff --git a/sortable.go b/sortable.go new file mode 100644 index 0000000..b3048b0 --- /dev/null +++ b/sortable.go @@ -0,0 +1,42 @@ +package stl + +// ImplicitSortable are types that can be compared with +// the less operator <, among others. +// +// As with C++, two keys are considered equal if neither +// x < y nor y < x, so these interfaces have nothing to +// do with Go's 'comparable' whatsoever. +// +// This behavior can be overridden by defining another +// type and then add a Compare method to the original key. +type ImplicitSortable interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | ~string | ~float32 | ~float64 +} + +// Sorter is a interface used in the construction of +// many collection types, to compare two keys with +// the '<' operator by default. +type Sorter[T any] interface { + Compare(x, y T) bool +} + +// Less implements Sorter for ImplicitSortables using the '<' operator. +type Less[T ImplicitSortable] struct{} + +func (Less[T]) Compare(x, y T) bool { + return x < y +} + +// Greater implements Sorter for ImplicitSortables using the '>' operator. +type Greater[T ImplicitSortable] struct{} + +func (Greater[T]) Compare(x, y T) bool { + return x > y +} + +// Equal determines if the two keys are considered equal, +// i.e., neither x < y nor y < x, per C++ traditions. +func Equal[C Sorter[T], T any](x, y T) bool { + var c C + return !c.Compare(x, y) && !c.Compare(y, x) +} diff --git a/tree.go b/tree.go new file mode 100644 index 0000000..bf0be38 --- /dev/null +++ b/tree.go @@ -0,0 +1,215 @@ +package stl + +import ( + "math/rand/v2" +) + +// Node is a node on the binary search tree. +type Node[K comparable, V any, C Sorter[K]] struct { + lson, rson *Node[K, V, C] + parent *Node[K, V, C] + bal uint32 // for treap balancing + + key K + Value V // The value on the node. +} + +// Key is read-only access to the key of the node. +func (n *Node[K, V, C]) Key() K { + return n.key +} + +// Tree is a binary search tree. Always take +// addresses and don't work with values. +// +// The zero value is a usable, empty tree. +type Tree[K comparable, V any, C Sorter[K]] struct { + root *Node[K, V, C] +} + +// TreeInc is an alias for a tree compared by the +// default Less operator. +type TreeInc[K ImplicitSortable, V any] = Tree[K, V, Less[K]] + +// TreeDec is an alias for a tree compared by the +// Greater operator. +type TreeDec[K ImplicitSortable, V any] = Tree[K, V, Greater[K]] + +// Clear clears the tree, erasing every element on it. +func (t *Tree[K, V, C]) Clear() { + t.root = nil +} + +// Insert tries to insert a new node on the tree, locating +// the value if the key is already in the tree. +// +// Returns the new (or found) value on the tree, and true/false +// if the key was inserted or found. +func (t *Tree[K, V, C]) Insert(key K, value V) (value_on_tree *V, is_added bool) { + node, is_added := t.InsertNode(key, value) + return &node.Value, is_added +} + +// InsertNode does the same as Insert, but returns the entire *Node. +func (t *Tree[K, V, C]) InsertNode(key K, value V) (node *Node[K, V, C], is_added bool) { + t.root = t.realInsertNode(t.root, nil, key, value, &node, &is_added) + return +} + +// realInsertNode recursivly searches and tries to insert the new node. +// If the key is found, it's passed into *result. +func (t *Tree[K, V, C]) realInsertNode(cur, parent *Node[K, V, C], key K, value V, result **Node[K, V, C], is_added *bool) (node *Node[K, V, C]) { + var c C + if cur == nil { + *result = &Node[K, V, C]{ + parent: parent, + key: key, + Value: value, + bal: rand.Uint32(), + } + *is_added = true + return *result + } else if c.Compare(key, cur.key) { + // key < now.key + cur.lson = t.realInsertNode(cur.lson, cur, key, value, result, is_added) + return cur + } else if c.Compare(cur.key, key) { + // key > now.key + cur.rson = t.realInsertNode(cur.rson, cur, key, value, result, is_added) + return cur + } else { + // key == now.key + *result = cur + *is_added = false + return cur + } + +} + +// Find finds the value related to the key on the tree, +// or nil if it is not found. +func (t *Tree[K, V, C]) Find(key K) (value_on_tree *V) { + node := t.FindNode(key) + if node == nil { + return nil + } else { + return &node.Value + } +} + +// FindNode finds the node instead, or nil if it is not found. +func (t *Tree[K, V, C]) FindNode(key K) (node *Node[K, V, C]) { + return t.realFindNode(t.root, key) +} + +func (t *Tree[K, V, C]) realFindNode(cur *Node[K, V, C], key K) (node *Node[K, V, C]) { + var c C + if cur == nil { + return nil + } else if c.Compare(key, cur.key) { + // key < cur.key + return t.realFindNode(cur.lson, key) + } else if c.Compare(cur.key, key) { + // key > cur.key + return t.realFindNode(cur.rson, key) + } else { + // key == cur.key + return cur + } +} + +// Deletes a node, does nothing if node is Nil. +// +// Use in tandem with FindNode to delete a key. +func (t *Tree[K, V, C]) Delete(node *Node[K, V, C]) { + if node == nil { + return + } + + for node.lson != nil && node.rson != nil { + if node.lson.bal < node.rson.bal { + node.lson.rotate(&t.root) + } else { + node.rson.rotate(&t.root) + } + } + + if node == t.root { + // select new root + if node.lson != nil { + t.root = node.lson + } else { + t.root = node.rson + } + } + + if node.lson != nil { + node.parent.connect(node.lson, node.tell()) + } else { + node.parent.connect(node.rson, node.tell()) + } + + // delete everything for the GC + // Seriously, why use a tree when you have a GC? + node.lson = nil + node.rson = nil + node.parent = nil + var k K + var v V + node.key = k + node.Value = v +} + +func (t *Tree[K, V, C]) FirstNode() (node *Node[K, V, C]) { + node = t.root + if node == nil { + return nil + } + for node.lson != nil { + node = node.lson + } + return node +} + +func (t *Tree[K, V, C]) LastNode() (node *Node[K, V, C]) { + node = t.root + if node == nil { + return nil + } + for node.rson != nil { + node = node.rson + } + return node +} + +func (node *Node[K, V, C]) Next() (next *Node[K, V, C]) { + if node.rson != nil { + next = node.rson + for next.lson != nil { + next = next.lson + } + return + } else { + next = node + for next.parent != nil && next.tell() == treeRight { + next = next.parent + } + return next.parent + } +} + +func (node *Node[K, V, C]) Previous() (prev *Node[K, V, C]) { + if node.lson != nil { + prev = node.lson + for prev.rson != nil { + prev = prev.rson + } + return prev + } else { + prev = node + for prev.parent != nil && prev.tell() == treeLeft { + prev = prev.parent + } + return prev.parent + } +} diff --git a/tree_adjust.go b/tree_adjust.go new file mode 100644 index 0000000..8db3ded --- /dev/null +++ b/tree_adjust.go @@ -0,0 +1,74 @@ +package stl + +type treeConnectType int8 + +const ( + treeLeft treeConnectType = iota + treeRight +) + +func (t treeConnectType) invert() treeConnectType { + return 1 - t +} + +func (son *Node[K, V, C]) tell() treeConnectType { + if son.parent == nil { + return treeLeft + } + + if son.parent.lson == son { + return treeLeft + } else { + return treeRight + } +} + +func (n *Node[K, V, C]) get(t treeConnectType) *Node[K, V, C] { + if t == treeLeft { + return n.lson + } else { + return n.rson + } +} + +func (parent *Node[K, V, C]) connect(son *Node[K, V, C], t treeConnectType) { + if son != nil { + son.parent = parent + } + if parent != nil { + if t == treeLeft { + parent.lson = son + } else { + parent.rson = son + } + } +} + +// rotate the node up +func (node *Node[K, V, C]) rotate(root **Node[K, V, C]) { + if node.parent == nil { + // is root, sanity check + return + } + + t := node.tell() + + f := node.parent + b := node.get(t.invert()) + + f.parent.connect(node, f.tell()) + node.connect(f, t.invert()) + f.connect(b, t) + + if node.parent == nil { + // new root + *root = node + } +} + +// adjust the new node in the tree Treap style. +func (node *Node[K, V, C]) adjust(root **Node[K, V, C]) { + for node.parent != nil && node.parent.bal > node.bal { + node.rotate(root) + } +}