package com.databricks.spark.sql.perf.mllib.classification;

import com.databricks.spark.sql.perf.MLMetric;
import com.databricks.spark.sql.perf.mllib.BenchmarkAlgorithm;
import com.databricks.spark.sql.perf.mllib.MLBenchContext;
import com.databricks.spark.sql.perf.mllib.OptionImplicits$;
import com.databricks.spark.sql.perf.mllib.ScoringWithEvaluator;
import com.databricks.spark.sql.perf.mllib.TestFromTraining;
import com.databricks.spark.sql.perf.mllib.TrainingSetFromTransformer;
import com.databricks.spark.sql.perf.mllib.data.DataGenerator$;
import java.util.Arrays;
import java.util.Random;
import org.apache.spark.ml.ModelBuilderSSP$;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import scala.Array$;
import scala.Function0;
import scala.Predef$;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Map;
import scala.collection.mutable.ArrayOps;
import scala.math.Numeric$DoubleIsFractional$;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;

/* compiled from: NaiveBayes.scala */
/* loaded from: input_file:com/databricks/spark/sql/perf/mllib/classification/NaiveBayes$.class */
/**
 * Benchmark definition for Spark ML's {@code org.apache.spark.ml.classification.NaiveBayes}.
 *
 * <p>This is the Java form of a Scala singleton object (see the "compiled from: NaiveBayes.scala"
 * header). {@link #MODULE$} holds the single instance. The static initializer creates it, and
 * the private constructor runs each mixed-in trait's {@code $init$}.
 *
 * <p>{@link #initialData} and {@link #trueModel} supply the inputs that
 * {@link TrainingSetFromTransformer} presumably combines into the training set. Test data and
 * scoring are delegated to the mixed-in traits.
 *
 * <p>NOTE(review): the {@code Iface.super.m(...)} forwarders below assume the Scala traits are
 * compiled to Java interfaces with default methods. The static {@code $init$} calls in the
 * constructor imply that encoding. Confirm against the compiled traits.
 */
public final class NaiveBayes$ implements BenchmarkAlgorithm, TestFromTraining, TrainingSetFromTransformer, ScoringWithEvaluator {
    /** Smallest arity drawn for a generated feature (inclusive). */
    private static final int MIN_FEATURE_ARITY = 2;

    /** Upper bound on the arity drawn for a generated feature (exclusive). */
    private static final int FEATURE_ARITY_BOUND = 20;

    /** Added to every raw class weight so none is zero and its logarithm stays finite. */
    private static final double PRIOR_WEIGHT_OFFSET = 1.0E-5d;

    /** Probability that class {@code i} assigns to its "own" feature {@code i} in the true model. */
    private static final double OWN_FEATURE_PROB = 0.7d;

    /** The singleton instance; assigned by the private constructor during class initialization. */
    public static NaiveBayes$ MODULE$;

    static {
        new NaiveBayes$();
    }

    /**
     * Scores {@code transformer} on {@code dataset} using the implementation inherited from
     * {@link ScoringWithEvaluator}.
     */
    @Override // com.databricks.spark.sql.perf.mllib.BenchmarkAlgorithm, com.databricks.spark.sql.perf.mllib.ScoringWithEvaluator
    public final MLMetric score(MLBenchContext mLBenchContext, Dataset<Row> dataset, Transformer transformer) {
        // Qualified delegation: an unqualified score(...) here would call this method again.
        return ScoringWithEvaluator.super.score(mLBenchContext, dataset, transformer);
    }

    /** Returns the training set built by the {@link TrainingSetFromTransformer} implementation. */
    @Override // com.databricks.spark.sql.perf.mllib.BenchmarkAlgorithm, com.databricks.spark.sql.perf.mllib.TrainingSetFromTransformer
    public final Dataset<Row> trainingDataSet(MLBenchContext mLBenchContext) {
        return TrainingSetFromTransformer.super.trainingDataSet(mLBenchContext);
    }

    /** Returns the test set built by the {@link TestFromTraining} implementation. */
    @Override // com.databricks.spark.sql.perf.mllib.BenchmarkAlgorithm, com.databricks.spark.sql.perf.mllib.TestFromTraining
    public final Dataset<Row> testDataSet(MLBenchContext mLBenchContext) {
        return TestFromTraining.super.testDataSet(mLBenchContext);
    }

    /** Returns the benchmark name using the default inherited from {@link BenchmarkAlgorithm}. */
    @Override // com.databricks.spark.sql.perf.mllib.BenchmarkAlgorithm
    public String name() {
        return BenchmarkAlgorithm.super.name();
    }

    /**
     * Returns the extra timed methods for {@code transformer}, using the default inherited from
     * {@link BenchmarkAlgorithm}.
     */
    @Override // com.databricks.spark.sql.perf.mllib.BenchmarkAlgorithm
    public Map<String, Function0<?>> testAdditionalMethods(MLBenchContext mLBenchContext, Transformer transformer) {
        return BenchmarkAlgorithm.super.testAdditionalMethods(mLBenchContext, transformer);
    }

    /**
     * Generates the initial feature data for the training set via
     * {@code DataGenerator.generateMixedFeatures}.
     *
     * <p>Each of the {@code numFeatures} features gets an arity drawn uniformly from [2, 20) with
     * the context's generator. NOTE(review): presumably the arity is the number of categories
     * for that feature; confirm against {@code DataGenerator}.
     *
     * @param mLBenchContext supplies the SQL context, seed, random generator and size parameters
     * @return the generated feature data
     */
    @Override // com.databricks.spark.sql.perf.mllib.TrainingSetFromTransformer
    public Dataset<Row> initialData(MLBenchContext mLBenchContext) {
        Random rng = mLBenchContext.newGenerator();
        int numFeatures = OptionImplicits$.MODULE$.oI2I(mLBenchContext.params().numFeatures());
        // A non-positive feature count yields no features rather than an exception.
        int[] featureArity = new int[Math.max(numFeatures, 0)];
        for (int f = 0; f < featureArity.length; f++) {
            featureArity[f] = MIN_FEATURE_ARITY + rng.nextInt(FEATURE_ARITY_BOUND - MIN_FEATURE_ARITY);
        }
        return DataGenerator$.MODULE$.generateMixedFeatures(
                mLBenchContext.sqlContext(),
                OptionImplicits$.MODULE$.oL2L(mLBenchContext.params().numExamples()),
                mLBenchContext.seed(),
                OptionImplicits$.MODULE$.oI2I(mLBenchContext.params().numPartitions()),
                featureArity);
    }

    /**
     * Builds the ground-truth Naive Bayes model.
     *
     * <p><b>pi</b> (log class priors): one random weight per class in [1e-5, 1 + 1e-5). Each
     * weight's log has the log of the weight sum subtracted, so {@code exp(pi)} sums to 1.
     *
     * <p><b>theta</b> (log class-conditional probabilities): a {@code numClasses x numFeatures}
     * matrix. Row {@code i} holds 0.7 at column {@code i} and splits the remaining 0.3 evenly over
     * the other columns (see {@link #$anonfun$trueModel$3}), and every entry is log-transformed.
     * This requires {@code numClasses <= numFeatures}; otherwise the row helper throws
     * {@link ArrayIndexOutOfBoundsException}.
     *
     * @param mLBenchContext supplies the random generator and {@code numClasses}/{@code numFeatures}
     * @return the model built by {@code ModelBuilderSSP.newNaiveBayesModel}
     */
    @Override // com.databricks.spark.sql.perf.mllib.TrainingSetFromTransformer
    public Transformer trueModel(MLBenchContext mLBenchContext) {
        Random rng = mLBenchContext.newGenerator();
        int numClasses = OptionImplicits$.MODULE$.oI2I(mLBenchContext.params().numClasses());
        // A non-positive class count yields empty priors/rows instead of an exception.
        int classCount = Math.max(numClasses, 0);

        // pi: strictly positive weights (nextDouble() is in [0, 1)), summed left to right.
        double[] weights = new double[classCount];
        double weightSum = 0.0d;
        for (int c = 0; c < classCount; c++) {
            weights[c] = rng.nextDouble() + PRIOR_WEIGHT_OFFSET;
            weightSum += weights[c];
        }
        double logWeightSum = Math.log(weightSum);
        double[] logPriors = new double[classCount];
        for (int c = 0; c < classCount; c++) {
            logPriors[c] = Math.log(weights[c]) - logWeightSum;
        }

        // theta: build every class row first, then log-transform and concatenate them row by row.
        double[][] rows = new double[classCount][];
        int valueCount = 0;
        for (int c = 0; c < classCount; c++) {
            rows[c] = $anonfun$trueModel$3(OWN_FEATURE_PROB, mLBenchContext, c);
            valueCount += rows[c].length;
        }
        double[] logTheta = new double[valueCount];
        int k = 0;
        for (double[] row : rows) {
            for (double p : row) {
                logTheta[k++] = Math.log(p);
            }
        }

        int numFeatures = OptionImplicits$.MODULE$.oI2I(mLBenchContext.params().numFeatures());
        return ModelBuilderSSP$.MODULE$.newNaiveBayesModel(
                Vectors$.MODULE$.dense(logPriors),
                // isTransposed = true: logTheta is laid out row-major (one class row after another).
                new DenseMatrix(numClasses, numFeatures, logTheta, true));
    }

    /**
     * Returns a new Spark ML NaiveBayes estimator with the smoothing from the benchmark params.
     */
    @Override // com.databricks.spark.sql.perf.mllib.BenchmarkAlgorithm
    public PipelineStage getPipelineStage(MLBenchContext mLBenchContext) {
        double smoothing = OptionImplicits$.MODULE$.oD2D(mLBenchContext.params().smoothing());
        // Fully qualified to avoid confusion with this benchmark's own NaiveBayes names.
        return new org.apache.spark.ml.classification.NaiveBayes().setSmoothing(smoothing);
    }

    /** Returns a default-configured multiclass evaluator used for scoring. */
    @Override // com.databricks.spark.sql.perf.mllib.ScoringWithEvaluator
    public Evaluator evaluator(MLBenchContext mLBenchContext) {
        return new MulticlassClassificationEvaluator();
    }

    /**
     * Builds the class-conditional probability row for class {@code i}. The row holds {@code d}
     * at index {@code i} and {@code (1 - d) / (numFeatures - 1)} at every other index.
     *
     * <p>The row has {@code numFeatures} entries (none if that is non-positive). This method
     * therefore throws {@link ArrayIndexOutOfBoundsException} when {@code i >= numFeatures}. The
     * compiler-generated name and public static visibility are kept unchanged in case anything
     * links against them.
     *
     * @param d probability given to the class's own feature
     * @param mLBenchContext supplies {@code numFeatures}
     * @param i class index
     * @return a new probability row
     */
    public static final /* synthetic */ double[] $anonfun$trueModel$3(double d, MLBenchContext mLBenchContext, int i) {
        int numFeatures = OptionImplicits$.MODULE$.oI2I(mLBenchContext.params().numFeatures());
        double offDiagonal = (1 - d) / (numFeatures - 1);
        double[] row = new double[Math.max(numFeatures, 0)];
        Arrays.fill(row, offDiagonal);
        row[i] = d;
        return row;
    }

    /** Registers the singleton and runs each mixed-in trait's initializer. */
    private NaiveBayes$() {
        MODULE$ = this;
        BenchmarkAlgorithm.$init$(this);
        TestFromTraining.$init$(this);
        TrainingSetFromTransformer.$init$(this);
        ScoringWithEvaluator.$init$(this);
    }
}
