note.wcoder.com
wcoder GitHub
package main

import (
	"fmt"
	"math"
	"math/rand"
	"sort"
	"time"
)

// Item 表示一个带有权重的元素
type Item struct {
	Value  interface{}
	Weight float64
}

// WeightedItem 用于随机排序的内部结构
type WeightedItem struct {
	item       Item
	randWeight float64 // 随机权重值
}

// 迭代器接口
type Iterator interface {
	Reset()
	HasNext() bool
	Next() (interface{}, error)
	All() []interface{}
}

// WeightedRandomIterator 实现权重随机迭代器 (原始实现)
type WeightedRandomIterator struct {
	items        []Item
	sorted       []WeightedItem
	currentIndex int
	initialized  bool
}

// NewWeightedRandomIterator 创建新的权重随机迭代器
func NewWeightedRandomIterator(items []Item) *WeightedRandomIterator {
	return &WeightedRandomIterator{
		items:        items,
		currentIndex: 0,
		initialized:  false,
	}
}

// 初始化排序,使用权重进行随机排序
func (w *WeightedRandomIterator) initialize() {
	if w.initialized {
		return
	}

	rand.Seed(time.Now().UnixNano())
	w.sorted = make([]WeightedItem, len(w.items))

	// 为每个元素生成随机权重
	for i, item := range w.items {
		// 使用 -log(random) / weight 公式实现权重随机
		// 这个公式保证了权重越大的元素被排在前面的概率越高
		randomValue := 0.0
		if item.Weight > 0 {
			randomValue = -math.Log(rand.Float64()) / item.Weight
		} else {
			randomValue = math.Inf(1) // 权重为0时,放在最后面
		}

		w.sorted[i] = WeightedItem{
			item:       item,
			randWeight: randomValue,
		}
	}

	// 按随机权重从小到大排序(权重高的元素随机值小,排在前面)
	sort.Slice(w.sorted, func(i, j int) bool {
		return w.sorted[i].randWeight < w.sorted[j].randWeight
	})

	w.initialized = true
}

// Reset 重置迭代器状态,可以重新进行随机排序
func (w *WeightedRandomIterator) Reset() {
	w.initialized = false
	w.currentIndex = 0
	w.initialize()
}

// HasNext 判断是否还有下一个元素
func (w *WeightedRandomIterator) HasNext() bool {
	w.initialize()
	return w.currentIndex < len(w.sorted)
}

// Next 返回下一个随机权重排序的元素
func (w *WeightedRandomIterator) Next() (interface{}, error) {
	if !w.HasNext() {
		return nil, fmt.Errorf("no more items in iterator")
	}

	result := w.sorted[w.currentIndex].item.Value
	w.currentIndex++
	return result, nil
}

// All 一次性返回所有随机排序后的元素
func (w *WeightedRandomIterator) All() []interface{} {
	w.initialize()
	result := make([]interface{}, len(w.sorted))
	for i, item := range w.sorted {
		result[i] = item.item.Value
	}
	return result
}

// AResIterator 使用A-Res算法的权重随机迭代器
type AResIterator struct {
	items        []Item
	result       []interface{}
	currentIndex int
	initialized  bool
}

// NewAResIterator 创建使用A-Res算法的迭代器
func NewAResIterator(items []Item) *AResIterator {
	return &AResIterator{
		items:        items,
		currentIndex: 0,
		initialized:  false,
	}
}

// 使用A-Res算法初始化
func (w *AResIterator) initialize() {
	if w.initialized {
		return
	}

	rand.Seed(time.Now().UnixNano())
	n := len(w.items)

	// 如果没有元素,直接返回
	if n == 0 {
		w.result = []interface{}{}
		w.initialized = true
		return
	}

	// 创建带权重键的项目列表
	type KeyedItem struct {
		Value interface{}
		Key   float64
	}

	keyedItems := make([]KeyedItem, n)

	// 为每个元素生成A-Res算法的随机键值
	// 使用公式 key = random^(1/weight)
	for i, item := range w.items {
		// 处理权重为0的情况
		weight := item.Weight
		var key float64 = 0
		if weight > 0 {
			// 计算 random^(1/weight),这是A-Res算法的核心
			key = math.Pow(rand.Float64(), 1.0/weight)
		}

		keyedItems[i] = KeyedItem{
			Value: item.Value,
			Key:   key,
		}
	}

	// 按键值降序排序(键值大的排在前面)
	sort.Slice(keyedItems, func(i, j int) bool {
		return keyedItems[i].Key > keyedItems[j].Key
	})

	// 提取排序后的值
	w.result = make([]interface{}, n)
	for i, item := range keyedItems {
		w.result[i] = item.Value
	}

	w.initialized = true
}

// Reset 重置A-Res迭代器
func (w *AResIterator) Reset() {
	w.initialized = false
	w.currentIndex = 0
	w.initialize()
}

// HasNext 判断A-Res迭代器是否有下一个元素
func (w *AResIterator) HasNext() bool {
	w.initialize()
	return w.currentIndex < len(w.result)
}

// Next 返回A-Res迭代器的下一个元素
func (w *AResIterator) Next() (interface{}, error) {
	if !w.HasNext() {
		return nil, fmt.Errorf("no more items in iterator")
	}

	result := w.result[w.currentIndex]
	w.currentIndex++
	return result, nil
}

// All 一次性返回A-Res迭代器的所有元素
func (w *AResIterator) All() []interface{} {
	w.initialize()
	return w.result
}

// AExpJIterator 使用A-ExpJ算法的权重随机迭代器
type AExpJIterator struct {
	items        []Item
	result       []interface{}
	currentIndex int
	initialized  bool
}

// NewAExpJIterator 创建使用A-ExpJ算法的迭代器
func NewAExpJIterator(items []Item) *AExpJIterator {
	return &AExpJIterator{
		items:        items,
		currentIndex: 0,
		initialized:  false,
	}
}

// 使用A-ExpJ算法初始化
func (w *AExpJIterator) initialize() {
	if w.initialized {
		return
	}

	rand.Seed(time.Now().UnixNano())
	n := len(w.items)

	// 计算总权重
	totalWeight := 0.0
	for _, item := range w.items {
		totalWeight += item.Weight
	}

	// 复制items,用于后续处理
	remainingItems := make([]Item, n)
	copy(remainingItems, w.items)

	w.result = make([]interface{}, n)

	// 为了防止出现nil值,首先用第一个元素的值填充结果数组
	if n > 0 {
		firstValue := w.items[0].Value
		for i := 0; i < n; i++ {
			w.result[i] = firstValue
		}
	}

	// 使用指数跳跃算法选择元素
	for i := 0; i < n; i++ {
		if len(remainingItems) == 0 {
			break
		}

		// 计算剩余元素的总权重
		remainingWeight := 0.0
		for _, item := range remainingItems {
			remainingWeight += item.Weight
		}

		if remainingWeight <= 0 {
			break
		}

		// 生成[0,remainingWeight)之间的随机数
		r := rand.Float64() * remainingWeight

		// 查找累积权重首次超过r的元素
		cumWeight := 0.0
		selectedIndex := 0
		for j, item := range remainingItems {
			cumWeight += item.Weight
			if cumWeight > r {
				selectedIndex = j
				break
			}
		}

		// 将选中的元素添加到结果中
		w.result[i] = remainingItems[selectedIndex].Value

		// 从剩余元素中移除已选中的元素
		remainingItems = append(remainingItems[:selectedIndex], remainingItems[selectedIndex+1:]...)
	}

	w.initialized = true
}

// Reset 重置A-ExpJ迭代器
func (w *AExpJIterator) Reset() {
	w.initialized = false
	w.currentIndex = 0
	w.initialize()
}

// HasNext 判断A-ExpJ迭代器是否有下一个元素
func (w *AExpJIterator) HasNext() bool {
	w.initialize()
	return w.currentIndex < len(w.result)
}

// Next 返回A-ExpJ迭代器的下一个元素
func (w *AExpJIterator) Next() (interface{}, error) {
	if !w.HasNext() {
		return nil, fmt.Errorf("no more items in iterator")
	}

	result := w.result[w.currentIndex]
	w.currentIndex++
	return result, nil
}

// All 一次性返回A-ExpJ迭代器的所有元素
func (w *AExpJIterator) All() []interface{} {
	w.initialize()
	return w.result
}

// 验证概率分布
func verifyProbability(iterator func([]Item) Iterator, items []Item, iterations int) map[interface{}]float64 {
	// 统计每个位置上的元素出现次数
	positionCounts := make(map[interface{}]map[int]int)

	// 初始化计数器
	for _, item := range items {
		positionCounts[item.Value] = make(map[int]int)
	}

	// 进行多次迭代测试
	for i := 0; i < iterations; i++ {
		iter := iterator(items)
		allItems := iter.All()

		// 统计每个位置上元素出现的次数
		for pos, val := range allItems {
			// 确保val存在于map中
			if val == nil {
				continue
			}
			if _, exists := positionCounts[val]; !exists {
				positionCounts[val] = make(map[int]int)
			}
			positionCounts[val][pos]++
		}
	}

	// 计算每个元素出现在第一位的概率
	firstPositionProb := make(map[interface{}]float64)
	for val, positions := range positionCounts {
		firstPositionProb[val] = float64(positions[0]) / float64(iterations)
	}

	return firstPositionProb
}

// 计算理论概率分布
func calculateTheoretical(items []Item) map[interface{}]float64 {
	result := make(map[interface{}]float64)
	totalWeight := 0.0

	// 计算总权重
	for _, item := range items {
		totalWeight += item.Weight
	}

	// 计算每个元素按权重应该具有的概率
	for _, item := range items {
		result[item.Value] = item.Weight / totalWeight
	}

	return result
}

// 性能测试函数
func benchmarkAlgorithms(sizes []int, runs int) {
	fmt.Println("\n性能测试结果 (单位: 微秒):")
	fmt.Println("元素数量\t原始算法\tA-Res算法\tA-ExpJ算法")
	fmt.Println("-----------------------------------------------")

	for _, size := range sizes {
		// 生成测试数据
		items := generateTestItems(size)

		// 测试原始算法
		originalTime := benchmarkIterator(func() Iterator {
			return NewWeightedRandomIterator(items)
		}, runs)

		// 测试A-Res算法
		aresTime := benchmarkIterator(func() Iterator {
			return NewAResIterator(items)
		}, runs)

		// 测试A-ExpJ算法
		aexpjTime := benchmarkIterator(func() Iterator {
			return NewAExpJIterator(items)
		}, runs)

		// 输出结果
		fmt.Printf("%d\t%.2f\t%.2f\t%.2f\n",
			size,
			float64(originalTime.Microseconds())/float64(runs),
			float64(aresTime.Microseconds())/float64(runs),
			float64(aexpjTime.Microseconds())/float64(runs))
	}
}

// 生成测试数据
func generateTestItems(count int) []Item {
	items := make([]Item, count)
	for i := 0; i < count; i++ {
		// 随机生成0.1到10之间的权重
		weight := 0.1 + rand.Float64()*9.9
		items[i] = Item{
			Value:  fmt.Sprintf("Item-%d", i),
			Weight: weight,
		}
	}
	return items
}

// 对单个迭代器进行性能测试
func benchmarkIterator(newIterator func() Iterator, runs int) time.Duration {
	start := time.Now()

	for i := 0; i < runs; i++ {
		iterator := newIterator()
		iterator.All() // 执行所有排序操作
	}

	return time.Since(start)
}

func main() {
	// 设置随机种子
	rand.Seed(time.Now().UnixNano())

	// 创建测试数据
	items := []Item{
		{Value: "A", Weight: 0.1},
		{Value: "B", Weight: 0.3},
		{Value: "C", Weight: 0.6},
		{Value: "D", Weight: 0.2},
		{Value: "E", Weight: 0.8},
	}

	// 计算理论概率分布
	theoreticalProb := calculateTheoretical(items)
	fmt.Println("理论概率分布(按权重):")
	for val, prob := range theoreticalProb {
		fmt.Printf("%v: %.4f\n", val, prob)
	}

	// 验证原始算法的实际概率分布
	iterations := 10000
	fmt.Printf("\n原始算法进行%d次迭代验证...\n", iterations)
	originalProb := verifyProbability(func(items []Item) Iterator { return NewWeightedRandomIterator(items) }, items, iterations)

	fmt.Println("\n原始算法实际出现在第一位的概率:")
	for val, prob := range originalProb {
		fmt.Printf("%v: %.4f (理论值: %.4f, 误差: %.4f%%)\n",
			val,
			prob,
			theoreticalProb[val],
			math.Abs(prob-theoreticalProb[val])/theoreticalProb[val]*100)
	}

	// 验证A-Res算法的实际概率分布
	fmt.Printf("\nA-Res算法进行%d次迭代验证...\n", iterations)
	aresProb := verifyProbability(func(items []Item) Iterator { return NewAResIterator(items) }, items, iterations)

	fmt.Println("\nA-Res算法实际出现在第一位的概率:")
	for val, prob := range aresProb {
		fmt.Printf("%v: %.4f (理论值: %.4f, 误差: %.4f%%)\n",
			val,
			prob,
			theoreticalProb[val],
			math.Abs(prob-theoreticalProb[val])/theoreticalProb[val]*100)
	}

	// 验证A-ExpJ算法的实际概率分布
	fmt.Printf("\nA-ExpJ算法进行%d次迭代验证...\n", iterations)
	aexpjProb := verifyProbability(func(items []Item) Iterator { return NewAExpJIterator(items) }, items, iterations)

	fmt.Println("\nA-ExpJ算法实际出现在第一位的概率:")
	for val, prob := range aexpjProb {
		fmt.Printf("%v: %.4f (理论值: %.4f, 误差: %.4f%%)\n",
			val,
			prob,
			theoreticalProb[val],
			math.Abs(prob-theoreticalProb[val])/theoreticalProb[val]*100)
	}

	// 展示三种算法的随机排序示例
	fmt.Println("\n原始算法随机排序示例:")
	iterator := NewWeightedRandomIterator(items)
	for _, val := range iterator.All() {
		fmt.Print(val, " ")
	}
	fmt.Println()

	fmt.Println("\nA-Res算法随机排序示例:")
	aresIterator := NewAResIterator(items)
	for _, val := range aresIterator.All() {
		fmt.Print(val, " ")
	}
	fmt.Println()

	fmt.Println("\nA-ExpJ算法随机排序示例:")
	aexpjIterator := NewAExpJIterator(items)
	for _, val := range aexpjIterator.All() {
		fmt.Print(val, " ")
	}
	fmt.Println()

	// 执行性能测试
	// 测试不同元素数量下的性能
	sizes := []int{10, 100, 1000, 10000}
	benchmarkAlgorithms(sizes, 100) // 每种算法每个大小运行100次取平均值

	// 测试大数据量的性能差异
	fmt.Println("\n大数据量性能测试 (10万元素):")
	largeItems := generateTestItems(100000)

	start := time.Now()
	NewWeightedRandomIterator(largeItems).All()
	fmt.Printf("原始算法耗时: %v\n", time.Since(start))

	start = time.Now()
	NewAResIterator(largeItems).All()
	fmt.Printf("A-Res算法耗时: %v\n", time.Since(start))

	start = time.Now()
	NewAExpJIterator(largeItems).All()
	fmt.Printf("A-ExpJ算法耗时: %v\n", time.Since(start))

	/*
		大数据量性能测试 (10万元素):
		原始算法耗时: 15.37625ms  这个性能最好
		A-Res算法耗时: 17.383792ms
		A-ExpJ算法耗时: 5.174245209s 这个虽然慢, 但是如果是100以内的取出不放回, 可以选择这个 (保证结果中元素不重复,适合"不放回"抽样)
	*/
}

迭代器版本, 进行优化

package main

import (
	"fmt"
	"math"
	"math/rand"
	"time"
)

// Item 表示一个带有权重的元素
type Item struct {
	Value  string
	Weight float64
}

// Weight 返回一个函数,该函数使用A-ExpJ算法进行权重随机选择
func Weight(cMembers []*Item) func(yield func(*Item) bool) {
	return func(yield func(*Item) bool) {
		if len(cMembers) == 0 {
			return
		}

		// 初始化随机种子
		rand.Seed(time.Now().UnixNano())

		// 复制items,用于后续处理
		remainingItems := make([]*Item, len(cMembers))
		copy(remainingItems, cMembers)

		// 使用指数跳跃算法选择元素
		for len(remainingItems) > 0 {
			// 计算剩余元素的总权重
			remainingWeight := 0.0
			for _, item := range remainingItems {
				remainingWeight += item.Weight
			}

			if remainingWeight <= 0 {
				break
			}

			// 生成[0,remainingWeight)之间的随机数
			r := rand.Float64() * remainingWeight

			// 查找累积权重首次超过r的元素
			cumWeight := 0.0
			selectedIndex := 0
			for j, item := range remainingItems {
				cumWeight += item.Weight
				if cumWeight > r {
					selectedIndex = j
					break
				}
			}

			// 使用yield函数处理选中的元素
			shouldContinue := yield(remainingItems[selectedIndex])

			// 从剩余元素中移除已选中的元素
			remainingItems = append(remainingItems[:selectedIndex], remainingItems[selectedIndex+1:]...)

			// 如果yield函数返回false,停止选择
			if !shouldContinue {
				break
			}
		}
	}
}

// 计算理论概率分布
func calculateTheoretical(items []*Item) map[string]float64 {
	result := make(map[string]float64)
	totalWeight := 0.0

	// 计算总权重
	for _, item := range items {
		totalWeight += item.Weight
	}

	// 计算每个元素按权重应该具有的概率
	for _, item := range items {
		result[item.Value] = item.Weight / totalWeight
	}

	return result
}

// 验证概率分布
func verifyProbability(items []*Item, iterations int) map[string]float64 {
	// 统计每个元素被选中的次数
	counts := make(map[string]int)

	for i := 0; i < iterations; i++ {
		weightFunc := Weight(items)
		weightFunc(func(selectedItem *Item) bool {
			counts[selectedItem.Value]++
			return false // 只统计第一个选中的元素
		})
	}

	// 计算每个元素被选中的概率
	probabilities := make(map[string]float64)
	for val, count := range counts {
		probabilities[val] = float64(count) / float64(iterations)
	}

	return probabilities
}

func main() {
	// 创建测试数据
	items := []*Item{
		{Value: "A", Weight: 0.1},
		{Value: "B", Weight: 0.3},
		{Value: "C", Weight: 0.6},
		{Value: "D", Weight: 0.2},
		{Value: "E", Weight: 0.8},
	}

	for i := range Weight(items) {
		fmt.Printf("Selected Item: %s\n", i.Value)
	}

	// 计算理论概率分布
	theoreticalProb := calculateTheoretical(items)
	fmt.Println("理论概率分布:")
	for val, prob := range theoreticalProb {
		fmt.Printf("%v: %.4f\n", val, prob)
	}

	// 验证实际概率分布
	iterations := 10000
	actualProb := verifyProbability(items, iterations)

	fmt.Println("\n实际概率分布:")
	for val, prob := range actualProb {
		fmt.Printf("%v: %.4f (理论值: %.4f, 误差: %.4f%%)\n",
			val,
			prob,
			theoreticalProb[val],
			(math.Abs(prob-theoreticalProb[val])/theoreticalProb[val])*100)
	}

}

← Previous Next →
Less
More