package main
import (
"fmt"
"math"
"math/rand"
"sort"
"time"
)
// Item 表示一个带有权重的元素
type Item struct {
Value interface{}
Weight float64
}
// WeightedItem 用于随机排序的内部结构
type WeightedItem struct {
item Item
randWeight float64 // 随机权重值
}
// 迭代器接口
type Iterator interface {
Reset()
HasNext() bool
Next() (interface{}, error)
All() []interface{}
}
// WeightedRandomIterator 实现权重随机迭代器 (原始实现)
type WeightedRandomIterator struct {
items []Item
sorted []WeightedItem
currentIndex int
initialized bool
}
// NewWeightedRandomIterator 创建新的权重随机迭代器
func NewWeightedRandomIterator(items []Item) *WeightedRandomIterator {
return &WeightedRandomIterator{
items: items,
currentIndex: 0,
initialized: false,
}
}
// 初始化排序,使用权重进行随机排序
func (w *WeightedRandomIterator) initialize() {
if w.initialized {
return
}
rand.Seed(time.Now().UnixNano())
w.sorted = make([]WeightedItem, len(w.items))
// 为每个元素生成随机权重
for i, item := range w.items {
// 使用 -log(random) / weight 公式实现权重随机
// 这个公式保证了权重越大的元素被排在前面的概率越高
randomValue := 0.0
if item.Weight > 0 {
randomValue = -math.Log(rand.Float64()) / item.Weight
} else {
randomValue = math.Inf(1) // 权重为0时,放在最后面
}
w.sorted[i] = WeightedItem{
item: item,
randWeight: randomValue,
}
}
// 按随机权重从小到大排序(权重高的元素随机值小,排在前面)
sort.Slice(w.sorted, func(i, j int) bool {
return w.sorted[i].randWeight < w.sorted[j].randWeight
})
w.initialized = true
}
// Reset 重置迭代器状态,可以重新进行随机排序
func (w *WeightedRandomIterator) Reset() {
w.initialized = false
w.currentIndex = 0
w.initialize()
}
// HasNext 判断是否还有下一个元素
func (w *WeightedRandomIterator) HasNext() bool {
w.initialize()
return w.currentIndex < len(w.sorted)
}
// Next 返回下一个随机权重排序的元素
func (w *WeightedRandomIterator) Next() (interface{}, error) {
if !w.HasNext() {
return nil, fmt.Errorf("no more items in iterator")
}
result := w.sorted[w.currentIndex].item.Value
w.currentIndex++
return result, nil
}
// All 一次性返回所有随机排序后的元素
func (w *WeightedRandomIterator) All() []interface{} {
w.initialize()
result := make([]interface{}, len(w.sorted))
for i, item := range w.sorted {
result[i] = item.item.Value
}
return result
}
// AResIterator 使用A-Res算法的权重随机迭代器
type AResIterator struct {
items []Item
result []interface{}
currentIndex int
initialized bool
}
// NewAResIterator 创建使用A-Res算法的迭代器
func NewAResIterator(items []Item) *AResIterator {
return &AResIterator{
items: items,
currentIndex: 0,
initialized: false,
}
}
// 使用A-Res算法初始化
func (w *AResIterator) initialize() {
if w.initialized {
return
}
rand.Seed(time.Now().UnixNano())
n := len(w.items)
// 如果没有元素,直接返回
if n == 0 {
w.result = []interface{}{}
w.initialized = true
return
}
// 创建带权重键的项目列表
type KeyedItem struct {
Value interface{}
Key float64
}
keyedItems := make([]KeyedItem, n)
// 为每个元素生成A-Res算法的随机键值
// 使用公式 key = random^(1/weight)
for i, item := range w.items {
// 处理权重为0的情况
weight := item.Weight
var key float64 = 0
if weight > 0 {
// 计算 random^(1/weight),这是A-Res算法的核心
key = math.Pow(rand.Float64(), 1.0/weight)
}
keyedItems[i] = KeyedItem{
Value: item.Value,
Key: key,
}
}
// 按键值降序排序(键值大的排在前面)
sort.Slice(keyedItems, func(i, j int) bool {
return keyedItems[i].Key > keyedItems[j].Key
})
// 提取排序后的值
w.result = make([]interface{}, n)
for i, item := range keyedItems {
w.result[i] = item.Value
}
w.initialized = true
}
// Reset 重置A-Res迭代器
func (w *AResIterator) Reset() {
w.initialized = false
w.currentIndex = 0
w.initialize()
}
// HasNext 判断A-Res迭代器是否有下一个元素
func (w *AResIterator) HasNext() bool {
w.initialize()
return w.currentIndex < len(w.result)
}
// Next 返回A-Res迭代器的下一个元素
func (w *AResIterator) Next() (interface{}, error) {
if !w.HasNext() {
return nil, fmt.Errorf("no more items in iterator")
}
result := w.result[w.currentIndex]
w.currentIndex++
return result, nil
}
// All 一次性返回A-Res迭代器的所有元素
func (w *AResIterator) All() []interface{} {
w.initialize()
return w.result
}
// AExpJIterator 使用A-ExpJ算法的权重随机迭代器
type AExpJIterator struct {
items []Item
result []interface{}
currentIndex int
initialized bool
}
// NewAExpJIterator 创建使用A-ExpJ算法的迭代器
func NewAExpJIterator(items []Item) *AExpJIterator {
return &AExpJIterator{
items: items,
currentIndex: 0,
initialized: false,
}
}
// 使用A-ExpJ算法初始化
func (w *AExpJIterator) initialize() {
if w.initialized {
return
}
rand.Seed(time.Now().UnixNano())
n := len(w.items)
// 计算总权重
totalWeight := 0.0
for _, item := range w.items {
totalWeight += item.Weight
}
// 复制items,用于后续处理
remainingItems := make([]Item, n)
copy(remainingItems, w.items)
w.result = make([]interface{}, n)
// 为了防止出现nil值,首先用第一个元素的值填充结果数组
if n > 0 {
firstValue := w.items[0].Value
for i := 0; i < n; i++ {
w.result[i] = firstValue
}
}
// 使用指数跳跃算法选择元素
for i := 0; i < n; i++ {
if len(remainingItems) == 0 {
break
}
// 计算剩余元素的总权重
remainingWeight := 0.0
for _, item := range remainingItems {
remainingWeight += item.Weight
}
if remainingWeight <= 0 {
break
}
// 生成[0,remainingWeight)之间的随机数
r := rand.Float64() * remainingWeight
// 查找累积权重首次超过r的元素
cumWeight := 0.0
selectedIndex := 0
for j, item := range remainingItems {
cumWeight += item.Weight
if cumWeight > r {
selectedIndex = j
break
}
}
// 将选中的元素添加到结果中
w.result[i] = remainingItems[selectedIndex].Value
// 从剩余元素中移除已选中的元素
remainingItems = append(remainingItems[:selectedIndex], remainingItems[selectedIndex+1:]...)
}
w.initialized = true
}
// Reset 重置A-ExpJ迭代器
func (w *AExpJIterator) Reset() {
w.initialized = false
w.currentIndex = 0
w.initialize()
}
// HasNext 判断A-ExpJ迭代器是否有下一个元素
func (w *AExpJIterator) HasNext() bool {
w.initialize()
return w.currentIndex < len(w.result)
}
// Next 返回A-ExpJ迭代器的下一个元素
func (w *AExpJIterator) Next() (interface{}, error) {
if !w.HasNext() {
return nil, fmt.Errorf("no more items in iterator")
}
result := w.result[w.currentIndex]
w.currentIndex++
return result, nil
}
// All 一次性返回A-ExpJ迭代器的所有元素
func (w *AExpJIterator) All() []interface{} {
w.initialize()
return w.result
}
// 验证概率分布
func verifyProbability(iterator func([]Item) Iterator, items []Item, iterations int) map[interface{}]float64 {
// 统计每个位置上的元素出现次数
positionCounts := make(map[interface{}]map[int]int)
// 初始化计数器
for _, item := range items {
positionCounts[item.Value] = make(map[int]int)
}
// 进行多次迭代测试
for i := 0; i < iterations; i++ {
iter := iterator(items)
allItems := iter.All()
// 统计每个位置上元素出现的次数
for pos, val := range allItems {
// 确保val存在于map中
if val == nil {
continue
}
if _, exists := positionCounts[val]; !exists {
positionCounts[val] = make(map[int]int)
}
positionCounts[val][pos]++
}
}
// 计算每个元素出现在第一位的概率
firstPositionProb := make(map[interface{}]float64)
for val, positions := range positionCounts {
firstPositionProb[val] = float64(positions[0]) / float64(iterations)
}
return firstPositionProb
}
// 计算理论概率分布
func calculateTheoretical(items []Item) map[interface{}]float64 {
result := make(map[interface{}]float64)
totalWeight := 0.0
// 计算总权重
for _, item := range items {
totalWeight += item.Weight
}
// 计算每个元素按权重应该具有的概率
for _, item := range items {
result[item.Value] = item.Weight / totalWeight
}
return result
}
// 性能测试函数
func benchmarkAlgorithms(sizes []int, runs int) {
fmt.Println("\n性能测试结果 (单位: 微秒):")
fmt.Println("元素数量\t原始算法\tA-Res算法\tA-ExpJ算法")
fmt.Println("-----------------------------------------------")
for _, size := range sizes {
// 生成测试数据
items := generateTestItems(size)
// 测试原始算法
originalTime := benchmarkIterator(func() Iterator {
return NewWeightedRandomIterator(items)
}, runs)
// 测试A-Res算法
aresTime := benchmarkIterator(func() Iterator {
return NewAResIterator(items)
}, runs)
// 测试A-ExpJ算法
aexpjTime := benchmarkIterator(func() Iterator {
return NewAExpJIterator(items)
}, runs)
// 输出结果
fmt.Printf("%d\t%.2f\t%.2f\t%.2f\n",
size,
float64(originalTime.Microseconds())/float64(runs),
float64(aresTime.Microseconds())/float64(runs),
float64(aexpjTime.Microseconds())/float64(runs))
}
}
// 生成测试数据
func generateTestItems(count int) []Item {
items := make([]Item, count)
for i := 0; i < count; i++ {
// 随机生成0.1到10之间的权重
weight := 0.1 + rand.Float64()*9.9
items[i] = Item{
Value: fmt.Sprintf("Item-%d", i),
Weight: weight,
}
}
return items
}
// 对单个迭代器进行性能测试
func benchmarkIterator(newIterator func() Iterator, runs int) time.Duration {
start := time.Now()
for i := 0; i < runs; i++ {
iterator := newIterator()
iterator.All() // 执行所有排序操作
}
return time.Since(start)
}
func main() {
// 设置随机种子
rand.Seed(time.Now().UnixNano())
// 创建测试数据
items := []Item{
{Value: "A", Weight: 0.1},
{Value: "B", Weight: 0.3},
{Value: "C", Weight: 0.6},
{Value: "D", Weight: 0.2},
{Value: "E", Weight: 0.8},
}
// 计算理论概率分布
theoreticalProb := calculateTheoretical(items)
fmt.Println("理论概率分布(按权重):")
for val, prob := range theoreticalProb {
fmt.Printf("%v: %.4f\n", val, prob)
}
// 验证原始算法的实际概率分布
iterations := 10000
fmt.Printf("\n原始算法进行%d次迭代验证...\n", iterations)
originalProb := verifyProbability(func(items []Item) Iterator { return NewWeightedRandomIterator(items) }, items, iterations)
fmt.Println("\n原始算法实际出现在第一位的概率:")
for val, prob := range originalProb {
fmt.Printf("%v: %.4f (理论值: %.4f, 误差: %.4f%%)\n",
val,
prob,
theoreticalProb[val],
math.Abs(prob-theoreticalProb[val])/theoreticalProb[val]*100)
}
// 验证A-Res算法的实际概率分布
fmt.Printf("\nA-Res算法进行%d次迭代验证...\n", iterations)
aresProb := verifyProbability(func(items []Item) Iterator { return NewAResIterator(items) }, items, iterations)
fmt.Println("\nA-Res算法实际出现在第一位的概率:")
for val, prob := range aresProb {
fmt.Printf("%v: %.4f (理论值: %.4f, 误差: %.4f%%)\n",
val,
prob,
theoreticalProb[val],
math.Abs(prob-theoreticalProb[val])/theoreticalProb[val]*100)
}
// 验证A-ExpJ算法的实际概率分布
fmt.Printf("\nA-ExpJ算法进行%d次迭代验证...\n", iterations)
aexpjProb := verifyProbability(func(items []Item) Iterator { return NewAExpJIterator(items) }, items, iterations)
fmt.Println("\nA-ExpJ算法实际出现在第一位的概率:")
for val, prob := range aexpjProb {
fmt.Printf("%v: %.4f (理论值: %.4f, 误差: %.4f%%)\n",
val,
prob,
theoreticalProb[val],
math.Abs(prob-theoreticalProb[val])/theoreticalProb[val]*100)
}
// 展示三种算法的随机排序示例
fmt.Println("\n原始算法随机排序示例:")
iterator := NewWeightedRandomIterator(items)
for _, val := range iterator.All() {
fmt.Print(val, " ")
}
fmt.Println()
fmt.Println("\nA-Res算法随机排序示例:")
aresIterator := NewAResIterator(items)
for _, val := range aresIterator.All() {
fmt.Print(val, " ")
}
fmt.Println()
fmt.Println("\nA-ExpJ算法随机排序示例:")
aexpjIterator := NewAExpJIterator(items)
for _, val := range aexpjIterator.All() {
fmt.Print(val, " ")
}
fmt.Println()
// 执行性能测试
// 测试不同元素数量下的性能
sizes := []int{10, 100, 1000, 10000}
benchmarkAlgorithms(sizes, 100) // 每种算法每个大小运行100次取平均值
// 测试大数据量的性能差异
fmt.Println("\n大数据量性能测试 (10万元素):")
largeItems := generateTestItems(100000)
start := time.Now()
NewWeightedRandomIterator(largeItems).All()
fmt.Printf("原始算法耗时: %v\n", time.Since(start))
start = time.Now()
NewAResIterator(largeItems).All()
fmt.Printf("A-Res算法耗时: %v\n", time.Since(start))
start = time.Now()
NewAExpJIterator(largeItems).All()
fmt.Printf("A-ExpJ算法耗时: %v\n", time.Since(start))
/*
大数据量性能测试 (10万元素):
原始算法耗时: 15.37625ms 这个性能最好
A-Res算法耗时: 17.383792ms
A-ExpJ算法耗时: 5.174245209s 这个虽然慢, 但是如果是100以内的取出不放回, 可以选择这个 (保证结果中元素不重复,适合"不放回"抽样)
*/
}
迭代器版本, 进行优化
package main
import (
"fmt"
"math"
"math/rand"
"time"
)
// Item 表示一个带有权重的元素
type Item struct {
Value string
Weight float64
}
// Weight 返回一个函数,该函数使用A-ExpJ算法进行权重随机选择
func Weight(cMembers []*Item) func(yield func(*Item) bool) {
return func(yield func(*Item) bool) {
if len(cMembers) == 0 {
return
}
// 初始化随机种子
rand.Seed(time.Now().UnixNano())
// 复制items,用于后续处理
remainingItems := make([]*Item, len(cMembers))
copy(remainingItems, cMembers)
// 使用指数跳跃算法选择元素
for len(remainingItems) > 0 {
// 计算剩余元素的总权重
remainingWeight := 0.0
for _, item := range remainingItems {
remainingWeight += item.Weight
}
if remainingWeight <= 0 {
break
}
// 生成[0,remainingWeight)之间的随机数
r := rand.Float64() * remainingWeight
// 查找累积权重首次超过r的元素
cumWeight := 0.0
selectedIndex := 0
for j, item := range remainingItems {
cumWeight += item.Weight
if cumWeight > r {
selectedIndex = j
break
}
}
// 使用yield函数处理选中的元素
shouldContinue := yield(remainingItems[selectedIndex])
// 从剩余元素中移除已选中的元素
remainingItems = append(remainingItems[:selectedIndex], remainingItems[selectedIndex+1:]...)
// 如果yield函数返回false,停止选择
if !shouldContinue {
break
}
}
}
}
// 计算理论概率分布
func calculateTheoretical(items []*Item) map[string]float64 {
result := make(map[string]float64)
totalWeight := 0.0
// 计算总权重
for _, item := range items {
totalWeight += item.Weight
}
// 计算每个元素按权重应该具有的概率
for _, item := range items {
result[item.Value] = item.Weight / totalWeight
}
return result
}
// 验证概率分布
func verifyProbability(items []*Item, iterations int) map[string]float64 {
// 统计每个元素被选中的次数
counts := make(map[string]int)
for i := 0; i < iterations; i++ {
weightFunc := Weight(items)
weightFunc(func(selectedItem *Item) bool {
counts[selectedItem.Value]++
return false // 只统计第一个选中的元素
})
}
// 计算每个元素被选中的概率
probabilities := make(map[string]float64)
for val, count := range counts {
probabilities[val] = float64(count) / float64(iterations)
}
return probabilities
}
func main() {
// 创建测试数据
items := []*Item{
{Value: "A", Weight: 0.1},
{Value: "B", Weight: 0.3},
{Value: "C", Weight: 0.6},
{Value: "D", Weight: 0.2},
{Value: "E", Weight: 0.8},
}
for i := range Weight(items) {
fmt.Printf("Selected Item: %s\n", i.Value)
}
// 计算理论概率分布
theoreticalProb := calculateTheoretical(items)
fmt.Println("理论概率分布:")
for val, prob := range theoreticalProb {
fmt.Printf("%v: %.4f\n", val, prob)
}
// 验证实际概率分布
iterations := 10000
actualProb := verifyProbability(items, iterations)
fmt.Println("\n实际概率分布:")
for val, prob := range actualProb {
fmt.Printf("%v: %.4f (理论值: %.4f, 误差: %.4f%%)\n",
val,
prob,
theoreticalProb[val],
(math.Abs(prob-theoreticalProb[val])/theoreticalProb[val])*100)
}
}