S-JIS[2017-01-15/2017-02-01] 変更履歴

Dataset

Apache SparkのDatasetクラスについて。


概要

Spark2.0以降はDatasetを使ってプログラミングする。
(Spark2では、DataFrameは「Dataset[Row]」の別名である)


最初のDatasetはSparkSessionを使って生成する。
Scalaのコレクション(Seq)から作る方法と、ファイル等から読み込んで作る方法がある。

DatasetクラスにはfilterやmapといったScalaのコレクションと同様のメソッドが用意されている。


Datasetの生成

Seqから生成する例

import org.apache.spark.sql.SparkSession
case class Person(name: String, age: Long)
    val spark = SparkSession.builder().〜.getOrCreate()
    import spark.implicits._

    val seq = Seq(Person("hoge", 20), Person("foo", 30), Person("bar", 40))
    val ds = seq.toDS()

seq.toDF()でDataFrameに変換することも出来るが、内部的にはseq.toDS().toDF()と同じ。


テキストファイルから生成する例

import org.apache.spark.sql.SparkSession
  def readText(spark: SparkSession): Unit = {
    import spark.implicits._

    // DataFrameとして読み込む
    val df = spark.read.text("D:/tmp/person.csv")
//  df.printSchema()

    // Dataset[Person]に変換
    val ds = df.map(row => {
      val s = row.getString(0)
      val a = s.split(",")
      Person(a(0), a(1).toLong)
    })
    ds.show()
  }

read.textメソッドは可変長引数なので、複数ファイルを指定することも出来る。
また、ワイルドカード「*」を使うことも出来る。


read.textでファイルを読み込むと、スキーマは、valueという名前のStringフィールドがひとつあるだけの状態になっている。

root
 |-- value: string (nullable = true)

csvファイルから生成する例

import org.apache.spark.sql.SparkSession
  def readCsv(spark: SparkSession): Unit = {
    import spark.implicits._

    // DataFrameとして読み込む
    val df = spark.read.csv("D:/tmp/person.csv")
//  df.printSchema()

    // Dataset[Person]に変換
    val ds = df.map(row => {
      val name = row.getString(0)
      val age  = row.getString(1).toLong
      Person(name, age)
    })
    ds.show()
  }

csvにはいくつかのオプションがあり、例えばフィールドセパレーターをタブにすればtsvファイルが読める。[2017-01-17]
他にもエンコーディングやヘッダー有無の指定等、いくつかのオプションがある。
どのようなオプションがあるのかは、def csv(paths: String*)メソッドのScaladocを参照。


read.csvでファイルを読み込むと、スキーマは、_c0, _c1といったカラム名でデータ型は全てStringという状態になっている。

root
 |-- _c0: string (nullable = true)
 |-- _c1: string (nullable = true)

スキーマを指定してcsvファイルから読み込む例

DataFrameのスキーマ(カラム名とデータ型)がケースクラスと一致していれば、(自分でmapを書かなくても)そのケースクラスのDatasetに変換できる。

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructType
  def readSchema(spark: SparkSession): Unit = {
    import spark.implicits._

    val schema = new StructType()
      .add("name", StringType)
      .add("age", LongType)
    val ds = spark.read.schema(schema).csv("D:/tmp/person.csv").as[Person]
    ds.show()
  }
スキーマ作成の例
説明
import org.apache.spark.sql.types._
val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("age", LongType)
))
StructTypeのファクトリーを使ってスキーマを生成する方法。
import org.apache.spark.sql.types._
val schema = new StructType()
  .add("name", StringType)
  .add("age", LongType)
StructTypeをビルダー風に使ってスキーマを生成する方法。
val schema = Seq.empty[Person].toDS().schema
該当ケースクラスの空のDatasetを作って、そのスキーマを取得する方法。
(スキーマ関連のクラスをインポートする必要は無いが、実行時にちょっと無駄がある?)
val schema = spark.emptyDataset[Person].schema
import org.apache.spark.sql.Encoder
def getSchema[T](implicit enc: Encoder[T]) = enc.schema
val schema = getSchema[Person]
Encoderを使って該当ケースクラスのスキーマを生成する方法。
import org.apache.spark.sql.Encoder
val schema = implicitly[Encoder[Person]].schema
import org.apache.spark.sql.Encoders
val schema = Encoders.product[Person].schema

tsvファイルから生成する例

tsvファイルを読み込む場合は、フィールドセパレーターをタブ文字にしてcsv読み込みメソッドを使用する。[2017-01-17]

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.SparkSession
  def readTsv(spark: SparkSession): Unit = {
    import spark.implicits._

    val schema = Encoders.product[Person].schema
    val ds = spark.read.schema(schema).option("sep", "\t").csv("D:/tmp/person.tsv").as[Person]
//  ds.printSchema()
    ds.show()
  }

Datasetのメソッド

Dataset[T]の主なメソッド。


設定系

メソッド ver 説明 実行結果
schema: StructType 1.6.0 スキーマを取得する。 ds.schema.foreach(f => println(s"${f.name}, ${f.dataType}")) name, StringType
age, LongType
printSchema(): Unit 1.6.0 スキーマの内容を表示する。
→データを表示するのはshow
ds.printSchema() root
 |-- name: string (nullable = true)
 |-- age: long (nullable = true)
columns: Array[String] 1.6.0 カラム名の一覧を取得する。 val columns = ds.columns  
dtypes: Array[(String, String)] 1.6.0 カラム名とデータ型(の名前)のペアの一覧を取得する。 val dtypes = ds.dtypes  
col(colName: String): Column 2.0.0 Columnを取得する。 val column = ds.col("name")  
apply(colName: String): Column 2.0.0 colと同じ。
(applyだとメソッド名を省略して記述できる
val column = ds("name")  
inputFiles: Array[String] 2.0.0 読み込んだファイルの一覧を返す。
(ワイルドカードを指定して読み込んだ場合でも、具体的な個々のファイル名を返す)
println(ds.inputFiles.toSeq)  
registerTempTable(tableName: String): Unit 1.6.0 非推奨。→createOrReplaceTempView    
createTempView(viewName: String): Unit 2.0.0      
createOrReplaceTempView(viewName: String): Unit 2.0.0      
createGlobalTempView(viewName: String): Unit 2.1.0      
explain(): Unit
explain(extended: Boolean): Unit
1.6.0 物理プランを表示する。(デバッグ用) ds.explain()  
isLocal: Boolean 1.6.0 trueの場合、collectおよびtakeがローカルで実行できる。    
isStreaming: Boolean 2.0.0 ストリーミングのデータ源を持っているかどうか。    
withWatermark(eventTime: String, delayThreshold: String): Dataset[T] 2.1.0 イベント時間を設定する。    
na: DataFrameNaFunctions 1.6.0 失われたデータを処理するオブジェクトを返す?    
stat: DataFrameStatFunctions 1.6.0 統計情報を処理するオブジェクトを返す?    

変換系

Datasetから別のDataset(DataFrame)を生成する処理。

メソッド ver 説明 実行結果
as[U]: Dataset[U] 1.6.0 DataFrameをDataset[U]に変換する。
UはEncoderが対応している必要がある。(ケースクラスならOK)
ケースクラスに変換する場合、スキーマ(のカラム名およびデータ型)と一致している必要がある。
val ds = df.as[Person]
ds.show()
+----+---+
|name|age|
+----+---+
|hoge| 20|
| foo| 30|
| bar| 40|
+----+---+
toDF(): DataFrame 1.6.0 DataFrameに変換する。 val df = ds.toDF()
df.show()
+----+---+
|name|age|
+----+---+
|hoge| 20|
| foo| 30|
| bar| 40|
+----+---+
toDF(colNames: String*): DataFrame 2.0.0 カラム名を変更したDataFrameを返す。
(引数のカラム数と変換元スキーマのカラム数は一致している必要がある)
→1カラムだけ変えるならwithColumnRenamed
val df = ds.toDF("c0", "c1")
df.show()
+----+---+
|  c0| c1|
+----+---+
|hoge| 20|
| foo| 30|
| bar| 40|
+----+---+
as(alias: String): Dataset[T]
as(alias: Symbol): Dataset[T]
1.6.0
2.0.0
テーブル名を付ける。
カラム指定時にテーブル名を使えるようになる。
val df = ds.as("p").select($"p.name", $"p.age")
df.show()
+----+---+
|name|age|
+----+---+
|hoge| 20|
| foo| 30|
| bar| 40|
+----+---+
alias(alias: String): Dataset[T]
alias(alias: Symbol): Dataset[T]
2.0.0 asと同じ。    
map[U](func: T => U): Dataset[U] 1.6.0 保持しているデータ(の型)を変換する。 case class Age(age: Long)
val dsAge = ds.map(person => Age(person.age))
dsAge.show()
+---+
|age|
+---+
| 20|
| 30|
| 40|
+---+
map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] 1.6.0 Java8の関数インターフェースラムダ式)用。    
mapPartitions[U](func: Iterator[T] => Iterator[U]): Dataset[U] 1.6.0      
mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] 1.6.0 Java8の関数インターフェースラムダ式)用。    
flatMap[U](func: T => TraversableOnce[U]): Dataset[U] 1.6.0 保持しているデータ(の型)(および件数)を変換する。 case class Age(age: Long)
val dsAge = ds.flatMap(person => Seq(Age(person.age), Age(person.age + 1)))
println(dsAge.collect().toSeq)
WrappedArray(Age(20), Age(21), Age(30), Age(31), Age(40), Age(41))
flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] 1.6.0 Java8の関数インターフェースラムダ式)用。    
explode[A](input: Column*)(f: Row => TraversableOnce[A]): DataFrame
explode[A, B](inputColumn: String, outputColumn: String)(f: A => TraversableOnce[B]): DataFrame
2.0.0 非推奨。→flatMap, select    
describe(cols: String*): DataFrame 1.6.0 件数・平均・標準偏差・最小・最大値を保持するDataFrameを作成する。 val df = ds.describe("age")
df.show()
+-------+----+
|summary| age|
+-------+----+
|  count|   3|
|   mean|30.0|
| stddev|10.0|
|    min|  20|
|    max|  40|
+-------+----+
toJSON: Dataset[String] 2.0.0 カラム名がvalue, データ内容がJSON文字列であるDataset[String]に変換する。 val dsJson = ds.toJSON
dsJosn.show()
+--------------------+
|               value|
+--------------------+
|{"name":"hoge","a...|
|{"name":"foo","ag...|
|{"name":"bar","ag...|
+--------------------+
repartition(numPartitions: Int): Dataset[T] 1.6.0 パーティション数を変更する。
RDDのrepartition
   
repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T]
repartition(partitionExprs: Column*): Dataset[T]
2.0.0 パーティションを分割する。(hash partitionになる)
Hiveの「DISTRIBUTE BY」と同じようなものらしいので、同一カラムのデータは同じパーティションに入ると思われる)
val ds2 = ds.repartition($"name")  
coalesce(numPartitions: Int): Dataset[T] 1.6.0 パーティション数を変更する。
RDDのcoalesce
   
rdd: RDD[T] 1.6.0 RDDを返す。    
toJavaRDD: JavaRDD[T]
javaRDD: JavaRDD[T]
1.6.0 JavaRDDを返す。    
transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] 1.6.0      

アクション

ジョブを実行して値を返す処理。(いわゆる終端処理)

メソッド ver 説明 実行結果
show(): Unit
show(numRows: Int): Unit
show(truncate: Boolean): Unit
show(numRows: Int, truncate: Boolean): Unit
show(numRows: Int, truncate: Int): Unit
1.6.0 Datasetの内容を表示する。
printSchema
ds.show() +----+---+
|name|age|
+----+---+
|hoge| 20|
| foo| 30|
| bar| 40|
+----+---+
foreach(f: T => Unit): Unit
foreach(func: ForeachFunction[T]): Unit
1.6.0 Datasetの中身を処理する。 ds.foreach(person => println(person)) Person(hoge,20)
Person(foo,30)
Person(bar,40)
foreachPartition(f: Iterator[T] => Unit): Unit
foreachPartition(func: ForeachPartitionFunction[T]): Unit
1.6.0      
count(): Long 1.6.0 Datasetの件数を返す。 println(ds.count()) 3
collect(): Array[T] 1.6.0 Datasetの内容を配列にして返す。 println(ds.collect().toSeq) WrappedArray(Person(hoge,20), Person(foo,30), Person(bar,40))
collectAsList(): java.util.List[T] 1.6.0 Java用。    
toLocalIterator(): java.util.Iterator[T] 2.0.0 Java用。 toLocalIteratorを使ってファイル出力する例  
take(n: Int): Array[T] 1.6.0 指定された件数だけ配列にして返す。
limithead
println(ds.take(2).toSeq) WrappedArray(Person(hoge,20), Person(foo,30))
takeAsList(n: Int): java.util.List[T] 1.6.0 Java用。    
head(): T 1.6.0 先頭1件を返す。 println(ds.head()) Person(hoge,20)
head(n: Int): Array[T] 1.6.0 先頭n件を配列にして返す。
take
println(ds.head(2).toSeq) WrappedArray(Person(hoge,20), Person(foo,30))
first(): T 1.6.0 headと同じ。    
reduce(func: (T, T) => T): T 1.6.0 集約処理を行う。
agg
println(ds.select("age").as[Long].reduce(_ + _)) 90
reduce(func: ReduceFunction[T]): T 1.6.0 Java8の関数インターフェースラムダ式)用。    
write: DataFrameWriter[T] 1.6.0 ファイル出力に使用する。
ファイル出力の例
ds.write.csv("/tmp/result/person")  
writeStream: DataStreamWriter[T] 2.0.0 ストリーミングデータのファイル出力に使用する(と思われる)    

列操作系

カラムを増減・演算する処理。

メソッド ver 説明 実行結果
select(col: String, cols: String*): DataFrame 2.0.0 カラムを選択して絞り込んだDataFrameを返す。 val df = ds.select("name")
df.show()
+----+
|name|
+----+
|hoge|
| foo|
| bar|
+----+
select(cols: Column*): DataFrame 2.0.0 Columnで選択カラムを指定する。 val df = ds.select($"name", $"age" + 1 as "inc")
df.show()
+----+---+
|name|inc|
+----+---+
|hoge| 21|
| foo| 31|
| bar| 41|
+----+---+
select[U1](c1: TypedColumn[T, U1]): Dataset[U1]
select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)]

引数5個版まで
1.6.0 TypedColumnで選択カラムおよび演算を指定する。
sql.functionsでColumnを取得でき、Columnに「.as[型]」を付けるとTypedColumnになる。
import org.apache.spark.sql.functions.{ col, expr }

val df = ds.select(col("name").as[String], expr("age + 1 as inc").as[Long])
df.show()
+----+---+
|name|inc|
+----+---+
|hoge| 21|
| foo| 31|
| bar| 41|
+----+---+
selectExpr(exprs: String*): DataFrame 2.0.0 カラムの演算を文字列で指定する。 val df = ds.selectExpr("name", "age + 1 as inc")
df.show()
+----+---+
|name|inc|
+----+---+
|hoge| 21|
| foo| 31|
| bar| 41|
+----+---+
withColumn(colName: String, col: Column): DataFrame 2.0.0 カラムを追加したDataFrameを返す。 val df = ds.withColumn("inc", $"age" + 1)
df.show()
+----+---+---+
|name|age|inc|
+----+---+---+
|hoge| 20| 21|
| foo| 30| 31|
| bar| 40| 41|
+----+---+---+
withColumnRenamed(existingName: String, newName: String): DataFrame 2.0.0 カラム名を変更したDataFrameを返す。
(存在しないカラム名が指定されていた場合は、何もしない(エラーにもならない))
(新しいカラム名が他で使われている名前だった場合、エラーにならないが、その名前を使うメソッドで例外が発生する)
→全カラム名をまとめて変えるならtoDF
val df = ds.withColumnRenamed("name", "zzz")
df.show()
+----+---+
| zzz|age|
+----+---+
|hoge| 20|
| foo| 30|
| bar| 40|
+----+---+
drop(colName: String): DataFrame
drop(colNames: String*): DataFrame
2.0.0 指定されたカラムを除去したDataFrameを返す。 val df = ds.drop("name")
df.show()
+---+
|age|
+---+
| 20|
| 30|
| 40|
+---+
drop(col: Column): DataFrame 2.0.0 val df = ds.drop($"name")
df.show()
+---+
|age|
+---+
| 20|
| 30|
| 40|
+---+

行操作系

レコードを操作する処理。

メソッド ver 説明 実行結果
filter(conditionExpr: String): Dataset[T] 1.6.0 条件に合致するレコードだけ抽出する。 val df = ds.filter("age >= 30")
df.show()
+----+---+
|name|age|
+----+---+
| foo| 30|
| bar| 40|
+----+---+
filter(condition: Column): Dataset[T] 1.6.0 val df = ds.filter($"age" >= 30)
df.show()
+----+---+
|name|age|
+----+---+
| foo| 30|
| bar| 40|
+----+---+
filter(func: T => Boolean): Dataset[T] 1.6.0 val ds2 = ds.filter(person => person.age >= 30)
ds2.show()
+----+---+
|name|age|
+----+---+
| foo| 30|
| bar| 40|
+----+---+
filter(func: FilterFunction[T]): Dataset[T] 1.6.0 Java8の関数インターフェースラムダ式)用。    
where(conditionExpr: String): Dataset[T]
where(condition: Column): Dataset[T]
1.6.0 filterと同じ。    
limit(n: Int): Dataset[T] 2.0.0 指定された件数だけ抽出する。
take
val df = ds.limit(2)
df.show()
+----+---+
|name|age|
+----+---+
|hoge| 20|
| foo| 30|
+----+---+
dropDuplicates(): Dataset[T] 2.0.0 重複排除(distinct)する。 val ds = Seq(Person("hoge", 20), Person("hoge", 20)).toDS()
val ds2 = ds.dropDuplicates()
ds2.show()
+----+---+
|name|age|
+----+---+
|hoge| 20|
+----+---+
dropDuplicates(col1: String, cols: String*): Dataset[T]
dropDuplicates(colNames: Seq[String]): Dataset[T]
dropDuplicates(colNames: Array[String]): Dataset[T]
2.0.0 指定されたカラムが重複しているものを排除する。
(どちらのレコードが残るかは、並び順次第?)
キー毎に重複排除する例
val ds = Seq(Person("hoge", 20), Person("hoge", 21), Person("foo", 20)).toDS()
val ds2 = ds.dropDuplicates("name")
ds2.show()
+----+---+
|name|age|
+----+---+
| foo| 20|
|hoge| 20|
+----+---+
distinct(): Dataset[T] 2.0.0 dropDuplicatesと同じ。    
sort(sortCol: String, sortCols: String*): Dataset[T] 2.0.0 指定したカラムでソートする。

昇順・降順を指定したい場合はColumnを使う。
asc, desc, asc_nulls_first, asc_nulls_last, desc_nulls_first, desc_nulls_lastのいずれかを指定する。
val ds2 = ds.sort("name")
ds2.show()
+----+---+
|name|age|
+----+---+
| bar| 40|
| foo| 30|
|hoge| 20|
+----+---+
sort(sortExprs: Column*): Dataset[T] 2.0.0 val ds2 = ds.sort($"age".desc)
ds2.show()
+----+---+
|name|age|
+----+---+
| bar| 40|
| foo| 30|
|hoge| 20|
+----+---+
sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T]
sortWithinPartitions(sortExprs: Column*): Dataset[T]
2.0.0 パーティション毎にソートする。
HiveのSORT BYのようなもの)
→全データのソートはsort
   
orderBy(sortCol: String, sortCols: String*): Dataset[T]
orderBy(sortExprs: Column*): Dataset[T]
2.0.0 sortと同じ。    
groupBy(col1: String, cols: String*): RelationalGroupedDataset 2.0.0 指定されたカラム毎の集約処理を行う。
→全レコードの集約はagg
val ds = Seq(Person("hoge", 20), Person("hoge", 21), Person("foo", 30)).toDS()
val ds2 = ds.groupBy("name").agg("age" -> "sum")
ds2.show()
+----+--------+
|name|sum(age)|
+----+--------+
| foo|      30|
|hoge|      41|
+----+--------+
groupBy(cols: Column*): RelationalGroupedDataset 2.0.0 import org.apache.spark.sql.functions.sum

val ds = Seq(Person("hoge", 20), Person("hoge", 21), Person("foo", 30)).toDS()
val ds2 = ds.groupBy($"name").agg(sum("age") as "s")
ds2.show()
+----+---+
|name|  s|
+----+---+
| foo| 30|
|hoge| 41|
+----+---+
rollup(col1: String, cols: String*): RelationalGroupedDataset 2.0.0      
rollup(cols: Column*): RelationalGroupedDataset 2.0.0    
cube(col1: String, cols: String*): RelationalGroupedDataset 2.0.0      
cube(cols: Column*): RelationalGroupedDataset 2.0.0    
groupByKey[K](func: T => K): KeyValueGroupedDataset[K, T] 2.0.0 キーを生成する関数を渡し、そのキー毎に集約する。
集約処理自体はKeyValueGroupedDatasetのメソッドを呼ぶ。
キー毎に重複排除する例
import org.apache.spark.sql.functions.expr
case class MyKey(mykey: String, dec: Long)

val ds = Seq(Person("hoge", 20), Person("hage", 21), Person("foo", 30)).toDS()
val ds2 = ds.groupByKey(person => MyKey(person.name.substring(0, 1), person.age / 10))
  .agg(expr("sum(dec) as n").as[Long], expr("sum(age) as s").as[Long])
ds2.show()
+-----+---+---+
|  key|  n|  s|
+-----+---+---+
|[h,2]|  4| 41|
|[f,3]|  3| 30|
+-----+---+---+
groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] 2.0.0 Java8の関数インターフェースラムダ式)用。    
agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame 2.0.0 全データを集約する。
→キー毎の集約はgroupBy
reduce
val df = ds.agg("age" -> "sum")
df.show()
+--------+
|sum(age)|
+--------+
|      90|
+--------+
agg(exprs: Map[String, String]): DataFrame 2.0.0 val map = Map("age" -> "sum")
val df = ds.agg(map)
df.show()
+--------+
|sum(age)|
+--------+
|      90|
+--------+
agg(expr: Column, exprs: Column*): DataFrame 2.0.0 import org.apache.spark.sql.functions.sum

val df = ds.agg(sum("age") as "s")
df.show()
+---+
|  s|
+---+
| 90|
+---+
agg(exprs: java.util.Map[String, String]): DataFrame 2.0.0 Java用。    
sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T]
sample(withReplacement: Boolean, fraction: Double): Dataset[T]
1.6.0 サンプリングする。(適当に行を抽出する) val ds2 = ds.sample(false, 0.2)
ds2.show()
+----+---+
|name|age|
+----+---+
| bar| 40|
+----+---+
randomSplit(weights: Array[Double]): Array[Dataset[T]]
randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]]
2.0.0 適当に分割する。 val dss = ds.randomSplit(Array(0.2, 0.8))
dss.foreach(_.show())
+----+---+
|name|age|
+----+---+
| foo| 30|
|hoge| 20|
+----+---+

+----+---+
|name|age|
+----+---+
| bar| 40|
+----+---+
randomSplitAsList(weights: Array[Double], seed: Long): java.util.List[Dataset[T]] 2.0.0 Java用。    

結合処理

他のDatasetを使用する処理。

メソッド ver 説明 実行結果
join(right: Dataset[_]): DataFrame 2.0.0 インナージョイン。
結合条件は後続のwhereで記述する。
case class Master(name: String, sub: String)
val ds2 = Seq(Master("hoge", "abc"), Master("foo", "def"), Master("zzz", "zzz")).toDS()
val df = ds.join(ds2).where(ds("name") === ds2("name"))
df.show()
+----+---+----+---+
|name|age|name|sub|
+----+---+----+---+
|hoge| 20|hoge|abc|
| foo| 30| foo|def|
+----+---+----+---+
join(right: Dataset[_], usingColumn: String): DataFrame 2.0.0 インナージョイン。
(結合キーは1つ)
case class Master(name: String, sub: String)
val ds2 = Seq(Master("hoge", "abc"), Master("foo", "def"), Master("zzz", "zzz")).toDS()
val df = ds.join(ds2, "name")
df.show()
+----+---+---+
|name|age|sub|
+----+---+---+
|hoge| 20|abc|
| foo| 30|def|
+----+---+---+
join(right: Dataset[_], usingColumns: Seq[String]): DataFrame 2.0.0 インナージョイン。
(結合キーは複数カラム指定可能)
case class Master(name: String, sub: String)
val ds2 = Seq(Master("hoge", "abc"), Master("foo", "def"), Master("zzz", "zzz")).toDS()
val df = ds.join(ds2, Seq("name"))
df.show()
 
+----+---+---+
|name|age|sub|
+----+---+---+
|hoge| 20|abc|
| foo| 30|def|
+----+---+---+
join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame 2.0.0 結合タイプを指定する。 case class Master(name: String, sub: String)
val ds2 = Seq(Master("hoge", "abc"), Master("foo", "def"), Master("zzz", "zzz")).toDS()
val df = ds.join(ds2, Seq("name"), "left")
df.show()
+----+---+----+
|name|age| sub|
+----+---+----+
|hoge| 20| abc|
| foo| 30| def|
| bar| 40|null|
+----+---+----+
join(right: Dataset[_], joinExprs: Column): DataFrame 2.0.0 インナージョイン。
結合条件を指定する。
case class Master(name: String, sub: String)
val ds2 = Seq(Master("hoge", "abc"), Master("foo", "def"), Master("zzz", "zzz")).toDS()
val df = ds.join(ds2, ds("name") === ds2("name"))
df.show()
+----+---+----+---+
|name|age|name|sub|
+----+---+----+---+
|hoge| 20|hoge|abc|
| foo| 30| foo|def|
+----+---+----+---+
join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame 2.0.0 結合タイプおよび結合条件を指定する。 case class Master(name: String, sub: String)
val ds2 = Seq(Master("hoge", "abc"), Master("foo", "def"), Master("zzz", "zzz")).toDS()
val df = ds.join(ds2, ds("name") === ds2("name"), "left")
df.show()
+----+---+----+----+
|name|age|name| sub|
+----+---+----+----+
|hoge| 20|hoge| abc|
| foo| 30| foo| def|
| bar| 40|null|null|
+----+---+----+----+
crossJoin(right: Dataset[_]): DataFrame 2.1.0 クロスジョイン。
結合条件は後続のwhereで記述する。
case class Master(name: String, sub: String)
val ds2 = Seq(Master("hoge", "abc"), Master("foo", "def"), Master("zzz", "zzz")).toDS()
val df = ds.crossJoin(ds2).where(ds("name") === ds2("name"))
df.printSchema()
df.show()
+----+---+----+---+
|name|age|name|sub|
+----+---+----+---+
|hoge| 20|hoge|abc|
| foo| 30| foo|def|
+----+---+----+---+
joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] 1.6.0 インナージョイン。
結合結果のレコードはタプル
case class Master(name: String, sub: String)
val ds2 = Seq(Master("hoge", "abc"), Master("foo", "def"), Master("zzz", "zzz")).toDS()
val df = ds.joinWith(ds2, ds("name") === ds2("name"))
df.show()
+---------+----------+
|       _1|        _2|
+---------+----------+
|[hoge,20]|[hoge,abc]|
| [foo,30]| [foo,def]|
+---------+----------+
joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] 1.6.0 結合タイプを指定する。
結合結果のレコードはタプル
case class Master(name: String, sub: String)
val ds2 = Seq(Master("hoge", "abc"), Master("foo", "def"), Master("zzz", "zzz")).toDS()
val df = ds.joinWith(ds2, ds("name") === ds2("name"), "left")
df.show()
+---------+----------+
|       _1|        _2|
+---------+----------+
|[hoge,20]|[hoge,abc]|
| [foo,30]| [foo,def]|
| [bar,40]|      null|
+---------+----------+
union(other: Dataset[T]): Dataset[T] 2.0.0 データの合流。 val ds2 = Seq(Person("zzz", 99)).toDS()
val ds3 = ds.union(ds2)
ds3.show()
+----+---+
|name|age|
+----+---+
|hoge| 20|
| foo| 30|
| bar| 40|
| zzz| 99|
+----+---+
unionAll(other: Dataset[T]): Dataset[T] 2.0.0 非推奨。→union    
intersect(other: Dataset[T]): Dataset[T] 2.0.0 両方に存在するデータだけ残す。 val ds2 = Seq(Person("hoge", 20), Person("foo", 30), Person("bar", 99)).toDS()
val ds3 = ds.intersect(ds2)
ds3.show()
+----+---+
|name|age|
+----+---+
|hoge| 20|
| foo| 30|
+----+---+
except(other: Dataset[T]): Dataset[T] 2.0.0 otherに無いデータだけ残す。 val ds2 = Seq(Person("hoge", 20), Person("bar", 99)).toDS()
val ds3 = ds.except(ds2)
ds3.show()
+----+---+
|name|age|
+----+---+
| bar| 40|
| foo| 30|
+----+---+

結合の注意

join結合キーとしてカラム名を指定するメソッドの場合は、双方のDatasetに同じ名前のカラムが存在する必要があるが、出力されるデータでは1カラムのみになる。
しかし結合条件を指定する方式(ds("name")===right("name"))の場合、双方に同名のカラムがあると、両方とも残る。 そして後続処理でそのカラム名を指定すると、一意に定まらなくてエラーになる。

この場合、Datasetに名前(テーブル名相当)を付けておけば、その名前を使ってカラムを一意に識別することが出来る。

    import spark.implicits._

    val ds1 = Seq(Person("hoge", 20), Person("foo", 30), Person("bar", 40)).toDS().as("person")
    val ds2 = Seq(Master("hoge", "abc"), Master("foo", "def"), Master("zzz", "zzz")).toDS().as("master")

    val df = ds1.join(ds2)
      .where($"person.name" === $"master.name")
      .select("person.name", "age", "sub")
    df.show()

↓実行結果

+----+---+---+
|name|age|sub|
+----+---+---+
|hoge| 20|abc|
| foo| 30|def|
+----+---+---+

インナージョインやレフトジョインでマスター側のデータに同一キーのレコードが複数有る場合、全て結合対象となる。
Asakusa FrameworkMasterJoinの場合、マスター側に複数レコードあると、どれか1レコードのみが結合対象となる)


ブロードキャスト

データ量の少ないDataset(DataFrame)なら、ブロードキャスト結合(hash join)をすることが出来る。[2017-01-16]

import org.apache.spark.sql.functions.broadcast
    val df = ds1.join(broadcast(ds2))

永続化(キャッシュ)処理

メソッド ver 説明 実行結果
persist(): this.type
persist(newLevel: StorageLevel): this.type
1.6.0 データを永続化する。
永続化とチェックポイントの違い
val ds2 = ds.persist()  
cache(): this.type 1.6.0 persistと同じ。    
storageLevel: StorageLevel 2.1.0 永続化のストレージの種類(DISK_ONLYとかMEMORY_ONLYとかNONEとか)を返す。    
unpersist(): this.type
unpersist(blocking: Boolean): this.type
1.6.0 永続化を解除する。 ds2.unpersist()  
checkpoint(): Dataset[T]
checkpoint(eager: Boolean): Dataset[T]
2.1.0 チェックポイントを作成する。
SparkContextでチェックポイントのディレクトリー(HDFS)を指定しておく必要がある。
永続化とチェックポイントの違い
spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoint")
val ds2 = ds.checkpoint()
 

RDDの永続化処理


キー毎に重複排除する例

同一キーで複数レコードある場合に1レコードのみ残すこと(重複排除・distinct)を考えてみる。[2017-01-16]


一番単純なのは、dropDuplicatesを使う方法。

    import spark.implicits._

    val ds = Seq(Person("hoge", 30), Person("hoge", 20), Person("hoge", 40), Person("foo", 30)).toDS()
    val ds2 = ds.dropDuplicates("name")
    ds2.show()

↓実行結果

+----+---+
|name|age|
+----+---+
| foo| 30|
|hoge| 30|
+----+---+

しかしこれだと、重複しているレコードの内のどのレコードが残るか分からない。
(先頭レコードが残るような感じがするけど、分散処理しているときにどれが先頭になるかは保証されるか?)


dropDuplicatesする前に全体をソートする方法は(少量データでは望んだ結果になったけど)、データがパーティションをまたがった場合に上手くいかないと思われる。

×  val ds2 = ds.sort($"name", $"age".asc).dropDuplicates("name")
キーでパーティションを分けてからソートしたら大丈夫そうな気がする。いまいち自信が無い^^;
    val ds2 = ds.repartition($"name").sortWithinPartitions($"name", $"age".asc).dropDuplicates("name")

キー毎にグルーピングして値一覧を処理すれば、そこで自分が望むデータだけを残すことが出来る。

    val ds2 = ds.groupByKey(_.name).mapGroups((k, i) => i.toSeq.sortBy(_.age).head)

iはIterator。ageでソートしたいがIteratorにはsortByが無いので、一旦Seqに変換している。
そして、ageでソートして先頭データ(つまりageが最小のPerson)を取得している。

↓実行結果

+----+---+
|name|age|
+----+---+
| foo| 30|
|hoge| 20|
+----+---+

ただ、自前でIteratorをソートしているので、件数が多いときは嫌なことになりそうな気がする。


自前でソートするとデータ件数が多い場合に不安だし、そもそもソートはシャッフル処理に任せたい気がするので、キーでパーティションを分けてからソートしておけば大丈夫か?(groupByKeyでパーティションが変わる可能性があるような気もする…)

    val ds2 = ds.repartition($"name").sortWithinPartitions($"name", $"age".asc).groupByKey(_.name).mapGroups((k, i) => i.next)

並び順は気にせずに最小のものを探すという方法も考えられる。

    val ds2 = ds.groupByKey(_.name).mapGroups { (k, i) =>
      var min: Person = null
      i.foreach { p =>
        if ((min eq null) || p.age < min.age) {
          min = p
        }
      }
      min
    }

しかし、なんだかとてもクドい感じ(爆)


IteratorにsortByは無いがminBy(やmaxBy)はあるので、これを使うのが一番シンプルかも。[2017-01-23]

    val ds2 = ds.groupByKey(_.name).mapGroups((k, i) => i.minBy(_.age))

minByなら(toSeq.sortByと違ってソートせずに)データ全体を1回走査するだけなので、ソートを伴う操作より早いかも?


自分で最小のものを探すなら、reduceを使う方法もある。

    val ds2 = ds.groupByKey(_.name).reduceGroups((p, r) => if (p.age <= r.age) p else r).map(_._2)

(reduceGroupsはキーと値(この例ではPerson)のタプルを返すので、mapを使ってタプルの2番目(Person)だけ取得している)

reduce内でif式を使いたくないなら、以下のような方法も考えられる。

    // 新しいPersonを作って返す
    val ds2 = ds.groupByKey(_.name).reduceGroups((p, r) => Person(p.name, math.min(p.age, r.age))).map(_._2)
    // 一部だけを変更したPersonを作って返す
    val ds2 = ds.groupByKey(_.name).reduceGroups((p, r) => p.copy(age = math.min(p.age, r.age))).map(_._2)

ケースクラスが可変(case class Person(name: String, var age: Long))であれば、該当フィールドだけ書き換えることも出来る。[2017-01-17]

    val ds2 = ds.groupByKey(_.name).reduceGroups { (p, r) => p.age = math.min(p.age, r.age); p }.map(_._2)

Window関数を使うという方法もあるようだ。(参考: stackoverflowのSPARK DataFrame: select the first row of each group

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number
    val w = Window.partitionBy($"name").orderBy($"age".asc)
    val ds2 = ds.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn").as[Person]

行番号を表すrow_numberカラムをwithColumnで追加し、それが1になるもの(つまり先頭行)だけをwhereで抽出している。
(row_numberカラムは最終的には不要なので、dropで削除している)
(withColumnを使うとDataFrameになってしまうので、(drop後に)as[Person]を呼び出してDataset[Person]にしている)


参考までに、Asakusa Framework(Java)のGroupSortだとグルーピングとソートが同時に指定できるので、以下のように書ける。

    @GroupSort
    public void distinctByName(@Key(group = {"name"}, order = {"age ASC"}) List<Person> personList, Result<Person> out) {
        out.add(personList.get(0));
    }

Asakusa Framewokで書いたアプリケーションをSparkに変換して動くAsakusa on Sparkがこれを実現できているんだから、何か方法はあると思うんだけど^^;


ファイル出力

Dataset#writeでファイルに出力することが出来る。[2017-01-31]


CSVファイルを出力する例

import org.apache.spark.sql.SaveMode
  def writeCsv(ds: Dataset[Person], path: String): Unit = {
    ds.write.mode(SaveMode.Overwrite).csv(path)
  }

pathは出力先ディレクトリー。この下にパーティションの個数分の複数のファイル(拡張子はcsv)が出力される。
modeをOverwrite(上書き)にしておかないと、ディレクトリーが既に存在している場合にエラーになる。

ただ、このcsvメソッドでデータを出力すると、各フィールドがtrimされてしまう。(つまり、空白1文字を出力したいと思っても、空文字列になってしまう
write(DataFrameWriter)にはoptionを設定することは出来る(def csv(path: String)メソッドのScaladocを参照)のだけれども、trimしないようにするオプションは無さげ。

常にダブルクォーテーションで囲むようにすればtrimされないかと思って試してみたけど、trimされてからダブルクォーテーションで囲まれてたorz[2017-02-02]

    ds.write.option("quoteAll", true).mode(SaveMode.Overwrite).csv(path)

ちなみにsortByというメソッドで出力順を制御できそうな感じがしたが、csvファイルは対象外だったorz[2017-02-01]


テキストファイルを出力する例

csvメソッドで出力すると各フィールドがtrimされてしまうが、自前で単一のStringに変換してテキストファイルとして出力することは出来る。[2017-02-01]

import org.apache.spark.sql.SaveMode
  def writeCsv(ds: Dataset[Person], path: String): Unit = {
    ds.map(_.productIterator.mkString(",")).write.mode(SaveMode.Overwrite).text(path)
  }

ケースクラスはProductトレイトをミックスインしているので、productIteratorで各フィールドの値が取れる。
これをmkStringで“カンマ区切りの1つのString”に変換する。

pathは出力先ディレクトリー。この下にパーティションの個数分の複数のファイル(拡張子はtxt)が出力される。


write前にrepartitionでパーティション数を1にしてやれば、出力するファイルを1つにすることは出来る。

  def writeCsv(ds: Dataset[Person], path: String): Unit = {
    ds.map(_.productIterator.mkString(",")).repartition(1).write.mode(SaveMode.Overwrite).text(path)
  }

ちなみに、write(DataFrameWriter)にはpartitionByというパーティションを変更できそうなメソッドがあるのだが、出力ファイルを1つにすることは出来なかった。


Writerを使って出力する失敗例

csvメソッドtextメソッドではファイル名が自由に出来ないので、自前でWriterを使って出力しようとしてみた。

import java.nio.charset.Charset
import java.nio.file.Files
import java.nio.file.Paths
  def writeByWriter(ds: Dataset[Person], path: String): Unit = {
    Files.createDirectories(Paths.get(path))

    val n = path.lastIndexOf('/');
    val name = path.substring(n + 1) + ".csv"

    val writer = Files.newBufferedWriter(Paths.get(path, name), Charset.forName("UTF-8"))
    try {
      ds.map(_.productIterator.mkString(",")).repartition(1).foreach { s =>
        writer.write(s)
        writer.write('\n')
      }
    } finally {
      writer.close()
    }

foreachでBufferedWirterにwriteする。

しかしいざ実行すると、BufferedWirterがシリアライザブルでない為、例外が発生する。
(foreachの中でwriterを使おうとしているので、シリアライズしてexecutorに渡そうとしているのだと思う)


Writerを使って出力する例

import java.nio.charset.Charset
import java.nio.file.Files
import java.nio.file.Paths
  def writeByIterator(ds: Dataset[Person], path: String): Unit = {
    Files.createDirectories(Paths.get(path))

    val n = path.lastIndexOf('/');
    val name = path.substring(n + 1) + ".csv"

    val writer = Files.newBufferedWriter(Paths.get(path, name), Charset.forName("UTF-8"))
    try {
      import scala.collection.JavaConverters._

      ds.map(_.productIterator.map {
        case null      => ""
        case s: String => if (s.contains(",") || s.contains("\r") || s.contains("\n")) '"' + s.replaceAll("\"", "\"\"") + '"' else s
        case x         => x.toString
      }.mkString(","))
        .toLocalIterator().asScala.foreach { s =>
          writer.write(s)
          writer.write("\r\n")
        }
    } finally {
      writer.close()
    }
  }

Datasetのforeachの中にはwriterを置くことが出来ないので、toLocalIteratorメソッドでDatasetからローカルのIteratorに変換し、そのforeachでBufferedWirterに出力する。


Spark目次へ戻る / Scalaへ戻る / 技術メモへ戻る
メールの送信先:ひしだま