農場IoTデータを用いて分類モデル作成

* pysparkによるグリッドサーチ
* 交差検証
* pipeline
* ロジスティクス回帰
* 決定木
In [11]:
import findspark
findspark.init('/home/yoshi-1/spark-3.1.1-bin-hadoop2.7')

from pyspark.sql import SparkSession
from pyspark.sql.types import *

from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.pipeline import Pipeline
from pyspark.ml.tuning import ParamGridBuilder
from pyspark.ml.tuning import CrossValidator

sparksessionのインスタンス化とデータ読み込み

In [12]:
# sparksessionのインスタンス化
ss = SparkSession \
            .builder \
            .appName("Classsification") \
            .enableHiveSupport() \
            .getOrCreate()
In [13]:
# 読み込むcsvファイルのスキーマを定義
struct = StructType([
            StructField('Year', StringType(), False),
            StructField('FarmID', DoubleType(), False),
            StructField('MeanHighestTemperature', DoubleType(), False),
            StructField('MeanMinimumtemperature', DoubleType(), False),
            StructField('MeanWhc', DoubleType(), False),
            StructField('MeanDaylightHours', DoubleType(), False),
            StructField('MeanDayOfSoilAcidityRange', DoubleType(), False),
            StructField('TotalYield', DoubleType(), False),
            StructField('Area', DoubleType(), False),
            StructField('YieldPerA', DoubleType(), False),
            StructField('label', DoubleType(), False)
        ])
In [14]:
# csv読み込み
df5 = ss.read.csv('./batchAnalysticsData_train_5.csv', 
                    header=True,
                     encoding='UTF-8',
                     schema=struct)

df5.show(5, truncate=False)
+----+------+----------------------+----------------------+-------+-----------------+-------------------------+----------+----+---------+-----+
|Year|FarmID|MeanHighestTemperature|MeanMinimumtemperature|MeanWhc|MeanDaylightHours|MeanDayOfSoilAcidityRange|TotalYield|Area|YieldPerA|label|
+----+------+----------------------+----------------------+-------+-----------------+-------------------------+----------+----+---------+-----+
|2007|1.0   |6.93                  |-1.3                  |14.17  |171.12           |18.0                     |1423222.21|4.5 |3162.72  |0.0  |
|2007|2.0   |7.77                  |-0.63                 |15.83  |172.62           |18.0                     |1457585.51|5.0 |2915.17  |0.0  |
|2007|3.0   |7.77                  |-1.13                 |14.5   |169.28           |18.0                     |1150258.61|3.0 |3834.2   |1.0  |
|2007|4.0   |6.77                  |0.03                  |16.67  |170.12           |19.0                     |2327859.58|6.0 |3879.77  |1.0  |
|2007|5.0   |6.93                  |-1.47                 |17.5   |173.78           |18.0                     |1448612.55|4.0 |3621.53  |1.0  |
+----+------+----------------------+----------------------+-------+-----------------+-------------------------+----------+----+---------+-----+
only showing top 5 rows

訓練データとテストデータに分割

In [15]:
df5TrainData, df5TestData = df5.randomSplit([0.7, 0.3], 50)

決定木モデルの作成

・pipelineの利用

・ GridSearchとCross Validationの使用

In [16]:
# 特徴量のベクトル化ステージ
assemblerForDTC = VectorAssembler(inputCols=[
                    "MeanHighestTemperature",
                    "MeanMinimumtemperature",
                    "MeanWhc",
                    "MeanDaylightHours",
                    "MeanDayOfSoilAcidityRange",
                    ], outputCol="features")
In [17]:
# 決定木のステージ
classifierByDT = DecisionTreeClassifier().setLabelCol("label").setFeaturesCol("features")
In [18]:
# pipeline作成
pipelineForDTC = Pipeline(stages=[assemblerForDTC, classifierByDT])
In [19]:
# グリッドサーチ用インスタンスの生成
# サーチ対象
#     ・maxBins:連続値を不連続値にビン分けするときのビンの数
#     ・maxDepth:木の深さ
paramGridForDTC = ParamGridBuilder()\
                    .addGrid(
                        classifierByDT.maxBins,
                        [10, 20, 30, 40, 50])\
                    .addGrid(
                        classifierByDT.maxDepth,
                        [2, 3, 4]).build()
In [20]:
# Evaluator(モデル評価用インスタンス)の生成
evaluatorForDTC = BinaryClassificationEvaluator()\
                    .setLabelCol("label")\
                    .setRawPredictionCol(classifierByDT.getRawPredictionCol())\
                    .setMetricName("areaUnderROC")
In [21]:
# クロスバリデーション用インスタンスの生成
crossValForDTC = CrossValidator()\
                    .setEstimator(pipelineForDTC)\
                    .setEvaluator(evaluatorForDTC)\
                    .setEstimatorParamMaps(paramGridForDTC)\
                    .setNumFolds(10)
In [22]:
# クロスバリデーションの実施
crossValForDTCModel = crossValForDTC.fit(df5TrainData)
In [23]:
# 訓練データで予測を行い、AUCを出力
predictionByDTC = crossValForDTCModel.transform(df5TrainData)
aucByDTC = evaluatorForDTC.evaluate(predictionByDTC)
print(" AUC-TrainData(DecisionTree): ", aucByDTC)
 AUC-TrainData(DecisionTree):  0.99375

ロジスティクス回帰モデルの作成

In [24]:
from pyspark.ml.feature import StandardScaler
from pyspark.ml.classification import LogisticRegression 
In [32]:
# 特徴量選択のため、候補となる組み合わせ分のVectorAssemblerを定義する
# 候補
# 1. 畑×土壌酸度範囲内日数×平均最高気温
# 2. 畑×土壌酸度範囲内日数×平均最低気温
# 3. 畑×土壌酸度範囲内日数×平均含水分量
# 4. 畑×土壌酸度範囲内日数×日照合計時間

assemblerForLC = []
# 1
assemblerForLC.append(
    VectorAssembler(inputCols=[
        "FarmID",
        "MeanDayOfSoilAcidityRange",
        "MeanHighestTemperature",
        ],
        outputCol="features")
    )
# 2
assemblerForLC.append(
    VectorAssembler(inputCols=[
        "FarmID",
        "MeanDayOfSoilAcidityRange",
        "MeanMinimumtemperature",
        ],
        outputCol="features")
    )
# 3
assemblerForLC.append(
    VectorAssembler(inputCols=[
        "FarmID",
        "MeanDayOfSoilAcidityRange",
        "MeanWhc",
        ],
        outputCol="features")
    )
# 4
assemblerForLC.append(
    VectorAssembler(inputCols=[
        "FarmID",
        "MeanDayOfSoilAcidityRange",
        "MeanDaylightHours",
        ],
        outputCol="features")
    )
In [33]:
# Pipelineの定義

# 標準化ステージ
scalerForLC = StandardScaler(
                inputCol="features",
                outputCol="standardedFeature",
                withStd=True, withMean=True)

# ロジスティクス回帰ステージ
logisticClassification = LogisticRegression().setLabelCol("label")\
                                            .setFeaturesCol("standardedFeature")\
                                            .setStandardization(True)

# 特徴量組み合わせごとのpipeline入れるリスト
pipelineForLC = []

# 特徴量組み合わせごとにpipelineを生成し、リストに入れる
for assembler in assemblerForLC:
    pipelineForLC.append(
        Pipeline(
            stages=[
                assembler,
                scalerForLC,
                logisticClassification
                ]
        )
    )
In [34]:
# グリッドサーチ、クロスバリデーション

# グリッドサーチ生成
# 最適化するパラメータの種類と、検証対象の値をセット
paramGridForLC = ParamGridBuilder()\
                    .addGrid(
                        logisticClassification.regParam,
                        [0.001, 0.01, 0.1, 1.0, 10.0, 100.0])\
                    .addGrid(
                        logisticClassification.maxIter,
                        [10, 100, 1000])\
                    .build()

# Evaluatorの生成
evaluatorForLC = BinaryClassificationEvaluator().setLabelCol("label").setMetricName("areaUnderROC")

# クロスバリデーションの生成
crossValidatorForLC = []
for pipeline in pipelineForLC:
    crossValidatorForLC.append(
        CrossValidator().setEstimator(pipeline).setEvaluator(evaluatorForLC)\
            .setEstimatorParamMaps(paramGridForLC).setNumFolds(10))
In [35]:
%%time

# モデルを作成し、訓練データをインプットに予測を行う

# クロスバリデーションモデルの生成
modelForLC = []
for crossValidator in crossValidatorForLC:
    modelForLC.append(crossValidator.fit(df5TrainData))
    
# 訓練データで予測を行い、AUCを取得し、出力
print(" -- df5TrainData --")
df5TrainData.show()
print(" -- AUC-TrainData(Logistic Regression) --")
for i, model in enumerate(modelForLC):
    prediction = model.transform(df5TrainData)
    auc = evaluatorForLC.evaluate(prediction)
    print(i, auc)
print("")
 -- df5TrainData --
+----+------+----------------------+----------------------+-------+-----------------+-------------------------+----------+----+---------+-----+
|Year|FarmID|MeanHighestTemperature|MeanMinimumtemperature|MeanWhc|MeanDaylightHours|MeanDayOfSoilAcidityRange|TotalYield|Area|YieldPerA|label|
+----+------+----------------------+----------------------+-------+-----------------+-------------------------+----------+----+---------+-----+
|2007|   1.0|                  6.93|                  -1.3|  14.17|           171.12|                     18.0|1423222.21| 4.5|  3162.72|  0.0|
|2007|   2.0|                  7.77|                 -0.63|  15.83|           172.62|                     18.0|1457585.51| 5.0|  2915.17|  0.0|
|2007|   3.0|                  7.77|                 -1.13|   14.5|           169.28|                     18.0|1150258.61| 3.0|   3834.2|  1.0|
|2007|   4.0|                  6.77|                  0.03|  16.67|           170.12|                     19.0|2327859.58| 6.0|  3879.77|  1.0|
|2007|   5.0|                  6.93|                 -1.47|   17.5|           173.78|                     18.0|1448612.55| 4.0|  3621.53|  1.0|
|2008|   1.0|                  6.77|                 -0.47|  14.67|           176.62|                     20.0|1817120.47| 4.5|  4038.05|  1.0|
|2008|   2.0|                  6.93|                 -1.13|  15.33|           166.62|                     20.0|2111691.74| 5.0|  4223.38|  1.0|
|2008|   3.0|                   8.1|                 -1.13|   15.5|           163.12|                     21.0|1225468.03| 3.0|  4084.89|  1.0|
|2008|   4.0|                  7.93|                  -0.8|   17.0|           174.12|                     19.0|2784617.62| 6.0|  4641.03|  1.0|
|2009|   1.0|                  7.43|                 -0.13|  16.33|           169.78|                     19.0| 1515486.0| 4.5|  3367.75|  1.0|
|2009|   2.0|                   6.6|                  0.03|  18.17|           175.45|                     21.0|2295080.42| 5.0|  4590.16|  1.0|
|2009|   3.0|                   7.1|                 -0.47|  17.33|           171.62|                     19.0|1235863.09| 3.0|  4119.54|  1.0|
|2009|   4.0|                   8.1|                   0.2|  18.33|           166.28|                     21.0|2176601.63| 6.0|  3627.67|  0.0|
|2010|   1.0|                  7.27|                  -0.8|  15.67|           169.62|                     19.0|1881345.57| 4.5|  4180.77|  1.0|
|2010|   2.0|                   8.1|                  -0.8|   15.5|           167.28|                     18.0|1189336.27| 5.0|  2378.67|  0.0|
|2010|   4.0|                  7.27|                   0.2|  15.67|           169.45|                     21.0|2496696.69| 6.0|  4161.16|  1.0|
|2010|   5.0|                  6.93|                 -0.13|  17.17|           171.12|                     19.0|1385340.26| 4.0|  3463.35|  1.0|
|2011|   1.0|                  6.77|                 -0.47|  16.67|           168.95|                     19.0|1582906.76| 4.5|  3517.57|  1.0|
|2011|   2.0|                  7.43|                 -0.13|  16.83|           172.45|                     18.0|2566974.75| 5.0|  5133.95|  1.0|
|2011|   3.0|                  7.93|                 -0.13|  17.33|           167.45|                     20.0| 829522.99| 3.0|  2765.08|  0.0|
+----+------+----------------------+----------------------+-------+-----------------+-------------------------+----------+----+---------+-----+
only showing top 20 rows

 -- AUC-TrainData(Logistic Regression) --
0 0.7375
1 0.5958333333333334
2 0.5958333333333333
3 0.6583333333333333

CPU times: user 19.6 s, sys: 10.4 s, total: 30 s
Wall time: 3min 45s
In [ ]:
# 上記結果より、1の組み合わせ(畑×土壌酸度範囲内日数×平均最高気温)を選択する

モデル選択

決定木、ロジスティクス回帰にテストデータ渡してAUCで評価

In [36]:
# 決定木モデルにテストデータ渡して、AUCを取得
predictionTestDataByDTC = crossValForDTCModel.transform(df5TestData)
aucTestDataByDTC = evaluatorForDTC.evaluate(predictionTestDataByDTC)
print("-- AUC-TestData(Decision Tree) --")
print(aucTestDataByDTC, "\n")
-- AUC-TestData(Decision Tree) --
0.59375 

In [37]:
# ロジスティクス回帰モデルにテストデータ渡して、AUCを取得
predictionTestDataByLC = modelForLC[0].transform(df5TestData)
aucTestDataByLC = evaluatorForLC.evaluate(predictionTestDataByLC)
print("-- AUC-TestData(Logistic Regression) --")
print(aucTestDataByLC, "\n")
-- AUC-TestData(Logistic Regression) --
0.625 

In [ ]:
# 上記より、ロジスティクス回帰のほうが精度が高いので、ロジスティクス回帰を選択する
In [38]:
# 未知データを用いてロジスティクス回帰で予測してみる

# 未知データよりDataFrame生成
df5Predict = ss.read.csv('./batchAnalysticsData_predict_5.csv',
                        header=True, encoding="UTF-8", schema=struct)
df5Predict.show(10)
+----+------+----------------------+----------------------+-------+-----------------+-------------------------+----------+----+---------+-----+
|Year|FarmID|MeanHighestTemperature|MeanMinimumtemperature|MeanWhc|MeanDaylightHours|MeanDayOfSoilAcidityRange|TotalYield|Area|YieldPerA|label|
+----+------+----------------------+----------------------+-------+-----------------+-------------------------+----------+----+---------+-----+
|2017|   1.0|                  7.77|                 -0.63|   15.0|           173.95|                     20.0|      null|null|     null| null|
|2017|   2.0|                  6.43|                 -1.13|  16.17|           175.78|                     19.0|      null|null|     null| null|
|2017|   3.0|                   6.6|                 -0.13|   16.0|           173.78|                     21.0|      null|null|     null| null|
|2017|   4.0|                  7.77|                  -0.8|  13.67|           170.45|                     20.0|      null|null|     null| null|
|2017|   5.0|                  7.43|                  0.03|  17.83|           169.45|                     17.0|      null|null|     null| null|
+----+------+----------------------+----------------------+-------+-----------------+-------------------------+----------+----+---------+-----+

In [39]:
# 予測
print("-- AUC-FutureData(Logistic Regression) --")
predictionFutureDataByLC = modelForLC[0].transform(df5Predict)
predictionFutureDataByLC.select("FarmID", "probability", "prediction").show()
-- AUC-FutureData(Logistic Regression) --
+------+--------------------+----------+
|FarmID|         probability|prediction|
+------+--------------------+----------+
|   1.0|[0.53259603960872...|       0.0|
|   2.0|[0.06610868314425...|       1.0|
|   3.0|[0.03963171210502...|       1.0|
|   4.0|[0.23382722347383...|       1.0|
|   5.0|[0.16286436053485...|       1.0|
+------+--------------------+----------+

In [ ]:
 
In [ ]:
 
In [ ]: