Spark2.0 机器学习 ML 库:数据分析方法小结(Scala 版)

一、前言

之前使用 Python 的 sklearn 库做机器学习,现在改用 Spark 的机器学习库 ML
本文所涉及是 Spark2.1ML 库(Spark1.x 为 MLlib)的基础部分
操作的对象是 DataFrame,其类似数据库的表
与 Python 机器学习的数据分析库 Pandas、Numpy 异曲同工。

参考文献:http://spark.apache.org/docs/2.1.0/api/scala/index.html

二、数据集操作方法

代码示例

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.storage.StorageLevel

/**
  * 1.Spark ML 基础入门
  *
  * 1.1.填坑:
  * ① spark 1.x 与 2.x
  * ② IDE编程 与 Shell编程
  * ③ 定义范围
  *
  * 1.2.访问 web UI
  * http://192.168.222.1:4040(http://ip:port)
  *
  * 1.3.小结
  * 作用类似 Python 数据分析库 pandas + numpy
  * 
  * 1.4.不足
  * 命名重复,运行所有代码将报错
  */

// 注意如下的 case class 要放在 object 外面,否则报错
case class Person(name: String, age: Int, height: Int)

case class Peoples(age: Int, names: String)

case class Score(name: String, score: Int)

object mlFirst {


  def main(args: Array[String]): Unit = {


    // 0.构建 Spark 对象
    val spark = SparkSession
      .builder()
      .master("local") // 本地测试,否则报错 A master URL must be set in your configuration at org.apache.spark.SparkContext.
      .appName("test")
      .enableHiveSupport()
      .getOrCreate() // 有就获取无则创建

    spark.sparkContext.setCheckpointDir("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest") //设置文件读取、存储的目录,HDFS最佳
    import spark.implicits._ // 缺少则报错:Unable to find encoder for type stored in a Dataset.  Primitive types (Int, String, etc) and Product types (case classes) are supported by importing spark.implicits._  Support for serializing other types will be added in future releases

    // 一、DataSet 的创建

    // 1.产生序列 DataSet
    val numDS = spark.range(5, 100, 5) // start、end、step
    numDS.orderBy("id").show(numRows = 5) // 默认 numRos=5,显示前 5 行
    /**
      * +---+
      * | id|
      * +---+
      * |  5|
      * | 10|
      * | 15|
      * | 20|
      * | 25|
      * +---+
      * only showing top 5 rows
      */
    numDS.describe().show()
    /**
      * +-------+------------------+
      * |summary|                id|
      * +-------+------------------+
      * |  count|                19|
      * |   mean|              50.0|
      * | stddev|28.136571693556885|
      * |    min|                 5|
      * |    max|                95|
      * +-------+------------------+
      */

    // 2.集合转成 DataSet
    val seq1 = Seq(Person("linhongcun", 20, 176), Person("linhongcai", 20, 178), Person("linyiguang", 27, 177))
    val ds1 = spark.createDataset(seq1)
    ds1.show()

    // 3.集合转成 DataFrame
    val df1 = spark.createDataFrame(seq1)
      .withColumnRenamed("name", "call") // 列名重命名
      .withColumnRenamed("age", "old")
    df1.show()
    /**
      * +----------+---+------+
      * |      call|old|height|
      * +----------+---+------+
      * |linhongcun| 20|   176|
      * |linhongcai| 20|   178|
      * |linyiguang| 27|   177|
      * +----------+---+------+
      */

    // 4.RDD 转成 DataFrame
    val array1 = Array(("linhongcun", 20, 176), ("linhongcai", 20, 178), ("linyiguang", 27, 177))
    val rdd1 = spark.sparkContext.parallelize(array1, 3) // numSlices
      .map(f => Row(f._1, f._2, f._3))
    val schema = StructType(
      StructField("name", StringType, false)
        ::
        StructField("age", IntegerType, false)
        ::
        Nil
    )
    val rddToDataFrame = spark.createDataFrame(rdd1, schema)
    rddToDataFrame.show()
    /**
      * +----------+---+
      * |name      |age|
      * +----------+---+
      * |linhongcun|20 |
      * |linhongcai|20 |
      * |linyiguang|27 |
      * +----------+---+
      */

    // 5.RDD 转成 DataSet、DataFrame
    val rdd2 = spark.sparkContext.parallelize(array1, 3)
      .map(f => Person(f._1, f._2, f._3))
    val ds2 = rdd2.toDS()
    val df2 = rdd2.toDS()
    ds2.orderBy("age").show()
    /**
      * +----------+---+------+
      * |      name|age|height|
      * +----------+---+------+
      * |linhongcai| 20|   178|
      * |linhongcun| 20|   176|
      * |linyiguang| 27|   177|
      * +----------+---+------+
      */
    df2.orderBy("height").show()
    /**
      * +----------+---+------+
      * |      name|age|height|
      * +----------+---+------+
      * |linhongcun| 20|   176|
      * |linyiguang| 27|   177|
      * |linhongcai| 20|   178|
      * +----------+---+------+
      */

    // 6.RDD 转成 DataSet
    val ds3 = spark.createDataset(rdd2)
    ds3.show()
    /**
      * +----------+---+------+
      * |      name|age|height|
      * +----------+---+------+
      * |linhongcun| 20|   176|
      * |linhongcai| 20|   178|
      * |linyiguang| 27|   177|
      * +----------+---+------+
      */

    // 7.读取文件
    val df4 = spark.read.csv("C:\\Users\\linhongcun\\Desktop\\spark\\src\\main\\resources\\t_teacher.csv") //测试数据:从数据库导出一张表即可
    df4.show(5)
    /**
      * +---+---+---+---+
      * |_c0|_c1|_c2|_c3|
      * +---+---+---+---+
      * |  1|ppt|doc|加勒比|
      * |  2|222|222|李克强|
      * |  3|kkk|sad| 杨幂|
      * |  4| mc|agt|胡景涛|
      * |  5|123|123|习近平|
      * +---+---+---+---+
      * only showing top 5 rows
      */

    // 8.读取文件 + 详细参数
    val schema2 = StructType(
      StructField("id", IntegerType, false)
        ::
        StructField("name", StringType, false)
        ::
        StructField("password", StringType, false)
        ::
        StructField("true_name", StringType, false)
        ::
        Nil
    )
    val df5 = spark.read
      .options(Map(("delimiter", ","), ("header", "false")))
      .schema(schema2)
      .csv("C:\\Users\\linhongcun\\Desktop\\spark\\src\\main\\resources\\t_teacher.csv")
    df5.show(5)
    /**
      * +---+----+--------+---------+
      * | id|name|password|true_name|
      * +---+----+--------+---------+
      * |  1| ppt|     doc|      加勒比|
      * |  2| 222|     222|      李克强|
      * |  3| kkk|     sad|       杨幂|
      * |  4|  mc|     agt|      胡景涛|
      * |  5| 123|     123|      习近平|
      * +---+----+--------+---------+
      * only showing top 5 rows
      */

    // 二、DataSet 基础函数

    // 9.1.DataSet 存储类型
    val seq2 = Seq(Person("linhongcun", 20, 176), Person("linhongcai", 20, 178), Person("linyiguang", 27, 177))
    val ds4 = spark.createDataset(seq2)
    ds4.show()
    /**
      * +----------+---+------+
      * |      name|age|height|
      * +----------+---+------+
      * |linhongcun| 20|   176|
      * |linhongcai| 20|   178|
      * |linyiguang| 27|   177|
      * +----------+---+------+
      */
    ds4.checkpoint()
    ds4.cache()
    ds4.persist(StorageLevel.MEMORY_ONLY)
    ds4.count()
    ds4.show()
    /**
      * +----------+---+------+
      * |      name|age|height|
      * +----------+---+------+
      * |linhongcun| 20|   176|
      * |linhongcai| 20|   178|
      * |linyiguang| 27|   177|
      * +----------+---+------+
      */
    ds4.unpersist(true)

    // 9.2.DataSet 结构属性
    ds4.columns
    ds4.dtypes
    ds4.explain()
    /**
      * == Physical Plan ==
      * LocalTableScan [name#79, age#80, height#81]
      */

    // 9.3.DataSet RDD 数据互转
    val rdd3 = ds4.rdd
    val ds5 = rdd3.toDS()
    ds5.show()
    /**
      * +----------+---+------+
      * |      name|age|height|
      * +----------+---+------+
      * |linhongcun| 20|   176|
      * |linhongcai| 20|   178|
      * |linyiguang| 27|   177|
      * +----------+---+------+
      */
    val df6 = rdd3.toDF()
    df6.show()
    /**
      * +----------+---+------+
      * |      name|age|height|
      * +----------+---+------+
      * |linhongcun| 20|   176|
      * |linhongcai| 20|   178|
      * |linyiguang| 27|   177|
      * +----------+---+------+
      */

    // 9.4.1.DataSet 保存文件
    ds5.select("name", "age", "height")
      .write.format("csv")
      .save("C:\\Users\\linhongcun\\Desktop\\spark\\src\\main\\resources\\save.csv")
    // 9.4.2.读取保存的文件
    val schema3 = StructType(
      StructField("name", StringType, false) ::
        StructField("age", IntegerType, false) ::
        StructField("height", IntegerType, true) :: Nil
    )
    val out = spark.read.options(
      Map(
        ("delimiter", ","),
        ("header", "false")
      )
    )
      .schema(schema3)
      .csv("C:\\Users\\linhongcun\\Desktop\\spark\\src\\main\\resources\\save.csv")
    out.show()
    /**
      * +----------+---+------+
      * |      name|age|height|
      * +----------+---+------+
      * |linhongcun| 20|   176|
      * |linhongcai| 20|   178|
      * |linyiguang| 27|   177|
      * +----------+---+------+
      */

    // 三、DataSet 的 Actions 操作

    // 10.1.显示数据集
    val seq3 = Seq(Person("至尊宝", 19, 178), Person("紫霞仙子", 18, 168), Person("孙悟空", 500, 150))
    val ds6 = spark.createDataset(seq3)
    ds6.show()
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 至尊宝| 19|   178|
      * |紫霞仙子| 18|   168|
      * | 孙悟空|500|   150|
      * +----+---+------+
      */
    ds6.show(2)
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 至尊宝| 19|   178|
      * |紫霞仙子| 18|   168|
      * +----+---+------+
      * only showing top 2 rows
      */
    ds6.show(2,true)
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 至尊宝| 19|   178|
      * |紫霞仙子| 18|   168|
      * +----+---+------+
      * only showing top 2 rows
      */

    //10.2.获取数据集
    val c1 = ds6.collect() // 所有组成Array
    val c2 = ds6.collectAsList() //所有组成List
    val h1 = ds6.head() //第一个
    val h2 = ds6.head(3) //前三个Array
    val f1 = ds6.first() //第一个
    val t1 = ds6.take(2) //前两个Array
    val t2 = ds6.takeAsList(2) //前两个组成List

    //10.3.统计
    println(ds6.count()) //3
    ds6.describe().show()

    /**
      * +-------+----+-----------------+------------------+
      * |summary|name|              age|            height|
      * +-------+----+-----------------+------------------+
      * |  count|   3|                3|                 3|
      * |   mean|null|            179.0|165.33333333333334|
      * | stddev|null|277.9946042641835|14.189197769195175|
      * |    min| 孙悟空|               18|               150|
      * |    max| 至尊宝|              500|               178|
      * +-------+----+-----------------+------------------+
      */
    ds6.describe("age").show()
    /**
      * +-------+-----------------+
      * |summary|              age|
      * +-------+-----------------+
      * |  count|                3|
      * |   mean|            179.0|
      * | stddev|277.9946042641835|
      * |    min|               18|
      * |    max|              500|
      * +-------+-----------------+
      */

    ds6.describe("age", "height").show()
    /**
      * +-------+-----------------+------------------+
      * |summary|              age|            height|
      * +-------+-----------------+------------------+
      * |  count|                3|                 3|
      * |   mean|            179.0|165.33333333333334|
      * | stddev|277.9946042641835|14.189197769195175|
      * |    min|               18|               150|
      * |    max|              500|               178|
      * +-------+-----------------+------------------+
      */

    //10.4.聚集
    println(ds6.reduce((f1, f2) => Person("sum", (f1.age + f2.age), (f1.height + f2.height)))) //Person(sum,537,496)

    // 四、DataSet 类型转化

    // 1 map 操作,flatMap 操作
    val seq1 = Seq(Peoples(30, "刘备,孙权,曹操"), Peoples(10, "刘禅,孙亮,曹丕"))
    val df1 = spark.createDataset(seq1)
    val df2 = df1.rdd.map { x => //.rdd.
      (x.age + 1, x.names)
    }
    println(df2.first()) //(31,刘备,孙权,曹操)

    val df3 = df1.rdd.flatMap { x => // .rdd.
      val a = x.age
      val s = x.names.split(",").map {
        x => (a, x)
      }
      s
    }
    println(df3.first()) //(30,刘备)

    // 2 filter 操作, where 操作
    val seq2 = Seq(Person("武宣卞皇后", 19, 168), Person("步练师", 17, 165), Person("孙尚香", 16, 169))
    val ds4 = spark.createDataset(seq2)
    ds4.filter("age <=18 and height >=168").show()

    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 孙尚香| 16|   169|
      * +----+---+------+
      */
    ds4.filter($"age" <= 18 && $"height" >= 168).show()

    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 孙尚香| 16|   169|
      * +----+---+------+
      */
    ds4.filter { x => x.age <= 18 && x.height >= 168 }.show()

    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 孙尚香| 16|   169|
      * +----+---+------+
      */
    ds4.where("age <=18 and height >= 168").show()

    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 孙尚香| 16|   169|
      * +----+---+------+
      */
    ds4.where($"age" <= 18 && $"height" >= 168).show()

    // 3 去重操作
    val seq3 = Seq(
      Person("赵丽颖", 18, 168),
      Person("杨幂", 17, 168),
      Person("刘亦菲", 18, 168),
      Person("白百合", 18, 167)
    )
    val ds5 = spark.createDataset(seq3)
    ds5.distinct().show() // 去掉完全重复的行
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 赵丽颖| 18|   168|
      * | 白百合| 18|   167|
      * | 刘亦菲| 18|   168|
      * |  杨幂| 17|   168|
      * +----+---+------+
      */
    ds5.dropDuplicates("age").show() // 去掉列 age 相同的行
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * |  杨幂| 17|   168|
      * | 赵丽颖| 18|   168|
      * +----+---+------+
      */
    ds5.dropDuplicates(Array("age", "height")).show() // 去掉 age、height 都相同的行
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * |  杨幂| 17|   168|
      * | 赵丽颖| 18|   168|
      * | 白百合| 18|   167|
      * +----+---+------+
      */

    // 4 加减法
    val seq4 = Seq(
      Person("白百合", 18, 167),
      Person("陈羽凡", 19, 178)
    )
    val ds6 = spark.createDataset(seq4)
    ds5.union(ds6).show() //并集
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 赵丽颖| 18|   168|
      * |  杨幂| 17|   168|
      * | 刘亦菲| 18|   168|
      * | 白百合| 18|   167|
      * | 白百合| 18|   167|
      * | 陈羽凡| 19|   178|
      * +----+---+------+
      */
    ds5.intersect(ds6).show() //交集
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 白百合| 18|   167|
      * +----+---+------+
      */
    ds5.except(ds6).show() // 并集-交集*2
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 赵丽颖| 18|   168|
      * | 刘亦菲| 18|   168|
      * |  杨幂| 17|   168|
      * +----+---+------+
      */
    // 5.select 操作
    ds6.select("name","age").show()
    /**
      * +----+---+
      * |name|age|
      * +----+---+
      * | 白百合| 18|
      * | 陈羽凡| 19|
      * +----+---+
      */
    // 6.排序操作
    ds6.sort("age").show()
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 白百合| 18|   167|
      * | 陈羽凡| 19|   178|
      * +----+---+------+
      */
    ds6.sort($"age".desc,$"height".desc).show()
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 陈羽凡| 19|   178|
      * | 白百合| 18|   167|
      * +----+---+------+
      */
    ds6.orderBy("age").show()
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 白百合| 18|   167|
      * | 陈羽凡| 19|   178|
      * +----+---+------+
      */
    ds6.orderBy($"age".desc,$"height".desc).show()
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 陈羽凡| 19|   178|
      * | 白百合| 18|   167|
      * +----+---+------+
      */

    // 7.分割抽样操作
    val ds7 = ds6.union(ds5)
    val rands = ds7.randomSplit(Array(0.5, 0.5)) // 随机“五五开”
    rands(0).show()
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 白百合| 18|   167|
      * | 陈羽凡| 19|   178|
      * | 白百合| 18|   167|
      * +----+---+------+
      */
    rands(1).show()
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 刘亦菲| 18|   168|
      * |  杨幂| 17|   168|
      * | 赵丽颖| 18|   168|
      * +----+---+------+
      */
    val sample = ds7.sample(false, 0.5).show() // 随机 50% 抽样
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 白百合| 18|   167|
      * | 陈羽凡| 19|   178|
      * | 赵丽颖| 18|   168|
      * |  杨幂| 17|   168|
      * +----+---+------+
      */

    // 8 列操作
    val ds8 = ds7.drop("height").show() //去掉列
    /**
      * +----+---+
      * |name|age|
      * +----+---+
      * | 白百合| 18|
      * | 陈羽凡| 19|
      * | 赵丽颖| 18|
      * |  杨幂| 17|
      * | 刘亦菲| 18|
      * | 白百合| 18|
      * +----+---+
      */
    val ds9 = ds7.withColumn("true_age", $"age" + 2).show() // 新增列
    /**
      * +----+---+------+--------+
      * |name|age|height|true_age|
      * +----+---+------+--------+
      * | 白百合| 18|   167|      20|
      * | 陈羽凡| 19|   178|      21|
      * | 赵丽颖| 18|   168|      20|
      * |  杨幂| 17|   168|      19|
      * | 刘亦菲| 18|   168|      20|
      * | 白百合| 18|   167|      20|
      * +----+---+------+--------+
      */
    val ds10 = ds7.withColumnRenamed("name", "true_name").show() //列名重命名
    /**
      * +---------+---+------+
      * |true_name|age|height|
      * +---------+---+------+
      * |      白百合| 18|   167|
      * |      陈羽凡| 19|   178|
      * |      赵丽颖| 18|   168|
      * |       杨幂| 17|   168|
      * |      刘亦菲| 18|   168|
      * |      白百合| 18|   167|
      * +---------+---+------+
      */

    // 8 join 操作(类似数据库链接查询)
    val seq6 = Seq(
      Score("白百合", 59),
      Score("陈羽凡", 98)
    )
    val ds66 = spark.createDataset(seq6)
    ds5.show()
    /**
      * +----+---+------+
      * |name|age|height|
      * +----+---+------+
      * | 赵丽颖| 18|   168|
      * |  杨幂| 17|   168|
      * | 刘亦菲| 18|   168|
      * | 白百合| 18|   167|
      * +----+---+------+
      */
    ds66.show()
    /**
      * +----+-----+
      * |name|score|
      * +----+-----+
      * | 白百合|   59|
      * | 陈羽凡|   98|
      * +----+-----+
      */
    val ds11=ds66.join(ds5,Seq("name"),"inner").show() // 内连接
    /**
      * +----+-----+---+------+
      * |name|score|age|height|
      * +----+-----+---+------+
      * | 白百合|   59| 18|   167|
      * +----+-----+---+------+
      */
    val ds12=ds66.join(ds5,Seq("name"),"left").show() // 左外连接
    /**
      * +----+-----+----+------+
      * |name|score| age|height|
      * +----+-----+----+------+
      * | 白百合|   59|  18|   167|
      * | 陈羽凡|   98|null|  null|
      * +----+-----+----+------+
      */
    val ds13=ds66.join(ds5,Seq("name"),"right").show() // 右外连接
    /**
      * +----+-----+---+------+
      * |name|score|age|height|
      * +----+-----+---+------+
      * | 赵丽颖| null| 18|   168|
      * |  杨幂| null| 17|   168|
      * | 刘亦菲| null| 18|   168|
      * | 白百合|   59| 18|   167|
      * +----+-----+---+------+
      */

    // 9. 聚合操作
    val ds11 = ds5.groupBy("age").agg(avg("height").as("avg_height")).show()

    /**
      * +---+------------------+
      * |age|        avg_height|
      * +---+------------------+
      * | 17|             168.0|
      * | 18|167.66666666666666|
      * +---+------------------+
      */
  }
}

三、数学方法

import org.apache.spark.sql.SparkSession
import breeze.linalg._
import breeze.numerics.{abs, asin}

/**
  * 向量计算
  */
object math {

  def main(args: Array[String]): Unit = {

    // 0.构建 Spark 对象
    val spark = SparkSession
      .builder()
      .master("local") // 本地测试,否则报错 A master URL must be set in your configuration at org.apache.spark.SparkContext.
      .appName("test")
      .enableHiveSupport()
      .getOrCreate() // 有就获取无则创建

    spark.sparkContext.setCheckpointDir("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest") //设置文件读取、存储的目录,HDFS最佳

    // 一、Breeze 创建函数

    // 1.1.1. 填充 0
    val v1 = DenseVector.zeros[Double](3) // zeros:0 填充
    println(v1)

    /**
      * DenseVector(0.0, 0.0, 0.0)
      */

    //1.1.2. 填充 1
    val v2 = DenseVector.ones[Double](3) //ones:1 填充
    println(v2)

    /**
      * DenseVector(1.0, 1.0, 1.0)
      */

    //1.1.3. 自定义填充
    val v3 = DenseVector.fill(3) {
      5
    }
    println(v3)

    /**
      * DenseVector(5, 5, 5)
      */

    // 1.1.4. 范围填充
    val v4 = DenseVector.range(1, 10, 2) // start、end、step
    println(v4)

    /**
      * DenseVector(1, 3, 5, 7, 9)
      */

    // 1.2.1. 矩阵
    val m1 = DenseMatrix.zeros[Double](2, 3) // rows、cols
    println(m1)

    /**
      * 0.0  0.0  0.0
      * 0.0  0.0  0.0
      */

    //1.2.2. 对角线
    val m2 = DenseMatrix.eye[Double](3) // dim 对角线长度
    println(m2)

    /**
      * 1.0  0.0  0.0
      * 0.0  1.0  0.0
      * 0.0  0.0  1.0
      */


    //1.3.1 自定义填充
    val x1 = diag(DenseVector(1.1, 2.2, 3.3))
    println(v3)

    /**
      * 1.1  0.0  0.0
      * 0.0  2.2  0.0
      * 0.0  0.0  3.3
      */

    // 1.3.2 函数式向量
    val x2 = DenseVector.tabulate(3) { i => i * i }
    println(x2)

    /**
      * DenseVector(0, 1, 4)
      */

    //1.3.3 函数式矩阵
    val x3 = DenseMatrix.tabulate(3, 2) { case (i, j) => i + j }
    println(x3)

    /**
      * 0  1
      * 1  2
      * 2  3
      */

    //1.3.4 数组式向量
    val x4 = new DenseVector(Array(1, 2, 3, 4))
    println(x4)

    /**
      * DenseVector(1, 2, 3, 4)
      */

    val x5 = new DenseMatrix(2, 3, Array(1, 2, 3, 4, 5, 6))
    println(x5)

    /**
      * 1  3  5
      * 2  4  6
      */

    // 二、 Breeze 元素访问

    // 2.1 向量访问
    val v = DenseVector(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
    println(v) //DenseVector(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
    println(v(0)) // 1
    println(v(1 to 4)) // (2, 3, 4, 5)
    println(1 to 6 by 2) // Range(1, 3, 5)

    // 2.2 矩阵访问
    val m = DenseMatrix((1, 2, 3), (4, 5, 6))
    println(m)

    /**
      * 1  2  3
      * 4  5  6
      */
    println(m(0, 0)) // 1

    // 三、Breeze 元素操作

    // 3.1. 变形
    val mr1 = DenseMatrix((1, 2, 3), (4, 5, 6))
    println(mr1)

    /**
      * 1  2  3
      * 4  5  6
      */
    val mr2 = mr1.reshape(3, 2) //rows、cols
    println(mr2)

    /**
      * 1  5
      * 4  3
      * 2  6
      */

    // 3.2.1. 去除右上角的元素
    val mt1 = DenseMatrix((1, 2, 3), (4, 5, 6), (7, 8, 9))
    println(mt1)

    /**
      * 1  2  3
      * 4  5  6
      * 7  8  9
      */

    val mt2 = lowerTriangular(mt1)
    println(mt2)

    /**
      * 1  0  0
      * 4  5  0
      * 7  8  9
      */

    // 3.2.2. 去除左下角的元素
    val mt3 = upperTriangular(mt1)
    println(mt3)

    /**
      * 1  2  3
      * 0  5  6
      * 0  0  9
      */

    // 3.3. 复制
    val mc1 = DenseMatrix((11, 22, 33), (44, 55, 66), (77, 88, 99))
    println(mc1)

    /**
      * 11  22  33
      * 44  55  66
      * 77  88  99
      */
    val mc2 = mc1.copy
    println(mc2)

    /**
      * 11  22  33
      * 44  55  66
      * 77  88  99
      */

    // 3.4. 获取对角线
    println(mc2)

    /**
      * 11  22  33
      * 44  55  66
      * 77  88  99
      */

    //println(diag(mc2)) //DenseVector(11, 55, 99)

    // 3.4. 修改元素值
    println(mc2)

    /**
      * 11  22  33
      * 44  55  66
      * 77  88  99
      */

    mc2(0, 0) = 10
    println(mc2)

    /**
      * 10  22  33
      * 44  55  66
      * 77  88  99
      */


    // 3.5 向量合并
    val catv1 = DenseVector(1, 2, 3)
    println(catv1)

    /**
      * DenseVector(1, 2, 3)
      */

    val catv2 = DenseVector(4, 5, 6)
    println(catv2)

    /**
      * DenseVector(4, 5, 6)
      */

    val catv3 = DenseVector.vertcat(catv1, catv2) //左右合并

    /**
      * DenseVector(1, 2, 3, 4, 5, 6)
      */

    println(catv3)

    val catv4 = DenseVector.horzcat(catv1, catv2) // 上下合并,再反转
    println(catv4)

    /**
      * 1  4
      * 2  5
      * 3  6
      */

    // 3.6. 矩阵合并
    val catm1 = DenseMatrix((1, 1, 1), (2, 2, 2))
    println(cat1)

    /**
      * 1  1  1
      * 2  2  2
      */

    val catm2 = DenseMatrix((3, 3, 3), (4, 4, 4))
    println(cat2)

    /**
      * 3  3  3
      * 4  4  4
      */

    val catm3 = DenseMatrix.vertcat(catm1, catm2) // 上下合并
    println(catm3)

    /**
      * 1  1  1
      * 2  2  2
      * 3  3  3
      * 4  4  4
      */

    val catm4 = DenseMatrix.horzcat(catm1, catm2) // 左右合并
    println(catm4)

    /**
      * 1  1  1  3  3  3
      * 2  2  2  4  4  4
      */

    // 四、Breeze 运算函数

    // 4.1. 可以 + - * / max sum 等, 以 + 为例
    val add1 = DenseMatrix((1, 2, 3), (4, 5, 6))
    println(add1)

    /**
      * 1  2  3
      * 4  5  6
      */

    val add2 = DenseMatrix((1, 1, 1), (1, 1, 1))
    println(add2)

    /**
      * 1  1  1
      * 1  1  1
      */

    val add3 = add1 + add2
    println(add3)

    /**
      * 2  3  4
      * 5  6  7
      */

    // 4.2 其他简单运算实例
    println(sum(add3))

    /**
      * 27
      */
    println(max(add3))

    /**
      * 7
      */
    println(sum(add3, Axis._0))

    /**
      * Transpose(DenseVector(7, 9, 11))
      */
    println(sum(add3, Axis._1))

    /**
      * DenseVector(9, 18)
      */

    //4.3 取正
    val unabs = DenseVector(9, 6, -2)

    /**
      * DenseVector(9, 6, -2)
      */
    println(unabs)
    /**
      * DenseVector(9, 6, 2)
      */
    println(abs(unabs))

    // 4.4 三角函数
    println(asin(0.5)) //0.5235987755982989


    // 五、线性代数

    // 5.1 转置
    val t1 = DenseMatrix((1, 2, 3), (4, 5, 6))
    println(t1)

    /**
      * 1  2  3
      * 4  5  6
      */
    println(t1.t)

    /**
      * 1  4
      * 2  5
      * 3  6
      */

    // 5.2 求特征值
    println(det(t1))

    /**
      * -3.0
      */

  }
}

四、其他

这里写图片描述

猜你喜欢

转载自blog.csdn.net/larger5/article/details/81634508