I would like to use the Dataset.map function to transform the rows of my dataset. The sample looks like this:
val result = testRepository.readTable(db, tableName)
.map(testInstance.doSomeOperation)
.count()
where testInstance is an instance of a class that extends java.io.Serializable, but testRepository does not. The code throws the following error:
Job aborted due to stage failure.
Caused by: NotSerializableException: TestRepository
Question
I understand why testInstance.doSomeOperation needs to be serializable, since it's inside the map and will be distributed to the Spark workers. But why does testRepository need to be serialized? I don't see why that is necessary for the map. Changing the definition to class TestRepository extends java.io.Serializable solves the issue, but that is not desirable in the larger context of the project.
Is there a way to make this work without making TestRepository serializable, or why is it required to be serializable?
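For reference, passing the method to map eta-expands into a function value that closes over testInstance, which is why testInstance at least has to be serializable. A rough sketch of the equivalent closure (the val name is illustrative; it uses the types from the example below):
val doSomeOperationFn: MyTableSchema => MyTableSchema =
  row => testInstance.doSomeOperation(row) // captures testInstance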
Minimal working example
Here's a full example with the code from both classes that reproduces the NotSerializableException:
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
// assumes a live SparkSession available as `spark` (e.g. spark-shell or a notebook);
// the implicits provide the Encoder needed by .as[MyTableSchema]
import spark.implicits._
case class MyTableSchema(id: String, key: String, value: Double)
val db = "temp_autodelete"
val tableName = "serialization_test"
class TestRepository {
def readTable(database: String, tableName: String): Dataset[MyTableSchema] = {
spark.table(f"$database.$tableName")
.as[MyTableSchema]
}
}
val testRepository = new TestRepository()
class TestClass() extends java.io.Serializable {
def doSomeOperation(row: MyTableSchema): MyTableSchema = {
row
}
}
val testInstance = new TestClass()
val result = testRepository.readTable(db, tableName)
.map(testInstance.doSomeOperation)
.count()
The reason is that your map operation reads from something that already lives on the executors.
If you look at your pipeline:
The first thing you do is testRepository.readTable(db, tableName). If we look inside the readTable method, we see that it performs a spark.table operation. Per the API docs, that method has the following signature:
def table(tableName: String): DataFrame
This is not an operation that takes place solely on the driver (imagine reading a file of >1TB on the driver alone), and it creates a DataFrame, which is itself a distributed dataset. That means the testRepository.readTable(db, tableName) call needs to be distributed, and so your testRepository object needs to be distributed as well.
Hope this helps you!
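Given that, one way to avoid making TestRepository serializable is to keep the repository out of the distributed lineage entirely, for example by reading through the SparkSession directly. A minimal sketch, assuming the same SparkSession spark, db, and tableName as in your example:
// Read directly via the SparkSession, so no TestRepository instance is
// referenced by the job; the map closure then only captures testInstance,
// which is serializable.
val ds = spark.table(f"$db.$tableName").as[MyTableSchema]
val result = ds.map(testInstance.doSomeOperation).count()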