S-JIS[2017-01-15/2020-10-08] 変更履歴
Apache SparkのDatasetクラスについて。
|
|
|
|
Spark2.0以降はDatasetを使ってプログラミングする。
(Spark2では、DataFrameは「Dataset[Row]
」の別名である)
最初のDatasetはSparkSessionを使って生成する。
Scalaのコレクション(Seq)から作る方法と、ファイル等から読み込んで作る方法がある。
DatasetクラスにはfilterやmapといったScalaのコレクションと同様のメソッドが用意されている。
import org.apache.spark.sql.SparkSession
case class Person(name: String, age: Long)
val spark = SparkSession.builder().〜.getOrCreate() import spark.implicits._ val seq = Seq(Person("hoge", 20), Person("foo", 30), Person("bar", 40)) val ds = seq.toDS()
seq.toDF()
でDataFrameに変換することも出来るが、内部的にはseq.toDS().
toDF()
と同じ。
import org.apache.spark.sql.SparkSession
def readText(spark: SparkSession): Unit = { import spark.implicits._ // DataFrameとして読み込む val df = spark.read.text("D:/tmp/person.csv") // df.printSchema() // Dataset[Person]に変換 val ds = df.map(row => { val s = row.getString(0) val a = s.split(",") Person(a(0), a(1).toLong) }) ds.show() }
read.textメソッドは可変長引数なので、複数ファイルを指定することも出来る。
また、ワイルドカード「*」を使うことも出来る。
read.textでファイルを読み込むと、スキーマは、valueという名前のStringフィールドがひとつあるだけの状態になっている。
root |-- value: string (nullable = true)
import org.apache.spark.sql.SparkSession
def readCsv(spark: SparkSession): Unit = { import spark.implicits._ // DataFrameとして読み込む val df = spark.read.csv("D:/tmp/person.csv") // df.printSchema() // Dataset[Person]に変換 val ds = df.map(row => { val name = row.getString(0) val age = row.getString(1).toLong Person(name, age) }) ds.show() }
csvにはいくつかのオプションがあり、例えばフィールドセパレーターをタブにすればtsvファイルが読める。[2017-01-17]
他にもエンコーディングやヘッダー有無の指定等、いくつかのオプションがある。
どのようなオプションがあるのかは、def
csv(paths: String*)メソッドのScaladocを参照。
read.csvでファイルを読み込むと、スキーマは、_c0, _c1といったカラム名でデータ型は全てStringという状態になっている。
root |-- _c0: string (nullable = true) |-- _c1: string (nullable = true)
DataFrameのスキーマ(カラム名とデータ型)がケースクラスと一致していれば、(自分でmapを書かなくても)そのケースクラスのDatasetに変換できる。
import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.LongType import org.apache.spark.sql.types.StringType import org.apache.spark.sql.types.StructType
def readSchema(spark: SparkSession): Unit = { import spark.implicits._ val schema = new StructType() .add("name", StringType) .add("age", LongType) val ds = spark.read.schema(schema).csv("D:/tmp/person.csv").as[Person] ds.show() }
例 | 説明 |
---|---|
import org.apache.spark.sql.types._ val schema = StructType(Seq( StructField("name", StringType), StructField("age", LongType) )) |
StructTypeのファクトリーを使ってスキーマを生成する方法。 |
import org.apache.spark.sql.types._ val schema = new StructType() .add("name", StringType) .add("age", LongType) |
StructTypeをビルダー風に使ってスキーマを生成する方法。 |
val schema = Seq.empty[Person].toDS().schema |
該当ケースクラスの空のDatasetを作って、そのスキーマを取得する方法。 (スキーマ関連のクラスをインポートする必要は無いが、実行時にちょっと無駄がある?) |
val schema = spark.emptyDataset[Person].schema |
|
import org.apache.spark.sql.Encoder def getSchema[T](implicit enc: Encoder[T]) = enc.schema val schema = getSchema[Person] |
Encoderを使って該当ケースクラスのスキーマを生成する方法。 |
import org.apache.spark.sql.Encoder val schema = implicitly[Encoder[Person]].schema |
|
import org.apache.spark.sql.Encoders val schema = Encoders.product[Person].schema |
tsvファイルを読み込む場合は、フィールドセパレーターをタブ文字にしてcsv読み込みメソッドを使用する。[2017-01-17]
import org.apache.spark.sql.Encoders import org.apache.spark.sql.SparkSession
def readTsv(spark: SparkSession): Unit = {
import spark.implicits._
val schema = Encoders.product[Person].schema
val ds = spark.read.schema(schema).option("sep", "\t").csv("D:/tmp/person.tsv").as[Person]
// ds.printSchema()
ds.show()
}
Dataset[T]の主なメソッド。
メソッド | ver | 説明 | 例 | 実行結果 |
---|---|---|---|---|
schema : StructType |
1.6.0 | スキーマを取得する。 | ds.schema.foreach(f =>
println(s"${f.name}, ${f.dataType}")) |
name, StringType |
printSchema (): Unit |
1.6.0 | スキーマの内容を表示する。 →データを表示するのはshow |
ds.printSchema() |
root |
columns : Array[String] |
1.6.0 | カラム名の一覧を取得する。 | val columns = ds. columns |
|
dtypes : Array[(String, String)] |
1.6.0 | カラム名とデータ型(の名前)のペアの一覧を取得する。 | val dtypes = ds. dtypes |
|
col (colName: String):
Column |
2.0.0 | Columnを取得する。 | val column = ds.col("name") |
|
apply (colName: String):
Column |
2.0.0 | colと同じ。 (applyだとメソッド名を省略して記述できる) |
val column = ds("name") |
|
inputFiles : Array[String] |
2.0.0 | 読み込んだファイルの一覧を返す。 (ワイルドカードを指定して読み込んだ場合でも、具体的な個々のファイル名を返す) |
println(ds.inputFiles.toSeq) |
|
registerTempTable (tableName: String): Unit |
1.6.0 | 非推奨。→createOrReplaceTempView | ||
createTempView (viewName: String): Unit |
2.0.0 | |||
createOrReplaceTempView (viewName: String): Unit |
2.0.0 | Spark SQLで使用するビュー名を定義する。[2020-10-08] | ds.createOrReplaceTempView("zzz") |
|
createGlobalTempView (viewName: String): Unit |
2.1.0 | |||
explain (): Unit |
1.6.0 | 物理プランを表示する。(デバッグ用) | ds.explain() |
|
isLocal : Boolean |
1.6.0 | trueの場合、collectおよびtakeがローカルで実行できる。 | ||
isStreaming : Boolean |
2.0.0 | ストリーミングのデータ源を持っているかどうか。 | ||
withWatermark (eventTime: String, delayThreshold:
String): Dataset[T] |
2.1.0 | イベント時間を設定する。 | ||
na : DataFrameNaFunctions |
1.6.0 | 失われたデータを処理するオブジェクトを返す? | ||
stat : DataFrameStatFunctions |
1.6.0 | 統計情報を処理するオブジェクトを返す? |
Datasetから別のDataset(DataFrame)を生成する処理。
メソッド | ver | 説明 | 例 | 実行結果 |
---|---|---|---|---|
as [U]: Dataset[U] |
1.6.0 | DataFrameをDataset[U]に変換する。 UはEncoderが対応している必要がある。(ケースクラスならOK) ケースクラスに変換する場合、スキーマ(のカラム名およびデータ型)と一致している必要がある。 |
val ds = df.as[Person] |
+----+---+ |
toDF (): DataFrame |
1.6.0 | DataFrameに変換する。 | val df = ds.toDF() |
+----+---+ |
toDF(colNames: String*): DataFrame |
2.0.0 | カラム名を変更したDataFrameを返す。 (引数のカラム数と変換元スキーマのカラム数は一致している必要がある) →1カラムだけ変えるならwithColumnRenamed |
val df = ds.toDF("c0", "c1") |
+----+---+ |
as(alias: String): Dataset[T] |
1.6.0 2.0.0 |
テーブル名を付ける。 カラム指定時にテーブル名を使えるようになる。 |
val df = ds.as("p").select($"p.name",
$"p.age") |
+----+---+ |
alias(alias: String): Dataset[T] |
2.0.0 | asと同じ。 | ||
map[U](func: T => U): Dataset[U] |
1.6.0 | 保持しているデータ(の型)を変換する。 | case class Age(age: Long) |
+---+ |
map[U](func: MapFunction[T, U], encoder:
Encoder[U]):
Dataset[U] |
1.6.0 | Java8の関数インターフェース(ラムダ式)用。 | ||
mapPartitions[U](func: Iterator[T] =>
Iterator[U]): Dataset[U] |
1.6.0 | |||
mapPartitions[U](f: MapPartitionsFunction[T, U],
encoder: Encoder[U]): Dataset[U] |
1.6.0 | Java8の関数インターフェース(ラムダ式)用。 | ||
flatMap[U](func: T => TraversableOnce[U]):
Dataset[U] |
1.6.0 | 保持しているデータ(の型)(および件数)を変換する。 | case class Age(age: Long) |
WrappedArray(Age(20), Age(21), Age(30), Age(31),
Age(40), Age(41)) |
flatMap[U](f: FlatMapFunction[T, U], encoder:
Encoder[U]): Dataset[U] |
1.6.0 | Java8の関数インターフェース(ラムダ式)用。 | ||
explode[A](input:
Column*)(f: Row =>
TraversableOnce[A]): DataFrame |
2.0.0 | 非推奨。→flatMap, select | ||
describe (cols: String*): DataFrame |
1.6.0 | 件数・平均・標準偏差・最小・最大値を保持するDataFrameを作成する。 | val df = ds.describe("age") |
+-------+----+ |
toJSON : Dataset[String] |
2.0.0 | カラム名がvalue, データ内容がJSON文字列であるDataset[String]に変換する。 | val dsJson = ds.toJSON |
+--------------------+ |
repartition(numPartitions: Int): Dataset[T] |
1.6.0 | パーティション数を変更する。 →RDDのrepartition |
||
repartition(numPartitions: Int, partitionExprs:
Column*): Dataset[T] |
2.0.0 | パーティションを分割する。(hash partitionになる) (Hiveの「DISTRIBUTE BY」と同じようなものらしいので、同一カラムのデータは同じパーティションに入ると思われる) |
val ds2 = ds.repartition($"name") |
|
coalesce (numPartitions: Int): Dataset[T] |
1.6.0 | パーティション数を変更する。 →RDDのcoalesce |
||
rdd : RDD[T] |
1.6.0 | RDDを返す。 | ||
toJavaRDD : JavaRDD[T] |
1.6.0 | JavaRDDを返す。 | ||
transform [U](t: Dataset[T] => Dataset[U]): Dataset[U] |
1.6.0 |
ジョブを実行して値を返す処理。(いわゆる終端処理)
メソッド | ver | 説明 | 例 | 実行結果 |
---|---|---|---|---|
show (): Unit |
1.6.0 | Datasetの内容を表示する。 →printSchema |
ds.show() |
+----+---+ |
foreach (f: T => Unit): Unit |
1.6.0 | Datasetの中身を処理する。 | ds.foreach(person => println(person)) |
Person(hoge,20) |
foreachPartition(f: Iterator[T] => Unit): Unit |
1.6.0 | |||
count (): Long |
1.6.0 | Datasetの件数を返す。 | println(ds.count()) |
3 |
collect (): Array[T] |
1.6.0 | Datasetの内容を配列にして返す。 | println(ds.collect().toSeq) |
WrappedArray(Person(hoge,20), Person(foo,30),
Person(bar,40)) |
collectAsList (): java.util.List[T] |
1.6.0 | Java用。 | ||
toLocalIterator (): java.util.Iterator[T] |
2.0.0 | Java用。 | →toLocalIteratorを使ってファイル出力する例 | |
take (n: Int): Array[T] |
1.6.0 | 指定された件数だけ配列にして返す。 →limit・head |
println(ds.take(2).toSeq) |
WrappedArray(Person(hoge,20), Person(foo,30)) |
takeAsList (n: Int): java.util.List[T] |
1.6.0 | Java用。 | ||
head (): T |
1.6.0 | 先頭1件を返す。 | println(ds.head()) |
Person(hoge,20) |
head(n: Int): Array[T] |
1.6.0 | 先頭n件を配列にして返す。 →take |
println(ds.head(2).toSeq) |
WrappedArray(Person(hoge,20), Person(foo,30)) |
first(): T |
1.6.0 | headと同じ。 | ||
reduce(func: (T, T) => T): T |
1.6.0 | 集約処理を行う。 →agg |
println(ds.select("age").as[Long].reduce(_
+ _)) |
90 |
reduce(func: ReduceFunction[T]): T |
1.6.0 | Java8の関数インターフェース(ラムダ式)用。 | ||
write: DataFrameWriter[T] |
1.6.0 | ファイル出力に使用する。 →ファイル出力の例 |
ds.write.csv("/tmp/result/person") |
|
writeStream: DataStreamWriter[T] |
2.0.0 | ストリーミングデータのファイル出力に使用する(と思われる)。 |
カラムを増減・演算する処理。
メソッド | ver | 説明 | 例 | 実行結果 |
---|---|---|---|---|
select (col: String, cols: String*): DataFrame |
2.0.0 | カラムを選択して絞り込んだDataFrameを返す。 | val df = ds.select("name") |
+----+ |
select(cols: Column*): DataFrame |
2.0.0 | Columnで選択カラムを指定する。 | val df = ds.select($"name", $"age" + 1 as
"inc") |
+----+---+ |
select[U1](c1: TypedColumn[T, U1]): Dataset[U1] 引数5個版まで |
1.6.0 | TypedColumnで選択カラムおよび演算を指定する。 sql.functionsでColumnを取得でき、Columnに「 .as[型] 」を付けるとTypedColumnになる。 |
import org.apache.spark.sql.functions.{ col, expr } |
+----+---+ |
selectExpr (exprs: String*): DataFrame |
2.0.0 | カラムの演算を文字列で指定する。 | val df = ds.selectExpr("name", "age + 1 as
inc") |
+----+---+ |
withColumn (colName: String, col:
Column): DataFrame |
2.0.0 | カラムを追加したDataFrameを返す。 | val df = ds.withColumn("inc", $"age" + 1) |
+----+---+---+ |
withColumnRenamed(existingName: String, newName:
String): DataFrame |
2.0.0 | カラム名を変更したDataFrameを返す。 (存在しないカラム名が指定されていた場合は、何もしない(エラーにもならない)) (新しいカラム名が他で使われている名前だった場合、エラーにならないが、その名前を使うメソッドで例外が発生する) →全カラム名をまとめて変えるならtoDF |
val df = ds.withColumnRenamed("name",
"zzz") |
+----+---+ |
drop (colName: String): DataFrame |
2.0.0 | 指定されたカラムを除去したDataFrameを返す。 | val df = ds.drop("name") |
+---+ |
drop(col: Column): DataFrame |
2.0.0 | val df = ds.drop($"name") |
+---+ |
レコードを操作する処理。
メソッド | ver | 説明 | 例 | 実行結果 |
---|---|---|---|---|
filter (conditionExpr: String): Dataset[T] |
1.6.0 | 条件に合致するレコードだけ抽出する。 | val df = ds.filter("age >= 30") |
+----+---+ |
filter(condition: Column): Dataset[T] |
1.6.0 | val df = ds.filter($"age" >= 30) |
+----+---+ |
|
filter(func: T => Boolean): Dataset[T] |
1.6.0 | val ds2 = ds.filter(person => person.age
>= 30) |
+----+---+ |
|
filter(func: FilterFunction[T]): Dataset[T] |
1.6.0 | Java8の関数インターフェース(ラムダ式)用。 | ||
where (conditionExpr: String): Dataset[T] |
1.6.0 | filterと同じ。 | ||
limit (n: Int): Dataset[T] |
2.0.0 | 指定された件数だけ抽出する。 →take |
val df = ds.limit(2) |
+----+---+ |
dropDuplicates (): Dataset[T] |
2.0.0 | 重複排除(distinct)する。 | val ds = Seq(Person("hoge", 20), Person("hoge",
20)).toDS() |
+----+---+ |
dropDuplicates(col1: String, cols: String*): Dataset[T] |
2.0.0 | 指定されたカラムが重複しているものを排除する。 (どちらのレコードが残るかは、並び順次第?) →キー毎に重複排除する例 |
val ds = Seq(Person("hoge", 20), Person("hoge",
21), Person("foo", 20)).toDS() |
+----+---+ |
distinct(): Dataset[T] |
2.0.0 | dropDuplicatesと同じ。 | ||
sort(sortCol: String, sortCols: String*): Dataset[T] |
2.0.0 | 指定したカラムでソートする。 昇順・降順を指定したい場合はColumnを使う。 asc, desc, asc_nulls_first, asc_nulls_last, desc_nulls_first, desc_nulls_lastのいずれかを指定する。 |
val ds2 = ds.sort("name") |
+----+---+ |
sort(sortExprs: Column*): Dataset[T] |
2.0.0 | val ds2 = ds.sort($"age".desc) |
+----+---+ |
|
sortWithinPartitions(sortCol: String, sortCols:
String*): Dataset[T] |
2.0.0 | パーティション毎にソートする。 (HiveのSORT BYのようなもの) →全データのソートはsort |
||
orderBy(sortCol: String, sortCols: String*): Dataset[T] |
2.0.0 | sortと同じ。 | ||
groupBy(col1: String, cols: String*):
RelationalGroupedDataset |
2.0.0 | 指定されたカラム毎の集約処理を行う。 →全レコードの集約はagg |
val ds = Seq(Person("hoge", 20), Person("hoge",
21), Person("foo", 30)).toDS() |
+----+--------+ |
groupBy(cols: Column*): RelationalGroupedDataset |
2.0.0 | import org.apache.spark.sql.functions.sum |
+----+---+ |
|
rollup(col1: String, cols: String*):
RelationalGroupedDataset |
2.0.0 | |||
rollup(cols: Column*): RelationalGroupedDataset |
2.0.0 | |||
cube(col1: String, cols: String*):
RelationalGroupedDataset |
2.0.0 | |||
cube(cols: Column*): RelationalGroupedDataset |
2.0.0 | |||
groupByKey[K](func: T => K):
KeyValueGroupedDataset[K, T] |
2.0.0 | キーを生成する関数を渡し、そのキー毎に集約する。 集約処理自体はKeyValueGroupedDatasetのメソッドを呼ぶ。 →キー毎に重複排除する例 |
import org.apache.spark.sql.functions.expr |
+-----+---+---+ |
groupByKey[K](func: MapFunction[T, K], encoder:
Encoder[K]):
KeyValueGroupedDataset[K, T] |
2.0.0 | Java8の関数インターフェース(ラムダ式)用。 | ||
agg(aggExpr: (String, String), aggExprs: (String,
String)*): DataFrame |
2.0.0 | 全データを集約する。 →キー毎の集約はgroupBy →reduce |
val df = ds.agg("age" -> "sum") |
+--------+ |
agg(exprs: Map[String, String]): DataFrame |
2.0.0 | val map = Map("age" -> "sum") |
+--------+ |
|
agg(expr: Column, exprs:
Column*): DataFrame |
2.0.0 | import org.apache.spark.sql.functions.sum |
+---+ |
|
agg(exprs: java.util.Map[String, String]): DataFrame |
2.0.0 | Java用。 | ||
sample (withReplacement: Boolean, fraction: Double,
seed: Long): Dataset[T] |
1.6.0 | サンプリングする。(適当に行を抽出する) | val ds2 = ds.sample(false, 0.2) |
+----+---+ |
randomSplit (weights: Array[Double]): Array[Dataset[T]] |
2.0.0 | 適当に分割する。 | val dss = ds.randomSplit(Array(0.2, 0.8)) |
+----+---+ |
randomSplitAsList (weights: Array[Double], seed: Long):
java.util.List[Dataset[T]] |
2.0.0 | Java用。 |
他のDatasetを使用する処理。
メソッド | ver | 説明 | 例 | 実行結果 |
---|---|---|---|---|
join(right: Dataset[_]): DataFrame |
2.0.0 | インナージョイン。 結合条件は後続のwhereで記述する。 |
case class Master(name: String, sub: String) |
+----+---+----+---+ |
join(right: Dataset[_], usingColumn: String): DataFrame |
2.0.0 | インナージョイン。 (結合キーは1つ) |
case class Master(name: String, sub: String) |
+----+---+---+ |
join(right: Dataset[_], usingColumns: Seq[String]):
DataFrame |
2.0.0 | インナージョイン。 (結合キーは複数カラム指定可能) |
case class Master(name: String, sub: String) |
+----+---+---+ |
join(right: Dataset[_], usingColumns: Seq[String],
joinType: String): DataFrame |
2.0.0 | 結合タイプを指定する。 | case class Master(name: String, sub: String) |
+----+---+----+ |
join(right: Dataset[_], joinExprs:
Column): DataFrame |
2.0.0 | インナージョイン。 結合条件を指定する。 |
case class Master(name: String, sub: String) |
+----+---+----+---+ |
join(right: Dataset[_], joinExprs:
Column, joinType:
String): DataFrame |
2.0.0 | 結合タイプおよび結合条件を指定する。 | case class Master(name: String, sub: String) |
+----+---+----+----+ |
crossJoin(right: Dataset[_]): DataFrame |
2.1.0 | クロスジョイン。 結合条件は後続のwhereで記述する。 |
case class Master(name: String, sub: String) |
+----+---+----+---+ |
joinWith[U](other: Dataset[U], condition:
Column):
Dataset[(T, U)] |
1.6.0 | インナージョイン。 結合結果のレコードはタプル。 |
case class Master(name: String, sub: String) |
+---------+----------+ |
joinWith[U](other: Dataset[U], condition:
Column,
joinType: String): Dataset[(T, U)] |
1.6.0 | 結合タイプを指定する。 結合結果のレコードはタプル。 |
case class Master(name: String, sub: String) |
+---------+----------+ |
union (other: Dataset[T]): Dataset[T] |
2.0.0 | データの合流。 | val ds2 = Seq(Person("zzz", 99)).toDS() |
+----+---+ |
unionAll (other: Dataset[T]): Dataset[T] |
2.0.0 | 非推奨。→union | ||
intersect (other: Dataset[T]): Dataset[T] |
2.0.0 | 両方に存在するデータだけ残す。 | val ds2 = Seq(Person("hoge", 20), Person("foo",
30), Person("bar", 99)).toDS() |
+----+---+ |
except (other: Dataset[T]): Dataset[T] |
2.0.0 | otherに無いデータだけ残す。 | val ds2 = Seq(Person("hoge", 20), Person("bar",
99)).toDS() |
+----+---+ |
joinで結合キーとしてカラム名を指定するメソッドの場合は、双方のDatasetに同じ名前のカラムが存在する必要があるが、出力されるデータでは1カラムのみになる。
しかし結合条件を指定する方式(ds("name")===right("name")
)の場合、双方に同名のカラムがあると、両方とも残る。
そして後続処理でそのカラム名を指定すると、一意に定まらなくてエラーになる。
この場合、Datasetに名前(テーブル名相当)を付けておけば、その名前を使ってカラムを一意に識別することが出来る。
import spark.implicits._ val ds1 = Seq(Person("hoge", 20), Person("foo", 30), Person("bar", 40)).toDS().as("person") val ds2 = Seq(Master("hoge", "abc"), Master("foo", "def"), Master("zzz", "zzz")).toDS().as("master") val df = ds1.join(ds2) .where($"person.name" === $"master.name") .select("person.name", "age", "sub") df.show()
↓実行結果
+----+---+---+ |name|age|sub| +----+---+---+ |hoge| 20|abc| | foo| 30|def| +----+---+---+
インナージョインやレフトジョインでマスター側のデータに同一キーのレコードが複数有る場合、全て結合対象となる。
(Asakusa FrameworkのMasterJoinの場合、マスター側に複数レコードあると、どれか1レコードのみが結合対象となる)
データ量の少ないDataset(DataFrame)なら、ブロードキャスト結合(hash join)をすることが出来る。[2017-01-16]
import org.apache.spark.sql.functions.broadcast
val df = ds1.join(broadcast(ds2))
メソッド | ver | 説明 | 例 | 実行結果 |
---|---|---|---|---|
persist(): this.type |
1.6.0 | データを永続化する。 →永続化とチェックポイントの違い |
val ds2 = ds.persist() |
|
cache(): this.type |
1.6.0 | persistと同じ。 | ||
storageLevel: StorageLevel |
2.1.0 | 永続化のストレージの種類(DISK_ONLYとかMEMORY_ONLYとかNONEとか)を返す。 | ||
unpersist(): this.type |
1.6.0 | 永続化を解除する。 | ds2.unpersist() |
|
checkpoint(): Dataset[T] |
2.1.0 | チェックポイントを作成する。 SparkContextでチェックポイントのディレクトリー(HDFS)を指定しておく必要がある。 →永続化とチェックポイントの違い |
spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoint") |
同一キーで複数レコードある場合に1レコードのみ残すこと(重複排除・distinct)を考えてみる。[2017-01-16]
一番単純なのは、dropDuplicatesを使う方法。
import spark.implicits._ val ds = Seq(Person("hoge", 30), Person("hoge", 20), Person("hoge", 40), Person("foo", 30)).toDS() val ds2 = ds.dropDuplicates("name") ds2.show()
↓実行結果
+----+---+ |name|age| +----+---+ | foo| 30| |hoge| 30| +----+---+
しかしこれだと、重複しているレコードの内のどのレコードが残るか分からない。
(先頭レコードが残るような感じがするけど、分散処理しているときにどれが先頭になるかは保証されるか?)
dropDuplicatesする前に全体をソートする方法は(少量データでは望んだ結果になったけど)、データがパーティションをまたがった場合に上手くいかないと思われる。
× val ds2 = ds.sort($"name", $"age".asc).dropDuplicates("name")キーでパーティションを分けてからソートしたら大丈夫そうな気がする。いまいち自信が無い^^;
val ds2 = ds.repartition($"name").sortWithinPartitions($"name", $"age".asc).dropDuplicates("name")
キー毎にグルーピングして値一覧を処理すれば、そこで自分が望むデータだけを残すことが出来る。
val ds2 = ds.groupByKey(_.name).mapGroups((k, i) => i.toSeq.sortBy(_.age).head)
iはIterator。ageでソートしたいがIteratorにはsortByが無いので、一旦Seqに変換している。
そして、ageでソートして先頭データ(つまりageが最小のPerson)を取得している。
↓実行結果
+----+---+ |name|age| +----+---+ | foo| 30| |hoge| 20| +----+---+
ただ、自前でIteratorをソートしているので、件数が多いときは嫌なことになりそうな気がする。
自前でソートするとデータ件数が多い場合に不安だし、そもそもソートはシャッフル処理に任せたい気がするので、キーでパーティションを分けてからソートしておけば大丈夫か?(groupByKeyでパーティションが変わる可能性があるような気もする…)
val ds2 = ds.repartition($"name").sortWithinPartitions($"name", $"age".asc).groupByKey(_.name).mapGroups((k, i) => i.next)
並び順は気にせずに最小のものを探すという方法も考えられる。
val ds2 = ds.groupByKey(_.name).mapGroups { (k, i) => var min: Person = null i.foreach { p => if ((min eq null) || p.age < min.age) { min = p } } min }
しかし、なんだかとてもクドい感じ(爆)
IteratorにsortByは無いがminBy(やmaxBy)はあるので、これを使うのが一番シンプルかも。[2017-01-23]
val ds2 = ds.groupByKey(_.name).mapGroups((k, i) => i.minBy(_.age))
minByなら(toSeq.sortByと違ってソートせずに)データ全体を1回走査するだけなので、ソートを伴う操作より早いかも?
自分で最小のものを探すなら、reduceを使う方法もある。
val ds2 = ds.groupByKey(_.name).reduceGroups((p, r) => if (p.age <= r.age) p else r).map(_._2)
(reduceGroupsはキーと値(この例ではPerson)のタプルを返すので、mapを使ってタプルの2番目(Person)だけ取得している)
reduce内でif式を使いたくないなら、以下のような方法も考えられる。
// 新しいPersonを作って返す val ds2 = ds.groupByKey(_.name).reduceGroups((p, r) => Person(p.name, math.min(p.age, r.age))).map(_._2)
// 一部だけを変更したPersonを作って返す val ds2 = ds.groupByKey(_.name).reduceGroups((p, r) => p.copy(age = math.min(p.age, r.age))).map(_._2)
ケースクラスが可変(case class Person(name: String, var age: Long)
)であれば、該当フィールドだけ書き換えることも出来る。[2017-01-17]
val ds2 = ds.groupByKey(_.name).reduceGroups { (p, r) => p.age = math.min(p.age, r.age); p }.map(_._2)
Window関数を使うという方法もあるようだ。(参考: stackoverflowのSPARK DataFrame: select the first row of each group)
import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.row_number
val w = Window.partitionBy($"name").orderBy($"age".asc) val ds2 = ds.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn").as[Person]
行番号を表すrow_numberカラムをwithColumnで追加し、それが1になるもの(つまり先頭行)だけをwhereで抽出している。
(row_numberカラムは最終的には不要なので、dropで削除している)
(withColumnを使うとDataFrameになってしまうので、(drop後に)as[Person]を呼び出してDataset[Person]にしている)
参考までに、Asakusa Framework(Java)のGroupSortだとグルーピングとソートが同時に指定できるので、以下のように書ける。
@GroupSort public void distinctByName(@Key(group = {"name"}, order = {"age ASC"}) List<Person> personList, Result<Person> out) { out.add(personList.get(0)); }
Asakusa Framewokで書いたアプリケーションをSparkに変換して動くAsakusa on Sparkがこれを実現できているんだから、何か方法はあると思うんだけど^^;
Dataset#writeでファイルに出力することが出来る。[2017-01-31]
import org.apache.spark.sql.SaveMode
def writeCsv(ds: Dataset[Person], path: String): Unit = { ds.write.mode(SaveMode.Overwrite).csv(path) }
pathは出力先ディレクトリー。この下にパーティションの個数分の複数のファイル(拡張子はcsv)が出力される。
modeをOverwrite(上書き)にしておかないと、ディレクトリーが既に存在している場合にエラーになる。
ただ、このcsvメソッドでデータを出力すると、各フィールドがtrimされてしまう。(つまり、空白1文字を出力したいと思っても、空文字列になってしまう)
write(DataFrameWriter)にはoptionを設定することは出来る(def
csv(path: String)メソッドのScaladocを参照)のだけれども、trimしないようにするオプションは無さげ。
常にダブルクォーテーションで囲むようにすればtrimされないかと思って試してみたけど、trimされてからダブルクォーテーションで囲まれてたorz[2017-02-02]
ds.write.option("quoteAll", true).mode(SaveMode.Overwrite).csv(path)
ちなみにsortByというメソッドで出力順を制御できそうな感じがしたが、csvファイルは対象外だったorz[2017-02-01]
csvメソッドで出力すると各フィールドがtrimされてしまうが、自前で単一のStringに変換してテキストファイルとして出力することは出来る。[2017-02-01]
import org.apache.spark.sql.SaveMode
def writeCsv(ds: Dataset[Person], path: String): Unit = { ds.map(_.productIterator.mkString(",")).write.mode(SaveMode.Overwrite).text(path) }
ケースクラスはProductトレイトをミックスインしているので、productIteratorで各フィールドの値が取れる。
これをmkStringで“カンマ区切りの1つのString”に変換する。
pathは出力先ディレクトリー。この下にパーティションの個数分の複数のファイル(拡張子はtxt)が出力される。
write前にrepartitionでパーティション数を1にしてやれば、出力するファイルを1つにすることは出来る。
def writeCsv(ds: Dataset[Person], path: String): Unit = { ds.map(_.productIterator.mkString(",")).repartition(1).write.mode(SaveMode.Overwrite).text(path) }
ちなみに、write(DataFrameWriter)にはpartitionByというパーティションを変更できそうなメソッドがあるのだが、出力ファイルを1つにすることは出来なかった。
csvメソッドやtextメソッドではファイル名が自由に出来ないので、自前でWriterを使って出力しようとしてみた。
import java.nio.charset.Charset import java.nio.file.Files import java.nio.file.Paths
def writeByWriter(ds: Dataset[Person], path: String): Unit = { Files.createDirectories(Paths.get(path)) val n = path.lastIndexOf('/'); val name = path.substring(n + 1) + ".csv" val writer = Files.newBufferedWriter(Paths.get(path, name), Charset.forName("UTF-8")) try { ds.map(_.productIterator.mkString(",")).repartition(1).foreach { s => writer.write(s) writer.write('\n') } } finally { writer.close() }
foreachでBufferedWirterにwriteする。
しかしいざ実行すると、BufferedWirterがシリアライザブルでない為、例外が発生する。
(foreachの中でwriterを使おうとしているので、シリアライズしてexecutorに渡そうとしているのだと思う)
import java.nio.charset.Charset import java.nio.file.Files import java.nio.file.Paths
def writeByIterator(ds: Dataset[Person], path: String): Unit = { Files.createDirectories(Paths.get(path)) val n = path.lastIndexOf('/'); val name = path.substring(n + 1) + ".csv" val writer = Files.newBufferedWriter(Paths.get(path, name), Charset.forName("UTF-8")) try { import scala.collection.JavaConverters._ ds.map(_.productIterator.map { case null => "" case s: String => if (s.contains(",") || s.contains("\r") || s.contains("\n")) '"' + s.replaceAll("\"", "\"\"") + '"' else s case x => x.toString }.mkString(",")) .toLocalIterator().asScala.foreach { s => writer.write(s) writer.write("\r\n") } } finally { writer.close() } }
Datasetのforeachの中にはwriterを置くことが出来ないので、toLocalIteratorメソッドでDatasetからローカルのIteratorに変換し、そのforeachでBufferedWirterに出力する。