diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyDefaultHyperparams.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyDefaultHyperparams.scala
new file mode 100644
index 00000000000..3c856e3986d
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyDefaultHyperparams.scala
@@ -0,0 +1,72 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.automl
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+import org.apache.spark.ml.classification._
+
+// scalastyle:off magic.number
+class VerifyDefaultHyperparams extends TestBase {
+  // Checks which param names DefaultHyperparams.defaultRange offers for each Spark classifier.
+  test("LogisticRegression default range is non-empty") {
+    // All three names below must appear among the returned (param, distribution) pairs.
+    val ranges = DefaultHyperparams.defaultRange(new LogisticRegression())
+    assert(ranges.nonEmpty)
+    val tuned = ranges.map { case (p, _) => p.name }.toSet
+    Seq("regParam", "elasticNetParam", "maxIter").foreach { expected =>
+      assert(tuned.contains(expected))
+    }
+  }
+
+  test("DecisionTreeClassifier default range is non-empty") {
+    val ranges = DefaultHyperparams.defaultRange(new DecisionTreeClassifier())
+    assert(ranges.nonEmpty)
+    val tuned = ranges.map { case (p, _) => p.name }.toSet
+    Seq("maxBins", "maxDepth").foreach { expected =>
+      assert(tuned.contains(expected))
+    }
+  }
+
+  test("GBTClassifier default range is non-empty") {
+    val ranges = DefaultHyperparams.defaultRange(new GBTClassifier())
+    // Only the number of entries is pinned here, not the individual param names.
+    assert(ranges.nonEmpty)
+    assert(ranges.length >= 5)
+  }
+
+  test("RandomForestClassifier default range is non-empty") {
+    val forest = new RandomForestClassifier()
+    val ranges = DefaultHyperparams.defaultRange(forest)
+    assert(ranges.nonEmpty)
+    val tuned = ranges.map { case (p, _) => p.name }.toSet
+    assert(tuned.contains("numTrees"))
+  }
+
+  test("MultilayerPerceptronClassifier default range is non-empty") {
+    val ranges = DefaultHyperparams.defaultRange(new MultilayerPerceptronClassifier())
+    assert(ranges.nonEmpty)
+    val tuned = ranges.map { case (p, _) => p.name }.toSet
+    Seq("blockSize", "layers").foreach { expected =>
+      assert(tuned.contains(expected))
+    }
+  }
+
+  test("NaiveBayes default range is non-empty") {
+    val bayes = new NaiveBayes()
+    val ranges = DefaultHyperparams.defaultRange(bayes)
+    assert(ranges.nonEmpty)
+    val tuned = ranges.map { case (p, _) => p.name }.toSet
+    assert(tuned.contains("smoothing"))
+  }
+
+  test("default ranges produce valid distributions") {
+    val ranges = DefaultHyperparams.defaultRange(new LogisticRegression())
+    // Draw one sample from every distribution; the param half of each pair is not needed.
+    ranges.foreach { case (_, dist) =>
+      val sample = dist.getNext
+      assert(sample != null) // scalastyle:ignore null
+    }
+  }
+}
+// scalastyle:on magic.number
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyEvaluationUtils.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyEvaluationUtils.scala
new file mode 100644
index 00000000000..3010ccc605d
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyEvaluationUtils.scala
@@ -0,0 +1,88 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.automl
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+import com.microsoft.azure.synapse.ml.core.metrics.MetricConstants
+import com.microsoft.azure.synapse.ml.core.schema.SchemaConstants
+
+class VerifyEvaluationUtils extends TestBase {
+  // getMetricWithOperator yields a (metric column name, ordering) pair; both halves are checked below.
+  test("getMetricWithOperator returns correct metric for regression MSE") {
+    val resolved = EvaluationUtils.getMetricWithOperator(
+      SchemaConstants.RegressionKind, MetricConstants.MseSparkMetric)
+    assert(resolved._1 === MetricConstants.MseColumnName)
+  }
+
+  test("getMetricWithOperator returns correct metric for regression RMSE") {
+    val resolved = EvaluationUtils.getMetricWithOperator(
+      SchemaConstants.RegressionKind, MetricConstants.RmseSparkMetric)
+    assert(resolved._1 === MetricConstants.RmseColumnName)
+  }
+
+  test("getMetricWithOperator returns correct metric for regression R2") {
+    val resolved = EvaluationUtils.getMetricWithOperator(
+      SchemaConstants.RegressionKind, MetricConstants.R2SparkMetric)
+    assert(resolved._1 === MetricConstants.R2ColumnName)
+  }
+
+  test("getMetricWithOperator returns correct metric for regression MAE") {
+    val resolved = EvaluationUtils.getMetricWithOperator(
+      SchemaConstants.RegressionKind, MetricConstants.MaeSparkMetric)
+    assert(resolved._1 === MetricConstants.MaeColumnName)
+  }
+
+  test("getMetricWithOperator returns correct metric for classification AUC") {
+    val resolved = EvaluationUtils.getMetricWithOperator(
+      SchemaConstants.ClassificationKind, MetricConstants.AucSparkMetric)
+    assert(resolved._1 === MetricConstants.AucColumnName)
+  }
+
+  test("getMetricWithOperator returns correct metric for classification accuracy") {
+    val resolved = EvaluationUtils.getMetricWithOperator(
+      SchemaConstants.ClassificationKind, MetricConstants.AccuracySparkMetric)
+    assert(resolved._1 === MetricConstants.AccuracyColumnName)
+  }
+
+  test("regression metrics use chooseLowest ordering (except R2)") {
+    val mseOrdering = EvaluationUtils.getMetricWithOperator(
+      SchemaConstants.RegressionKind, MetricConstants.MseSparkMetric)._2
+    // Lower MSE is better, so the smaller value must compare as the greater element.
+    assert(mseOrdering.compare(1.0, 2.0) > 0)
+
+    val r2Ordering = EvaluationUtils.getMetricWithOperator(
+      SchemaConstants.RegressionKind, MetricConstants.R2SparkMetric)._2
+    // Higher R2 is better, so 1.0 must compare below 2.0.
+    assert(r2Ordering.compare(1.0, 2.0) < 0)
+  }
+
+  test("classification metrics use chooseHighest ordering") {
+    val aucOrdering = EvaluationUtils.getMetricWithOperator(
+      SchemaConstants.ClassificationKind, MetricConstants.AucSparkMetric)._2
+    // Higher AUC is better, so 1.0 must compare below 2.0.
+    assert(aucOrdering.compare(1.0, 2.0) < 0)
+  }
+
+  test("unsupported regression metric throws") {
+    intercept[Exception] {
+      EvaluationUtils.getMetricWithOperator(SchemaConstants.RegressionKind, "bogus_metric")
+    }
+  }
+
+  test("unsupported classification metric throws") {
+    intercept[Exception] {
+      EvaluationUtils.getMetricWithOperator(SchemaConstants.ClassificationKind, "bogus_metric")
+    }
+  }
+
+  test("unsupported model type throws") {
+    intercept[Exception] {
+      EvaluationUtils.getMetricWithOperator("unsupported_type", MetricConstants.MseSparkMetric)
+    }
+  }
+
+  test("ModelTypeUnsupportedErr constant is defined") {
+    assert(EvaluationUtils.ModelTypeUnsupportedErr.nonEmpty)
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyHyperparamBuilder.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyHyperparamBuilder.scala
new file mode 100644
index 00000000000..def599af7c2
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyHyperparamBuilder.scala
@@ -0,0 +1,117 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.automl
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+import org.apache.spark.ml.param.{IntParam, DoubleParam, Param, Params, ParamMap}
+import org.apache.spark.ml.util.Identifiable
+
+import scala.collection.JavaConverters._
+
+// scalastyle:off magic.number
+class VerifyHyperparamBuilder extends TestBase {
+  // Covers the range/discrete samplers, the builder that pairs them with params, and HyperParamUtils.
+  private object TestParams extends Params {
+    override val uid: String = Identifiable.randomUID("TestParams") // scalastyle:ignore field.name
+    override def copy(extra: ParamMap): Params = this
+    val intParam = new IntParam(this, "intParam", "test int param") // scalastyle:ignore field.name
+    val doubleParam = new DoubleParam(this, "doubleParam", "test double param") // scalastyle:ignore field.name
+  }
+
+  test("IntRangeHyperParam generates values within range") {
+    val sampler = new IntRangeHyperParam(5, 15, seed = 42)
+    val draws = Seq.fill(100)(sampler.getNext())
+    assert(draws.forall(d => 5 <= d && d < 15))
+    assert(draws.distinct.size > 1) // not all the same
+  }
+
+  test("DoubleRangeHyperParam generates values within range") {
+    val sampler = new DoubleRangeHyperParam(0.0, 1.0, seed = 42)
+    val draws = Seq.fill(100)(sampler.getNext())
+    assert(draws.forall(d => 0.0 <= d && d < 1.0))
+  }
+
+  test("FloatRangeHyperParam generates values within range") {
+    val sampler = new FloatRangeHyperParam(0.0f, 1.0f, seed = 42)
+    val draws = Seq.fill(100)(sampler.getNext())
+    assert(draws.forall(d => 0.0f <= d && d < 1.0f))
+  }
+
+  test("LongRangeHyperParam generates values") {
+    val sampler = new LongRangeHyperParam(0L, 100L, seed = 42)
+    val drawn = sampler.getNext()
+    assert(drawn.isInstanceOf[Long])
+  }
+
+  test("DiscreteHyperParam selects from provided values") {
+    val sampler = new DiscreteHyperParam(List("a", "b", "c"), seed = 42)
+    val draws = Seq.fill(100)(sampler.getNext())
+    assert(draws.forall(d => d == "a" || d == "b" || d == "c"))
+    assert(draws.distinct.size > 1)
+  }
+
+  test("DiscreteHyperParam getValues returns Java list") {
+    val sampler = new DiscreteHyperParam(List(1, 2, 3))
+    val asScalaList = sampler.getValues.asScala.toList
+    assert(asScalaList === List(1, 2, 3))
+  }
+
+  test("HyperparamBuilder builds array of param-dist pairs") {
+    val pairs = new HyperparamBuilder()
+      .addHyperparam(TestParams.intParam, new IntRangeHyperParam(1, 10))
+      .addHyperparam(TestParams.doubleParam, new DoubleRangeHyperParam(0.0, 1.0))
+      .build()
+    assert(pairs.length === 2)
+    assert(pairs.map { case (p, _) => p.name }.toSet === Set("intParam", "doubleParam"))
+  }
+
+  test("HyperparamBuilder empty build returns empty array") {
+    val pairs = new HyperparamBuilder().build()
+    assert(pairs.isEmpty)
+  }
+
+  test("HyperParamUtils.getRangeHyperParam matches Int type") {
+    val sampler = HyperParamUtils.getRangeHyperParam(1, 10)
+    assert(sampler.isInstanceOf[IntRangeHyperParam])
+  }
+
+  test("HyperParamUtils.getRangeHyperParam matches Double type") {
+    val sampler = HyperParamUtils.getRangeHyperParam(0.0, 1.0)
+    assert(sampler.isInstanceOf[DoubleRangeHyperParam])
+  }
+
+  test("HyperParamUtils.getRangeHyperParam matches Float type") {
+    val sampler = HyperParamUtils.getRangeHyperParam(0.0f, 1.0f)
+    assert(sampler.isInstanceOf[FloatRangeHyperParam])
+  }
+
+  test("HyperParamUtils.getRangeHyperParam matches Long type") {
+    val sampler = HyperParamUtils.getRangeHyperParam(0L, 100L)
+    assert(sampler.isInstanceOf[LongRangeHyperParam])
+  }
+
+  test("HyperParamUtils.getRangeHyperParam throws on unsupported type") {
+    intercept[Exception] {
+      HyperParamUtils.getRangeHyperParam("a", "z")
+    }
+  }
+
+  test("HyperParamUtils.getDiscreteHyperParam creates from Java ArrayList") {
+    val options = new java.util.ArrayList[String]()
+    options.add("x")
+    options.add("y")
+    val sampler = HyperParamUtils.getDiscreteHyperParam(options)
+    val draws = Seq.fill(50)(sampler.getNext().toString)
+    assert(draws.forall(Set("x", "y").contains))
+  }
+
+  test("seeded RangeHyperParam produces deterministic sequences") {
+    val first = new IntRangeHyperParam(0, 100, seed = 123)
+    val second = new IntRangeHyperParam(0, 100, seed = 123)
+    val firstRun = Seq.fill(10)(first.getNext())
+    val secondRun = Seq.fill(10)(second.getNext())
+    assert(firstRun === secondRun)
+  }
+}
+// scalastyle:on magic.number
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyParamSpace.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyParamSpace.scala
new file mode 100644
index 00000000000..61c4db270b1
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyParamSpace.scala
@@ -0,0 +1,64 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.automl
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
+import org.apache.spark.ml.util.Identifiable
+
+// scalastyle:off magic.number
+class VerifyParamSpace extends TestBase {
+  // Checks the iterators exposed by GridSpace and RandomSpace, plus Dist.getParamPair.
+  private object TestParams extends Params {
+    override val uid: String = Identifiable.randomUID("TestParams") // scalastyle:ignore field.name
+    override def copy(extra: ParamMap): Params = this
+    val intParam = new IntParam(this, "intParam", "test int param") // scalastyle:ignore field.name
+  }
+
+  test("GridSpace iterates over all ParamMaps") {
+    // Three single-entry maps in, three maps out.
+    val maps = Array(1, 2, 3).map(i => ParamMap(TestParams.intParam -> i))
+    val grid = new GridSpace(maps)
+    assert(grid.paramMaps.toList.length === 3)
+  }
+
+  test("GridSpace with empty array produces empty iterator") {
+    val grid = new GridSpace(Array.empty[ParamMap])
+    assert(!grid.paramMaps.hasNext)
+  }
+
+  test("RandomSpace produces infinite iterator") {
+    val ranges = new HyperparamBuilder()
+      .addHyperparam(TestParams.intParam, new IntRangeHyperParam(1, 100))
+    val space = new RandomSpace(ranges.build())
+    // take(50) would come up short if the iterator ever ran dry.
+    val sampled = space.paramMaps.take(50).toList
+    assert(sampled.length === 50)
+    sampled.foreach { pm =>
+      val drawn = pm.get(TestParams.intParam)
+      assert(drawn.isDefined)
+      assert(drawn.get >= 1 && drawn.get < 100)
+    }
+  }
+
+  test("RandomSpace iterator always hasNext") {
+    val ranges = new HyperparamBuilder()
+      .addHyperparam(TestParams.intParam, new IntRangeHyperParam(0, 10))
+    // Keep a single iterator: separate paramMaps calls may each hand back a fresh one.
+    val draws = new RandomSpace(ranges.build()).paramMaps
+    assert(draws.hasNext)
+    draws.next()
+    assert(draws.hasNext)
+  }
+
+  test("Dist.getParamPair creates correct ParamPair") {
+    val sampler = new IntRangeHyperParam(5, 15, seed = 42)
+    val pair = sampler.getParamPair(TestParams.intParam)
+    assert(pair.param.name === TestParams.intParam.name)
+    val drawn = pair.value.asInstanceOf[Int]
+    assert(drawn >= 5)
+    assert(drawn < 15)
+  }
+}
+// scalastyle:on magic.number
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/codegen/VerifyCodegenConfig.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/codegen/VerifyCodegenConfig.scala
new file mode 100644
index 00000000000..5df471015d0
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/codegen/VerifyCodegenConfig.scala
@@ -0,0 +1,62 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.codegen
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+import java.io.File
+
+class VerifyCodegenConfig extends TestBase {
+  // Pins the directory layout and string constants exposed by one fixed CodegenConfig instance.
+  private val config = CodegenConfig(
+    name = "testmod",
+    jarName = Some("testmod.jar"),
+    topDir = "/top",
+    targetDir = "/target",
+    version = "1.0.0",
+    pythonizedVersion = "1.0.0",
+    rVersion = "1.0.0",
+    packageName = "com.test"
+  )
+
+  test("generatedDir returns correct path") {
+    assert(config.generatedDir === new File("/target", "generated"))
+  }
+
+  test("pySrcDir derives from srcDir") {
+    assert(config.pySrcDir === new File(config.srcDir, "python"))
+  }
+
+  test("rSrcDir derives from rSrcRoot") {
+    assert(config.rSrcDir === new File(config.rSrcRoot, "synapseml/R"))
+  }
+
+  test("srcDir derives from generatedDir") {
+    assert(config.srcDir === new File(config.generatedDir, "src"))
+  }
+
+  test("testDir derives from generatedDir") {
+    assert(config.testDir === new File(config.generatedDir, "test"))
+  }
+
+  test("copyrightLines is non-empty") {
+    assert(config.copyrightLines.nonEmpty)
+    assert(config.copyrightLines.contains("Copyright"))
+  }
+
+  test("scopeDepth is 4 spaces") {
+    assert(config.scopeDepth === "    ")
+    assert(config.scopeDepth.length === 4) // scalastyle:ignore magic.number
+  }
+
+  test("internalPrefix is underscore") {
+    assert(config.internalPrefix === "_")
+  }
+
+  test("packageHelp produces valid content") {
+    val help = config.packageHelp("import foo")
+    assert(help.contains("SynapseML"))
+    assert(help.contains("import foo"))
+    assert(help.contains(config.pythonizedVersion))
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/codegen/VerifyDefaultParamInfo.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/codegen/VerifyDefaultParamInfo.scala
new file mode 100644
index 00000000000..fef8a9a02d8
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/codegen/VerifyDefaultParamInfo.scala
@@ -0,0 +1,51 @@
+// Copyright (C)
Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.codegen
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.util.Identifiable
+
+class VerifyDefaultParamInfo extends TestBase {
+  // Maps Spark Param subclasses to DefaultParamInfo entries and pins each entry's pyType string.
+  private object TestParams extends Params {
+    override val uid: String = Identifiable.randomUID("TestParams") // scalastyle:ignore field.name
+    override def copy(extra: ParamMap): Params = this
+  }
+
+  test("getGeneralParamInfo returns BooleanInfo for BooleanParam") {
+    val resolved = DefaultParamInfo.getGeneralParamInfo(new BooleanParam(TestParams, "b", "desc"))
+    assert(resolved === DefaultParamInfo.BooleanInfo)
+  }
+
+  test("getGeneralParamInfo returns IntInfo for IntParam") {
+    val resolved = DefaultParamInfo.getGeneralParamInfo(new IntParam(TestParams, "i", "desc"))
+    assert(resolved === DefaultParamInfo.IntInfo)
+  }
+
+  test("getGeneralParamInfo returns DoubleInfo for DoubleParam") {
+    val resolved = DefaultParamInfo.getGeneralParamInfo(new DoubleParam(TestParams, "d", "desc"))
+    assert(resolved === DefaultParamInfo.DoubleInfo)
+  }
+
+  test("getGeneralParamInfo returns StringArrayInfo for StringArrayParam") {
+    val resolved = DefaultParamInfo.getGeneralParamInfo(new StringArrayParam(TestParams, "sa", "desc"))
+    assert(resolved === DefaultParamInfo.StringArrayInfo)
+  }
+
+  test("getGeneralParamInfo returns UnknownInfo for unrecognized param") {
+    val resolved = DefaultParamInfo.getGeneralParamInfo(new Param[Any](TestParams, "unknown", "desc"))
+    assert(resolved === DefaultParamInfo.UnknownInfo)
+  }
+
+  test("ParamInfo instances have correct pyType values") {
+    assert(DefaultParamInfo.BooleanInfo.pyType === "bool")
+    assert(DefaultParamInfo.IntInfo.pyType === "int")
+    assert(DefaultParamInfo.DoubleInfo.pyType === "float")
+    assert(DefaultParamInfo.StringArrayInfo.pyType === "list")
+    assert(DefaultParamInfo.StringStringMapInfo.pyType === "dict")
+    assert(DefaultParamInfo.StringInfo.pyType === "str")
+    assert(DefaultParamInfo.UnknownInfo.pyType === "object")
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/codegen/VerifyGenerationUtils.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/codegen/VerifyGenerationUtils.scala
new file mode 100644
index 00000000000..456dc892286
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/codegen/VerifyGenerationUtils.scala
@@ -0,0 +1,51 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.codegen
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+
+class VerifyGenerationUtils extends TestBase {
+  // NOTE(review): leading spaces in the indent expectations may be whitespace-collapsed here - confirm widths.
+  test("indent adds correct number of spaces") {
+    val text = "line1\nline2\nline3"
+    val indented = GenerationUtils.indent(text, 1)
+    assert(indented === " line1\n line2\n line3")
+  }
+
+  test("indent with multiple tabs") {
+    val text = "hello"
+    val indented = GenerationUtils.indent(text, 3)
+    assert(indented === " hello")
+  }
+
+  test("indent with zero tabs") {
+    val text = "hello\nworld"
+    val indented = GenerationUtils.indent(text, 0)
+    assert(indented === "hello\nworld")
+  }
+
+  test("camelToSnake converts simple camelCase") {
+    assertResult("max_iter")(GenerationUtils.camelToSnake("maxIter"))
+  }
+
+  test("camelToSnake converts single word") {
+    assertResult("hello")(GenerationUtils.camelToSnake("hello"))
+  }
+
+  test("camelToSnake handles leading uppercase") {
+    assertResult("gbt_classifier")(GenerationUtils.camelToSnake("GBTClassifier"))
+  }
+
+  test("camelToSnake handles multiple uppercase transitions") {
+    assertResult("min_instances_per_node")(GenerationUtils.camelToSnake("minInstancesPerNode"))
+  }
+
+  test("camelToSnake handles all uppercase") {
+    assertResult("abc")(GenerationUtils.camelToSnake("ABC"))
+  }
+
+  test("camelToSnake handles digits as word boundaries") {
+    assertResult("spark_3_version")(GenerationUtils.camelToSnake("spark3Version"))
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/core/utils/VerifyJarLoadingUtils.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/utils/VerifyJarLoadingUtils.scala
new file mode 100644
index 00000000000..84ad8b3e968
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/utils/VerifyJarLoadingUtils.scala
@@ -0,0 +1,31 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.core.utils
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+
+class VerifyJarLoadingUtils extends TestBase {
+  // Pins JarLoadingUtils.className on a handful of entry paths and touches OsUtils.IsWindows.
+  test("className strips .class extension and converts slashes to dots") {
+    assertResult("com.example.MyClass")(JarLoadingUtils.className("com/example/MyClass.class"))
+  }
+
+  test("className returns input unchanged if no .class extension") {
+    assertResult("com.example.MyClass")(JarLoadingUtils.className("com.example.MyClass"))
+  }
+
+  test("className handles nested class paths") {
+    assertResult("a.b.c.D")(JarLoadingUtils.className("a/b/c/D.class"))
+  }
+
+  test("className handles simple filename") {
+    assertResult("Main")(JarLoadingUtils.className("Main.class"))
+  }
+
+  test("OsUtils.IsWindows returns a boolean") {
+    // The value is platform-dependent, so this only proves that
+    // evaluating the flag completes without throwing.
+    assert(OsUtils.IsWindows.isInstanceOf[Boolean])
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/core/utils/VerifyModelEquality.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/utils/VerifyModelEquality.scala
new file mode 100644
index 00000000000..692de244dc8
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/utils/VerifyModelEquality.scala
@@ -0,0 +1,23 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.core.utils
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+
+class VerifyModelEquality extends TestBase {
+  // NOTE(review): "hello"/"world" share letters yet expect 0.0, so similarity is presumably per whole string - confirm.
+  test("jaccardSimilarity of identical strings is 1.0") {
+    assert(ModelEquality.jaccardSimilarity("hello", "hello") === 1.0)
+  }
+
+  test("jaccardSimilarity of different strings is 0.0") {
+    assert(ModelEquality.jaccardSimilarity("hello", "world") === 0.0)
+  }
+
+  test("jaccardSimilarity is symmetric") {
+    val forward = ModelEquality.jaccardSimilarity("abc", "def")
+    val backward = ModelEquality.jaccardSimilarity("def", "abc")
+    assert(forward === backward)
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/io/http/VerifyRESTHelpers.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/http/VerifyRESTHelpers.scala
new file mode 100644
index 00000000000..d65829d9379
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/http/VerifyRESTHelpers.scala
@@ -0,0 +1,44 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.io.http
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+
+// scalastyle:off magic.number
+class VerifyRESTHelpers extends TestBase {
+  // Drives RESTHelpers.retry with thunks that succeed at once, succeed late, or never succeed.
+  test("retry succeeds on first try with empty backoffs") {
+    assert(RESTHelpers.retry(List.empty[Int], () => 42) === 42)
+  }
+
+  test("retry succeeds on first try with non-empty backoffs") {
+    val outcome = RESTHelpers.retry(List(100, 200), () => "ok")
+    assert(outcome === "ok")
+  }
+
+  test("retry retries on failure and eventually succeeds") {
+    var calls = 0
+    def flaky(): String = {
+      calls += 1
+      if (calls < 3) throw new RuntimeException("fail")
+      "success"
+    }
+    val outcome = RESTHelpers.retry(List(1, 1, 1), () => flaky())
+    assert(outcome === "success")
+    assert(calls === 3)
+  }
+
+  test("retry throws when all retries exhausted") {
+    assertThrows[RuntimeException] {
+      RESTHelpers.retry(List(1), () => throw new RuntimeException("always fails"))
+    }
+  }
+
+  test("retry with empty backoff list throws immediately") {
+    assertThrows[RuntimeException] {
+      RESTHelpers.retry(List.empty[Int], () => throw new RuntimeException("immediate"))
+    }
+  }
+}
+// scalastyle:on magic.number
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/stages/VerifyCacher.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/stages/VerifyCacher.scala
new file mode 100644
index 00000000000..f6aca960c76
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/stages/VerifyCacher.scala
@@ -0,0 +1,44 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.stages
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+import org.apache.spark.storage.StorageLevel
+
+class VerifyCacher extends TestBase {
+  import spark.implicits._
+  // Checks that Cacher caches only when not disabled, keeps its param on copy, and leaves the schema alone.
+  test("transform with disable=false caches the dataframe") {
+    val df = Seq(1, 2, 3).toDF("x")
+    val cached = new Cacher().setDisable(false).transform(df)
+    // Release the cache in finally so a failed assertion cannot leave the frame persisted.
+    try {
+      cached.count()
+      assert(cached.storageLevel !== StorageLevel.NONE)
+    } finally cached.unpersist()
+  }
+
+  test("transform with disable=true does not cache") {
+    val df = Seq(4, 5, 6).toDF("y")
+    val cacher = new Cacher().setDisable(true)
+    val result = cacher.transform(df)
+    assert(result.storageLevel === StorageLevel.NONE)
+  }
+
+  test("default disable value is false") {
+    assert(!new Cacher().getDisable)
+  }
+
+  test("copy preserves params") {
+    val cacher = new Cacher().setDisable(true)
+    val copied = cacher.copy(new org.apache.spark.ml.param.ParamMap())
+    assert(copied.asInstanceOf[Cacher].getDisable)
+  }
+
+  test("transformSchema returns same schema") {
+    val df = Seq(1, 2, 3).toDF("x")
+    val cacher = new Cacher()
+    assert(cacher.transformSchema(df.schema) === df.schema)
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/stages/VerifyTextPreprocessor.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/stages/VerifyTextPreprocessor.scala
new file mode 100644
index 00000000000..ec788d314cc
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/stages/VerifyTextPreprocessor.scala
@@ -0,0 +1,51 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.stages
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+import org.apache.spark.sql.types.{StringType, StructField}
+
+class VerifyTextPreprocessor extends TestBase {
+  import spark.implicits._
+  // Runs TextPreprocessor over one-row DataFrames and checks the replaced text and the output schema.
+  test("replaces matched substrings in DataFrame column") {
+    val frame = Seq("hello world").toDF("text")
+    val stage = new TextPreprocessor()
+      .setInputCol("text")
+      .setOutputCol("out")
+      .setMap(Map("hello" -> "hi"))
+    val rows = stage.transform(frame).select("out").collect()
+    assert(rows.head.getString(0) === "hi world")
+  }
+
+  test("no matches returns original text") {
+    val frame = Seq("hello world").toDF("text")
+    val stage = new TextPreprocessor()
+      .setInputCol("text")
+      .setOutputCol("out")
+      .setMap(Map("xyz" -> "abc"))
+    val rows = stage.transform(frame).select("out").collect()
+    assert(rows.head.getString(0) === "hello world")
+  }
+
+  test("multiple replacements in same text") {
+    val frame = Seq("hello world.").toDF("text")
+    val stage = new TextPreprocessor()
+      .setInputCol("text")
+      .setOutputCol("out")
+      .setMap(Map("hello" -> "hi", "world" -> "earth"))
+    val rows = stage.transform(frame).select("out").collect()
+    assert(rows.head.getString(0) === "hi earth.")
+  }
+
+  test("transformSchema adds output column") {
+    val frame = Seq("a").toDF("text")
+    val stage = new TextPreprocessor()
+      .setInputCol("text")
+      .setOutputCol("out")
+    val outSchema = stage.transformSchema(frame.schema)
+    assert(outSchema.fieldNames.contains("out"))
+    assert(outSchema("out") === StructField("out", StringType))
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/stages/VerifyTrie.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/stages/VerifyTrie.scala
new file mode 100644
index 00000000000..afaa0b1dccb
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/stages/VerifyTrie.scala
@@ -0,0 +1,50 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.stages
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+
+class VerifyTrie extends TestBase {
+  // Exercises Trie construction (apply/put/putAll), child lookup via get, and text rewriting via mapText.
+  test("Trie.apply creates from map") {
+    val trie = Trie(Map("hello" -> "hi"))
+    assert(trie.get('h').nonEmpty)
+  }
+
+  test("put and get for single characters") {
+    val trie = new Trie().put("a", "b")
+    assert(trie.get('a').nonEmpty)
+  }
+
+  test("mapText replaces matching substrings") {
+    val trie = Trie(Map("hello" -> "hi"))
+    assert(trie.mapText("hello there") === "hi there")
+  }
+
+  test("mapText with no matches returns original text") {
+    val trie = Trie(Map("xyz" -> "abc"))
+    assert(trie.mapText("hello world") === "hello world")
+  }
+
+  test("mapText with multiple replacements") {
+    val trie = Trie(Map("hello" -> "hi", "world" -> "earth"))
+    assert(trie.mapText("hello world.") === "hi earth.")
+  }
+
+  test("putAll adds all entries") {
+    val trie = new Trie().putAll(Map("a" -> "1", "b" -> "2"))
+    assert(trie.get('a').nonEmpty)
+    assert(trie.get('b').nonEmpty)
+  }
+
+  test("get returns None for missing key") {
+    val emptyTrie = new Trie()
+    assert(emptyTrie.get('z').isEmpty)
+  }
+
+  test("mapText with overlapping keys longer key wins") {
+    val trie = Trie(Map("he" -> "X", "hello" -> "Y"))
+    assert(trie.mapText("hello.") === "Y.")
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/stages/VerifyUnicodeNormalize.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/stages/VerifyUnicodeNormalize.scala
new file mode 100644
index 00000000000..c79a44f2b0b
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/stages/VerifyUnicodeNormalize.scala
@@ -0,0 +1,59 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.stages
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+import org.apache.spark.sql.types.{StringType, StructField}
+
+class VerifyUnicodeNormalize extends TestBase {
+  import spark.implicits._
+  // Runs UnicodeNormalize over one-row DataFrames and checks the form, the lower flag and the output schema.
+  test("normalizes unicode text") {
+    // e + combining acute accent (decomposed) vs precomposed e-acute
+    val decomposed = "cafe\u0301"
+    val df = Seq(decomposed).toDF("text")
+    val normalizer = new UnicodeNormalize()
+      .setInputCol("text")
+      .setOutputCol("normalized")
+      .setForm("NFC")
+      .setLower(false)
+    val result = normalizer.transform(df).select("normalized").collect()
+    assert(result.head.getString(0) === "caf\u00e9")
+  }
+
+  test("lower=true lowercases output") {
+    val df = Seq("HELLO").toDF("text")
+    val normalizer = new UnicodeNormalize()
+      .setInputCol("text")
+      .setOutputCol("out")
+      .setLower(true)
+    val result = normalizer.transform(df).select("out").collect()
+    assert(result.head.getString(0) === "hello")
+  }
+
+  test("lower=false preserves case") {
+    val df = Seq("Hello").toDF("text")
+    val normalizer = new UnicodeNormalize()
+      .setInputCol("text")
+      .setOutputCol("out")
+      .setLower(false)
+    val result = normalizer.transform(df).select("out").collect()
+    // Compare the whole string: a substring check would also accept mangled or re-cased output.
+    assert(result.head.getString(0) === "Hello")
+  }
+
+  test("default form is NFKD") {
+    assert(new UnicodeNormalize().getForm === "NFKD")
+  }
+
+  test("transformSchema adds output column") {
+    val df = Seq("a").toDF("text")
+    val normalizer = new UnicodeNormalize()
+      .setInputCol("text")
+      .setOutputCol("out")
+    val schema = normalizer.transformSchema(df.schema)
+    assert(schema.fieldNames.contains("out"))
+    assert(schema(schema.fieldIndex("out")) === StructField("out", StringType))
+  }
+}
diff --git a/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/dataset/VerifyDatasetUtils.scala b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/dataset/VerifyDatasetUtils.scala
new file mode 100644
index 00000000000..4eae009f14c
--- /dev/null
+++ b/lightgbm/src/test/scala/com/microsoft/azure/synapse/ml/lightgbm/dataset/VerifyDatasetUtils.scala
@@ -0,0 +1,72 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.lightgbm.dataset
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+import org.apache.spark.sql.types._
+
+// scalastyle:off magic.number
+/** Checks DatasetUtils.countCardinality, getArrayType and validateGroupColumn on small literal inputs. */
+class VerifyDatasetUtils extends TestBase {
+
+  test("countCardinality with all same values") {
+    // One run of three equal values is expected to be reported as a single count of 3.
+    assert(DatasetUtils.countCardinality(Seq(1, 1, 1)) === Array(3))
+  }
+
+  test("countCardinality with all different values") {
+    // Every value differs from its neighbour, so each one is expected to count as 1.
+    assert(DatasetUtils.countCardinality(Seq(1, 2, 3)) === Array(1, 1, 1))
+  }
+
+  test("countCardinality with grouped values") {
+    // Runs of length 2, 3 and 1 are expected to be reported in input order.
+    assert(DatasetUtils.countCardinality(Seq(1, 1, 2, 2, 2, 3)) === Array(2, 3, 1))
+  }
+
+  test("countCardinality with empty sequence") {
+    // Pins the behaviour for no input at all: the expected result is a single
+    // zero count, not an empty array.
+    assert(DatasetUtils.countCardinality(Seq.empty[Int]) === Array(0))
+  }
+
+  test("getArrayType with sparse returns true") {
+    // No rows are supplied; the test only inspects the returned flag.
+    val (_, sparseFlag) = DatasetUtils.getArrayType(Iterator.empty, "sparse", "features")
+    assert(sparseFlag)
+  }
+
+  test("getArrayType with dense returns false") {
+    // Same call as the sparse case with "dense"; the flag is expected to be false.
+    val (_, sparseFlag) = DatasetUtils.getArrayType(Iterator.empty, "dense", "features")
+    assert(!sparseFlag)
+  }
+
+  test("getArrayType with invalid type throws") {
+    // "invalid" is expected to be rejected with an exception.
+    intercept[Exception](DatasetUtils.getArrayType(Iterator.empty, "invalid", "features"))
+  }
+
+  test("validateGroupColumn throws for unsupported types") {
+    val doubleSchema = StructType(Seq(StructField("g", DoubleType)))
+    // A DoubleType group column is expected to be refused.
+    intercept[IllegalArgumentException](DatasetUtils.validateGroupColumn("g", doubleSchema))
+  }
+
+  test("validateGroupColumn passes for IntegerType") {
+    // Returning without an exception is the whole assertion.
+    DatasetUtils.validateGroupColumn("g", StructType(Seq(StructField("g", IntegerType))))
+  }
+
+  test("validateGroupColumn passes for LongType") {
+    // Returning without an exception is the whole assertion.
+    DatasetUtils.validateGroupColumn("g", StructType(Seq(StructField("g", LongType))))
+  }
+
+  test("validateGroupColumn passes for StringType") {
+    // Returning without an exception is the whole assertion.
+    DatasetUtils.validateGroupColumn("g", StructType(Seq(StructField("g", StringType))))
+  }
+}
+// scalastyle:on magic.number
diff --git a/vw/src/test/scala/com/microsoft/azure/synapse/ml/vw/VerifyVectorUtils.scala b/vw/src/test/scala/com/microsoft/azure/synapse/ml/vw/VerifyVectorUtils.scala
new file mode 100644
index 00000000000..a580952fbcd
--- /dev/null
+++ b/vw/src/test/scala/com/microsoft/azure/synapse/ml/vw/VerifyVectorUtils.scala
@@ -0,0 +1,73 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.vw
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+
+// scalastyle:off magic.number
+/** Checks VectorUtils.sortAndDistinct for ordering, duplicate handling and small edge cases. */
+class VerifyVectorUtils extends TestBase {
+
+  test("sortAndDistinct with empty arrays returns empty arrays") {
+    val (idx, vals) = VectorUtils.sortAndDistinct(Array.empty[Int], Array.empty[Double])
+    assert(idx.isEmpty)
+    assert(vals.isEmpty)
+  }
+
+  test("sortAndDistinct sorts indices and values together") {
+    // Each value must stay paired with its index when the indices are reordered.
+    val (idx, vals) = VectorUtils.sortAndDistinct(Array(3, 1, 2), Array(30.0, 10.0, 20.0))
+    assert(idx === Array(1, 2, 3))
+    assert(vals === Array(10.0, 20.0, 30.0))
+  }
+
+  test("sortAndDistinct deduplicates and sums collisions by default") {
+    // Index 1 occurs twice (10.0 and 5.0); by default the two values are expected to be added.
+    val (idx, vals) = VectorUtils.sortAndDistinct(Array(1, 2, 1), Array(10.0, 20.0, 5.0))
+    assert(idx === Array(1, 2))
+    assert(vals === Array(15.0, 20.0))
+  }
+
+  test("sortAndDistinct deduplicates without summing when sumCollisions is false") {
+    // Same input as the summing case; here 10.0 is the value expected to remain for index 1.
+    val (idx, vals) = VectorUtils.sortAndDistinct(Array(1, 2, 1), Array(10.0, 20.0, 5.0), sumCollisions = false)
+    assert(idx === Array(1, 2))
+    assert(vals === Array(10.0, 20.0))
+  }
+
+  test("sortAndDistinct with single element") {
+    val (idx, vals) = VectorUtils.sortAndDistinct(Array(5), Array(42.0))
+    assert(idx === Array(5))
+    assert(vals === Array(42.0))
+  }
+
+  test("sortAndDistinct with already sorted no-duplicate input") {
+    // Input that is already sorted and unique should come back unchanged.
+    val (idx, vals) = VectorUtils.sortAndDistinct(Array(1, 2, 3, 4), Array(1.0, 2.0, 3.0, 4.0))
+    assert(idx === Array(1, 2, 3, 4))
+    assert(vals === Array(1.0, 2.0, 3.0, 4.0))
+  }
+
+  test("sortAndDistinct with all duplicate indices sums to single element") {
+    // All three entries share index 5, so 1.0 + 2.0 + 3.0 is expected in a single slot.
+    val (idx, vals) = VectorUtils.sortAndDistinct(Array(5, 5, 5), Array(1.0, 2.0, 3.0))
+    assert(idx === Array(5))
+    assert(vals === Array(6.0))
+  }
+
+  test("sortAndDistinct with multiple collision groups") {
+    // Expected sums: index 1 -> 2.0 + 4.0, index 2 -> 5.0, index 3 -> 1.0 + 3.0.
+    val (idx, vals) = VectorUtils.sortAndDistinct(Array(3, 1, 3, 1, 2), Array(1.0, 2.0, 3.0, 4.0, 5.0))
+    assert(idx === Array(1, 2, 3))
+    assert(vals === Array(6.0, 5.0, 4.0))
+  }
+
+  test("sortAndDistinct preserves negative values") {
+    // Negative values are reordered with their indices and are otherwise expected to be untouched.
+    val (idx, vals) = VectorUtils.sortAndDistinct(Array(2, 1), Array(-5.0, -3.0))
+    assert(idx === Array(1, 2))
+    assert(vals === Array(-3.0, -5.0))
+  }
+}
+// scalastyle:on magic.number