package com.databricks.spark.sql.perf.mllib;

import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.attribute.NominalAttribute$;
import org.apache.spark.ml.attribute.NumericAttribute$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: BenchmarkAlgorithm.scala */
@ScalaSignature(bytes = "\u0006\u0001I3\u0001\"\u0002\u0004\u0011\u0002\u0007\u00051c\u0013\u0005\u00065\u0001!\ta\u0007\u0005\u0006?\u00011\t\u0002\t\u0005\u0006\u007f\u00011\t\u0002\u0011\u0005\u0006\u0011\u0002!)%\u0013\u0002\u001b)J\f\u0017N\\5oON+GO\u0012:p[R\u0013\u0018M\\:g_JlWM\u001d\u0006\u0003\u000f!\tQ!\u001c7mS\nT!!\u0003\u0006\u0002\tA,'O\u001a\u0006\u0003\u00171\t1a]9m\u0015\tia\"A\u0003ta\u0006\u00148N\u0003\u0002\u0010!\u0005QA-\u0019;bEJL7m[:\u000b\u0003E\t1aY8n\u0007\u0001\u0019\"\u0001\u0001\u000b\u0011\u0005UAR\"\u0001\f\u000b\u0003]\tQa]2bY\u0006L!!\u0007\f\u0003\r\u0005s\u0017PU3g\u0003\u0019!\u0013N\\5uIQ\tA\u0004\u0005\u0002\u0016;%\u0011aD\u0006\u0002\u0005+:LG/A\u0006j]&$\u0018.\u00197ECR\fGCA\u0011:!\t\u0011cG\u0004\u0002$g9\u0011A%\r\b\u0003K=r!A\n\u0017\u000f\u0005\u001dRS\"\u0001\u0015\u000b\u0005%\u0012\u0012A\u0002\u001fs_>$h(C\u0001,\u0003\ry'oZ\u0005\u0003[9\na!\u00199bG\",'\"A\u0016\n\u00055\u0001$BA\u0017/\u0013\tY!G\u0003\u0002\u000ea%\u0011A'N\u0001\ba\u0006\u001c7.Y4f\u0015\tY!'\u0003\u00028q\tIA)\u0019;b\rJ\fW.\u001a\u0006\u0003iUBQA\u000f\u0002A\u0002m\n1a\u0019;y!\taT(D\u0001\u0007\u0013\tqdA\u0001\bN\u0019\n+gn\u00195D_:$X\r\u001f;\u0002\u0013Q\u0014X/Z'pI\u0016dGCA!H!\t\u0011U)D\u0001D\u0015\t!%'\u0001\u0002nY&\u0011ai\u0011\u0002\f)J\fgn\u001d4pe6,'\u000fC\u0003;\u0007\u0001\u00071(A\bue\u0006Lg.\u001b8h\t\u0006$\u0018mU3u)\t\t#\nC\u0003;\t\u0001\u00071HE\u0002M\u001d>3A!\u0014\u0001\u0001\u0017\naAH]3gS:,W.\u001a8u}A\u0011A\b\u0001\t\u0003yAK!!\u0015\u0004\u0003%\t+gn\u00195nCJ\\\u0017\t\\4pe&$\b.\u001c")
/* loaded from: input_file:com/databricks/spark/sql/perf/mllib/TrainingSetFromTransformer.class */
public interface TrainingSetFromTransformer {
    /**
     * Returns the input data that {@link #trueModel(MLBenchContext)} is applied to when building the
     * training set.
     *
     * <p>NOTE(review): the default {@link #trainingDataSet(MLBenchContext)} expects the transformed
     * output to contain a {@code "features"} column; presumably this data supplies it — confirm with
     * implementations.
     *
     * @param mLBenchContext benchmark context for the current run
     * @return the data to be labeled
     */
    Dataset<Row> initialData(MLBenchContext mLBenchContext);

    /**
     * Returns the transformer used to generate labels for {@link #initialData(MLBenchContext)}.
     *
     * <p>NOTE(review): its output must contain a {@code "prediction"} column, which
     * {@link #trainingDataSet(MLBenchContext)} renames to {@code "label"}.
     *
     * @param mLBenchContext benchmark context for the current run
     * @return the labeling transformer
     */
    Transformer trueModel(MLBenchContext mLBenchContext);

    /**
     * Builds a training set by applying {@link #trueModel} to {@link #initialData} and keeping only
     * the {@code "features"} column plus the {@code "prediction"} column renamed to {@code "label"}.
     *
     * <p>Label column metadata depends on {@code mLBenchContext.params().numClasses()}:
     * <ul>
     *   <li>{@code Some(0)}: numeric attribute named {@code "label"};</li>
     *   <li>{@code Some(n)} with {@code n != 0}: nominal attribute named {@code "label"} with
     *       {@code n} values;</li>
     *   <li>{@code None}: plain rename with no attribute metadata.</li>
     * </ul>
     *
     * @param mLBenchContext benchmark context for the current run
     * @return a dataset with columns {@code features} and {@code label}
     * @throws MatchError if {@code numClasses()} is neither a {@code Some} nor {@code None}
     */
    default Dataset<Row> trainingDataSet(MLBenchContext mLBenchContext) {
        Dataset<Row> initialData = initialData(mLBenchContext);
        Transformer trueModel = trueModel(mLBenchContext);
        Column features = functions$.MODULE$.col("features");
        Column prediction = functions$.MODULE$.col("prediction");

        // The accessor returns a Scala Option; declaring it as Some (as the decompiled code did) does
        // not compile and would make the None branch unreachable.
        Option<?> numClasses = mLBenchContext.params().numClasses();
        Column label;
        if (numClasses instanceof Some) {
            int classCount = BoxesRunTime.unboxToInt(((Some<?>) numClasses).value());
            // A class count of 0 yields numeric label metadata; any other count yields nominal metadata.
            label = prediction.as("label", (classCount == 0
                    ? NumericAttribute$.MODULE$.defaultAttr().withName("label")
                    : NominalAttribute$.MODULE$.defaultAttr().withName("label").withNumValues(classCount))
                    .toMetadata());
        } else if (None$.MODULE$.equals(numClasses)) {
            label = prediction.as("label");
        } else {
            // Mirrors the Scala pattern match: anything that is neither Some nor None (e.g. null) fails.
            throw new MatchError(numClasses);
        }
        return trueModel.transform(initialData)
                .select(Predef$.MODULE$.wrapRefArray(new Column[]{features, label}));
    }

    /**
     * Static initializer hook for this interface; it has no state to set up and does nothing.
     *
     * <p>NOTE(review): presumably emitted by the Scala compiler as the trait initializer — kept for
     * binary compatibility with callers compiled against the trait.
     *
     * @param trainingSetFromTransformer the instance being initialized (unused)
     */
    static void $init$(TrainingSetFromTransformer trainingSetFromTransformer) {
    }
}
