用 Go 实现函数式泛型 Stream（以 interface{} 模拟泛型）

定义函数式泛型流数据结构

package stream

import (
	"fmt"
	"strings"
)

// Placeholder "generic" types. Both are empty interfaces, so any value can
// be stored in them, and callers recover concrete values with type
// assertions such as t.(int).
//
// T is the element type held by a Stream. U is the result/accumulator type
// used by the fold callbacks and by the function passed to Map.
type (
	T interface{}
	U interface{}
)

// Stream is an immutable singly linked list (a cons list) used for
// functional-style stream computations.
//
// The zero value Stream{} is the empty stream. Streams are never modified in
// place: Add returns a new Stream whose tail points at a copy of the
// receiver, so older streams remain valid and are shared as suffixes.
type Stream struct {
	head     T       // first element; only meaningful when notEmpty is true
	tail     *Stream // remaining elements; nil only for the empty zero value
	length   int     // number of elements; New sets it to tail.length + 1
	notEmpty bool    // false for the empty stream, true for any stream built by New
}

// Nil is the empty Stream, the usual starting point for Add/Addall.
// It is simply the zero value of Stream.
var Nil Stream

// New returns a non-empty Stream whose first element is head and whose
// remaining elements are those of *tail.
//
// A nil tail is treated as the empty stream. Previously it caused a nil
// pointer dereference on tail.length, so New(x, nil) now yields a
// one-element stream. The replacement is a fresh empty Stream rather than
// &Nil, so that the package-level variable is never aliased.
func New(head T, tail *Stream) Stream {
	if tail == nil {
		tail = &Stream{}
	}
	return Stream{
		head:     head,
		tail:     tail,
		length:   tail.length + 1,
		notEmpty: true,
	}
}

// Add returns a new Stream with i as its head and the receiver's elements as
// its tail. The receiver is passed by value, so the new stream links to a
// copy of it and s itself is left unchanged.
func (s Stream) Add(i T) Stream {
	rest := s
	return New(i, &rest)
}

// Addall adds every value in i via Add, in argument order, so the last
// argument ends up as the head. ToString therefore lists the values in the
// same order they were passed.
func (s Stream) Addall(i ...T) Stream {
	result := s
	for idx := range i {
		result = result.Add(i[idx])
	}
	return result
}

// FoldLeft is a left fold (reduce). It walks the stream from head to tail
// and computes f(...f(f(i, e0), e1)..., eN), where e0 is the head, i.e. the
// most recently added element. It returns i unchanged for an empty stream.
//
// This is written as a loop instead of recursion. Go does not eliminate tail
// calls, so the recursive form used one stack frame per element and could
// overflow the goroutine stack on very long streams. Results are identical.
func (s Stream) FoldLeft(i U, f func(U, T) U) U {
	acc := i
	for cur := &s; cur.notEmpty; cur = cur.tail {
		acc = f(acc, cur.head)
	}
	return acc
}

// FoldRight is a right fold. It computes f(...f(f(i, eN), eN-1)..., e0),
// where e0 is the head and eN the deepest (oldest) element. In other words,
// f is applied to elements in insertion order, oldest first. It returns i
// unchanged for an empty stream.
//
// The head-to-tail walk is iterative. The heads are buffered in a slice and
// then folded from the end, which reproduces the recursive application order
// without consuming one stack frame per element.
func (s Stream) FoldRight(i U, f func(U, T) U) U {
	heads := make([]T, 0, s.length)
	for cur := &s; cur.notEmpty; cur = cur.tail {
		heads = append(heads, cur.head)
	}
	acc := i
	for k := len(heads) - 1; k >= 0; k-- {
		acc = f(acc, heads[k])
	}
	return acc
}

// Merge returns a Stream holding the elements of s followed, in insertion
// order, by the elements of t. The elements of t are Added on top of s in
// the order they were originally inserted. Neither input is modified, and s
// is returned as-is when t is empty.
func (s Stream) Merge(t Stream) Stream {
	if !t.notEmpty {
		return s
	}
	prepend := func(acc U, item T) U {
		return acc.(Stream).Add(item)
	}
	return t.FoldRight(s, prepend).(Stream)
}

// Reverse returns a new Stream with the element order reversed. It walks
// head to tail with FoldLeft and pushes each element onto an initially
// empty stream.
func (s Stream) Reverse() Stream {
	push := func(acc U, item T) U {
		return acc.(Stream).Add(item)
	}
	reversed := s.FoldLeft(Nil, push)
	return reversed.(Stream)
}

// Map returns a new Stream containing f(e) for every element e, in the same
// order as s. Because FoldRight visits elements oldest first and Add
// prepends, the original ordering is preserved.
func (s Stream) Map(f func(T) U) Stream {
	apply := func(acc U, item T) U {
		mapped := f(item)
		return acc.(Stream).Add(mapped)
	}
	result := s.FoldRight(Nil, apply)
	return result.(Stream)
}

// Filter returns a new Stream containing only the elements for which f
// reports true, in their original order.
func (s Stream) Filter(f func(T) bool) Stream {
	keep := func(acc U, item T) U {
		if !f(item) {
			return acc
		}
		return acc.(Stream).Add(item)
	}
	result := s.FoldRight(Nil, keep)
	return result.(Stream)
}

// ToString renders the stream as "[e1,e2,...]". Each element is formatted
// with %v and elements are listed in insertion order (oldest first). An
// empty stream renders as "[]".
func (s Stream) ToString() string {
	collect := func(acc U, item T) U {
		return append(acc.([]string), fmt.Sprintf("%v", item))
	}
	items := s.FoldRight([]string{}, collect).([]string)

	var sb strings.Builder
	sb.WriteString("[")
	sb.WriteString(strings.Join(items, ","))
	sb.WriteString("]")
	return sb.String()
}

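测试上面的数据结构和方法（注意：下面的 import . "./stream" 是相对路径导入，只能在旧的 GOPATH 模式下使用；启用 Go Modules 后需改为模块路径，例如 import . "example.com/demo/stream"）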

package main

import (
	"fmt"
	"math"
	"strings"
)
import . "./stream"

func main() {
	// Addall prepends, but ToString prints in insertion order.
	nums := Nil.Addall(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
	fmt.Println(nums.ToString()) // [1,2,3,4,5,6,7,8,9,10]

	reversed := nums.Reverse()
	fmt.Println(reversed.ToString()) // [10,9,8,7,6,5,4,3,2,1]

	more := Nil.Addall(11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
	merged := nums.Merge(more)
	fmt.Println(merged.ToString()) // [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]

	// map: x -> x * x
	square := func(t T) U {
		n := t.(int)
		return n * n
	}
	squares := merged.Map(square)
	fmt.Println(squares.ToString()) // [1,4,9,16,25,36,49,64,81,100,121,144,169,196,225,256,289,324,361,400]

	// filter: x % 2 == 0
	isEven := func(t T) bool {
		return t.(int)%2 == 0
	}
	evens := squares.Filter(isEven)
	fmt.Println(evens.ToString()) // [4,16,36,64,100,144,196,256,324,400]

	// sum(evens) as a left fold.
	add := func(u U, t T) U {
		return u.(int) + t.(int)
	}
	total := evens.FoldLeft(0, add)
	fmt.Println(total) // 1540

	// Minimum of a float stream. Without lambda syntax the callback has to be
	// spelled out as a full anonymous function.
	floats := Nil.Addall(3.5, 4.3, 2.6, 1.1, 7.83, 4.42)
	smaller := func(u U, t T) U {
		if u.(float64) < t.(float64) {
			return u
		}
		return t
	}
	fmt.Println(floats.FoldLeft(math.MaxFloat64, smaller)) // 1.1

	// Keep only the strings that contain the letter "a".
	words := Nil.Addall("aaa", "bbb", "aba", "ccc", "cbb", "cba")
	hasA := func(t T) bool {
		return strings.Contains(t.(string), "a")
	}
	fmt.Println(words.Filter(hasA).ToString()) // [aaa,aba,cba]
}

  

转载自 www.cnblogs.com/scala/p/9557305.html