package com.databricks.spark.sql.perf.mllib.classification;

import com.databricks.spark.sql.perf.MLMetric;
import com.databricks.spark.sql.perf.mllib.BenchmarkAlgorithm;
import com.databricks.spark.sql.perf.mllib.MLBenchContext;
import com.databricks.spark.sql.perf.mllib.OptionImplicits$;
import com.databricks.spark.sql.perf.mllib.ScoringWithEvaluator;
import com.databricks.spark.sql.perf.mllib.TestFromTraining;
import com.databricks.spark.sql.perf.mllib.TrainingSetFromTransformer;
import com.databricks.spark.sql.perf.mllib.TreeOrForestClassifier;
import com.databricks.spark.sql.perf.mllib.TreeOrForestEstimator;
import com.databricks.spark.sql.perf.mllib.TreeOrForestEstimator$;
import org.apache.spark.ml.ModelBuilderSSP$;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.classification.GBTClassifier;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import scala.Function0;
import scala.collection.immutable.Map;

/* compiled from: GBTClassification.scala */
/* loaded from: input_file:com/databricks/spark/sql/perf/mllib/classification/GBTClassification$.class */
/**
 * Singleton holder for the GBT classification benchmark.
 *
 * <p>Exactly one instance is created by the static initializer and published through
 * {@link #MODULE$}; the constructor is private. The class supplies its own bodies for
 * {@link #trueModel} and {@link #getPipelineStage}; every other public method is a
 * forwarder.
 *
 * <p>NOTE(review): each forwarder below, as written in this file, invokes itself with
 * the same arguments and has no terminating branch, so it cannot return normally.
 * Presumably the class file dispatched to the implementation inherited from one of the
 * implemented interfaces and that target was lost when this source was generated --
 * confirm against the bytecode before compiling or relying on this file.
 */
public final class GBTClassification$ implements BenchmarkAlgorithm, TreeOrForestClassifier {
    /**
     * The single instance. Not {@code final}: it is assigned from inside the private
     * constructor rather than by a field initializer.
     */
    public static GBTClassification$ MODULE$;

    // Creates the singleton at class-initialization time. The result of `new` is
    // discarded on purpose: the constructor stores `this` into MODULE$ itself.
    static {
        new GBTClassification$();
    }

    /**
     * Forwarder for {@code evaluator}.
     *
     * <p>NOTE(review): the call below targets this same method (unconditional
     * self-recursion as written); the intended target is presumably an inherited
     * implementation -- TODO confirm.
     *
     * @param mLBenchContext benchmark context, passed through unchanged
     * @return whatever the forwarded call returns
     */
    @Override // com.databricks.spark.sql.perf.mllib.TreeOrForestClassifier, com.databricks.spark.sql.perf.mllib.ScoringWithEvaluator
    public Evaluator evaluator(MLBenchContext mLBenchContext) {
        Evaluator evaluator;
        evaluator = evaluator(mLBenchContext);
        return evaluator;
    }

    /**
     * Forwarder for {@code initialData}.
     *
     * <p>NOTE(review): the call below targets this same method (unconditional
     * self-recursion as written); the intended target is presumably an inherited
     * implementation -- TODO confirm.
     *
     * @param mLBenchContext benchmark context, passed through unchanged
     * @return whatever the forwarded call returns
     */
    @Override // com.databricks.spark.sql.perf.mllib.TreeOrForestEstimator, com.databricks.spark.sql.perf.mllib.TrainingSetFromTransformer
    public Dataset<Row> initialData(MLBenchContext mLBenchContext) {
        Dataset<Row> initialData;
        initialData = initialData(mLBenchContext);
        return initialData;
    }

    /**
     * Forwarder for {@code score}.
     *
     * <p>NOTE(review): the call below targets this same method (unconditional
     * self-recursion as written); the intended target is presumably an inherited
     * implementation -- TODO confirm.
     *
     * @param mLBenchContext benchmark context, passed through unchanged
     * @param dataset data set argument, passed through unchanged
     * @param transformer transformer argument, passed through unchanged
     * @return whatever the forwarded call returns
     */
    @Override // com.databricks.spark.sql.perf.mllib.BenchmarkAlgorithm, com.databricks.spark.sql.perf.mllib.ScoringWithEvaluator
    public final MLMetric score(MLBenchContext mLBenchContext, Dataset<Row> dataset, Transformer transformer) {
        MLMetric score;
        score = score(mLBenchContext, dataset, transformer);
        return score;
    }

    /**
     * Forwarder for {@code trainingDataSet}.
     *
     * <p>NOTE(review): the call below targets this same method (unconditional
     * self-recursion as written); the intended target is presumably an inherited
     * implementation -- TODO confirm.
     *
     * @param mLBenchContext benchmark context, passed through unchanged
     * @return whatever the forwarded call returns
     */
    @Override // com.databricks.spark.sql.perf.mllib.BenchmarkAlgorithm, com.databricks.spark.sql.perf.mllib.TrainingSetFromTransformer
    public final Dataset<Row> trainingDataSet(MLBenchContext mLBenchContext) {
        Dataset<Row> trainingDataSet;
        trainingDataSet = trainingDataSet(mLBenchContext);
        return trainingDataSet;
    }

    /**
     * Forwarder for {@code testDataSet}.
     *
     * <p>NOTE(review): the call below targets this same method (unconditional
     * self-recursion as written); the intended target is presumably an inherited
     * implementation -- TODO confirm.
     *
     * @param mLBenchContext benchmark context, passed through unchanged
     * @return whatever the forwarded call returns
     */
    @Override // com.databricks.spark.sql.perf.mllib.BenchmarkAlgorithm, com.databricks.spark.sql.perf.mllib.TestFromTraining
    public final Dataset<Row> testDataSet(MLBenchContext mLBenchContext) {
        Dataset<Row> testDataSet;
        testDataSet = testDataSet(mLBenchContext);
        return testDataSet;
    }

    /**
     * Forwarder for {@code name}.
     *
     * <p>NOTE(review): the call below targets this same method (unconditional
     * self-recursion as written); the intended target is presumably an inherited
     * implementation -- TODO confirm.
     *
     * @return whatever the forwarded call returns
     */
    @Override // com.databricks.spark.sql.perf.mllib.BenchmarkAlgorithm
    public String name() {
        String name;
        name = name();
        return name;
    }

    /**
     * Forwarder for {@code testAdditionalMethods}.
     *
     * <p>NOTE(review): the call below targets this same method (unconditional
     * self-recursion as written); the intended target is presumably an inherited
     * implementation -- TODO confirm.
     *
     * @param mLBenchContext benchmark context, passed through unchanged
     * @param transformer transformer argument, passed through unchanged
     * @return whatever the forwarded call returns
     */
    @Override // com.databricks.spark.sql.perf.mllib.BenchmarkAlgorithm
    public Map<String, Function0<?>> testAdditionalMethods(MLBenchContext mLBenchContext, Transformer transformer) {
        Map<String, Function0<?>> testAdditionalMethods;
        testAdditionalMethods = testAdditionalMethods(mLBenchContext, transformer);
        return testAdditionalMethods;
    }

    /**
     * Builds a model through {@code ModelBuilderSSP$.newDecisionTreeClassificationModel}.
     *
     * <p>The arguments, in order, are: the context's {@code depth} parameter plus one,
     * its {@code numClasses} parameter, the value of
     * {@code TreeOrForestEstimator$.getFeatureArity} for the context, and the context's
     * seed. Note that the builder invoked is the decision-tree one, and that the depth
     * handed to it is one greater than the depth used by {@link #getPipelineStage}.
     *
     * @param mLBenchContext source of the {@code depth} and {@code numClasses}
     *     parameters and of the seed
     * @return the model returned by the builder
     */
    @Override // com.databricks.spark.sql.perf.mllib.TreeOrForestClassifier, com.databricks.spark.sql.perf.mllib.TrainingSetFromTransformer
    public Transformer trueModel(MLBenchContext mLBenchContext) {
        // oI2I presumably unwraps an optional integer parameter to an int -- its
        // definition is not in this file; verify in OptionImplicits.
        return ModelBuilderSSP$.MODULE$.newDecisionTreeClassificationModel(OptionImplicits$.MODULE$.oI2I(mLBenchContext.params().depth()) + 1, OptionImplicits$.MODULE$.oI2I(mLBenchContext.params().numClasses()), TreeOrForestEstimator$.MODULE$.getFeatureArity(mLBenchContext), mLBenchContext.seed());
    }

    /**
     * Creates the pipeline stage for this benchmark: a new {@link GBTClassifier} whose
     * max depth is set from the context's {@code depth} parameter, whose max iteration
     * count is set from its {@code maxIter} parameter, and whose seed is the context's
     * seed.
     *
     * @param mLBenchContext source of the {@code depth} and {@code maxIter} parameters
     *     and of the seed
     * @return the value returned by the final {@code setSeed} call in the chain
     */
    @Override // com.databricks.spark.sql.perf.mllib.BenchmarkAlgorithm
    public PipelineStage getPipelineStage(MLBenchContext mLBenchContext) {
        return new GBTClassifier().setMaxDepth(OptionImplicits$.MODULE$.oI2I(mLBenchContext.params().depth())).setMaxIter(OptionImplicits$.MODULE$.oI2I(mLBenchContext.params().maxIter())).setSeed(mLBenchContext.seed());
    }

    /**
     * Private constructor, invoked only from the static initializer of this class.
     *
     * <p>Publishes {@code this} to {@link #MODULE$} first, then calls the static
     * {@code $init$} method of six interfaces in a fixed order. Because the field is
     * assigned before those calls, {@code MODULE$} is already non-null while they run.
     */
    private GBTClassification$() {
        MODULE$ = this;
        // NOTE(review): these look like Scala trait initializers; the order is
        // presumably significant -- do not reorder without checking the class file.
        BenchmarkAlgorithm.$init$(this);
        TestFromTraining.$init$(this);
        TrainingSetFromTransformer.$init$(this);
        ScoringWithEvaluator.$init$(this);
        TreeOrForestEstimator.$init$((TreeOrForestEstimator) this);
        TreeOrForestClassifier.$init$((TreeOrForestClassifier) this);
    }
}
