diff --git a/.agents/skills/scala-code/SKILL.md b/.agents/skills/scala-code/SKILL.md new file mode 100644 index 00000000000..4fba83e9fab --- /dev/null +++ b/.agents/skills/scala-code/SKILL.md @@ -0,0 +1,44 @@ +# Scala Code Skill + +Use this skill for any Scala code change in this repository. + +## Objectives +- Keep Scala changes correct, minimal, and production-safe. +- Prevent CI/CD breakage by validating style and compilation before completion. +- Ensure every behavior change is covered by tests. +- Keep Scala code efficient and maintainable, following existing patterns and practices in the codebase. + +## Scala best practices +- Follow existing SynapseML patterns (`DefaultParamsReadable`, `DefaultParamsWritable`, `Wrappable`, `SynapseMLLogging`). +- Keep business logic in Scala (not generated Python wrappers). +- Do not edit generated files under `target/`. +- Preserve license headers and existing package structure. +- Prefer small, focused changes and reuse existing helpers/traits. +- Avoid introducing flaky tests or network-dependent behavior unless already required by the suite. +- Do not work around scalastyle violations by adding `// scalastyle:off` or `// scalastyle:ignore`; fix the underlying issue or refactor to avoid it. If the refactor would make the code more complex, it is acceptable to disable the specific scalastyle rule for that line or block of code, but do so sparingly and with a justifying comment. + +## Required validation steps (must run) +Run these commands before finishing Scala changes: + +1. Scala style checks: + - `sbt scalastyle "Test / scalastyle"` +2. Scala compile checks: + - `sbt compile` + - `sbt "Test / compile"` +3. Relevant tests for touched modules/files: + - Example: `sbt "core/testOnly *SuiteName*"` or `sbt "cognitive/testOnly *SuiteName*"` + +If a command fails, fix the issue and rerun until passing. 
+ +## Testing requirement for code changes +- Any Scala code change must be accompanied by tests (new tests or updates to existing tests). +- Bug fixes must include a regression test that fails before the fix and passes after. +- New logic/branches should include coverage for success and failure/edge cases where practical. +- Keep tests deterministic and aligned with current module test conventions. + +## Completion checklist +- [ ] Code follows existing Scala/SynapseML conventions. +- [ ] Style checks pass. +- [ ] Compile checks pass. +- [ ] Relevant tests pass. +- [ ] Scala code changes include corresponding tests. diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index 85bfae3c347..c8fcf9be75e 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -47,7 +47,7 @@ jobs: sudo apt-get install -yq sbt - name: Scalastyle check - run: sbt scalastyle test:scalastyle + run: sbt scalastyle "Test / scalastyle" - name: Compile - run: sbt compile test:compile + run: sbt compile "Test / compile" diff --git a/build.sbt b/build.sbt index 72d00806c69..60cd91663a5 100644 --- a/build.sbt +++ b/build.sbt @@ -268,6 +268,19 @@ uploadNotebooks := { uploadToBlob(localNotebooksFolder, blobNotebooksFolder, "docs") } +// Scoverage configuration: exclude legitimately untestable code +// These are either build-time utilities or require external environments +val coverageExclusions = Seq( + // Build-time code generation utilities (not runtime code - invoked during sbt build) + "com\\.microsoft\\.azure\\.synapse\\.ml\\.codegen\\..*", + // Generated BuildInfo classes (auto-generated by sbt-buildinfo) + "com\\.microsoft\\.azure\\.synapse\\.ml\\.build\\..*", + // Microsoft Fabric integration (requires Fabric environment files at specific paths) + "com\\.microsoft\\.azure\\.synapse\\.ml\\.fabric\\..*", + // Fabric-specific logging/telemetry (requires Fabric environment) + 
"com\\.microsoft\\.azure\\.synapse\\.ml\\.logging\\.fabric\\..*" +).mkString(";") + val settings = Seq( Test / scalastyleConfig := (ThisBuild / baseDirectory).value / "scalastyle-test-config.xml", Test / logBuffered := false, @@ -281,7 +294,13 @@ val settings = Seq( assembly / assemblyOption := (assembly / assemblyOption).value.copy(includeScala = false), autoAPIMappings := true, pomPostProcess := pomPostFunc, - sbtPlugin := false + sbtPlugin := false, + // Scoverage settings + coverageExcludedPackages := coverageExclusions, + coverageFailOnMinimum := false, + coverageHighlighting := true, + // Enable Cobertura XML output for Azure DevOps code coverage reporting + coverageOutputCobertura := true ) ThisBuild / publishMavenStyle := true diff --git a/codecov.yaml b/codecov.yaml index a9b25613766..b3f664c4d50 100644 --- a/codecov.yaml +++ b/codecov.yaml @@ -1,6 +1,18 @@ codecov: notify: require_ci_to_pass: no + # Wait for expected number of coverage uploads before sending notifications + # This prevents premature comments with incomplete data + after_n_builds: 40 + +comment: + layout: "reach,diff,flags,files" + behavior: new + require_changes: false + require_base: false + require_head: true + # Wait for all expected uploads before commenting + after_n_builds: 40 coverage: precision: 2 @@ -25,6 +37,9 @@ flags: scala: paths: - src/main/scala + # Carry forward coverage from previous builds for unchanged files + carryforward: true python: paths: - src/main/python + carryforward: true diff --git a/cognitive/src/test/java/mssparkutils/cognitiveService.java b/cognitive/src/test/java/mssparkutils/cognitiveService.java new file mode 100644 index 00000000000..fbcfde09ea8 --- /dev/null +++ b/cognitive/src/test/java/mssparkutils/cognitiveService.java @@ -0,0 +1,27 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. See LICENSE in project root for information. + +package mssparkutils; + +public final class cognitiveService { + + private cognitiveService() { } + + public static String getEndpoint(String linkedServiceName) { + return "https://" + linkedServiceName + ".endpoint"; + } + + public static String getKey(String linkedServiceName) { + return "key-" + linkedServiceName; + } + + public static String getLocation(String linkedServiceName) { + if ("gov".equals(linkedServiceName)) { + return "usgovvirginia"; + } + if ("cn".equals(linkedServiceName)) { + return "chinanorth"; + } + return "eastus"; + } +} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBaseSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBaseSuite.scala new file mode 100644 index 00000000000..46b1ef30150 --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBaseSuite.scala @@ -0,0 +1,189 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package com.microsoft.azure.synapse.ml.services + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import com.microsoft.azure.synapse.ml.param.ServiceParam +import org.apache.http.entity.AbstractHttpEntity +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.sql.Row +import spray.json.DefaultJsonProtocol._ + +private class ServiceParamHarness(override val uid: String = "serviceParamHarness") + extends Params with HasServiceParams { + + val requiredText: ServiceParam[String] = + new ServiceParam[String](this, "requiredText", "required text", isRequired = true) + + val optionalText: ServiceParam[String] = + new ServiceParam[String](this, "optionalText", "optional text") + + val urlVersion: ServiceParam[String] = + new ServiceParam[String](this, "urlVersion", "url version", isURLParam = true) + + def vectorParamMap: Map[String, String] = getVectorParamMap + + def requiredParamNames: Set[String] = getRequiredParams.map(_.name).toSet + + def urlParamNames: Set[String] = getUrlParams.map(_.name).toSet + + def shouldSkipRow(row: Row): Boolean = shouldSkip(row) + + def valueMap(row: Row, excludes: Set[ServiceParam[_]] = Set()): Map[String, Any] = getValueMap(row, excludes) + + def valueAnyOpt(row: Row, p: ServiceParam[_]): Option[Any] = getValueAnyOpt(row, p) + + override def copy(extra: ParamMap): Params = this +} + +private class LocationHarness(override val uid: String = "locationHarness") + extends Params with HasSetLocation { + + override def urlPath: String = "/v1/resource" + + override def copy(extra: ParamMap): Params = this +} + +private class CustomDomainHarness(override val uid: String = "customDomainHarness") + extends Params with HasCustomCogServiceDomain { + + override def urlPath: String = "/deployments/chat/completions" + + override private[ml] def internalServiceType: String = "openai" + + override def copy(extra: ParamMap): Params = this +} + +private class LinkedServiceHarness(override val uid: String = 
"linkedServiceHarness") + extends Params with HasSetLinkedService { + + override def urlPath: String = "/analyze" + + override def copy(extra: ParamMap): Params = this +} + +private class LinkedServiceLocationHarness(override val uid: String = "linkedServiceLocationHarness") + extends Params with HasSetLinkedServiceUsingLocation { + + override def urlPath: String = "/analyze" + + override def copy(extra: ParamMap): Params = this +} + +private class CognitiveInputHarness(override val uid: String = "cognitiveInputHarness") + extends Params with HasCognitiveServiceInput { + + val apiVersion: ServiceParam[String] = + new ServiceParam[String](this, "apiVersion", "api version", isURLParam = true) + + val requiredText: ServiceParam[String] = + new ServiceParam[String](this, "requiredText", "required text", isRequired = true) + + override protected def prepareEntity: Row => Option[AbstractHttpEntity] = _ => None + + def buildUrl(row: Row): String = prepareUrl(row) + + def headers(row: Row, addContentType: Boolean = true): Map[String, String] = getHeaders(row, addContentType) + + def shouldSkipRow(row: Row): Boolean = shouldSkip(row) + + override def copy(extra: ParamMap): Params = this +} + +class CognitiveServiceBaseSuite extends TestBase { + + import spark.implicits._ + + test("setLocation maps cloud domains deterministically") { + val service = new LocationHarness() + .setLocation("eastus") + assert(service.getUrl == "https://eastus.api.cognitive.microsoft.com/v1/resource") + + service.setLocation("usgovarizona") + assert(service.getUrl == "https://usgovarizona.api.cognitive.microsoft.us/v1/resource") + + service.setLocation("chinanorth") + assert(service.getUrl == "https://chinanorth.api.cognitive.microsoft.cn/v1/resource") + } + + test("custom domain helpers build deterministic urls") { + val service = new CustomDomainHarness() + .setCustomServiceName("contoso") + assert(service.getUrl == "https://contoso.cognitiveservices.azure.com/deployments/chat/completions") + + 
service.setEndpoint("https://custom.endpoint/") + assert(service.getUrl == "https://custom.endpoint/deployments/chat/completions") + + val internal = new CustomDomainHarness() + .setDefaultInternalEndpoint("https://fabric") + assert(internal.getOrDefault(internal.url) == "https://fabric/cognitive/openai/deployments/chat/completions") + } + + test("linked service setter resolves endpoint and key locally") { + val service = new LinkedServiceHarness() + .setLinkedService("demo") + + assert(service.getUrl == "https://demo.endpoint/analyze") + assert(service.getSubscriptionKey == "key-demo") + } + + test("linked service location setter resolves domain and key locally") { + val service = new LinkedServiceLocationHarness() + .setLinkedService("gov") + + assert(service.getUrl == "https://usgovvirginia.api.cognitive.microsoft.us/analyze") + assert(service.getSubscriptionKey == "key-gov") + } + + test("service param helper methods are deterministic") { + val harness = new ServiceParamHarness() + harness.setVectorParam(harness.requiredText, "requiredCol") + harness.setVectorParam("urlVersion", "versionCol") + harness.setScalarParam("optionalText", "fallback") + + val row = Seq(("hello", "2024-10-01")).toDF("requiredCol", "versionCol").head() + assert(harness.vectorParamMap == Map("requiredText" -> "requiredCol", "urlVersion" -> "versionCol")) + assert(harness.requiredParamNames == Set("requiredText")) + assert(harness.urlParamNames == Set("urlVersion")) + assert(!harness.shouldSkipRow(row)) + assert(harness.valueAnyOpt(row, harness.urlVersion).contains("2024-10-01")) + assert( + harness.valueMap(row, Set(harness.urlVersion)) == + Map("requiredText" -> "hello", "optionalText" -> "fallback") + ) + + val missingRequired = Seq((Option.empty[String], "2024-10-01")).toDF("requiredCol", "versionCol").head() + assert(harness.shouldSkipRow(missingRequired)) + } + + test("cognitive input helper methods build urls and headers locally") { + val input = new CognitiveInputHarness() + 
input.setUrl("https://example.test/root") + input.setVectorParam(input.requiredText, "textCol") + input.setVectorParam(input.apiVersion, "versionCol") + + val row = Seq(("hello", "2024-10-01")).toDF("textCol", "versionCol").head() + assert(input.buildUrl(row) == "https://example.test/root?apiVersion=2024-10-01") + assert(!input.shouldSkipRow(row)) + + input.setCustomUrlRoot("https://override.test/root") + assert(input.buildUrl(row) == "https://override.test/root") + + val subscriptionHeaders = new CognitiveInputHarness().setSubscriptionKey("sub-key").headers(Row.empty) + assert(subscriptionHeaders("Ocp-Apim-Subscription-Key") == "sub-key") + assert(subscriptionHeaders("Content-Type") == "application/json") + assert(!subscriptionHeaders.contains("Authorization")) + + val aadHeaders = new CognitiveInputHarness().setAADToken("aad-token").headers(Row.empty) + assert(aadHeaders("Authorization") == "Bearer aad-token") + + val customHeaders = new CognitiveInputHarness() + .setCustomAuthHeader("Shared custom-auth") + .setCustomHeaders(Map("X-Test" -> "1")) + .headers(Row.empty) + assert(customHeaders("Authorization") == "Shared custom-auth") + assert(customHeaders("X-Test") == "1") + assert(customHeaders.contains("x-ai-telemetry-properties")) + } +} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/anomaly/AnomalyDetectionCoreSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/anomaly/AnomalyDetectionCoreSuite.scala new file mode 100644 index 00000000000..c7ba5056f0b --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/anomaly/AnomalyDetectionCoreSuite.scala @@ -0,0 +1,154 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package com.microsoft.azure.synapse.ml.services.anomaly + +import org.apache.http.entity.StringEntity +import org.apache.http.util.EntityUtils +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.types.{ + ArrayType, DataType, DoubleType, IntegerType, StringType, StructField, StructType +} +import org.scalatest.funsuite.AnyFunSuite +import spray.json._ +import spray.json.DefaultJsonProtocol._ + +private[anomaly] class TestableAnomalyDetector(uid: String) extends AnomalyDetectorBase(uid) { + def setSeries(v: Seq[TimeSeriesPoint]): this.type = setScalarParam(series, v) + + def setSeriesCol(v: String): this.type = setVectorParam(series, v) + + def buildEntity(row: Row): StringEntity = prepareEntity(row).get.asInstanceOf[StringEntity] + + override def responseDataType: DataType = ADEntireResponse.schema + + override def urlPath: String = "/anomalydetector/v1.1/timeseries/entire/detect" +} + +class AnomalyDetectionCoreSuite extends AnyFunSuite { + + import AnomalyDetectorProtocol._ + + private val timeSeriesPointSchema = StructType(Seq( + StructField("timestamp", StringType, nullable = false), + StructField("value", DoubleType, nullable = false) + )) + + private val requestSchema = StructType(Seq( + StructField("seriesInput", ArrayType(timeSeriesPointSchema), nullable = false), + StructField("granularityInput", StringType, nullable = false), + StructField("maxRatioInput", DoubleType, nullable = true), + StructField("sensitivityInput", IntegerType, nullable = true), + StructField("customIntervalInput", IntegerType, nullable = true), + StructField("periodInput", IntegerType, nullable = true), + StructField("imputeModeInput", StringType, nullable = true), + StructField("imputeFixedValueInput", DoubleType, nullable = true) + )) + + private def parseEntity(entity: StringEntity): ADRequest = + EntityUtils.toString(entity, "UTF-8").parseJson.convertTo[ADRequest] + + private def 
timeSeriesRow(timestamp: String, value: Double): Row = + new GenericRowWithSchema(Array[Any](timestamp, value), timeSeriesPointSchema) + + test("ADEntireResponse explode returns aligned single responses") { + val entire = ADEntireResponse( + isAnomaly = Seq(false, true), + isPositiveAnomaly = Seq(false, true), + isNegativeAnomaly = Seq(false, false), + period = 12, + expectedValues = Seq(10.0, 99.0), + upperMargins = Seq(11.0, 105.0), + lowerMargins = Seq(9.0, 95.0), + severity = Seq(0.1, 0.8) + ) + + assert(entire.explode == Seq( + ADSingleResponse(false, false, false, 12, 10.0, 11.0, 9.0, 0.1), + ADSingleResponse(true, true, false, 12, 99.0, 105.0, 95.0, 0.8) + )) + } + + test("AnomalyDetectorProtocol json format round-trips ADRequest") { + val request = ADRequest( + series = Seq(TimeSeriesPoint("2024-01-01T00:00:00Z", 1.5)), + granularity = "daily", + maxAnomalyRatio = Some(0.3), + sensitivity = None, + customInterval = Some(2), + period = Some(7), + imputeMode = Some("linear"), + imputeFixedValue = None + ) + + val json = request.toJson.asJsObject + assert(json.fields("series").convertTo[Seq[TimeSeriesPoint]] == request.series) + assert(json.fields("granularity").convertTo[String] == "daily") + assert(json.convertTo[ADRequest] == request) + } + + test("prepareEntity builds deterministic payload from scalar params") { + val detector = new TestableAnomalyDetector("scalar-ad") + .setSeries(Seq( + TimeSeriesPoint("2024-01-01T00:00:00Z", 10.0), + TimeSeriesPoint("2024-01-02T00:00:00Z", 11.5) + )) + .setGranularity("hourly") + .setMaxAnomalyRatio(0.2) + .setSensitivity(88) + .setCustomInterval(3) + .setPeriod(24) + .setImputeMode("auto") + .setImputeFixedValue(5.0) + + val request = parseEntity(detector.buildEntity(Row.empty)) + assert(request.series.map(_.timestamp) == Seq("2024-01-01T00:00:00Z", "2024-01-02T00:00:00Z")) + assert(request.granularity == "hourly") + assert(request.maxAnomalyRatio.contains(0.2)) + assert(request.sensitivity.contains(88)) + 
assert(request.customInterval.contains(3)) + assert(request.period.contains(24)) + assert(request.imputeMode.contains("auto")) + assert(request.imputeFixedValue.contains(5.0)) + } + + test("prepareEntity converts row-based series structs and optional params") { + val detector = new TestableAnomalyDetector("row-ad") + .setSeriesCol("seriesInput") + .setGranularityCol("granularityInput") + .setMaxAnomalyRatioCol("maxRatioInput") + .setSensitivityCol("sensitivityInput") + .setCustomIntervalCol("customIntervalInput") + .setPeriodCol("periodInput") + .setImputeModeCol("imputeModeInput") + .setImputeFixedValueCol("imputeFixedValueInput") + + val requestRow = new GenericRowWithSchema(Array[Any]( + Seq( + timeSeriesRow("2024-05-01T00:00:00Z", 20.0), + timeSeriesRow("2024-05-02T00:00:00Z", 21.0) + ), + "daily", + 0.4, + 75, + 5, + 14, + "fixed", + 2.5 + ), requestSchema) + + val request = parseEntity(detector.buildEntity(requestRow)) + assert(request.series == Seq( + TimeSeriesPoint("2024-05-01T00:00:00Z", 20.0), + TimeSeriesPoint("2024-05-02T00:00:00Z", 21.0) + )) + assert(request.granularity == "daily") + assert(request.maxAnomalyRatio.contains(0.4)) + assert(request.sensitivity.contains(75)) + assert(request.customInterval.contains(5)) + assert(request.period.contains(14)) + assert(request.imputeMode.contains("fixed")) + assert(request.imputeFixedValue.contains(2.5)) + } +} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/anomaly/AnomalyDetectorSchemasSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/anomaly/AnomalyDetectorSchemasSuite.scala new file mode 100644 index 00000000000..b5763f2611c --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/anomaly/AnomalyDetectorSchemasSuite.scala @@ -0,0 +1,210 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package com.microsoft.azure.synapse.ml.services.anomaly + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import spray.json._ + +import scala.collection.JavaConverters._ + +class AnomalyDetectorSchemasSuite extends TestBase { + + import AnomalyDetectorProtocol._ + import MADJsonProtocol._ + + test("AnomalyDetectorProtocol serializes and deserializes ADRequest with optionals") { + val request = ADRequest( + series = Seq( + TimeSeriesPoint("2024-01-01T00:00:00Z", 1.0), + TimeSeriesPoint("2024-01-01T01:00:00Z", 2.0) + ), + granularity = "hourly", + maxAnomalyRatio = Some(0.2), + sensitivity = Some(95), + customInterval = Some(4), + period = Some(8), + imputeMode = Some("fixed"), + imputeFixedValue = Some(1.5) + ) + + assert(request.toJson.convertTo[ADRequest] == request) + assert(TimeSeriesPoint("2024-01-01T00:00:00Z", 1.0).toJson.convertTo[TimeSeriesPoint] == + TimeSeriesPoint("2024-01-01T00:00:00Z", 1.0)) + } + + test("AnomalyDetectorProtocol deserializes ADRequest optionals to None when omitted") { + val json = + """{"series":[{"timestamp":"2024-01-01T00:00:00Z","value":1.0}],"granularity":"daily"}""" + val request = json.parseJson.convertTo[ADRequest] + + assert(request.series == Seq(TimeSeriesPoint("2024-01-01T00:00:00Z", 1.0))) + assert(request.granularity == "daily") + assert(request.maxAnomalyRatio.isEmpty) + assert(request.sensitivity.isEmpty) + assert(request.customInterval.isEmpty) + assert(request.period.isEmpty) + assert(request.imputeMode.isEmpty) + assert(request.imputeFixedValue.isEmpty) + } + + test("ADEntireResponse explode returns one ADSingleResponse per index") { + val response = ADEntireResponse( + isAnomaly = Seq(true, false), + isPositiveAnomaly = Seq(true, false), + isNegativeAnomaly = Seq(false, true), + period = 12, + expectedValues = Seq(1.1, 2.2), + upperMargins = Seq(1.3, 2.4), + lowerMargins = Seq(0.9, 2.0), + severity = Seq(0.8, 0.1) + ) + + val exploded = response.explode + assert(exploded == Seq( + 
ADSingleResponse(true, true, false, 12, 1.1, 1.3, 0.9, 0.8), + ADSingleResponse(false, false, true, 12, 2.2, 2.4, 2.0, 0.1) + )) + } + + test("MADJsonProtocol round trips key schema case classes") { + val error = DMAError(Some("BadRequest"), Some("invalid input")) + val variableState = DMAVariableState( + Some("cpu"), + Some(0.2), + Some(30), + Some("2024-01-01T00:00:00Z"), + Some("2024-01-01T01:00:00Z") + ) + val modelState = ModelState( + Some(Seq(1, 2)), + Some(Seq(1.0, 0.8)), + Some(Seq(1.1, 0.9)), + Some(Seq(0.4, 0.5)) + ) + val diagnosticsInfo = DiagnosticsInfo(Some(modelState), Some(Seq(variableState))) + val alignPolicy = AlignPolicy(Some("Outer"), Some("Linear"), Some(0)) + val maeRequest = MAERequest( + "https://storage/path.csv", + "OneTable", + "2024-01-01T00:00:00Z", + "2024-01-02T00:00:00Z", + Some(300), + Some(alignPolicy), + Some("demo-model") + ) + val correlationChanges = CorrelationChanges(Some(Seq("memory"))) + val interpretation = Interpretation(Some("cpu"), Some(0.7), Some(correlationChanges)) + val dmaValue = DMAValue(Some(Seq(interpretation)), Some(true), Some(0.4), Some(0.9)) + val dmaResult = DMAResult("2024-01-01T00:01:00Z", Some(dmaValue), Some(Seq(error))) + val setupInfo = DMASetupInfo( + "https://storage/path.csv", + Some(5), + "2024-01-01T00:00:00Z", + "2024-01-02T00:00:00Z" + ) + val summary = DMASummary("Ready", Some(Seq(error)), Some(Seq(variableState)), setupInfo) + val maeModelInfo = MAEModelInfo( + Some(300), + Some(alignPolicy), + "https://storage/path.csv", + "OneTable", + "2024-01-01T00:00:00Z", + "2024-01-02T00:00:00Z", + Some("demo-model"), + "Ready", + Some(Seq(error)), + Some(diagnosticsInfo) + ) + val variable = Variable( + Seq("2024-01-01T00:00:00Z", "2024-01-01T00:01:00Z"), + Seq(1.0, 2.0), + "cpu" + ) + val dlmaRequest = DLMARequest(Seq(variable), 3) + val dmaRequest = DMARequest( + "https://storage/path.csv", + "2024-01-01T00:00:00Z", + "2024-01-02T00:00:00Z", + Some(5) + ) + + 
assert(error.toJson.convertTo[DMAError] == error) + assert(variableState.toJson.convertTo[DMAVariableState] == variableState) + assert(modelState.toJson.convertTo[ModelState] == modelState) + assert(diagnosticsInfo.toJson.convertTo[DiagnosticsInfo] == diagnosticsInfo) + assert(alignPolicy.toJson.convertTo[AlignPolicy] == alignPolicy) + assert(maeRequest.toJson.convertTo[MAERequest] == maeRequest) + assert(correlationChanges.toJson.convertTo[CorrelationChanges] == correlationChanges) + assert(interpretation.toJson.convertTo[Interpretation] == interpretation) + assert(dmaValue.toJson.convertTo[DMAValue] == dmaValue) + assert(dmaResult.toJson.convertTo[DMAResult] == dmaResult) + assert(setupInfo.toJson.convertTo[DMASetupInfo] == setupInfo) + assert(summary.toJson.convertTo[DMASummary] == summary) + assert(maeModelInfo.toJson.convertTo[MAEModelInfo] == maeModelInfo) + assert(variable.toJson.convertTo[Variable] == variable) + assert(dlmaRequest.toJson.convertTo[DLMARequest] == dlmaRequest) + assert(dmaRequest.toJson.convertTo[DMARequest] == dmaRequest) + } + + test("MADJsonProtocol deserializes optional fields to None when absent") { + val dmaRequestJson = + """{"dataSource":"source","startTime":"2024-01-01T00:00:00Z","endTime":"2024-01-02T00:00:00Z"}""" + val maeRequestJson = + """{"dataSource":"source","dataSchema":"OneTable",""" + + """"startTime":"2024-01-01T00:00:00Z",""" + + """"endTime":"2024-01-02T00:00:00Z"}""" + val dmaResultJson = """{"timestamp":"2024-01-01T00:00:00Z"}""" + val dmaValueJson = """{}""" + + assert(dmaRequestJson.parseJson.convertTo[DMARequest].topContributorCount.isEmpty) + val maeRequest = maeRequestJson.parseJson.convertTo[MAERequest] + assert(maeRequest.slidingWindow.isEmpty) + assert(maeRequest.alignPolicy.isEmpty) + assert(maeRequest.displayName.isEmpty) + val result = dmaResultJson.parseJson.convertTo[DMAResult] + assert(result.value.isEmpty) + assert(result.errors.isEmpty) + val value = dmaValueJson.parseJson.convertTo[DMAValue] + 
assert(value.interpretation.isEmpty) + assert(value.isAnomaly.isEmpty) + assert(value.severity.isEmpty) + assert(value.score.isEmpty) + } + + test("DMAVariableState and DiagnosticsInfo getters cover present and empty branches") { + val variableState = DMAVariableState( + Some("cpu"), + Some(0.3), + Some(11), + Some("2024-01-01T00:00:00Z"), + Some("2024-01-01T00:10:00Z") + ) + assert(variableState.getVariable == "cpu") + assert(variableState.getFilledNARatio == 0.3) + assert(variableState.getEffectiveCount == 11) + assert(variableState.getFirstTimestamp == "2024-01-01T00:00:00Z") + assert(variableState.getLastTimestamp == "2024-01-01T00:10:00Z") + + val modelState = ModelState(None, None, None, None) + assert(modelState.getEpochIds.asScala.isEmpty) + assert(modelState.getTrainLosses.asScala.isEmpty) + assert(modelState.getValidationLosses.asScala.isEmpty) + assert(modelState.getLatenciesInSeconds.asScala.isEmpty) + + val diagnosticsInfo = DiagnosticsInfo(Some(modelState), Some(Seq(variableState))) + assert(diagnosticsInfo.getModelState == modelState) + assert(diagnosticsInfo.getVariableStates.asScala.toSeq == Seq(variableState)) + + val emptyVariableState = DMAVariableState(None, None, None, None, None) + val emptyDiagnosticsInfo = DiagnosticsInfo(None, None) + intercept[NoSuchElementException](emptyVariableState.getVariable) + intercept[NoSuchElementException](emptyVariableState.getFilledNARatio) + intercept[NoSuchElementException](emptyVariableState.getEffectiveCount) + intercept[NoSuchElementException](emptyVariableState.getFirstTimestamp) + intercept[NoSuchElementException](emptyVariableState.getLastTimestamp) + intercept[NoSuchElementException](emptyDiagnosticsInfo.getModelState) + intercept[NoSuchElementException](emptyDiagnosticsInfo.getVariableStates) + } + +} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/anomaly/MultivariateAnamolyDetectionSuite.scala 
b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/anomaly/MultivariateAnamolyDetectionSuite.scala deleted file mode 100644 index 8a6148ef7eb..00000000000 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/anomaly/MultivariateAnamolyDetectionSuite.scala +++ /dev/null @@ -1,340 +0,0 @@ -// Copyright (C) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. See LICENSE in project root for information. - -package com.microsoft.azure.synapse.ml.services.anomaly -// -//import com.microsoft.azure.synapse.ml.Secrets -//import com.microsoft.azure.synapse.ml.core.test.base.{Flaky, TestBase} -//import com.microsoft.azure.synapse.ml.core.test.benchmarks.DatasetUtils -//import com.microsoft.azure.synapse.ml.core.test.fuzzing.{EstimatorFuzzing, TestObject, TransformerFuzzing} -//import org.apache.hadoop.conf.Configuration -//import org.apache.spark.ml.util.MLReadable -//import org.apache.spark.sql.DataFrame -//import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} -//import spray.json.{DefaultJsonProtocol, _} -// -//import java.time.ZonedDateTime -//import java.time.format.DateTimeFormatter -//import scala.collection.mutable -// -// -//case class MADListModelsResponse(models: Seq[MADModel], -// currentCount: Int, -// maxCount: Int, -// nextLink: Option[String]) -// -//case class MADModel(modelId: String, -// createdTime: String, -// lastUpdatedTime: String, -// status: String, -// displayName: Option[String], -// variablesCount: Int) -// -//object MADListModelsProtocol extends DefaultJsonProtocol { -// -// implicit val MADModelEnc: RootJsonFormat[MADModel] = jsonFormat6(MADModel) -// implicit val MADLMRespEnc: RootJsonFormat[MADListModelsResponse] = jsonFormat4(MADListModelsResponse) -// -//} -// -//trait StorageCredentials { -// -// lazy val storageKey: String = sys.env.getOrElse("STORAGE_KEY", Secrets.MADTestStorageKey) -// lazy val storageAccount = "anomalydetectiontest" -// lazy val 
containerName = "madtest" -// -//} -// -//trait MADTestUtils extends TestBase with AnomalyKey with StorageCredentials { -// -// lazy val startTime: String = "2021-01-01T00:00:00Z" -// lazy val endTime: String = "2021-01-02T12:00:00Z" -// lazy val timestampColumn: String = "timestamp" -// lazy val inputColumns: Array[String] = Array("feature0", "feature1", "feature2") -// lazy val intermediateSaveDir: String = -// s"wasbs://$containerName@$storageAccount.blob.core.windows.net/intermediateData" -// lazy val fileLocation: String = DatasetUtils.madTestFile("mad_example.csv").toString -// lazy val fileSchema: StructType = StructType(Array( -// StructField(timestampColumn, StringType, nullable = true) -// ) ++ inputColumns.map(inputCol => StructField(inputCol, DoubleType, nullable = true))) -// lazy val df: DataFrame = spark.read.format("csv") -// .option("header", "true").schema(fileSchema).load(fileLocation) -// -//} -// -//class SimpleFitMultivariateAnomalySuite extends EstimatorFuzzing[SimpleFitMultivariateAnomaly] -// with MADTestUtils with Flaky { -// -// def simpleMultiAnomalyEstimator: SimpleFitMultivariateAnomaly = new SimpleFitMultivariateAnomaly() -// .setSubscriptionKey(anomalyKey) -// .setLocation(anomalyLocation) -// .setOutputCol("result") -// .setStartTime(startTime) -// .setEndTime(endTime) -// .setIntermediateSaveDir(intermediateSaveDir) -// .setTimestampCol(timestampColumn) -// .setInputCols(inputColumns) -// -// test("SimpleFitMultivariateAnomaly basic usage") { -// val smae = simpleMultiAnomalyEstimator.setSlidingWindow(50) -// val model = smae.fit(df) -// smae.cleanUpIntermediateData() -// -// // model might not be ready -// tryWithRetries(Array(100, 500, 1000)) { () => -// val result = model -// .setStartTime(startTime) -// .setEndTime(endTime) -// .setOutputCol("result") -// .setTimestampCol(timestampColumn) -// .setInputCols(inputColumns) -// .transform(df) -// .collect() -// model.cleanUpIntermediateData() -// assert(result.length == 
df.collect().length) -// } -// } -// -// test("Throw errors if alignMode is not set correctly") { -// val caught = intercept[IllegalArgumentException] { -// simpleMultiAnomalyEstimator.setAlignMode("alignMode").fit(df) -// } -// assert(caught.getMessage.contains("alignMode must be either `inner` or `outer`.")) -// } -// -// test("Throw errors if slidingWindow is not between 28 and 2880") { -// val caught = intercept[IllegalArgumentException] { -// simpleMultiAnomalyEstimator.setSlidingWindow(20).fit(df) -// } -// assert(caught.getMessage.contains("slidingWindow must be between 28 and 2880 (both inclusive).")) -// } -// -// test("Throw errors if authentication is not provided") { -// val caught = intercept[IllegalAccessError] { -// new SimpleFitMultivariateAnomaly() -// .setSubscriptionKey(anomalyKey) -// .setLocation(anomalyLocation) -// .setIntermediateSaveDir(s"wasbs://$containerName@notreal.blob.core.windows.net/intermediateData") -// .setOutputCol("result") -// .setInputCols(Array("feature0")) -// .fit(df) -// } -// assert(caught.getMessage.contains("Could not find the storage account credentials.")) -// } -// -// test("Throw errors if start/end time is not ISO8601 format") { -// val caught = intercept[IllegalArgumentException] { -// val smae = simpleMultiAnomalyEstimator -// .setStartTime("2021-01-01 00:00:00") -// smae.fit(df) -// } -// assert(caught.getMessage.contains("StartTime should be ISO8601 format.")) -// -// val caught2 = intercept[IllegalArgumentException] { -// val smae = simpleMultiAnomalyEstimator -// .setEndTime("2021-01-01 00:00:00") -// smae.fit(df) -// } -// assert(caught2.getMessage.contains("EndTime should be ISO8601 format.")) -// } -// -// test("Expose correct error message during fitting") { -// val caught = intercept[RuntimeException] { -// val testDf = df.limit(50) -// simpleMultiAnomalyEstimator -// .fit(testDf) -// } -// assert(caught.getMessage.contains("TrainFailed")) -// } -// -// test("Expose correct error message during 
inference") { -// val caught = intercept[RuntimeException] { -// val testDf = df.limit(50) -// val smae = simpleMultiAnomalyEstimator -// val model = smae.fit(df) -// smae.cleanUpIntermediateData() -// assert(model.getDiagnosticsInfo.variableStates.get.length.equals(3)) -// -// model.setStartTime(startTime) -// .setEndTime(endTime) -// .setOutputCol("result") -// .setTimestampCol(timestampColumn) -// .setInputCols(inputColumns) -// .transform(testDf) -// .collect() -// } -// assert(caught.getMessage.contains("Not enough data.")) -// } -// -// test("Expose correct error message for invalid modelId") { -// val caught = intercept[RuntimeException] { -// val detectMultivariateAnomaly = new SimpleDetectMultivariateAnomaly() -// .setModelId("FAKE_MODEL_ID") -// .setSubscriptionKey(anomalyKey) -// .setLocation(anomalyLocation) -// .setIntermediateSaveDir(intermediateSaveDir) -// detectMultivariateAnomaly -// .setStartTime(startTime) -// .setEndTime(endTime) -// .setOutputCol("result") -// .setTimestampCol(timestampColumn) -// .setInputCols(inputColumns) -// .transform(df) -// .collect() -// } -// assert(caught.getMessage.contains("Encounter error while fetching model")) -// } -// -// test("return modelId after retries and get model status before inference") { -// val caught = intercept[RuntimeException] { -// val smae = simpleMultiAnomalyEstimator -// .setMaxPollingRetries(1) -// val model = smae.fit(df) -// smae.cleanUpIntermediateData() -// -// model.setStartTime(startTime) -// .setEndTime(endTime) -// .setOutputCol("result") -// .setTimestampCol(timestampColumn) -// .setInputCols(inputColumns) -// .transform(df) -// .collect() -// model.cleanUpIntermediateData() -// } -// assert(caught.getMessage.contains("not ready yet")) -// } -// -// override def testSerialization(): Unit = { -// println("ignore the Serialization Fuzzing test because fitting process takes more than 3 minutes") -// } -// -// override def testExperiments(): Unit = { -// println("ignore the Experiment 
Fuzzing test because fitting process takes more than 3 minutes") -// } -// -// override def afterAll(): Unit = { -// MADUtils.cleanUpAllModels(anomalyKey, anomalyLocation) -// super.afterAll() -// } -// -// override def beforeAll(): Unit = { -// super.beforeAll() -// val hc = spark.sparkContext.hadoopConfiguration -// hc.set("fs.azure", "org.apache.hadoop.fs.azure.NativeAzureFileSystem") -// hc.set(s"fs.azure.account.keyprovider.$storageAccount.blob.core.windows.net", -// "org.apache.hadoop.fs.azure.SimpleKeyProvider") -// hc.set(s"fs.azure.account.key.$storageAccount.blob.core.windows.net", storageKey) -// cleanOldModels() -// } -// -// override def testObjects(): Seq[TestObject[SimpleFitMultivariateAnomaly]] = -// Seq(new TestObject(simpleMultiAnomalyEstimator.setSlidingWindow(200), df)) -// -// def stringToTime(dateString: String): ZonedDateTime = { -// val tsFormat = "yyyy-MM-dd'T'HH:mm:ssz" -// val formatter = DateTimeFormatter.ofPattern(tsFormat) -// ZonedDateTime.parse(dateString, formatter) -// } -// -// def cleanOldModels(): Unit = { -// val url = simpleMultiAnomalyEstimator.setLocation(anomalyLocation).getUrl + "/" -// val twoDaysAgo = ZonedDateTime.now().minusDays(2) -// val modelSet: mutable.HashSet[String] = mutable.HashSet() -// var modelDeleted: Boolean = false -// -// // madListModels doesn't necessarily return all models, so just in case, -// // if we delete any models, we loop around to see if there are more to check. 
-// // scalastyle:off while -// do { -// modelDeleted = false -// val models = MADUtils.madListModels(anomalyKey, anomalyLocation) -// .parseJson.asJsObject().fields("models").asInstanceOf[JsArray].elements -// .map(modelJson => modelJson.asJsObject.fields("modelId").asInstanceOf[JsString].value) -// models.foreach { modelId => -// if (!modelSet.contains(modelId)) { -// modelSet += modelId -// val lastUpdated = -// MADUtils.madGetModel(url, modelId, anomalyKey).parseJson.asJsObject.fields("lastUpdatedTime") -// val lastUpdatedTime = stringToTime(lastUpdated.toString().replaceAll("\"", "")) -// if (lastUpdatedTime.isBefore(twoDaysAgo)) { -// println(s"Deleting $modelId") -// MADUtils.madDelete(modelId, anomalyKey, anomalyLocation) -// modelDeleted = true -// } -// } -// } -// } while (modelDeleted) -// // scalastyle:on while -// } -// -// override def reader: MLReadable[_] = SimpleFitMultivariateAnomaly -// -// override def modelReader: MLReadable[_] = SimpleDetectMultivariateAnomaly -//} -// -//class DetectLastMultivariateAnomalySuite extends TransformerFuzzing[DetectLastMultivariateAnomaly] -// with MADTestUtils { -// -// lazy val sfma: SimpleFitMultivariateAnomaly = { -// val hc: Configuration = spark.sparkContext.hadoopConfiguration -// hc.set("fs.azure", "org.apache.hadoop.fs.azure.NativeAzureFileSystem") -// hc.set(s"fs.azure.account.keyprovider.$storageAccount.blob.core.windows.net", -// "org.apache.hadoop.fs.azure.SimpleKeyProvider") -// hc.set(s"fs.azure.account.key.$storageAccount.blob.core.windows.net", storageKey) -// -// new SimpleFitMultivariateAnomaly() -// .setSubscriptionKey(anomalyKey) -// .setLocation(anomalyLocation) -// .setOutputCol("result") -// .setStartTime(startTime) -// .setEndTime(endTime) -// .setIntermediateSaveDir(intermediateSaveDir) -// .setTimestampCol(timestampColumn) -// .setInputCols(inputColumns) -// .setSlidingWindow(50) -// } -// -// lazy val modelId: String = { -// val model: SimpleDetectMultivariateAnomaly = sfma.fit(df) -// 
MADUtils.CreatedModels += model.getModelId -// model.getModelId -// } -// -// lazy val dlma: DetectLastMultivariateAnomaly = new DetectLastMultivariateAnomaly() -// .setSubscriptionKey(anomalyKey) -// .setLocation(anomalyLocation) -// .setModelId(modelId) -// .setInputVariablesCols(inputColumns) -// .setOutputCol("result") -// .setTimestampCol(timestampColumn) -// -// test("Basic Usage") { -// val result = dlma.setBatchSize(50) -// .transform(df.limit(100)) -// .collect() -// assert(result(0).get(6) == null) -// assert(!result(50).getAs[Boolean]("isAnomaly")) -// assert(result(68).getAs[Boolean]("isAnomaly")) -// } -// -// test("Error if batch size is smaller than sliding window") { -// val result = dlma.setBatchSize(10).transform(df.limit(50)) -// result.show(50, truncate = false) -// assert(result.collect().head.getAs[StringType](dlma.getErrorCol).toString.contains("NotEnoughData")) -// } -// -// override def afterAll(): Unit = { -// MADUtils.cleanUpAllModels(anomalyKey, anomalyLocation) -// sfma.cleanUpIntermediateData() -// super.afterAll() -// } -// -// override def testSerialization(): Unit = { -// println("ignore the Serialization Fuzzing test because fitting process takes more than 3 minutes") -// } -// -// override def testObjects(): Seq[TestObject[DetectLastMultivariateAnomaly]] = -// Seq(new TestObject(dlma, df)) -// -// override def reader: MLReadable[_] = DetectLastMultivariateAnomaly -//} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/anomaly/MultivariateAnomalyDetectionSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/anomaly/MultivariateAnomalyDetectionSuite.scala new file mode 100644 index 00000000000..f2d46a094c7 --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/anomaly/MultivariateAnomalyDetectionSuite.scala @@ -0,0 +1,189 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.services.anomaly + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.hadoop.fs.RemoteIterator +import org.apache.http.entity.AbstractHttpEntity +import org.apache.http.util.EntityUtils +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} +import spray.json._ + +class MultivariateAnomalyDetectionSuite extends TestBase { + + import MADJsonProtocol._ + + private class ExposedSimpleFitMultivariateAnomaly(uid: String) + extends SimpleFitMultivariateAnomaly(uid) { + def requestBody(dataSource: String): String = entityToBody(prepareEntity(dataSource).get) + } + + private class ExposedSimpleDetectMultivariateAnomaly(uid: String) + extends SimpleDetectMultivariateAnomaly(uid) { + def requestBody(dataSource: String): String = entityToBody(prepareEntity(dataSource).get) + } + + private class ExposedDetectLastMultivariateAnomaly(uid: String) + extends DetectLastMultivariateAnomaly(uid) { + def requestBody(row: Row): String = entityToBody(prepareEntity(row).get) + def requestUrl(row: Row): String = prepareUrl(row) + } + + private class TestRemoteIterator(values: Seq[Int]) extends RemoteIterator[Int] { + private val underlying = values.iterator + override def hasNext: Boolean = underlying.hasNext + override def next(): Int = underlying.next() + } + + private def entityToBody(entity: AbstractHttpEntity): String = EntityUtils.toString(entity) + + test("MADUtils.madUrl creates expected endpoint") { + assert(MADUtils.madUrl("westus2") == + "https://westus2.api.cognitive.microsoft.com/anomalydetector/v1.1/multivariate/") + } + + test("start and end times accept valid ISO instant format") { + val fit = new SimpleFitMultivariateAnomaly("fit-time-test") + .setStartTime("2024-01-01T00:00:00Z") + .setEndTime("2024-01-02T00:00:00Z") + + assert(fit.getStartTime == "2024-01-01T00:00:00Z") + 
assert(fit.getEndTime == "2024-01-02T00:00:00Z") + } + + test("invalid start time throws deterministic validation error") { + val exception = intercept[IllegalArgumentException] { + new SimpleFitMultivariateAnomaly("fit-invalid-time").setStartTime("2024/01/01") + } + assert(exception.getMessage.contains("StartTime should be ISO8601 format")) + } + + test("intermediateSaveDir validates accepted and rejected schemes") { + val fit = new SimpleFitMultivariateAnomaly("fit-save-dir") + .setIntermediateSaveDir("wasbs://container@account.blob.core.windows.net/datasets") + assert(fit.getIntermediateSaveDir.startsWith("wasbs://")) + + val exception = intercept[IllegalArgumentException] { + fit.setIntermediateSaveDir("file:///tmp") + } + assert(exception.getMessage.contains("improper HDFS loacation")) + } + + test("fit parameter setters normalize and validate values") { + val fit = new SimpleFitMultivariateAnomaly("fit-params") + .setSlidingWindow(28) + .setAlignMode("inner") + .setFillNAMethod("fixed") + + assert(fit.getSlidingWindow == 28) + assert(fit.getAlignMode == "Inner") + assert(fit.getFillNAMethod == "Fixed") + + intercept[IllegalArgumentException] { + fit.setSlidingWindow(27) + } + intercept[IllegalArgumentException] { + fit.setAlignMode("bad-mode") + } + intercept[IllegalArgumentException] { + fit.setFillNAMethod("bad-method") + } + } + + test("SimpleFitMultivariateAnomaly prepareEntity uses deterministic defaults") { + val fit = new ExposedSimpleFitMultivariateAnomaly("fit-default-entity") + .setStartTime("2024-01-01T00:00:00Z") + .setEndTime("2024-01-02T00:00:00Z") + val request = fit.requestBody("https://storage/path.csv").parseJson.convertTo[MAERequest] + + assert(request.dataSource == "https://storage/path.csv") + assert(request.dataSchema == "OneTable") + assert(request.slidingWindow.contains(300)) + assert(request.alignPolicy.exists(_.alignMode.contains("Outer"))) + assert(request.alignPolicy.exists(_.fillNAMethod.contains("Linear"))) + 
assert(request.alignPolicy.exists(_.paddingValue.isEmpty)) + assert(request.displayName.isEmpty) + } + + test("SimpleFitMultivariateAnomaly prepareEntity reflects configured values") { + val fit = new ExposedSimpleFitMultivariateAnomaly("fit-custom-entity") + .setStartTime("2024-01-01T00:00:00Z") + .setEndTime("2024-01-02T00:00:00Z") + .setSlidingWindow(120) + .setAlignMode("inner") + .setFillNAMethod("fixed") + .setPaddingValue(9) + .setDisplayName("demo-model") + + val request = fit.requestBody("https://storage/path.csv").parseJson.convertTo[MAERequest] + assert(request.slidingWindow.contains(120)) + assert(request.alignPolicy.contains(AlignPolicy(Some("Inner"), Some("Fixed"), Some(9)))) + assert(request.displayName.contains("demo-model")) + } + + test("SimpleDetectMultivariateAnomaly prepareEntity respects topContributorCount") { + val detect = new ExposedSimpleDetectMultivariateAnomaly("detect-entity") + .setStartTime("2024-01-01T00:00:00Z") + .setEndTime("2024-01-02T00:00:00Z") + .setTopContributorCount(7) + + val request = detect.requestBody("https://storage/path.csv").parseJson.convertTo[DMARequest] + assert(request.dataSource == "https://storage/path.csv") + assert(request.topContributorCount.contains(7)) + } + + test("DetectLastMultivariateAnomaly prepareEntity builds expected variables payload") { + import spark.implicits._ + val row = Seq( + (Seq("2024-01-01T00:00:00Z", "2024-01-01T00:01:00Z"), Seq(1.0, 2.0), Seq(5.0, 6.0)) + ).toDF("timestamp_list", "varA_list", "varB_list").head() + + val detectLast = new ExposedDetectLastMultivariateAnomaly("detect-last-entity") + .setInputVariablesCols(Array("varA", "varB")) + .setTopContributorCount(3) + + val request = detectLast.requestBody(row).parseJson.convertTo[DLMARequest] + assert(request.topContributorCount == 3) + assert(request.variables.map(_.variable) == Seq("varA", "varB")) + assert(request.variables.forall(_.timestamps == Seq("2024-01-01T00:00:00Z", "2024-01-01T00:01:00Z"))) + 
assert(request.variables.find(_.variable == "varA").exists(_.values == Seq(1.0, 2.0))) + } + + test("DetectLastMultivariateAnomaly prepareUrl appends model and detect-last path") { + val detectLast = new ExposedDetectLastMultivariateAnomaly("detect-last-url") + .setLocation("westus2") + .setModelId("model-123") + + assert(detectLast.requestUrl(Row.empty) == + "https://westus2.api.cognitive.microsoft.com/anomalydetector/v1.1/multivariate/models/model-123:detect-last") + } + + test("transformSchema appends deterministic output, error and isAnomaly fields") { + val inputSchema = StructType(Seq(StructField("timestamp", StringType))) + + val fitSchema = new SimpleFitMultivariateAnomaly("fit-schema") + .setOutputCol("fitOutput") + .setErrorCol("fitError") + .transformSchema(inputSchema) + assert(fitSchema("fitOutput").dataType == DMAResponse.schema) + assert(fitSchema("fitError").dataType == DMAError.schema) + assert(fitSchema("isAnomaly").dataType == BooleanType) + + val detectLastSchema = new DetectLastMultivariateAnomaly("detect-last-schema") + .setOutputCol("detectOutput") + .setErrorCol("detectError") + .transformSchema(inputSchema) + assert(detectLastSchema("detectOutput").dataType == DLMAResponse.schema) + assert(detectLastSchema("detectError").dataType == DMAError.schema) + assert(detectLastSchema("isAnomaly").dataType == BooleanType) + } + + test("remote iterator conversion wraps Hadoop iterator deterministically") { + import Conversions._ + val iterator: Iterator[Int] = new TestRemoteIterator(Seq(1, 2, 3)) + assert(iterator.toSeq == Seq(1, 2, 3)) + } + +} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/form/FormCoreOfflineSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/form/FormCoreOfflineSuite.scala new file mode 100644 index 00000000000..65f6423a8a3 --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/form/FormCoreOfflineSuite.scala @@ -0,0 +1,235 @@ +// Copyright (C) 
Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.services.form + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import com.microsoft.azure.synapse.ml.io.http.ErrorUtils +import org.apache.http.client.utils.URLEncodedUtils +import org.apache.http.entity.AbstractHttpEntity +import org.apache.http.util.EntityUtils +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType} +import spray.json.DefaultJsonProtocol._ +import spray.json._ + +import java.net.URI +import java.nio.charset.StandardCharsets +import scala.collection.JavaConverters._ + +class FormCoreOfflineSuite extends TestBase { + + private class ExposedAnalyzeLayout(uid: String) extends AnalyzeLayout(uid) { + def requestEntity(row: Row): Option[AbstractHttpEntity] = prepareEntity(row) + def requestUrl(row: Row): String = prepareUrl(row) + } + + private class ExposedAnalyzeDocument(uid: String) extends AnalyzeDocument(uid) { + def requestEntity(row: Row): Option[AbstractHttpEntity] = prepareEntity(row) + def requestUrl(row: Row): String = prepareUrl(row) + } + + private class ExposedGetCustomModel(uid: String) extends GetCustomModel(uid) { + def requestUrl(row: Row): String = prepareUrl(row) + } + + private class ExposedAnalyzeCustomModel(uid: String) extends AnalyzeCustomModel(uid) { + def requestUrl(row: Row): String = prepareUrl(row) + } + + private def entityToBody(entity: AbstractHttpEntity): String = EntityUtils.toString(entity) + + private def queryParams(url: String): Map[String, String] = { + URLEncodedUtils + .parse(new URI(url), StandardCharsets.UTF_8) + .asScala + .map(p => p.getName -> p.getValue) + .toMap + } + + test("Form Recognizer params validate allowed values offline") { + val receipts = new AnalyzeReceipts("offline-receipts").setLocale("en-US") + assert(receipts.getLocale == "en-US") + + 
val document = new AnalyzeDocument("offline-document") + .setStringIndexType("utf16CodeUnit") + .setFeatures(Seq("barcodes", "languages")) + assert(document.getStringIndexType == "utf16CodeUnit") + assert(document.getFeatures == Seq("barcodes", "languages")) + + intercept[IllegalArgumentException] { + receipts.setLocale("fr-FR") + } + intercept[IllegalArgumentException] { + document.setStringIndexType("bad-value") + } + intercept[IllegalArgumentException] { + document.setFeatures(Seq("barcodes", "unsupported")) + } + } + + test("Form recognizer request entities build deterministic url and byte payloads") { + val urlInput = "https://contoso.example/layout.jpg" + val bytesInput = Array[Byte](1, 2, 3) + + val layoutFromUrl = new ExposedAnalyzeLayout("layout-url").setImageUrl(urlInput) + val urlPayload = entityToBody(layoutFromUrl.requestEntity(Row.empty).get).parseJson.asJsObject + assert(urlPayload.fields("source").convertTo[String] == urlInput) + + val layoutFromBytes = new ExposedAnalyzeLayout("layout-bytes").setImageBytes(bytesInput) + val bodyBytes = EntityUtils.toByteArray(layoutFromBytes.requestEntity(Row.empty).get) + assert(bodyBytes.sameElements(bytesInput)) + } + + test("AnalyzeDocument builds v3 request url and payload from local params") { + val imageUrl = "https://contoso.example/form.pdf" + val analyzeDocument = new ExposedAnalyzeDocument("analyze-document-url") + .setLocation("eastus") + .setPrebuiltModelId("prebuilt-layout") + .setImageUrl(imageUrl) + .setPages("1-2") + .setStringIndexType("utf16CodeUnit") + .setFeatures(Seq("barcodes", "languages")) + + val url = analyzeDocument.requestUrl(Row.empty) + val uri = new URI(url) + val query = queryParams(url) + assert(uri.getPath.endsWith("/formrecognizer/documentModels/prebuilt-layout:analyze")) + assert(query("api-version") == "2023-07-31") + assert(query("pages") == "1-2") + assert(query("stringIndexType") == "utf16CodeUnit") + assert(query("features") == "List(barcodes, languages)") + + val 
requestBody = entityToBody(analyzeDocument.requestEntity(Row.empty).get).parseJson.asJsObject + assert(requestBody.fields("urlSource").convertTo[String] == imageUrl) + } + + test("custom model endpoints append model id deterministically") { + val getModel = new ExposedGetCustomModel("get-custom-model-url") + .setLocation("eastus") + .setModelId("model-123") + .setIncludeKeys(true) + val getModelUrl = getModel.requestUrl(Row.empty) + assert(new URI(getModelUrl).getPath.endsWith("/formrecognizer/v2.1/custom/models/model-123")) + assert(queryParams(getModelUrl)("includeKeys") == "true") + + val analyzeCustom = new ExposedAnalyzeCustomModel("analyze-custom-model-url") + .setLocation("eastus") + .setModelId("model-123") + assert(new URI(analyzeCustom.requestUrl(Row.empty)) + .getPath.endsWith("/formrecognizer/v2.1/custom/models/model-123/analyze")) + } + + test("form recognizer schemas expose deterministic output and error columns") { + spark + val inputSchema = StructType(Seq(StructField("id", StringType, nullable = true))) + + val analyzeDocumentSchema = new AnalyzeDocument("analyze-document-schema") + .setLocation("eastus") + .setPrebuiltModelId("prebuilt-read") + .setImageUrl("https://contoso.example/doc.png") + .setOutputCol("documentResult") + .setErrorCol("documentError") + .transformSchema(inputSchema) + assert(analyzeDocumentSchema("documentResult").dataType == AnalyzeDocumentResponse.schema) + assert(analyzeDocumentSchema("documentError").dataType == ErrorUtils.ErrorSchema) + + val analyzeLayoutSchema = new AnalyzeLayout("analyze-layout-schema") + .setLocation("eastus") + .setImageUrl("https://contoso.example/doc.png") + .setOutputCol("layoutResult") + .setErrorCol("layoutError") + .transformSchema(inputSchema) + assert(analyzeLayoutSchema("layoutResult").dataType == AnalyzeResponse.schema) + assert(analyzeLayoutSchema("layoutError").dataType == ErrorUtils.ErrorSchema) + } + + test("field result helpers convert recursive values and simplify data types") { + 
import FormsJsonProtocol._ + + val numberValue = FieldResultRecursive( + `type` = "number", + page = None, + confidence = None, + boundingBox = None, + text = Some("7"), + valueString = None, + valuePhoneNumber = None, + valueNumber = Some(7.0), + valueDate = None, + valueTime = None, + valueObject = None, + valueArray = None + ) + val textValue = FieldResultRecursive( + `type` = "string", + page = None, + confidence = None, + boundingBox = None, + text = None, + valueString = Some("widget"), + valuePhoneNumber = None, + valueNumber = None, + valueDate = None, + valueTime = None, + valueObject = None, + valueArray = None + ) + + val mixedArray = FieldResultRecursive( + `type` = "array", + page = None, + confidence = None, + boundingBox = None, + text = None, + valueString = None, + valuePhoneNumber = None, + valueNumber = None, + valueDate = None, + valueTime = None, + valueObject = None, + valueArray = Some(Seq(textValue, numberValue)) + ) + assert(mixedArray.toSimplifiedDataType == ArrayType(StringType)) + + val objectValue = FieldResult( + `type` = "object", + page = None, + confidence = None, + boundingBox = None, + text = None, + valueString = None, + valuePhoneNumber = None, + valueNumber = None, + valueDate = None, + valueTime = None, + valueObject = Some(Map("count" -> numberValue, "label" -> textValue).toJson.compactPrint), + valueArray = None + ).toFieldResultRecursive + + val objectType = objectValue.toSimplifiedDataType.asInstanceOf[StructType] + assert(objectType.fieldNames.toSet == Set("count", "label")) + val objectRow = objectValue.viewAsDataType(objectType).asInstanceOf[Row] + val objectValues = objectRow.toSeq.toSet + assert(objectValues.contains(Some(7.0))) + assert(objectValues.contains(Some("widget"))) + assert(numberValue.viewAsDataType(StringType) == "7") + assert(numberValue.viewAsDataType(DoubleType) == 7.0) + } + + test("FormOntologyLearner.combineDataTypes merges nested schemas deterministically") { + val left = StructType(Seq( + 
StructField("shared", StringType, nullable = true), + StructField("leftOnly", DoubleType, nullable = true) + )) + val right = StructType(Seq( + StructField("shared", DoubleType, nullable = true), + StructField("rightOnly", StringType, nullable = true) + )) + + val merged = FormOntologyLearner.combineDataTypes(left, right).asInstanceOf[StructType] + assert(merged.fieldNames.toSet == Set("shared", "leftOnly", "rightOnly")) + assert(merged("shared").dataType == StringType) + assert(FormOntologyLearner.combineDataTypes(ArrayType(StringType), ArrayType(DoubleType)) == ArrayType(StringType)) + } +} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/geospatial/GeospatialCoreSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/geospatial/GeospatialCoreSuite.scala new file mode 100644 index 00000000000..d7becb8b103 --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/geospatial/GeospatialCoreSuite.scala @@ -0,0 +1,167 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package com.microsoft.azure.synapse.ml.services.geospatial + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.http.client.methods.{HttpGet, HttpPost} +import org.apache.http.util.EntityUtils +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.ArrayType + +import java.net.{URI, URLDecoder} + +private[geospatial] class TestableAddressGeocoder extends AddressGeocoder { + def buildRequest(row: Row): Option[HttpPost] = + inputFunc(row).map(_.asInstanceOf[HttpPost]) +} + +private[geospatial] class TestableReverseAddressGeocoder extends ReverseAddressGeocoder { + def buildRequest(row: Row): Option[HttpPost] = + inputFunc(row).map(_.asInstanceOf[HttpPost]) +} + +private[geospatial] class TestableCheckPointInPolygon extends CheckPointInPolygon { + def buildRequest(row: Row): Option[HttpGet] = + inputFunc(row).map(_.asInstanceOf[HttpGet]) +} + +class GeospatialCoreSuite extends TestBase { + + import spark.implicits._ + + private def toQueryMap(uri: URI): Map[String, String] = { + Option(uri.getRawQuery).toSeq.flatMap(_.split("&")).map { kv => + val pair = kv.split("=", 2) + val key = URLDecoder.decode(pair(0), "UTF-8") + val value = if (pair.length > 1) URLDecoder.decode(pair(1), "UTF-8") else "" + key -> value + }.toMap + } + + test("address geocoder builds deterministic request payload and query params") { + val request = new TestableAddressGeocoder() + .setSubscriptionKey("fake-key") + .setAddress(Seq("One Microsoft Way, Redmond", "400 Broad St, Seattle")) + .buildRequest(Row.empty) + .get + + val query = toQueryMap(request.getURI) + assert(request.getURI.getPath.endsWith("/search/address/batch/json")) + assert(query("api-version") == "1.0") + assert(query("subscription-key") == "fake-key") + assert(request.getFirstHeader("Content-Type").getValue == "application/json") + + val payload = EntityUtils.toString(request.getEntity, "UTF-8") + assert(payload.contains("?query=One+Microsoft+Way%2C+Redmond&limit=1")) + 
assert(payload.contains("?query=400+Broad+St%2C+Seattle&limit=1")) + } + + test("reverse geocoder builds deterministic request payload and query params") { + val request = new TestableReverseAddressGeocoder() + .setSubscriptionKey("fake-key") + .setLatitude(Seq(48.858561, 47.639765)) + .setLongitude(Seq(2.294911, -122.127896)) + .buildRequest(Row.empty) + .get + + val query = toQueryMap(request.getURI) + assert(request.getURI.getPath.endsWith("/search/address/reverse/batch/json")) + assert(query("api-version") == "1.0") + assert(query("subscription-key") == "fake-key") + assert(request.getFirstHeader("Content-Type").getValue == "application/json") + + val payload = EntityUtils.toString(request.getEntity, "UTF-8") + assert(payload.contains("?query=48.858561,2.294911&limit=1")) + assert(payload.contains("?query=47.639765,-122.127896&limit=1")) + } + + test("address and reverse schema behavior is deterministic") { + val addressInput = Seq(Seq("One Microsoft Way, Redmond")).toDF("address") + val addressSchema = new AddressGeocoder() + .setAddressCol("address") + .setOutputCol("output") + .setErrorCol("addressError") + .transformSchema(addressInput.schema) + assert(addressSchema.fieldNames.toSet == Set("address", "output", "addressError")) + assert(addressSchema("output").dataType == ArrayType(SearchAddressBatchItem.schema)) + + val reverseInput = Seq((Seq(47.6418), Seq(-122.1275))).toDF("latitude", "longitude") + val reverseSchema = new ReverseAddressGeocoder() + .setLatitudeCol("latitude") + .setLongitudeCol("longitude") + .setOutputCol("output") + .setErrorCol("reverseError") + .transformSchema(reverseInput.schema) + assert(reverseSchema.fieldNames.toSet == Set("latitude", "longitude", "output", "reverseError")) + assert(reverseSchema("output").dataType == ArrayType(ReverseSearchAddressBatchItem.schema)) + } + + test("geospatial transformers validate missing input columns locally") { + val addressInput = Seq(Seq("One Microsoft Way, Redmond")).toDF("address") + val 
addressError = intercept[AssertionError] { + new AddressGeocoder().setAddressCol("missingAddress").transformSchema(addressInput.schema) + } + assert(addressError.getMessage.contains("Could not find dynamic columns")) + assert(addressError.getMessage.contains("missingAddress")) + + val reverseInput = Seq((Seq(47.6418), Seq(-122.1275))).toDF("latitude", "longitude") + val reverseError = intercept[AssertionError] { + new ReverseAddressGeocoder() + .setLatitudeCol("latitude") + .setLongitudeCol("missingLongitude") + .transformSchema(reverseInput.schema) + } + assert(reverseError.getMessage.contains("Could not find dynamic columns")) + assert(reverseError.getMessage.contains("missingLongitude")) + } + + test("checkpoint helper logic, schema behavior, and retired transform are deterministic") { + val transformer = new TestableCheckPointInPolygon() + .setSubscriptionKey("fake-key") + .setGeography("us") + .setUserDataIdentifier("udid-1") + .setLatitude(47.6418) + .setLongitude(-122.1275) + + assert(transformer.getUrl == "https://us.atlas.microsoft.com/spatial/pointInPolygon/json") + assert(transformer.getLatitude == Seq(47.6418)) + assert(transformer.getLongitude == Seq(-122.1275)) + assert(transformer.getUserDataIdentifier == "udid-1") + + val request = transformer.buildRequest(Row.empty).get + val query = toQueryMap(request.getURI) + assert(request.getURI.getPath.endsWith("/spatial/pointInPolygon/json")) + assert(query("api-version") == "1.0") + assert(query("subscription-key") == "fake-key") + assert(query("udid") == "udid-1") + assert(query("lat").contains("47.6418")) + assert(query("lon").contains("-122.1275")) + + val input = Seq((Seq(47.6418), Seq(-122.1275), "udid-1")).toDF("latitude", "longitude", "udid") + val schema = new CheckPointInPolygon() + .setLatitudeCol("latitude") + .setLongitudeCol("longitude") + .setUserDataIdentifierCol("udid") + .setOutputCol("pointInPolygon") + .setErrorCol("pointInPolygonError") + .transformSchema(input.schema) + 
assert(schema.fieldNames.toSet == Set("latitude", "longitude", "udid", "pointInPolygon", "pointInPolygonError")) + assert(schema("pointInPolygon").dataType == PointInPolygonProcessResult.schema) + + /* transformSchema must reject a longitude column name that is not in the input schema. */ val missingColumnError = intercept[AssertionError] { + new CheckPointInPolygon() + .setLatitudeCol("latitude") + .setLongitudeCol("missingLongitude") + .setUserDataIdentifierCol("udid") + .transformSchema(input.schema) + } + assert(missingColumnError.getMessage.contains("Could not find dynamic columns")) + assert(missingColumnError.getMessage.contains("missingLongitude")) + + /* transform is expected to throw instead of running; only the retirement-date text of the message is matched. */ val retiredError = intercept[UnsupportedOperationException] { + transformer.transform(input) + } + assert(retiredError.getMessage.contains("retired on September 30, 2025")) + } +} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextCoreOfflineSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextCoreOfflineSuite.scala new file mode 100644 index 00000000000..c54c64db460 --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextCoreOfflineSuite.scala @@ -0,0 +1,159 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information.
+ +package com.microsoft.azure.synapse.ml.services.language + +import com.microsoft.azure.synapse.ml.io.http.{ + EntityData, HTTPResponseData, ProtocolVersionData, StatusLineData +} +import org.apache.http.entity.StringEntity +import org.apache.http.util.EntityUtils +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.DataType +import org.scalatest.funsuite.AnyFunSuite +import spray.json._ +import spray.json.DefaultJsonProtocol._ + +import java.net.URI + +/* Test-only subclass that republishes AnalyzeText's prepareEntity and responseDataType as buildEntity and outputSchema so the suite can inspect the request body and response schema directly. buildEntity assumes prepareEntity yields a defined StringEntity. */ private[language] class TestableAnalyzeText extends AnalyzeText { + def buildEntity(row: Row): StringEntity = prepareEntity(row).get.asInstanceOf[StringEntity] + + def outputSchema: DataType = responseDataType +} + +/* Same wrapper for AnalyzeTextLongRunningOperations, additionally republishing modifyPollingURI as buildPollingURI. */ private[language] class TestableAnalyzeTextLRO extends AnalyzeTextLongRunningOperations { + def buildEntity(row: Row): StringEntity = prepareEntity(row).get.asInstanceOf[StringEntity] + + def outputSchema: DataType = responseDataType + + def buildPollingURI(uri: URI): URI = modifyPollingURI(uri) +} + +class AnalyzeTextCoreOfflineSuite extends AnyFunSuite { + + /* Decodes the entity body as UTF-8 and parses it as a JSON object. */ private def parseEntity(entity: StringEntity): JsObject = { + EntityUtils.toString(entity, "UTF-8").parseJson.asJsObject + } + + /* Wraps a JSON string in an HTTP/1.1 200 OK HTTPResponseData with no headers and a repeatable, non-chunked, non-streaming entity. */ private def responseWithEntity(json: String): HTTPResponseData = { + HTTPResponseData( + headers = Array.empty, + entity = Some(EntityData( + json.getBytes("UTF-8"), + None, + None, + None, + isChunked = false, + isRepeatable = true, + isStreaming = false)), + statusLine = StatusLineData(ProtocolVersionData("HTTP", 1, 1), 200, "OK"), + locale = "en_US") + } + + /* Each setter call below passes a value the test expects to be rejected with IllegalArgumentException. */ test("local parameter validation rejects invalid values") { + intercept[IllegalArgumentException] { + new AnalyzeText().setKind("UnknownTask") + } + intercept[IllegalArgumentException] { + new AnalyzeText().setStringIndexType("invalid-index") + } + intercept[IllegalArgumentException] { + new AnalyzeTextLongRunningOperations().setSentenceCount(0) + } + intercept[IllegalArgumentException] { + new AnalyzeTextLongRunningOperations().setSortBy("Score")
+ } + intercept[IllegalArgumentException] { + new AnalyzeTextLongRunningOperations().setSummaryLength("tiny") + } + } + + /* Builds the LanguageDetection request body from literal params and checks kind, both documents (including the empty-text one), the countryHint copied onto each document, and the loggingOptOut/modelVersion parameters. */ test("analyze text request-building is deterministic for language detection") { + val transformer = new TestableAnalyzeText() + .setKind("LanguageDetection") + .setText(Seq("Hello", "")) + .setCountryHint("US") + .setModelVersion("2024-10-01") + .setLoggingOptOut(true) + + val payload = parseEntity(transformer.buildEntity(Row.empty)) + assert(payload.fields("kind").convertTo[String] == "LanguageDetection") + + val analysisInput = payload.fields("analysisInput").asJsObject + val JsArray(documents) = analysisInput.fields("documents") + assert(documents.length == 2) + assert(documents.head.asJsObject.fields("countryHint").convertTo[String] == "US") + assert(documents(1).asJsObject.fields("countryHint").convertTo[String] == "US") + assert(documents(1).asJsObject.fields("text").convertTo[String] == "") + + val params = payload.fields("parameters").asJsObject + assert(params.fields("loggingOptOut").convertTo[Boolean]) + assert(params.fields("modelVersion").convertTo[String] == "2024-10-01") + } + + /* outputSchema must track the current kind: changing kind on the same instance changes the reported response schema, for both AnalyzeText and AnalyzeTextLongRunningOperations. */ test("schema behavior follows selected task kind") { + val analyze = new TestableAnalyzeText().setKind("EntityLinking") + assert(analyze.outputSchema == EntityLinkingResponse.schema) + analyze.setKind("SentimentAnalysis") + assert(analyze.outputSchema == SentimentResponse.schema) + + val lro = new TestableAnalyzeTextLRO() + .setKind(AnalysisTaskKind.CustomMultiLabelClassification) + assert(lro.outputSchema == CustomLabelJobState.schema) + lro.setKind(AnalysisTaskKind.Healthcare) + assert(lro.outputSchema == HealthcareJobState.schema) + } + + /* Builds an EntityRecognition long-running request and checks the document language plus the first task's parameters, including the nested overlapPolicy.policyKind and inferenceOptions.excludeNormalizedValues fields. */ test("lro request-building captures helper options deterministically") { + val transformer = new TestableAnalyzeTextLRO() + .setKind(AnalysisTaskKind.EntityRecognition) + .setText(Seq("John Doe")) + .setLanguage("en") + .setModelVersion("2024-06-01") + .setStringIndexType("UnicodeCodePoint") + .setLoggingOptOut(true) +
.setInclusionList(Seq("Person")) + .setOverlapPolicy("allowOverlap") + .setExcludeNormalizedValues(true) + + val payload = parseEntity(transformer.buildEntity(Row.empty)) + val analysisInput = payload.fields("analysisInput").asJsObject + val JsArray(documents) = analysisInput.fields("documents") + assert(documents.head.asJsObject.fields("language").convertTo[String] == "en") + + val JsArray(tasks) = payload.fields("tasks") + val parameters = tasks.head.asJsObject.fields("parameters").asJsObject + assert(parameters.fields("modelVersion").convertTo[String] == "2024-06-01") + assert(parameters.fields("stringIndexType").convertTo[String] == "UnicodeCodePoint") + assert(parameters.fields("inclusionList").convertTo[Seq[String]] == Seq("Person")) + assert( + parameters.fields("overlapPolicy").asJsObject.fields("policyKind").convertTo[String] == "allowOverlap") + assert( + parameters.fields("inferenceOptions").asJsObject.fields("excludeNormalizedValues").convertTo[Boolean]) + } + + /* Three independent helper checks: kind parsing from a string (valid and invalid), the polling URI with and without showStats, and the response-body rewrite of class keys to classifications for CustomSingleLabelClassification. */ test("helper logic is deterministic for polling uri, kind mapping, and response rewrite") { + assert(AnalysisTaskKind.getKindFromString("Healthcare") == AnalysisTaskKind.Healthcare) + val ex = intercept[IllegalArgumentException] { + AnalysisTaskKind.getKindFromString("Nope") + } + assert(ex.getMessage.contains("Invalid kind")) + + val uri = new URI("https://example.test/jobs/1?api-version=2023-04-01") + val noStats = new TestableAnalyzeTextLRO() + assert(noStats.buildPollingURI(uri) == uri) + noStats.setShowStats(true) + assert(noStats.buildPollingURI(uri).toString.endsWith("&showStats=true")) + + val raw = responseWithEntity("""{"class":"Top","nested":{"class":"Secondary"}}""") + val rewritten = new TestableAnalyzeTextLRO() + .setKind(AnalysisTaskKind.CustomSingleLabelClassification) + .modifyResponse(Some(raw)) + .get + val rewrittenBody = new String(rewritten.entity.get.content, "UTF-8") + assert(rewrittenBody.contains("\"classifications\":\"Top\"")) +
assert(rewrittenBody.contains("\"classifications\":\"Secondary\"")) + assert(!rewrittenBody.contains("\"class\":")) + } +} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICoreOfflineSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICoreOfflineSuite.scala new file mode 100644 index 00000000000..d5af8139f20 --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICoreOfflineSuite.scala @@ -0,0 +1,186 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.services.openai + +import org.apache.http.entity.StringEntity +import org.apache.http.util.EntityUtils +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.types.{ArrayType, IntegerType, MapType, StringType, StructField, StructType} +import org.scalatest.funsuite.AnyFunSuite +import spray.json._ +import spray.json.DefaultJsonProtocol._ + +class OpenAICoreOfflineSuite extends AnyFunSuite { + + /* Message row shape with plain string content: role, content, name. */ private val stringMessageSchema = StructType(Seq( + StructField("role", StringType, nullable = false), + StructField("content", StringType, nullable = true), + StructField("name", StringType, nullable = true) + )) + + /* Message row shape whose content is an array of string-to-string maps (one map per content part). */ private val compositeMessageSchema = StructType(Seq( + StructField("role", StringType, nullable = false), + StructField( + "content", + ArrayType( + MapType(StringType, StringType, valueContainsNull = true), + containsNull = false + ), + nullable = true + ), + StructField("name", StringType, nullable = true) + )) + + /* Message row shape with integer content and no name field; used by the suite to build a row with an unsupported content type. */ private val unsupportedContentSchema = StructType(Seq( + StructField("role", StringType, nullable = false), + StructField("content", IntegerType, nullable = false) + )) + + /* Builds a schema-carrying row with the given role and string content; name is set to the empty string. */ private def messageRow(role: String, content: String): Row = + new
GenericRowWithSchema(Array[Any](role, content, ""), stringMessageSchema) + + /* Builds a schema-carrying row whose content is a list of content-part maps; name is set to the empty string. */ private def compositeMessageRow(role: String, parts: Seq[Map[String, String]]): Row = + new GenericRowWithSchema(Array[Any](role, parts, ""), compositeMessageSchema) + + /* Builds a schema-carrying row with integer content, a shape the suite expects the encoder to reject. */ private def unsupportedMessageRow(role: String, content: Int): Row = + new GenericRowWithSchema(Array[Any](role, content), unsupportedContentSchema) + + /* Parses the entity body as a JSON object. UTF-8 is passed as the default charset so decoding does not fall back to the HttpCore default when the entity declares no charset of its own. */ private def parseEntity(entity: StringEntity): JsObject = + EntityUtils.toString(entity, "UTF-8").parseJson.asJsObject + + /* One plain-string message and one message made of content parts must both be converted; the parts keep their order and type values. */ test("encodeMessagesToMap supports text and composite message shapes") { + val chat = new OpenAIChatCompletion() + val compositeParts = Seq( + Map("type" -> "text", "text" -> "first"), + Map("type" -> "input_file", "filename" -> "example.txt") + ) + + val mapped = chat.encodeMessagesToMap(Seq( + messageRow("user", "hello"), + compositeMessageRow("assistant", compositeParts) + )) + + assert(mapped.head("role") == "user") + assert(mapped.head("content") == "hello") + val secondContent = mapped(1)("content").asInstanceOf[Seq[Map[String, Any]]] + assert(secondContent.head("type") == "text") + assert(secondContent(1)("type") == "input_file") + } + + /* Integer content is expected to be rejected with IllegalArgumentException. */ test("encodeMessagesToMap rejects unsupported content types") { + val chat = new OpenAIChatCompletion() + val ex = intercept[IllegalArgumentException] { + chat.encodeMessagesToMap(Seq(unsupportedMessageRow("user", 123))) + } + assert(ex.getMessage.contains("Unsupported content type")) + } + + /* For chat completions the test expects the text parts to be joined with a newline and the non-text part to be left out of the serialized content. */ test("OpenAIChatCompletion getStringEntity collapses content parts into text") { + val chat = new OpenAIChatCompletion() + val messageParts = Seq( + Map("type" -> "text", "text" -> "Line one"), + Map("type" -> "input_file", "filename" -> "example.txt"), + Map("type" -> "text", "text" -> "Line two") + ) + + val entity = chat.getStringEntity( + Seq(compositeMessageRow("user", messageParts)), + Map("temperature" -> 0.0) + ) + + val payload = parseEntity(entity) + val JsArray(messages) = payload.fields("messages") + val content =
messages.head.asJsObject.fields("content").convertTo[String] + + assert(content == "Line one\nLine two") + } + + /* A map with name/strict/schema but no type is expected to come back wrapped as type json_schema, with name and schema available under the json_schema key. */ test("OpenAIChatCompletion response_format wraps bare schemas and exposes type") { + val chat = new OpenAIChatCompletion() + chat.setResponseFormat(Map( + "name" -> "answer_schema", + "strict" -> true, + "schema" -> Map( + "type" -> "object", + "properties" -> Map("answer" -> Map("type" -> "string")) + ) + )) + + val responseFormat = chat.getResponseFormat + assert(chat.getResponseFormatType == "json_schema") + assert(responseFormat("type") == "json_schema") + val jsonSchema = responseFormat("json_schema").asInstanceOf[Map[String, Any]] + assert(jsonSchema("name") == "answer_schema") + assert(jsonSchema.contains("schema")) + } + + /* With a gpt-5-mini deployment the test expects temperature and top_p to be omitted, the response format and verbosity to be nested under text, and the reasoning effort under reasoning.effort instead of a flat reasoning_effort key. */ test("OpenAIResponses optional params merge text/reasoning and drop gpt-5 sampling") { + val responses = new OpenAIResponses() + .setDeploymentName("gpt-5-mini") + .setTemperature(0.3) + .setTopP(0.7) + .setResponseFormat("json_object") + .setVerbosity("high") + .setReasoningEffort("medium") + + val params = responses.getOptionalParams(messageRow("user", "hello")) + + assert(params("model") == "gpt-5-mini") + assert(!params.contains("temperature")) + assert(!params.contains("top_p")) + assert(!params.contains("reasoning_effort")) + + val text = params("text").asInstanceOf[Map[String, Any]] + val format = text("format").asInstanceOf[Map[String, Any]] + assert(format("type") == "json_object") + assert(text("verbosity") == "high") + + val reasoning = params("reasoning").asInstanceOf[Map[String, Any]] + assert(reasoning("effort") == "medium") + } + + /* Counterpart to the gpt-5 case above: a gpt-4.1-mini deployment keeps temperature and top_p in the optional params. */ test("OpenAIResponses keeps sampling params for non-gpt5 deployments") { + val responses = new OpenAIResponses() + .setDeploymentName("gpt-4.1-mini") + .setTemperature(0.2) + .setTopP(0.6) + + val params = responses.getOptionalParams(messageRow("user", "hello")) + + assert(params("model") == "gpt-4.1-mini") + assert(params("temperature") == 0.2) + assert(params("top_p") == 0.6) + } + +
/* For the Responses API a plain string message is expected to be serialized as a single input_text part, while a message already made of parts keeps its part type. */ test("OpenAIResponses getStringEntity wraps plain text and preserves composite parts") { + val responses = new OpenAIResponses() + val compositeParts = Seq( + Map("type" -> "input_file", "filename" -> "example.txt", "file_data" -> "AAA") + ) + + val entity = responses.getStringEntity( + Seq( + messageRow("user", "plain text"), + compositeMessageRow("user", compositeParts) + ), + Map("model" -> "gpt-4.1-mini") + ) + + val payload = parseEntity(entity) + val JsArray(inputs) = payload.fields("input") + + val JsArray(firstContent) = inputs.head.asJsObject.fields("content") + assert(firstContent.head.asJsObject.fields("type").convertTo[String] == "input_text") + assert(firstContent.head.asJsObject.fields("text").convertTo[String] == "plain text") + + val JsArray(secondContent) = inputs(1).asJsObject.fields("content") + assert(secondContent.head.asJsObject.fields("type").convertTo[String] == "input_file") + } + + /* Pins the response schema object reported by each stage. */ test("OpenAI chat and responses stages expose expected response schemas") { + assert(new OpenAIChatCompletion().responseDataType == ChatModelResponse.schema) + assert(new OpenAIResponses().responseDataType == ResponsesModelResponse.schema) + } +} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptParserSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptParserSuite.scala new file mode 100644 index 00000000000..5ce4fb8eb2c --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptParserSuite.scala @@ -0,0 +1,126 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information.
+ +package com.microsoft.azure.synapse.ml.services.openai + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{ArrayType, DataType, StringType} + +import java.nio.charset.StandardCharsets +import java.nio.file.Files + +class OpenAIPromptParserSuite extends TestBase { + + import spark.implicits._ + + /* The parsed column must equal the input verbatim, surrounding spaces included. */ test("PassThroughParser returns input text and string schema") { + val parser = new PassThroughParser() + val parsed = Seq(" keep spacing ").toDF("response") + .select(parser.parse(col("response")).alias("parsed")) + .head() + .getString(0) + + assert(parsed == " keep spacing ") + assert(parser.outputSchema == StringType) + } + + /* The expected result shows that only the ends of the whole string are trimmed; the spaces around the inner item banana are kept. */ test("DelimiterParser trims outer whitespace and splits values") { + val parser = new DelimiterParser(",") + val parsed = Seq(" apple, banana ,carrot ").toDF("response") + .select(parser.parse(col("response")).alias("parsed")) + .head() + .getSeq[String](0) + + assert(parsed == Seq("apple", " banana ", "carrot")) + assert(parser.outputSchema == ArrayType(StringType)) + } + + /* The input is wrapped in a json code fence; the parser is expected to strip the fence and parse the object with the DDL schema. */ test("JsonParser removes code fences and parses JSON by schema") { + val schema = "name STRING, value INT" + val parser = new JsonParser(schema, Map.empty) + val parsed = Seq( + """```json + |{"name":"alpha","value":7} + |```""".stripMargin + ).toDF("response") + .select(parser.parse(col("response")).alias("parsed")) + .head() + .getAs[Row]("parsed") + + assert(parsed.getAs[String]("name") == "alpha") + assert(parsed.getAs[Int]("value") == 7) + assert(parser.outputSchema == DataType.fromDDL(schema)) + } + + /* Capture group 1 of the pattern is expected back as a string. */ test("RegexParser extracts configured group and uses string schema") { + val parser = new RegexParser("score=(\\d+)", 1) + val parsed = Seq("score=42 done").toDF("response") + .select(parser.parse(col("response")).alias("parsed")) + .head() + .getString(0) + + assert(parsed == "42") + assert(parser.outputSchema == StringType) + } + +
/* With the default API type the wrapper uses type text; after setApiType("responses") it uses input_text. */ test("stringMessageWrapper changes text type for responses API") { + val prompt = new OpenAIPrompt() + assert(prompt.stringMessageWrapper("hello") == Map("type" -> "text", "text" -> "hello")) + + prompt.setApiType("responses") + assert(prompt.stringMessageWrapper("hello") == Map("type" -> "input_text", "text" -> "hello")) + } + + /* A whitespace-only path attachment value is expected to make createMessagesForRow return null rather than a message list. */ test("createMessagesForRow returns null when path attachments are empty") { + val prompt = new OpenAIPrompt() + val messages = prompt.createMessagesForRow("Summarize", Map("filePath" -> " "), Seq("filePath")) + assert(messages == null) + } + + /* Writes a temp text file, passes its path as the attachment column, and expects the user message to carry the prompt part followed by a text part containing the file content; the temp file is deleted in finally. */ test("createMessagesForRow includes local text file contents for chat completions") { + val prompt = new OpenAIPrompt() + val tempFile = Files.createTempFile("synapseml-openai-local", ".txt") + + try { + Files.write(tempFile, "example content".getBytes(StandardCharsets.UTF_8)) + + val messages = prompt.createMessagesForRow( + "Summarize", + Map("filePath" -> tempFile.toString), + Seq("filePath") + ) + val userParts = messages.find(_.role == "user").get.content + + assert(userParts.head == Map("type" -> "text", "text" -> "Summarize")) + assert(userParts(1).get("type").contains("text")) + assert(userParts(1).get("text").exists(_.contains("Content: example content"))) + } finally { + Files.deleteIfExists(tempFile) + } + } + + /* Same temp-file setup with apiType responses: system and user parts must use input_text, and the file part is expected to be exactly the file content. */ test("createMessagesForRow formats text file content for responses API") { + val prompt = new OpenAIPrompt().setApiType("responses") + val tempFile = Files.createTempFile("synapseml-openai-local", ".txt") + + try { + Files.write(tempFile, "response content".getBytes(StandardCharsets.UTF_8)) + + val messages = prompt.createMessagesForRow( + "Summarize", + Map("filePath" -> tempFile.toString), + Seq("filePath") + ) + val systemParts = messages.find(_.role == "system").get.content + val userParts = messages.find(_.role == "user").get.content + + assert(systemParts.head.get("type").contains("input_text")) + assert(userParts.head == Map("type" -> "input_text", "text" -> "Summarize")) +
assert(userParts(1) == Map("type" -> "input_text", "text" -> "response content")) + } finally { + Files.deleteIfExists(tempFile) + } + } +} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/UsageUtilsSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/UsageUtilsSuite.scala new file mode 100644 index 00000000000..0d2c47f1a0f --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/UsageUtilsSuite.scala @@ -0,0 +1,116 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.services.openai + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, struct} + +class UsageUtilsSuite extends TestBase { + + import spark.implicits._ + + /* Applies UsageUtils.normalize to the usage column of df with the given mapping and returns the resulting struct from the first row. */ private def normalizeUsage(df: DataFrame, mapping: UsageUtils.UsageFieldMapping): Row = { + df.select(UsageUtils.normalize(col("usage"), mapping).alias("usage")).head().getAs[Row]("usage") + } + + /* Builds a usage struct in chat-completions shape (prompt/completion/total tokens plus two nested detail structs) and checks the normalized input/output/total fields and both detail maps. */ test("normalize maps chat completion usage fields and nested details") { + val usageDf = Seq((1L, 2L, 3L, 10L, 11L, 21L, 22L, 23L, 24L)) + .toDF( + "prompt_tokens", + "completion_tokens", + "total_tokens", + "audio_tokens", + "cached_tokens", + "accepted_prediction_tokens", + "completion_audio_tokens", + "reasoning_tokens", + "rejected_prediction_tokens" + ) + .withColumn("usage", struct( + col("prompt_tokens"), + col("completion_tokens"), + col("total_tokens"), + struct(col("audio_tokens"), col("cached_tokens")).alias("prompt_tokens_details"), + struct( + col("accepted_prediction_tokens"), + col("completion_audio_tokens").alias("audio_tokens"), + col("reasoning_tokens"), + col("rejected_prediction_tokens") + ).alias("completion_tokens_details") + )) + .select("usage") + + val normalized = normalizeUsage(usageDf,
UsageUtils.UsageMappings.ChatCompletions) + assert(normalized.getAs[Long]("input_tokens") == 1L) + assert(normalized.getAs[Long]("output_tokens") == 2L) + assert(normalized.getAs[Long]("total_tokens") == 3L) + assert( + normalized.getMap[String, Long](normalized.fieldIndex("input_token_details")) == + Map("audio_tokens" -> 10L, "cached_tokens" -> 11L) + ) + assert( + normalized.getMap[String, Long](normalized.fieldIndex("output_token_details")) == + Map( + "accepted_prediction_tokens" -> 21L, + "audio_tokens" -> 22L, + "reasoning_tokens" -> 23L, + "rejected_prediction_tokens" -> 24L + ) + ) + } + + /* Builds a usage struct in Responses shape (input/output/total tokens plus one-field detail structs) and checks the normalized fields and detail maps. */ test("normalize maps responses usage fields and details") { + val usageDf = Seq((4L, 5L, 9L, 2L, 3L)) + .toDF("input_tokens", "output_tokens", "total_tokens", "cached_tokens", "reasoning_tokens") + .withColumn("usage", struct( + col("input_tokens"), + col("output_tokens"), + col("total_tokens"), + struct(col("cached_tokens")).alias("input_tokens_details"), + struct(col("reasoning_tokens")).alias("output_tokens_details") + )) + .select("usage") + + val normalized = normalizeUsage(usageDf, UsageUtils.UsageMappings.Responses) + assert(normalized.getAs[Long]("input_tokens") == 4L) + assert(normalized.getAs[Long]("output_tokens") == 5L) + assert(normalized.getAs[Long]("total_tokens") == 9L) + assert(normalized.getMap[String, Long](normalized.fieldIndex("input_token_details")) == Map("cached_tokens" -> 2L)) + assert(normalized.getMap[String, Long](normalized.fieldIndex("output_token_details")) == + Map("reasoning_tokens" -> 3L)) + } + + /* Covers mappings without output tokens or details: the Embeddings mapping on a prompt/total-only struct, and a custom mapping whose input detail field list is empty. */ test("normalize handles missing or empty detail mappings") { + val embeddingDf = Seq((7L, 7L)) + .toDF("prompt_tokens", "total_tokens") + .withColumn("usage", struct(col("prompt_tokens"), col("total_tokens"))) + .select("usage") + val embeddingUsage = normalizeUsage(embeddingDf, UsageUtils.UsageMappings.Embeddings) + assert(embeddingUsage.getAs[Long]("input_tokens") == 7L) +
assert(embeddingUsage.isNullAt(embeddingUsage.fieldIndex("output_tokens"))) + assert(embeddingUsage.getMap[String, Long](embeddingUsage.fieldIndex("input_token_details")).isEmpty) + assert(embeddingUsage.getMap[String, Long](embeddingUsage.fieldIndex("output_token_details")).isEmpty) + + /* Custom mapping: input details point at a struct but list no fields to copy, so the normalized detail map is expected to be empty even though the source struct has a cached_tokens value. */ val customMapping = UsageUtils.UsageFieldMapping( + inputTokens = Some("input_tokens"), + outputTokens = None, + totalTokens = Some("total_tokens"), + inputDetails = Some("input_tokens_details" -> Seq.empty), + outputDetails = None + ) + val customDf = Seq((8L, 8L, 4L)) + .toDF("input_tokens", "total_tokens", "cached_tokens") + .withColumn("usage", struct( + col("input_tokens"), + col("total_tokens"), + struct(col("cached_tokens")).alias("input_tokens_details") + )) + .select("usage") + + val customUsage = normalizeUsage(customDf, customMapping) + assert(customUsage.getAs[Long]("input_tokens") == 8L) + assert(customUsage.getMap[String, Long](customUsage.fieldIndex("input_token_details")).isEmpty) + } +} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/translate/TextTranslatorCoreSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/translate/TextTranslatorCoreSuite.scala new file mode 100644 index 00000000000..25abe5f1091 --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/translate/TextTranslatorCoreSuite.scala @@ -0,0 +1,340 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information.
+ +package com.microsoft.azure.synapse.ml.services.translate + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.http.client.methods.HttpPost +import org.apache.http.util.EntityUtils +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.ArrayType +import org.apache.spark.sql.types.StructType + +import java.net.URLDecoder + +/* This and the five Testable* classes that follow are test-only subclasses with the same shape: buildRequest runs the stage's inputFunc for the given schema and row and casts the produced request to HttpPost, so request construction can be inspected by the suite. */ private[translate] class TestableTranslate extends Translate { + def buildRequest(schema: StructType, row: Row): Option[HttpPost] = + inputFunc(schema)(row).map(_.asInstanceOf[HttpPost]) +} + +private[translate] class TestableTransliterate extends Transliterate { + def buildRequest(schema: StructType, row: Row): Option[HttpPost] = + inputFunc(schema)(row).map(_.asInstanceOf[HttpPost]) +} + +private[translate] class TestableDetect extends Detect { + def buildRequest(schema: StructType, row: Row): Option[HttpPost] = + inputFunc(schema)(row).map(_.asInstanceOf[HttpPost]) +} + +private[translate] class TestableBreakSentence extends BreakSentence { + def buildRequest(schema: StructType, row: Row): Option[HttpPost] = + inputFunc(schema)(row).map(_.asInstanceOf[HttpPost]) +} + +private[translate] class TestableDictionaryLookup extends DictionaryLookup { + def buildRequest(schema: StructType, row: Row): Option[HttpPost] = + inputFunc(schema)(row).map(_.asInstanceOf[HttpPost]) +} + +private[translate] class TestableDictionaryExamples extends DictionaryExamples { + def buildRequest(schema: StructType, row: Row): Option[HttpPost] = + inputFunc(schema)(row).map(_.asInstanceOf[HttpPost]) +} + +class TextTranslatorCoreSuite extends TestBase { + + import spark.implicits._ + + /* Splits the request's raw query string on & and on the first = of each pair, URL-decodes keys and values as UTF-8, and returns them as a map. A pair without = maps to the empty string; a URI with no query yields an empty map. */ private def toQueryMap(post: HttpPost): Map[String, String] = { + Option(post.getURI.getRawQuery).toSeq.flatMap(_.split("&")).map { kv => + val pair = kv.split("=", 2) + val key = URLDecoder.decode(pair(0), "UTF-8") + val value = if (pair.length > 1) URLDecoder.decode(pair(1), "UTF-8") else "" + key -> value + }.toMap + } + + /* Checks the subscription region getter and the translator URL produced by setLocation for the eastus, usgovarizona, and chinanorth locations. */ test("setLocation
sets translator endpoint and subscription region") { + val global = new Translate().setLocation("eastus") + assert(global.getSubscriptionRegion == "eastus") + assert(global.getUrl == "https://api.cognitive.microsofttranslator.com/translate") + + val usGov = new Translate().setLocation("usgovarizona") + assert(usGov.getUrl == "https://api.cognitive.microsofttranslator.us/translate") + + val china = new Translate().setLocation("chinanorth") + assert(china.getUrl == "https://api.cognitive.microsofttranslator.cn/translate") + } + + /* Pins the default value of each optional Translate parameter read below. */ test("translate defaults are deterministic") { + val t = new Translate() + assert(t.getOrDefault(t.textType) == Left("plain")) + assert(t.getOrDefault(t.category) == Left("general")) + assert(t.getOrDefault(t.profanityAction) == Left("NoAction")) + assert(t.getOrDefault(t.profanityMarker) == Left("Asterisk")) + assertResult(Left(false))(t.getOrDefault(t.includeAlignment)) + assertResult(Left(false))(t.getOrDefault(t.includeSentenceLength)) + assertResult(Left(true))(t.getOrDefault(t.allowFallback)) + } + + /* Each setter is called with a value the test expects to be rejected with IllegalArgumentException. */ test("translate rejects invalid enum parameters") { + intercept[IllegalArgumentException] { + new Translate().setTextType("markdown") + } + intercept[IllegalArgumentException] { + new Translate().setProfanityAction("Mask") + } + intercept[IllegalArgumentException] { + new Translate().setProfanityMarker("Bracket") + } + } + + /* Builds a request from a one-row DataFrame with column-bound params and checks every query parameter (defaults included), the subscription headers, Content-Type, and the JSON body. */ test("translate request building maps query params and body deterministically") { + val df = Seq((Seq("hello", "world"), Seq("de", "fr"), "en")) + .toDF("text", "toLanguage", "fromLanguage") + + val t = new TestableTranslate() + .setSubscriptionKey("fake-key") + .setLocation("eastus") + .setTextCol("text") + .setToLanguageCol("toLanguage") + .setFromLanguageCol("fromLanguage") + + val request = t.buildRequest(df.schema, df.head()).get + val query = toQueryMap(request) + assert(query("api-version") == "3.0") + assert(query("from") == "en") + assert(query("to") == "de,fr") + assert(query("textType") == "plain") +
assert(query("category") == "general") + assert(query("profanityAction") == "NoAction") + assert(query("profanityMarker") == "Asterisk") + assert(query("includeAlignment") == "false") + assert(query("includeSentenceLength") == "false") + assert(query("allowFallback") == "true") + assert(request.getFirstHeader("Ocp-Apim-Subscription-Key").getValue == "fake-key") + assert(request.getFirstHeader("Ocp-Apim-Subscription-Region").getValue == "eastus") + assert(request.getFirstHeader("Content-Type").getValue == "application/json; charset=UTF-8") + assert(EntityUtils.toString(request.getEntity, "UTF-8") == """[{"Text":"hello"},{"Text":"world"}]""") + } + + /* buildRequest must return None when the text list is empty, the target-language list is empty, or the target-language value is null. */ test("translate request building skips empty or missing text and targets") { + val t = new TestableTranslate() + .setLocation("eastus") + .setTextCol("text") + .setToLanguageCol("toLanguage") + + val emptyTextDf = Seq((Seq.empty[String], Seq("de"))).toDF("text", "toLanguage") + assert(t.buildRequest(emptyTextDf.schema, emptyTextDf.head()).isEmpty) + + val emptyToDf = Seq((Seq("hello"), Seq.empty[String])).toDF("text", "toLanguage") + assert(t.buildRequest(emptyToDf.schema, emptyToDf.head()).isEmpty) + + val nullToDf = Seq((Seq("hello"), Option.empty[Seq[String]])).toDF("text", "toLanguage") + assert(t.buildRequest(nullToDf.schema, nullToDf.head()).isEmpty) + } + + /* The output schema must contain only the input columns plus the configured output and error columns. */ test("translate transformSchema adds output and error columns without temp columns") { + val input = Seq(("hello", "de")).toDF("text", "toLanguage") + val t = new Translate() + .setTextCol("text") + .setToLanguageCol("toLanguage") + .setOutputCol("translation") + .setErrorCol("translationError") + + val schema = t.transformSchema(input.schema) + assert(schema.fieldNames.toSet == Set("text", "toLanguage", "translation", "translationError")) + assert(schema("translation").dataType == ArrayType(TranslateResponse.schema)) + } + + /* With toLanguage left unset, transformSchema must fail with an AssertionError whose message names the missing param. */ test("translate validates required parameters during schema creation") { + val textOnly = Seq("hello").toDF("text") + val err =
intercept[AssertionError] { + new Translate().setTextCol("text").transformSchema(textOnly.schema) + } + assert(err.getMessage.contains("Missing required params")) + assert(err.getMessage.contains("toLanguage")) + } + + /* Column-bound language and script params must appear as query parameters, and the non-ASCII text must reach the UTF-8 JSON body unchanged. */ test("transliterate request building maps required params and body deterministically") { + val df = Seq((Seq("こんにちは"), "ja", "Jpan", "Latn")).toDF("text", "language", "fromScript", "toScript") + + val t = new TestableTransliterate() + .setSubscriptionKey("fake-key") + .setLocation("eastus") + .setTextCol("text") + .setLanguageCol("language") + .setFromScriptCol("fromScript") + .setToScriptCol("toScript") + + val request = t.buildRequest(df.schema, df.head()).get + val query = toQueryMap(request) + assert(request.getURI.getPath.endsWith("/transliterate")) + assert(query("api-version") == "3.0") + assert(query("language") == "ja") + assert(query("fromScript") == "Jpan") + assert(query("toScript") == "Latn") + assert(request.getFirstHeader("Ocp-Apim-Subscription-Key").getValue == "fake-key") + assert(request.getFirstHeader("Ocp-Apim-Subscription-Region").getValue == "eastus") + assert(EntityUtils.toString(request.getEntity, "UTF-8") == """[{"Text":"こんにちは"}]""") + } + + /* Detect is expected to send only api-version in the query; BreakSentence additionally sends language and script. Both send the text list as Text items. */ test("detect and breaksentence request building is deterministic offline") { + val detectDf = Seq(Seq("hello", "world")).toDF("text") + val detectRequest = new TestableDetect() + .setLocation("eastus") + .setTextCol("text") + .buildRequest(detectDf.schema, detectDf.head()) + .get + assert(detectRequest.getURI.getPath.endsWith("/detect")) + assert(toQueryMap(detectRequest) == Map("api-version" -> "3.0")) + assert(EntityUtils.toString(detectRequest.getEntity, "UTF-8") == """[{"Text":"hello"},{"Text":"world"}]""") + + val breakDf = Seq((Seq("hello"), "en", "Latn")).toDF("text", "language", "script") + val breakRequest = new TestableBreakSentence() + .setLocation("eastus") + .setTextCol("text") + .setLanguageCol("language") + .setScriptCol("script") + .buildRequest(breakDf.schema,
breakDf.head()) + .get + val breakQuery = toQueryMap(breakRequest) + assert(breakRequest.getURI.getPath.endsWith("/breaksentence")) + assert(breakQuery("api-version") == "3.0") + assert(breakQuery("language") == "en") + assert(breakQuery("script") == "Latn") + assert(EntityUtils.toString(breakRequest.getEntity, "UTF-8") == """[{"Text":"hello"}]""") + } + + test("dictionary lookup and examples request building maps query params and body") { + val lookupDf = Seq((Seq("fly"), "en", "es")).toDF("text", "fromLanguage", "toLanguage") + val lookupRequest = new TestableDictionaryLookup() + .setSubscriptionKey("fake-key") + .setLocation("eastus") + .setTextCol("text") + .setFromLanguageCol("fromLanguage") + .setToLanguageCol("toLanguage") + .buildRequest(lookupDf.schema, lookupDf.head()) + .get + val lookupQuery = toQueryMap(lookupRequest) + assert(lookupRequest.getURI.getPath.endsWith("/dictionary/lookup")) + assert(lookupQuery("api-version") == "3.0") + assert(lookupQuery("from") == "en") + assert(lookupQuery("to") == "es") + assert(EntityUtils.toString(lookupRequest.getEntity, "UTF-8") == """[{"Text":"fly"}]""") + + val examplesDf = Seq((Seq(TextAndTranslation("fly", "volar")), "en", "es")) + .toDF("textAndTranslation", "fromLanguage", "toLanguage") + val examplesRequest = new TestableDictionaryExamples() + .setLocation("eastus") + .setTextAndTranslationCol("textAndTranslation") + .setFromLanguageCol("fromLanguage") + .setToLanguageCol("toLanguage") + .buildRequest(examplesDf.schema, examplesDf.head()) + .get + val examplesQuery = toQueryMap(examplesRequest) + assert(examplesRequest.getURI.getPath.endsWith("/dictionary/examples")) + assert(examplesQuery("api-version") == "3.0") + assert(examplesQuery("from") == "en") + assert(examplesQuery("to") == "es") + assert(EntityUtils.toString(examplesRequest.getEntity, "UTF-8") == """[{"Text":"fly","Translation":"volar"}]""") + } + + test("dictionary examples request building supports scalar text and translation input") { + val 
request = new TestableDictionaryExamples() + .setLocation("eastus") + .setFromLanguage("en") + .setToLanguage("es") + .setTextAndTranslation(TextAndTranslation("fly", "volar")) + .buildRequest(StructType(Seq.empty), Row.empty) + .get + val query = toQueryMap(request) + assert(request.getURI.getPath.endsWith("/dictionary/examples")) + assert(query("api-version") == "3.0") + assert(query("from") == "en") + assert(query("to") == "es") + assert(EntityUtils.toString(request.getEntity, "UTF-8") == """[{"Text":"fly","Translation":"volar"}]""") + } + + test("non-translate transformSchema adds deterministic output and error columns") { + val textOnly = Seq("hello").toDF("text") + val textAndTranslationOnly = Seq(Seq(TextAndTranslation("fly", "volar"))).toDF("textAndTranslation") + + val transliterateSchema = new Transliterate() + .setTextCol("text") + .setLanguage("ja") + .setFromScript("Jpan") + .setToScript("Latn") + .setOutputCol("transliteration") + .setErrorCol("transliterationError") + .transformSchema(textOnly.schema) + assert(transliterateSchema.fieldNames.toSet == Set("text", "transliteration", "transliterationError")) + assert(transliterateSchema("transliteration").dataType == ArrayType(TransliterateResponse.schema)) + + val detectSchema = new Detect() + .setTextCol("text") + .setOutputCol("detection") + .setErrorCol("detectionError") + .transformSchema(textOnly.schema) + assert(detectSchema.fieldNames.toSet == Set("text", "detection", "detectionError")) + assert(detectSchema("detection").dataType == ArrayType(DetectResponse.schema)) + + val breakSentenceSchema = new BreakSentence() + .setTextCol("text") + .setOutputCol("sentenceBreaks") + .setErrorCol("breakError") + .transformSchema(textOnly.schema) + assert(breakSentenceSchema.fieldNames.toSet == Set("text", "sentenceBreaks", "breakError")) + assert(breakSentenceSchema("sentenceBreaks").dataType == ArrayType(BreakSentenceResponse.schema)) + + val lookupSchema = new DictionaryLookup() + .setTextCol("text") + 
.setFromLanguage("en") + .setToLanguage("es") + .setOutputCol("lookup") + .setErrorCol("lookupError") + .transformSchema(textOnly.schema) + assert(lookupSchema.fieldNames.toSet == Set("text", "lookup", "lookupError")) + assert(lookupSchema("lookup").dataType == ArrayType(DictionaryLookupResponse.schema)) + + val examplesSchema = new DictionaryExamples() + .setTextAndTranslationCol("textAndTranslation") + .setFromLanguage("en") + .setToLanguage("es") + .setOutputCol("examples") + .setErrorCol("examplesError") + .transformSchema(textAndTranslationOnly.schema) + assert(examplesSchema.fieldNames.toSet == Set("textAndTranslation", "examples", "examplesError")) + assert(examplesSchema("examples").dataType == ArrayType(DictionaryExamplesResponse.schema)) + } + + test("non-translate classes validate required parameters during schema creation") { + val textOnly = Seq("hello").toDF("text") + val textAndTranslationOnly = Seq(Seq(TextAndTranslation("fly", "volar"))).toDF("textAndTranslation") + + val transliterateError = intercept[AssertionError] { + new Transliterate().setTextCol("text").transformSchema(textOnly.schema) + } + assert(transliterateError.getMessage.contains("Missing required params")) + assert(transliterateError.getMessage.contains("language")) + assert(transliterateError.getMessage.contains("fromScript")) + assert(transliterateError.getMessage.contains("toScript")) + + val dictionaryLookupError = intercept[AssertionError] { + new DictionaryLookup().setTextCol("text").transformSchema(textOnly.schema) + } + assert(dictionaryLookupError.getMessage.contains("Missing required params")) + assert(dictionaryLookupError.getMessage.contains("fromLanguage")) + assert(dictionaryLookupError.getMessage.contains("toLanguage")) + + val dictionaryExamplesError = intercept[AssertionError] { + new DictionaryExamples() + .setTextAndTranslationCol("textAndTranslation") + .transformSchema(textAndTranslationOnly.schema) + } + 
assert(dictionaryExamplesError.getMessage.contains("Missing required params")) + assert(dictionaryExamplesError.getMessage.contains("fromLanguage")) + assert(dictionaryExamplesError.getMessage.contains("toLanguage")) + } +} diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala index 425d7314f6f..3f1ae537c46 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala @@ -99,6 +99,7 @@ object PyCodegen { | long_description="SynapseML contains Microsoft's open source " | + "contributions to the Apache Spark ecosystem", | license="MIT", + | license_expression="MIT", | packages=find_namespace_packages(include=['synapse.ml.*']) ${extraPackage}, | url="https://github.com/Microsoft/SynapseML", | author="Microsoft", @@ -108,8 +109,6 @@ object PyCodegen { | "Intended Audience :: Developers", | "Intended Audience :: Science/Research", | "Topic :: Software Development :: Libraries", - | "License :: OSI Approved :: MIT License", - | "Programming Language :: Python :: 2", | "Programming Language :: Python :: 3", | ], | zip_safe=True, diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/logging/SynapseMLLogging.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/logging/SynapseMLLogging.scala index 2269592be27..b2358f487ff 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/logging/SynapseMLLogging.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/logging/SynapseMLLogging.scala @@ -41,7 +41,7 @@ case class RequiredErrorFields(errorType: String, def toMap: Map[String, String] = { Map( "errorType" -> errorType, - "errorMessage" -> errorType + "errorMessage" -> errorMessage ) } } diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala 
b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala index 9ff42cd8b28..cffaedba28d 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala @@ -28,8 +28,10 @@ object GlobalParams { def getParam[T](p: Param[T]): Option[T] = { ParamToKeyMap.get(p).flatMap { key => - key match { - case k: GlobalKey[T] => + // Using @unchecked because GlobalKey[T] type parameter is erased at runtime, + // but we know the types are correct due to how registerParam stores them + (key: @unchecked) match { + case k: GlobalKey[T @unchecked] => getGlobalParam(k) case _ => None } diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyDefaultHyperparams.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyDefaultHyperparams.scala new file mode 100644 index 00000000000..c335c1f234d --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyDefaultHyperparams.scala @@ -0,0 +1,124 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package com.microsoft.azure.synapse.ml.automl + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.ml.classification._ + +class VerifyDefaultHyperparams extends TestBase { + + test("defaultRange for LogisticRegression returns non-empty array") { + val lr = new LogisticRegression() + val ranges = DefaultHyperparams.defaultRange(lr) + assert(ranges.nonEmpty) + assert(ranges.length === 3) // regParam, elasticNetParam, maxIter + } + + test("defaultRange for LogisticRegression includes expected params") { + val lr = new LogisticRegression() + val ranges = DefaultHyperparams.defaultRange(lr) + val paramNames = ranges.map(_._1.name).toSet + assert(paramNames.contains("regParam")) + assert(paramNames.contains("elasticNetParam")) + assert(paramNames.contains("maxIter")) + } + + test("defaultRange for DecisionTreeClassifier returns non-empty array") { + val dt = new DecisionTreeClassifier() + val ranges = DefaultHyperparams.defaultRange(dt) + assert(ranges.nonEmpty) + assert(ranges.length === 4) // maxBins, maxDepth, minInfoGain, minInstancesPerNode + } + + test("defaultRange for DecisionTreeClassifier includes expected params") { + val dt = new DecisionTreeClassifier() + val ranges = DefaultHyperparams.defaultRange(dt) + val paramNames = ranges.map(_._1.name).toSet + assert(paramNames.contains("maxBins")) + assert(paramNames.contains("maxDepth")) + assert(paramNames.contains("minInfoGain")) + assert(paramNames.contains("minInstancesPerNode")) + } + + test("defaultRange for GBTClassifier returns non-empty array") { + val gbt = new GBTClassifier() + val ranges = DefaultHyperparams.defaultRange(gbt) + assert(ranges.nonEmpty) + assert(ranges.length === 7) + } + + test("defaultRange for GBTClassifier includes expected params") { + val gbt = new GBTClassifier() + val ranges = DefaultHyperparams.defaultRange(gbt) + val paramNames = ranges.map(_._1.name).toSet + assert(paramNames.contains("maxBins")) + assert(paramNames.contains("maxDepth")) 
+ assert(paramNames.contains("minInfoGain")) + assert(paramNames.contains("minInstancesPerNode")) + assert(paramNames.contains("maxIter")) + assert(paramNames.contains("stepSize")) + assert(paramNames.contains("subsamplingRate")) + } + + test("defaultRange for RandomForestClassifier returns non-empty array") { + val rf = new RandomForestClassifier() + val ranges = DefaultHyperparams.defaultRange(rf) + assert(ranges.nonEmpty) + assert(ranges.length === 6) + } + + test("defaultRange for RandomForestClassifier includes expected params") { + val rf = new RandomForestClassifier() + val ranges = DefaultHyperparams.defaultRange(rf) + val paramNames = ranges.map(_._1.name).toSet + assert(paramNames.contains("maxBins")) + assert(paramNames.contains("maxDepth")) + assert(paramNames.contains("minInfoGain")) + assert(paramNames.contains("minInstancesPerNode")) + assert(paramNames.contains("numTrees")) + assert(paramNames.contains("subsamplingRate")) + } + + test("defaultRange for MultilayerPerceptronClassifier returns non-empty array") { + val mlp = new MultilayerPerceptronClassifier() + val ranges = DefaultHyperparams.defaultRange(mlp) + assert(ranges.nonEmpty) + assert(ranges.length === 4) // blockSize, maxIter, tol, layers + } + + test("defaultRange for MultilayerPerceptronClassifier includes expected params") { + val mlp = new MultilayerPerceptronClassifier() + val ranges = DefaultHyperparams.defaultRange(mlp) + val paramNames = ranges.map(_._1.name).toSet + assert(paramNames.contains("blockSize")) + assert(paramNames.contains("maxIter")) + assert(paramNames.contains("tol")) + assert(paramNames.contains("layers")) + } + + test("defaultRange for NaiveBayes returns non-empty array") { + val nb = new NaiveBayes() + val ranges = DefaultHyperparams.defaultRange(nb) + assert(ranges.nonEmpty) + assert(ranges.length === 1) // smoothing + } + + test("defaultRange for NaiveBayes includes smoothing param") { + val nb = new NaiveBayes() + val ranges = 
DefaultHyperparams.defaultRange(nb) + val paramNames = ranges.map(_._1.name).toSet + assert(paramNames.contains("smoothing")) + } + + test("all defaultRange methods return valid Dist instances") { + val lr = new LogisticRegression() + val ranges = DefaultHyperparams.defaultRange(lr) + ranges.foreach { case (param, dist) => + assert(param != null) + assert(dist != null) + // Verify the distribution is a known type + assert(dist.isInstanceOf[Dist[_]]) + } + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyEvaluationUtils.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyEvaluationUtils.scala new file mode 100644 index 00000000000..2097fdb11c9 --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyEvaluationUtils.scala @@ -0,0 +1,122 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.automl + +import com.microsoft.azure.synapse.ml.core.metrics.MetricConstants +import com.microsoft.azure.synapse.ml.core.schema.SchemaConstants +import com.microsoft.azure.synapse.ml.core.test.base.TestBase + +class VerifyEvaluationUtils extends TestBase { + + test("ModelTypeUnsupportedErr constant has expected value") { + assert(EvaluationUtils.ModelTypeUnsupportedErr === "Model type not supported for evaluation") + } + + test("getMetricWithOperator returns correct metric for regression MSE") { + val (metricName, ordering) = EvaluationUtils.getMetricWithOperator( + SchemaConstants.RegressionKind, + MetricConstants.MseSparkMetric + ) + assert(metricName === MetricConstants.MseColumnName) + // MSE should use lowest (reverse ordering) + assert(ordering.compare(1.0, 2.0) > 0) // 1.0 is "better" than 2.0 for MSE + } + + test("getMetricWithOperator returns correct metric for regression RMSE") { + val (metricName, ordering) = EvaluationUtils.getMetricWithOperator( + 
SchemaConstants.RegressionKind, + MetricConstants.RmseSparkMetric + ) + assert(metricName === MetricConstants.RmseColumnName) + // RMSE should use lowest + assert(ordering.compare(1.0, 2.0) > 0) + } + + test("getMetricWithOperator returns correct metric for regression R2") { + val (metricName, ordering) = EvaluationUtils.getMetricWithOperator( + SchemaConstants.RegressionKind, + MetricConstants.R2SparkMetric + ) + assert(metricName === MetricConstants.R2ColumnName) + // R2 should use highest + assert(ordering.compare(2.0, 1.0) > 0) // 2.0 is "better" than 1.0 for R2 + } + + test("getMetricWithOperator returns correct metric for regression MAE") { + val (metricName, ordering) = EvaluationUtils.getMetricWithOperator( + SchemaConstants.RegressionKind, + MetricConstants.MaeSparkMetric + ) + assert(metricName === MetricConstants.MaeColumnName) + // MAE should use lowest + assert(ordering.compare(1.0, 2.0) > 0) + } + + test("getMetricWithOperator returns correct metric for classification AUC") { + val (metricName, ordering) = EvaluationUtils.getMetricWithOperator( + SchemaConstants.ClassificationKind, + MetricConstants.AucSparkMetric + ) + assert(metricName === MetricConstants.AucColumnName) + // AUC should use highest + assert(ordering.compare(2.0, 1.0) > 0) + } + + test("getMetricWithOperator returns correct metric for classification Precision") { + val (metricName, ordering) = EvaluationUtils.getMetricWithOperator( + SchemaConstants.ClassificationKind, + MetricConstants.PrecisionSparkMetric + ) + assert(metricName === MetricConstants.PrecisionColumnName) + // Precision should use highest + assert(ordering.compare(2.0, 1.0) > 0) + } + + test("getMetricWithOperator returns correct metric for classification Recall") { + val (metricName, ordering) = EvaluationUtils.getMetricWithOperator( + SchemaConstants.ClassificationKind, + MetricConstants.RecallSparkMetric + ) + assert(metricName === MetricConstants.RecallColumnName) + // Recall should use highest + 
assert(ordering.compare(2.0, 1.0) > 0) + } + + test("getMetricWithOperator returns correct metric for classification Accuracy") { + val (metricName, ordering) = EvaluationUtils.getMetricWithOperator( + SchemaConstants.ClassificationKind, + MetricConstants.AccuracySparkMetric + ) + assert(metricName === MetricConstants.AccuracyColumnName) + // Accuracy should use highest + assert(ordering.compare(2.0, 1.0) > 0) + } + + test("getMetricWithOperator throws for unsupported regression metric") { + assertThrows[Exception] { + EvaluationUtils.getMetricWithOperator( + SchemaConstants.RegressionKind, + "unsupported_metric" + ) + } + } + + test("getMetricWithOperator throws for unsupported classification metric") { + assertThrows[Exception] { + EvaluationUtils.getMetricWithOperator( + SchemaConstants.ClassificationKind, + "unsupported_metric" + ) + } + } + + test("getMetricWithOperator throws for unsupported model type") { + assertThrows[Exception] { + EvaluationUtils.getMetricWithOperator( + "unsupported_model_type", + MetricConstants.MseSparkMetric + ) + } + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyHyperparamBuilder.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyHyperparamBuilder.scala new file mode 100644 index 00000000000..ed3d697aa2c --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/automl/VerifyHyperparamBuilder.scala @@ -0,0 +1,169 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package com.microsoft.azure.synapse.ml.automl + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.ml.param.{IntParam, Params, ParamMap, DoubleParam, LongParam, FloatParam} + +class VerifyHyperparamBuilder extends TestBase { + + // Helper class for creating test params + private class TestParams extends Params { + override val uid: String = "test" + val intParam = new IntParam(this, "intParam", "test int param") + val doubleParam = new DoubleParam(this, "doubleParam", "test double param") + val longParam = new LongParam(this, "longParam", "test long param") + val floatParam = new FloatParam(this, "floatParam", "test float param") + override def copy(extra: ParamMap): Params = this + } + + private val testParamsInstance = new TestParams + + test("IntRangeHyperParam generates values within range") { + val param = new IntRangeHyperParam(10, 20, seed = 42) + for (_ <- 1 to 100) { + val value = param.getNext() + assert(value >= 10 && value < 20) + } + } + + test("IntRangeHyperParam respects seed for reproducibility") { + val param1 = new IntRangeHyperParam(0, 100, seed = 42) + val param2 = new IntRangeHyperParam(0, 100, seed = 42) + val values1 = (1 to 10).map(_ => param1.getNext()) + val values2 = (1 to 10).map(_ => param2.getNext()) + assert(values1 === values2) + } + + test("DoubleRangeHyperParam generates values within range") { + val param = new DoubleRangeHyperParam(0.0, 1.0, seed = 42) + for (_ <- 1 to 100) { + val value = param.getNext() + assert(value >= 0.0 && value < 1.0) + } + } + + test("DoubleRangeHyperParam respects seed for reproducibility") { + val param1 = new DoubleRangeHyperParam(0.0, 10.0, seed = 42) + val param2 = new DoubleRangeHyperParam(0.0, 10.0, seed = 42) + val values1 = (1 to 10).map(_ => param1.getNext()) + val values2 = (1 to 10).map(_ => param2.getNext()) + assert(values1 === values2) + } + + test("LongRangeHyperParam generates values") { + val param = new LongRangeHyperParam(0L, 100L, seed = 42) + 
val value = param.getNext() + // Just verify it returns a Long + assert(value.isInstanceOf[Long]) + } + + test("FloatRangeHyperParam generates values within range") { + val param = new FloatRangeHyperParam(0.0f, 1.0f, seed = 42) + for (_ <- 1 to 100) { + val value = param.getNext() + assert(value >= 0.0f && value < 1.0f) + } + } + + test("DiscreteHyperParam selects from provided values") { + val values = List("a", "b", "c") + val param = new DiscreteHyperParam(values, seed = 42) + for (_ <- 1 to 100) { + val value = param.getNext() + assert(values.contains(value)) + } + } + + test("DiscreteHyperParam.getValues returns Java list") { + val values = List(1, 2, 3) + val param = new DiscreteHyperParam(values) + val javaList = param.getValues + assert(javaList.size() === 3) + assert(javaList.get(0) === 1) + assert(javaList.get(1) === 2) + assert(javaList.get(2) === 3) + } + + test("HyperparamBuilder builds empty array when no params added") { + val builder = new HyperparamBuilder() + val result = builder.build() + assert(result.isEmpty) + } + + test("HyperparamBuilder adds single hyperparam") { + val builder = new HyperparamBuilder() + builder.addHyperparam(testParamsInstance.intParam, new IntRangeHyperParam(1, 10)) + val result = builder.build() + assert(result.length === 1) + assert(result.head._1 === testParamsInstance.intParam) + } + + test("HyperparamBuilder adds multiple hyperparams") { + val builder = new HyperparamBuilder() + .addHyperparam(testParamsInstance.intParam, new IntRangeHyperParam(1, 10)) + .addHyperparam(testParamsInstance.doubleParam, new DoubleRangeHyperParam(0.0, 1.0)) + val result = builder.build() + assert(result.length === 2) + } + + test("HyperparamBuilder supports method chaining") { + val builder = new HyperparamBuilder() + val result = builder + .addHyperparam(testParamsInstance.intParam, new IntRangeHyperParam(1, 10)) + .addHyperparam(testParamsInstance.doubleParam, new DoubleRangeHyperParam(0.0, 1.0)) + .build() + assert(result.length === 
2) + } + + test("HyperParamUtils.getRangeHyperParam returns IntRangeHyperParam for Int") { + val result = HyperParamUtils.getRangeHyperParam(1, 10) + assert(result.isInstanceOf[IntRangeHyperParam]) + val intResult = result.asInstanceOf[IntRangeHyperParam] + assert(intResult.min === 1) + assert(intResult.max === 10) + } + + test("HyperParamUtils.getRangeHyperParam returns DoubleRangeHyperParam for Double") { + val result = HyperParamUtils.getRangeHyperParam(0.0, 1.0) + assert(result.isInstanceOf[DoubleRangeHyperParam]) + val doubleResult = result.asInstanceOf[DoubleRangeHyperParam] + assert(doubleResult.min === 0.0) + assert(doubleResult.max === 1.0) + } + + test("HyperParamUtils.getRangeHyperParam returns LongRangeHyperParam for Long") { + val result = HyperParamUtils.getRangeHyperParam(0L, 100L) + assert(result.isInstanceOf[LongRangeHyperParam]) + } + + test("HyperParamUtils.getRangeHyperParam returns FloatRangeHyperParam for Float") { + val result = HyperParamUtils.getRangeHyperParam(0.0f, 1.0f) + assert(result.isInstanceOf[FloatRangeHyperParam]) + } + + test("HyperParamUtils.getRangeHyperParam throws for unsupported types") { + assertThrows[Exception] { + HyperParamUtils.getRangeHyperParam("a", "b") + } + } + + test("HyperParamUtils.getDiscreteHyperParam creates DiscreteHyperParam from Java ArrayList") { + val javaList = new java.util.ArrayList[Int]() + javaList.add(1) + javaList.add(2) + javaList.add(3) + val result = HyperParamUtils.getDiscreteHyperParam(javaList) + assert(result.isInstanceOf[DiscreteHyperParam[_]]) + val value = result.getNext() + assert(Seq(1, 2, 3).contains(value)) + } + + test("RangeHyperParam stores min, max, and seed") { + val param = new IntRangeHyperParam(5, 15, seed = 123) + assert(param.min === 5) + assert(param.max === 15) + assert(param.seed === 123) + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/causal/VerifyCacheOps.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/causal/VerifyCacheOps.scala new 
file mode 100644 index 00000000000..cb132c26ab5 --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/causal/VerifyCacheOps.scala @@ -0,0 +1,61 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.causal + +import breeze.linalg.{DenseVector => BDV} +import com.microsoft.azure.synapse.ml.core.test.base.TestBase + +class VerifyCacheOps extends TestBase { + + test("CacheOps trait has default implementations that return input unchanged") { + val ops = new CacheOps[String] {} + val data = "test" + assert(ops.cache(data) === data) + assert(ops.checkpoint(data) === data) + } + + test("BDVCacheOps.cache returns the same vector") { + val vector = BDV(1.0, 2.0, 3.0) + val result = BDVCacheOps.cache(vector) + assert(result eq vector) + } + + test("BDVCacheOps.checkpoint returns the same vector") { + val vector = BDV(1.0, 2.0, 3.0) + val result = BDVCacheOps.checkpoint(vector) + assert(result eq vector) + } + + test("BDVCacheOps is a no-op for dense vectors") { + val vector = BDV(1.0, 2.0, 3.0, 4.0, 5.0) + + // Both operations should return the exact same instance + val cached = BDVCacheOps.cache(vector) + val checkpointed = BDVCacheOps.checkpoint(vector) + + assert(cached eq vector) + assert(checkpointed eq vector) + assert(cached.toArray === Array(1.0, 2.0, 3.0, 4.0, 5.0)) + } + + test("BDVCacheOps preserves vector data") { + val vector = BDV(10.0, 20.0, 30.0) + val cached = BDVCacheOps.cache(vector) + + assert(cached(0) === 10.0) + assert(cached(1) === 20.0) + assert(cached(2) === 30.0) + assert(cached.length === 3) + } + + test("CacheOps works with generic types") { + case class TestData(value: Int) + + val ops = new CacheOps[TestData] {} + val data = TestData(42) + + assert(ops.cache(data) === data) + assert(ops.checkpoint(data) === data) + } +} diff --git 
a/core/src/test/scala/com/microsoft/azure/synapse/ml/causal/VerifySharedParams.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/causal/VerifySharedParams.scala new file mode 100644 index 00000000000..f100be4c4d8 --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/causal/VerifySharedParams.scala @@ -0,0 +1,104 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.causal + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.ml.param.Params +import org.apache.spark.ml.util.Identifiable + +class VerifySharedParams extends TestBase { + + // Test implementation that mixes in all the traits + private class TestParamsImpl(override val uid: String) + extends Params + with HasTreatmentCol + with HasOutcomeCol + with HasPostTreatmentCol + with HasUnitCol + with HasTimeCol { + override def copy(extra: org.apache.spark.ml.param.ParamMap): Params = this + } + + test("HasTreatmentCol sets and gets treatment column") { + val params = new TestParamsImpl(Identifiable.randomUID("test")) + params.setTreatmentCol("treatment") + assert(params.getTreatmentCol === "treatment") + } + + test("HasTreatmentCol param has correct name and doc") { + val params = new TestParamsImpl(Identifiable.randomUID("test")) + assert(params.treatmentCol.name === "treatmentCol") + assert(params.treatmentCol.doc === "treatment column") + } + + test("HasOutcomeCol sets and gets outcome column") { + val params = new TestParamsImpl(Identifiable.randomUID("test")) + params.setOutcomeCol("outcome") + assert(params.getOutcomeCol === "outcome") + } + + test("HasOutcomeCol param has correct name and doc") { + val params = new TestParamsImpl(Identifiable.randomUID("test")) + assert(params.outcomeCol.name === "outcomeCol") + assert(params.outcomeCol.doc === "outcome column") + } + + test("HasPostTreatmentCol sets and gets 
post treatment column") { + val params = new TestParamsImpl(Identifiable.randomUID("test")) + params.setPostTreatmentCol("postTreatment") + assert(params.getPostTreatmentCol === "postTreatment") + } + + test("HasPostTreatmentCol param has correct name and doc") { + val params = new TestParamsImpl(Identifiable.randomUID("test")) + assert(params.postTreatmentCol.name === "postTreatmentCol") + assert(params.postTreatmentCol.doc === "post treatment indicator column") + } + + test("HasUnitCol sets and gets unit column") { + val params = new TestParamsImpl(Identifiable.randomUID("test")) + params.setUnitCol("userId") + assert(params.getUnitCol === "userId") + } + + test("HasUnitCol param has correct name") { + val params = new TestParamsImpl(Identifiable.randomUID("test")) + assert(params.unitCol.name === "unitCol") + assert(params.unitCol.doc.contains("identifier for each observed unit")) + } + + test("HasTimeCol sets and gets time column") { + val params = new TestParamsImpl(Identifiable.randomUID("test")) + params.setTimeCol("date") + assert(params.getTimeCol === "date") + } + + test("HasTimeCol param has correct name") { + val params = new TestParamsImpl(Identifiable.randomUID("test")) + assert(params.timeCol.name === "timeCol") + assert(params.timeCol.doc.contains("time when outcome is measured")) + } + + test("All params can be set together") { + val params = new TestParamsImpl(Identifiable.randomUID("test")) + params + .setTreatmentCol("treatment") + .setOutcomeCol("outcome") + .setPostTreatmentCol("post") + .setUnitCol("user") + .setTimeCol("time") + + assert(params.getTreatmentCol === "treatment") + assert(params.getOutcomeCol === "outcome") + assert(params.getPostTreatmentCol === "post") + assert(params.getUnitCol === "user") + assert(params.getTimeCol === "time") + } + + test("Setters return this.type for chaining") { + val params = new TestParamsImpl(Identifiable.randomUID("test")) + val result = params.setTreatmentCol("treatment") + assert(result eq params) 
+ } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/core/contracts/VerifyMetrics.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/contracts/VerifyMetrics.scala new file mode 100644 index 00000000000..740e1dd8924 --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/contracts/VerifyMetrics.scala @@ -0,0 +1,105 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.core.contracts + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase + +class VerifyMetrics extends TestBase { + + test("TypedMetric stores name and value correctly") { + val metric = TypedMetric("accuracy", 0.95) + assert(metric.name === "accuracy") + assert(metric.value === 0.95) + } + + test("TypedMetric works with different types") { + val doubleMetric = TypedMetric[Double]("score", 1.5) + val stringMetric = TypedMetric[String]("label", "positive") + val intMetric = TypedMetric[Int]("count", 42) + + assert(doubleMetric.value === 1.5) + assert(stringMetric.value === "positive") + assert(intMetric.value === 42) + } + + test("DoubleMetric stores name and double value") { + val metric = DoubleMetric("precision", 0.85) + assert(metric.name === "precision") + assert(metric.value === 0.85) + } + + test("StringMetric stores name and string value") { + val metric = StringMetric("category", "classification") + assert(metric.name === "category") + assert(metric.value === "classification") + } + + test("IntegralMetric stores name and long value") { + val metric = IntegralMetric("count", 1000L) + assert(metric.name === "count") + assert(metric.value === 1000L) + } + + test("TypenameMetricGroup stores name and values map") { + val metrics = Map( + "group1" -> Seq(DoubleMetric("m1", 1.0), DoubleMetric("m2", 2.0)), + "group2" -> Seq(StringMetric("s1", "test")) + ) + val group = TypenameMetricGroup("myGroup", metrics) + 
assert(group.name === "myGroup") + assert(group.values.size === 2) + assert(group.values("group1").length === 2) + } + + test("MetricData stores data, metricType, and modelName") { + val data = Map("accuracy" -> Seq(0.9, 0.91, 0.92)) + val metricData = MetricData(data, "classification", "logisticRegression") + + assert(metricData.data === data) + assert(metricData.metricType === "classification") + assert(metricData.modelName === "logisticRegression") + } + + test("MetricData.create converts single values to sequences") { + val singleValues = Map("accuracy" -> 0.95, "precision" -> 0.90) + val metricData = MetricData.create(singleValues, "classification", "svm") + + assert(metricData.data("accuracy") === List(0.95)) + assert(metricData.data("precision") === List(0.90)) + assert(metricData.metricType === "classification") + assert(metricData.modelName === "svm") + } + + test("MetricData.createTable preserves sequences") { + val tableData = Map( + "mse" -> Seq(0.1, 0.2, 0.3), + "rmse" -> Seq(0.316, 0.447, 0.548) + ) + val metricData = MetricData.createTable(tableData, "regression", "linearRegression") + + assert(metricData.data("mse") === Seq(0.1, 0.2, 0.3)) + assert(metricData.data("rmse").length === 3) + assert(metricData.metricType === "regression") + assert(metricData.modelName === "linearRegression") + } + + test("MetricData.create handles empty map") { + val metricData = MetricData.create(Map.empty[String, Double], "test", "model") + assert(metricData.data.isEmpty) + } + + test("MetricData.createTable handles empty map") { + val metricData = MetricData.createTable(Map.empty[String, Seq[Double]], "test", "model") + assert(metricData.data.isEmpty) + } + + test("ConvenienceTypes type aliases work correctly") { + import ConvenienceTypes._ + val name: UniqueName = "testMetric" + val table: MetricTable = Map(name -> Seq(TypedMetric("m1", 1.0))) + + assert(name === "testMetric") + assert(table.contains("testMetric")) + } +} diff --git 
a/core/src/test/scala/com/microsoft/azure/synapse/ml/core/contracts/VerifyParams.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/contracts/VerifyParams.scala new file mode 100644 index 00000000000..8088b3cebae --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/contracts/VerifyParams.scala @@ -0,0 +1,189 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.core.contracts + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.Identifiable + +// Test implementations of the param traits +class TestHasInputCol(override val uid: String) + extends HasInputCol { + def this() = this(Identifiable.randomUID("TestHasInputCol")) + override def copy(extra: ParamMap): TestHasInputCol = defaultCopy(extra) +} + +class TestHasOutputCol(override val uid: String) + extends HasOutputCol { + def this() = this(Identifiable.randomUID("TestHasOutputCol")) + override def copy(extra: ParamMap): TestHasOutputCol = defaultCopy(extra) +} + +class TestHasInputCols(override val uid: String) + extends HasInputCols { + def this() = this(Identifiable.randomUID("TestHasInputCols")) + override def copy(extra: ParamMap): TestHasInputCols = defaultCopy(extra) +} + +class TestHasOutputCols(override val uid: String) + extends HasOutputCols { + def this() = this(Identifiable.randomUID("TestHasOutputCols")) + override def copy(extra: ParamMap): TestHasOutputCols = defaultCopy(extra) +} + +class TestHasLabelCol(override val uid: String) + extends HasLabelCol { + def this() = this(Identifiable.randomUID("TestHasLabelCol")) + override def copy(extra: ParamMap): TestHasLabelCol = defaultCopy(extra) +} + +class TestHasFeaturesCol(override val uid: String) + extends HasFeaturesCol { + def this() = this(Identifiable.randomUID("TestHasFeaturesCol")) + 
override def copy(extra: ParamMap): TestHasFeaturesCol = defaultCopy(extra) +} + +class TestHasWeightCol(override val uid: String) + extends HasWeightCol { + def this() = this(Identifiable.randomUID("TestHasWeightCol")) + override def copy(extra: ParamMap): TestHasWeightCol = defaultCopy(extra) +} + +class TestHasScoredLabelsCol(override val uid: String) + extends HasScoredLabelsCol { + def this() = this(Identifiable.randomUID("TestHasScoredLabelsCol")) + override def copy(extra: ParamMap): TestHasScoredLabelsCol = defaultCopy(extra) +} + +class TestHasScoresCol(override val uid: String) + extends HasScoresCol { + def this() = this(Identifiable.randomUID("TestHasScoresCol")) + override def copy(extra: ParamMap): TestHasScoresCol = defaultCopy(extra) +} + +class TestHasScoredProbabilitiesCol(override val uid: String) + extends HasScoredProbabilitiesCol { + def this() = this(Identifiable.randomUID("TestHasScoredProbabilitiesCol")) + override def copy(extra: ParamMap): TestHasScoredProbabilitiesCol = defaultCopy(extra) +} + +class TestHasEvaluationMetric(override val uid: String) + extends HasEvaluationMetric { + def this() = this(Identifiable.randomUID("TestHasEvaluationMetric")) + override def copy(extra: ParamMap): TestHasEvaluationMetric = defaultCopy(extra) +} + +class TestHasValidationIndicatorCol(override val uid: String) + extends HasValidationIndicatorCol { + def this() = this(Identifiable.randomUID("TestHasValidationIndicatorCol")) + override def copy(extra: ParamMap): TestHasValidationIndicatorCol = defaultCopy(extra) +} + +class TestHasInitScoreCol(override val uid: String) + extends HasInitScoreCol { + def this() = this(Identifiable.randomUID("TestHasInitScoreCol")) + override def copy(extra: ParamMap): TestHasInitScoreCol = defaultCopy(extra) +} + +class TestHasGroupCol(override val uid: String) + extends HasGroupCol { + def this() = this(Identifiable.randomUID("TestHasGroupCol")) + override def copy(extra: ParamMap): TestHasGroupCol = defaultCopy(extra) 
+} + +class VerifyParams extends TestBase { + + test("HasInputCol set and get work correctly") { + val obj = new TestHasInputCol() + obj.setInputCol("myInput") + assert(obj.getInputCol === "myInput") + } + + test("HasOutputCol set and get work correctly") { + val obj = new TestHasOutputCol() + obj.setOutputCol("myOutput") + assert(obj.getOutputCol === "myOutput") + } + + test("HasInputCols set and get work correctly") { + val obj = new TestHasInputCols() + val cols = Array("col1", "col2", "col3") + obj.setInputCols(cols) + assert(obj.getInputCols.sameElements(cols)) + } + + test("HasOutputCols set and get work correctly") { + val obj = new TestHasOutputCols() + val cols = Array("out1", "out2") + obj.setOutputCols(cols) + assert(obj.getOutputCols.sameElements(cols)) + } + + test("HasLabelCol set and get work correctly") { + val obj = new TestHasLabelCol() + obj.setLabelCol("target") + assert(obj.getLabelCol === "target") + } + + test("HasFeaturesCol set and get work correctly") { + val obj = new TestHasFeaturesCol() + obj.setFeaturesCol("features") + assert(obj.getFeaturesCol === "features") + } + + test("HasWeightCol set and get work correctly") { + val obj = new TestHasWeightCol() + obj.setWeightCol("weight") + assert(obj.getWeightCol === "weight") + } + + test("HasScoredLabelsCol set and get work correctly") { + val obj = new TestHasScoredLabelsCol() + obj.setScoredLabelsCol("scoredLabels") + assert(obj.getScoredLabelsCol === "scoredLabels") + } + + test("HasScoresCol set and get work correctly") { + val obj = new TestHasScoresCol() + obj.setScoresCol("scores") + assert(obj.getScoresCol === "scores") + } + + test("HasScoredProbabilitiesCol set and get work correctly") { + val obj = new TestHasScoredProbabilitiesCol() + obj.setScoredProbabilitiesCol("probs") + assert(obj.getScoredProbabilitiesCol === "probs") + } + + test("HasEvaluationMetric set and get work correctly") { + val obj = new TestHasEvaluationMetric() + obj.setEvaluationMetric("accuracy") + 
assert(obj.getEvaluationMetric === "accuracy") + } + + test("HasValidationIndicatorCol set and get work correctly") { + val obj = new TestHasValidationIndicatorCol() + obj.setValidationIndicatorCol("isValidation") + assert(obj.getValidationIndicatorCol === "isValidation") + } + + test("HasInitScoreCol set and get work correctly") { + val obj = new TestHasInitScoreCol() + obj.setInitScoreCol("initScore") + assert(obj.getInitScoreCol === "initScore") + } + + test("HasGroupCol set and get work correctly") { + val obj = new TestHasGroupCol() + obj.setGroupCol("group") + assert(obj.getGroupCol === "group") + } + + // Test chaining + test("param setters return this for chaining") { + val obj = new TestHasInputCol() + val result = obj.setInputCol("test") + assert(result eq obj) + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/core/env/VerifyPackageUtils.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/env/VerifyPackageUtils.scala new file mode 100644 index 00000000000..86af2aa3665 --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/env/VerifyPackageUtils.scala @@ -0,0 +1,58 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package com.microsoft.azure.synapse.ml.core.env + +import com.microsoft.azure.synapse.ml.build.BuildInfo +import com.microsoft.azure.synapse.ml.core.test.base.TestBase + +class VerifyPackageUtils extends TestBase { + + test("ScalaVersionSuffix is extracted from BuildInfo.scalaVersion") { + // Scala version is typically like "2.12.15" or "2.13.x" + // ScalaVersionSuffix should be "2.12" or "2.13" + val suffix = PackageUtils.ScalaVersionSuffix + assert(suffix.split("\\.").length === 2) + assert(suffix.startsWith("2.")) + } + + test("PackageGroup has expected value") { + assert(PackageUtils.PackageGroup === "com.microsoft.azure") + } + + test("PackageName contains scala version suffix") { + assert(PackageUtils.PackageName.startsWith("synapseml_")) + assert(PackageUtils.PackageName.contains(PackageUtils.ScalaVersionSuffix)) + } + + test("PackageMavenCoordinate has correct format") { + val coord = PackageUtils.PackageMavenCoordinate + // Format should be: group:artifact:version + val parts = coord.split(":") + assert(parts.length === 3) + assert(parts(0) === PackageUtils.PackageGroup) + assert(parts(1) === PackageUtils.PackageName) + assert(parts(2) === BuildInfo.version) + } + + test("PackageRepository is a valid URL") { + val repo = PackageUtils.PackageRepository + assert(repo.startsWith("https://")) + } + + test("SparkMavenPackageList contains package coordinate") { + val packages = PackageUtils.SparkMavenPackageList + assert(packages.contains(PackageUtils.PackageMavenCoordinate)) + } + + test("SparkMavenPackageList contains spark-avro") { + val packages = PackageUtils.SparkMavenPackageList + assert(packages.contains("spark-avro")) + } + + test("SparkMavenRepositoryList is set") { + val repos = PackageUtils.SparkMavenRepositoryList + assert(repos.nonEmpty) + assert(repos === PackageUtils.PackageRepository) + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/core/metrics/VerifyMetricConstants.scala 
b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/metrics/VerifyMetricConstants.scala new file mode 100644 index 00000000000..bc484328bcb --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/metrics/VerifyMetricConstants.scala @@ -0,0 +1,155 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.core.metrics + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase + +class VerifyMetricConstants extends TestBase { + + // Regression metrics tests + test("regression metric constants have expected values") { + assert(MetricConstants.MseSparkMetric === "mse") + assert(MetricConstants.RmseSparkMetric === "rmse") + assert(MetricConstants.R2SparkMetric === "r2") + assert(MetricConstants.MaeSparkMetric === "mae") + assert(MetricConstants.RegressionMetricsName === "regression") + } + + test("RegressionMetrics set contains all regression metrics") { + assert(MetricConstants.RegressionMetrics.contains(MetricConstants.MseSparkMetric)) + assert(MetricConstants.RegressionMetrics.contains(MetricConstants.RmseSparkMetric)) + assert(MetricConstants.RegressionMetrics.contains(MetricConstants.R2SparkMetric)) + assert(MetricConstants.RegressionMetrics.contains(MetricConstants.MaeSparkMetric)) + assert(MetricConstants.RegressionMetrics.contains(MetricConstants.RegressionMetricsName)) + assert(MetricConstants.RegressionMetrics.size === 5) + } + + // Classification metrics tests + test("classification metric constants have expected values") { + assert(MetricConstants.AreaUnderROCMetric === "areaUnderROC") + assert(MetricConstants.AucSparkMetric === "AUC") + assert(MetricConstants.AccuracySparkMetric === "accuracy") + assert(MetricConstants.PrecisionSparkMetric === "precision") + assert(MetricConstants.RecallSparkMetric === "recall") + assert(MetricConstants.ClassificationMetricsName === "classification") + } + + 
test("ClassificationMetrics set contains all classification metrics") { + assert(MetricConstants.ClassificationMetrics.contains(MetricConstants.AreaUnderROCMetric)) + assert(MetricConstants.ClassificationMetrics.contains(MetricConstants.AucSparkMetric)) + assert(MetricConstants.ClassificationMetrics.contains(MetricConstants.AccuracySparkMetric)) + assert(MetricConstants.ClassificationMetrics.contains(MetricConstants.PrecisionSparkMetric)) + assert(MetricConstants.ClassificationMetrics.contains(MetricConstants.RecallSparkMetric)) + assert(MetricConstants.ClassificationMetrics.contains(MetricConstants.ClassificationMetricsName)) + assert(MetricConstants.ClassificationMetrics.size === 6) + } + + test("AllSparkMetrics constant") { + assert(MetricConstants.AllSparkMetrics === "all") + } + + // Column name tests + test("regression column names have expected values") { + assert(MetricConstants.MseColumnName === "mean_squared_error") + assert(MetricConstants.RmseColumnName === "root_mean_squared_error") + assert(MetricConstants.R2ColumnName === "R^2") + assert(MetricConstants.MaeColumnName === "mean_absolute_error") + } + + test("classification column names have expected values") { + assert(MetricConstants.AucColumnName === "AUC") + assert(MetricConstants.PrecisionColumnName === "precision") + assert(MetricConstants.RecallColumnName === "recall") + assert(MetricConstants.AccuracyColumnName === "accuracy") + } + + test("multiclass column names have expected values") { + assert(MetricConstants.AverageAccuracy === "average_accuracy") + assert(MetricConstants.MacroAveragedRecall === "macro_averaged_recall") + assert(MetricConstants.MacroAveragedPrecision === "macro_averaged_precision") + assert(MetricConstants.ConfusionMatrix === "confusion_matrix") + } + + // MetricToColumnName mapping tests + test("MetricToColumnName contains correct mappings") { + assert(MetricConstants.MetricToColumnName(MetricConstants.AccuracySparkMetric) === + MetricConstants.AccuracyColumnName) + 
assert(MetricConstants.MetricToColumnName(MetricConstants.PrecisionSparkMetric) === + MetricConstants.PrecisionColumnName) + assert(MetricConstants.MetricToColumnName(MetricConstants.RecallSparkMetric) === + MetricConstants.RecallColumnName) + assert(MetricConstants.MetricToColumnName(MetricConstants.MseSparkMetric) === + MetricConstants.MseColumnName) + assert(MetricConstants.MetricToColumnName(MetricConstants.RmseSparkMetric) === + MetricConstants.RmseColumnName) + assert(MetricConstants.MetricToColumnName(MetricConstants.R2SparkMetric) === + MetricConstants.R2ColumnName) + assert(MetricConstants.MetricToColumnName(MetricConstants.MaeSparkMetric) === + MetricConstants.MaeColumnName) + } + + // Column lists tests + test("ClassificationColumns contains expected columns") { + assert(MetricConstants.ClassificationColumns === List( + MetricConstants.AccuracyColumnName, + MetricConstants.PrecisionColumnName, + MetricConstants.RecallColumnName)) + } + + test("RegressionColumns contains expected columns") { + assert(MetricConstants.RegressionColumns === List( + MetricConstants.MseColumnName, + MetricConstants.RmseColumnName, + MetricConstants.R2ColumnName, + MetricConstants.MaeColumnName)) + } + + // Evaluation type tests + test("evaluation type constants have expected values") { + assert(MetricConstants.ClassificationEvaluationType === "Classification") + assert(MetricConstants.EvaluationType === "evaluation_type") + } + + // ROC column names tests + test("ROC column names have expected values") { + assert(MetricConstants.FpRateROCColumnName === "false_positive_rate") + assert(MetricConstants.TpRateROCColumnName === "true_positive_rate") + assert(MetricConstants.FpRateROCLog === "fpr") + assert(MetricConstants.TpRateROCLog === "tpr") + } + + test("BinningThreshold has expected value") { + assert(MetricConstants.BinningThreshold === 1000) + } + + // Per instance metrics tests + test("per instance metric constants have expected values") { + 
assert(MetricConstants.L1LossMetric === "L1_loss") + assert(MetricConstants.L2LossMetric === "L2_loss") + assert(MetricConstants.LogLossMetric === "log_loss") + } + + test("RegressionPerInstanceMetrics contains expected metrics") { + assert(MetricConstants.RegressionPerInstanceMetrics.contains(MetricConstants.RegressionMetricsName)) + assert(MetricConstants.RegressionPerInstanceMetrics.size === 1) + } + + test("ClassificationPerInstanceMetrics contains expected metrics") { + assert(MetricConstants.ClassificationPerInstanceMetrics.contains(MetricConstants.ClassificationMetricsName)) + assert(MetricConstants.ClassificationPerInstanceMetrics.size === 1) + } + + // FindBestModelMetrics tests + test("FindBestModelMetrics contains all expected metrics") { + assert(MetricConstants.FindBestModelMetrics.contains(MetricConstants.MseSparkMetric)) + assert(MetricConstants.FindBestModelMetrics.contains(MetricConstants.RmseSparkMetric)) + assert(MetricConstants.FindBestModelMetrics.contains(MetricConstants.R2SparkMetric)) + assert(MetricConstants.FindBestModelMetrics.contains(MetricConstants.MaeSparkMetric)) + assert(MetricConstants.FindBestModelMetrics.contains(MetricConstants.AccuracySparkMetric)) + assert(MetricConstants.FindBestModelMetrics.contains(MetricConstants.PrecisionSparkMetric)) + assert(MetricConstants.FindBestModelMetrics.contains(MetricConstants.RecallSparkMetric)) + assert(MetricConstants.FindBestModelMetrics.contains(MetricConstants.AucSparkMetric)) + assert(MetricConstants.FindBestModelMetrics.size === 8) + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/core/schema/VerifyBinaryFileSchema.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/schema/VerifyBinaryFileSchema.scala new file mode 100644 index 00000000000..86b81ad7a6f --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/schema/VerifyBinaryFileSchema.scala @@ -0,0 +1,92 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.core.schema + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class VerifyBinaryFileSchema extends TestBase { + + test("ColumnSchema has correct structure") { + val schema = BinaryFileSchema.ColumnSchema + assert(schema.fields.length === 2) + + val pathField = schema.fields(0) + assert(pathField.name === "path") + assert(pathField.dataType === StringType) + assert(pathField.nullable === true) + + val bytesField = schema.fields(1) + assert(bytesField.name === "bytes") + assert(bytesField.dataType === BinaryType) + assert(bytesField.nullable === true) + } + + test("Schema wraps ColumnSchema in value field") { + val schema = BinaryFileSchema.Schema + assert(schema.fields.length === 1) + + val valueField = schema.fields(0) + assert(valueField.name === "value") + assert(valueField.dataType === BinaryFileSchema.ColumnSchema) + assert(valueField.nullable === true) + } + + test("getPath extracts path from Row") { + val testPath = "/path/to/file.bin" + val testBytes = Array[Byte](1, 2, 3) + val row = Row(testPath, testBytes) + + assert(BinaryFileSchema.getPath(row) === testPath) + } + + test("getBytes extracts bytes from Row") { + val testPath = "/path/to/file.bin" + val testBytes = Array[Byte](1, 2, 3, 4, 5) + val row = Row(testPath, testBytes) + + val result = BinaryFileSchema.getBytes(row) + assert(result.sameElements(testBytes)) + } + + test("getPath handles empty path") { + val row = Row("", Array[Byte]()) + assert(BinaryFileSchema.getPath(row) === "") + } + + test("getBytes handles empty byte array") { + val row = Row("test", Array[Byte]()) + assert(BinaryFileSchema.getBytes(row).length === 0) + } + + test("isBinaryFile with DataType returns true for matching schema") { + assert(BinaryFileSchema.isBinaryFile(BinaryFileSchema.ColumnSchema) === true) + } + + 
test("isBinaryFile with DataType returns false for non-matching schema") { + assert(BinaryFileSchema.isBinaryFile(StringType) === false) + assert(BinaryFileSchema.isBinaryFile(BinaryType) === false) + assert(BinaryFileSchema.isBinaryFile(IntegerType) === false) + + // Different StructType + val differentSchema = StructType(Seq( + StructField("path", StringType, true) + )) + assert(BinaryFileSchema.isBinaryFile(differentSchema) === false) + } + + test("isBinaryFile with StructField returns true for matching schema") { + val field = StructField("data", BinaryFileSchema.ColumnSchema, true) + assert(BinaryFileSchema.isBinaryFile(field) === true) + } + + test("isBinaryFile with StructField returns false for non-matching schema") { + val stringField = StructField("data", StringType, true) + assert(BinaryFileSchema.isBinaryFile(stringField) === false) + + val binaryField = StructField("data", BinaryType, true) + assert(BinaryFileSchema.isBinaryFile(binaryField) === false) + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/core/schema/VerifyImageSchemaUtils.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/schema/VerifyImageSchemaUtils.scala new file mode 100644 index 00000000000..3aebc578557 --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/schema/VerifyImageSchemaUtils.scala @@ -0,0 +1,86 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package com.microsoft.azure.synapse.ml.core.schema + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.ml.image.ImageSchema +import org.apache.spark.sql.types._ + +class VerifyImageSchemaUtils extends TestBase { + + test("ColumnSchemaNullable has correct structure") { + val schema = ImageSchemaUtils.ColumnSchemaNullable + assert(schema.fields.length === 6) + + assert(schema.fields(0).name === "origin") + assert(schema.fields(0).dataType === StringType) + + assert(schema.fields(1).name === "height") + assert(schema.fields(1).dataType === IntegerType) + + assert(schema.fields(2).name === "width") + assert(schema.fields(2).dataType === IntegerType) + + assert(schema.fields(3).name === "nChannels") + assert(schema.fields(3).dataType === IntegerType) + + assert(schema.fields(4).name === "mode") + assert(schema.fields(4).dataType === IntegerType) + + assert(schema.fields(5).name === "data") + assert(schema.fields(5).dataType === BinaryType) + } + + test("ColumnSchemaNullable fields are all nullable") { + val schema = ImageSchemaUtils.ColumnSchemaNullable + schema.fields.foreach { field => + assert(field.nullable === true, s"Field ${field.name} should be nullable") + } + } + + test("ImageSchemaNullable wraps ColumnSchemaNullable") { + val schema = ImageSchemaUtils.ImageSchemaNullable + assert(schema.fields.length === 1) + assert(schema.fields(0).name === "image") + assert(schema.fields(0).dataType === ImageSchemaUtils.ColumnSchemaNullable) + assert(schema.fields(0).nullable === true) + } + + test("isImage returns true for ImageSchema.columnSchema") { + assert(ImageSchemaUtils.isImage(ImageSchema.columnSchema) === true) + } + + test("isImage returns false for non-image types") { + assert(ImageSchemaUtils.isImage(StringType) === false) + assert(ImageSchemaUtils.isImage(BinaryType) === false) + assert(ImageSchemaUtils.isImage(IntegerType) === false) + } + + test("isImage returns false for different struct type") { + val 
differentSchema = StructType(Seq( + StructField("path", StringType, true) + )) + assert(ImageSchemaUtils.isImage(differentSchema) === false) + } + + test("isImage with StructField returns true for image column") { + val imageField = StructField("img", ImageSchema.columnSchema, true) + assert(ImageSchemaUtils.isImage(imageField) === true) + } + + test("isImage with StructField returns false for non-image column") { + val stringField = StructField("text", StringType, true) + assert(ImageSchemaUtils.isImage(stringField) === false) + } + + test("ColumnSchemaNullable matches ImageSchema.columnSchema structurally") { + // The nullable version should match structurally when ignoring nullability + val isMatch = DataType.equalsStructurally( + ImageSchemaUtils.ColumnSchemaNullable, + ImageSchema.columnSchema, + ignoreNullability = true + ) + assert(isMatch === true) + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/core/schema/VerifySchemaConstants.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/schema/VerifySchemaConstants.scala new file mode 100644 index 00000000000..19908e14fa3 --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/schema/VerifySchemaConstants.scala @@ -0,0 +1,54 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package com.microsoft.azure.synapse.ml.core.schema + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase + +class VerifySchemaConstants extends TestBase { + + test("score column kind constants have expected values") { + assert(SchemaConstants.ScoreColumnKind === "ScoreColumnKind") + assert(SchemaConstants.ScoreValueKind === "ScoreValueKind") + } + + test("label and score column names have expected values") { + assert(SchemaConstants.TrueLabelsColumn === "true_labels") + assert(SchemaConstants.ScoredLabelsColumn === "scored_labels") + assert(SchemaConstants.ScoresColumn === "scores") + assert(SchemaConstants.ScoredProbabilitiesColumn === "scored_probabilities") + } + + test("model and tag constants have expected values") { + assert(SchemaConstants.ScoreModelPrefix === "score_model") + assert(SchemaConstants.MMLTag === "mml") + assert(SchemaConstants.MLlibTag === "ml_attr") + } + + test("residual column names have expected values") { + assert(SchemaConstants.TreatmentResidualColumn === "treatment_residual") + assert(SchemaConstants.OutcomeResidualColumn === "outcome_residual") + } + + test("categorical metadata tag constants have expected values") { + assert(SchemaConstants.Ordinal === "ord") + assert(SchemaConstants.MLlibTypeTag === "type") + assert(SchemaConstants.ValuesString === "vals") + assert(SchemaConstants.ValuesInt === "vals_int") + assert(SchemaConstants.ValuesLong === "vals_long") + assert(SchemaConstants.ValuesDouble === "vals_double") + assert(SchemaConstants.ValuesBool === "vals_bool") + assert(SchemaConstants.HasNullLevels === "null_exists") + } + + test("ML kind constants have expected values") { + assert(SchemaConstants.ClassificationKind === "Classification") + assert(SchemaConstants.RegressionKind === "Regression") + } + + test("Spark native column name constants have expected values") { + assert(SchemaConstants.SparkPredictionColumn === "prediction") + assert(SchemaConstants.SparkRawPredictionColumn === "rawPrediction") + 
assert(SchemaConstants.SparkProbabilityColumn === "probability") + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/core/utils/VerifyOsUtils.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/utils/VerifyOsUtils.scala new file mode 100644 index 00000000000..7281e686002 --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/utils/VerifyOsUtils.scala @@ -0,0 +1,21 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.core.utils + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase + +class VerifyOsUtils extends TestBase { + + test("IsWindows returns a boolean based on os.name property") { + val osName = System.getProperty("os.name").toLowerCase() + val expected = osName.indexOf("win") >= 0 + assert(OsUtils.IsWindows === expected) + } + + test("IsWindows is consistent across multiple accesses") { + val first = OsUtils.IsWindows + val second = OsUtils.IsWindows + assert(first === second) + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/VerifyExplainerSharedParams.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/VerifyExplainerSharedParams.scala new file mode 100644 index 00000000000..b5b68cb905a --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/VerifyExplainerSharedParams.scala @@ -0,0 +1,136 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+
+package com.microsoft.azure.synapse.ml.explainers
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.ml.util.Identifiable
+
+// Exercises the setters, getters, defaults and validators of the shared
+// explainer param traits through a throwaway Params implementation.
+class VerifyExplainerSharedParams extends TestBase {
+
+  // Minimal Params holder that mixes in every shared trait under test
+  private class TestExplainerParams(override val uid: String)
+    extends Params
+      with HasMetricsCol
+      with HasNumSamples
+      with HasTokensCol
+      with HasSuperpixelCol
+      with HasSamplingFraction
+      with HasExplainTarget {
+    override def copy(extra: ParamMap): Params = this
+  }
+
+  // Each test builds its own instance, so no param value leaks between tests
+  private def newStage(): TestExplainerParams =
+    new TestExplainerParams(Identifiable.randomUID("test"))
+
+  test("HasMetricsCol sets and gets metrics column") {
+    val stage = newStage()
+    stage.setMetricsCol("metrics_output")
+    assertResult("metrics_output")(stage.getMetricsCol)
+  }
+
+  test("HasMetricsCol param has correct name and doc") {
+    val metricsParam = newStage().metricsCol
+    assertResult("metricsCol")(metricsParam.name)
+    assert(metricsParam.doc.contains("fitting metrics"))
+  }
+
+  test("HasNumSamples sets and gets number of samples") {
+    val stage = newStage()
+    stage.setNumSamples(100)
+    assertResult(100)(stage.getNumSamples)
+  }
+
+  test("HasNumSamples getNumSamplesOpt returns Some when set") {
+    val stage = newStage()
+    stage.setNumSamples(50)
+    assertResult(Some(50))(stage.getNumSamplesOpt)
+  }
+
+  test("HasNumSamples getNumSamplesOpt returns None when not set") {
+    assert(newStage().getNumSamplesOpt.isEmpty)
+  }
+
+  test("HasNumSamples validates numSamples must be positive") {
+    val stage = newStage()
+    // Zero is checked first, then a negative count; both must be rejected
+    Seq(0, -1).foreach { invalid =>
+      assertThrows[IllegalArgumentException] {
+        stage.setNumSamples(invalid)
+      }
+    }
+  }
+
+  test("HasTokensCol sets and gets tokens column") {
+    val stage = newStage()
+    stage.setTokensCol("tokens")
+    assertResult("tokens")(stage.getTokensCol)
+  }
+
+  test("HasTokensCol param has correct name and doc") {
+    val tokensParam = newStage().tokensCol
+    assertResult("tokensCol")(tokensParam.name)
+    assert(tokensParam.doc.contains("tokens"))
+  }
+
+  test("HasSuperpixelCol sets and gets superpixel column") {
+    val stage = newStage()
+    stage.setSuperpixelCol("superpixels")
+    assertResult("superpixels")(stage.getSuperpixelCol)
+  }
+
+  test("HasSuperpixelCol param has correct name and doc") {
+    val superpixelParam = newStage().superpixelCol
+    assertResult("superpixelCol")(superpixelParam.name)
+    assert(superpixelParam.doc.contains("superpixel"))
+  }
+
+  test("HasSamplingFraction sets and gets sampling fraction") {
+    val stage = newStage()
+    stage.setSamplingFraction(0.5)
+    assertResult(0.5)(stage.getSamplingFraction)
+  }
+
+  test("HasSamplingFraction validates fraction in range 0 to 1") {
+    val stage = newStage()
+    // Both bounds and an interior point are accepted without throwing
+    Seq(0.0, 1.0, 0.5).foreach(accepted => stage.setSamplingFraction(accepted))
+
+    // Values on either side of the [0, 1] range are rejected
+    Seq(-0.1, 1.1).foreach { outOfRange =>
+      assertThrows[IllegalArgumentException] {
+        stage.setSamplingFraction(outOfRange)
+      }
+    }
+  }
+
+  test("HasExplainTarget has default targetCol value") {
+    assertResult("probability")(newStage().getTargetCol)
+  }
+
+  test("HasExplainTarget sets and gets targetCol") {
+    val stage = newStage()
+    stage.setTargetCol("prediction")
+    assertResult("prediction")(stage.getTargetCol)
+  }
+
+  test("HasExplainTarget has default empty targetClasses") {
+    assert(newStage().getTargetClasses.isEmpty)
+  }
+
+  test("HasExplainTarget sets and gets targetClasses") {
+    val stage = newStage()
+    stage.setTargetClasses(Array(0, 1, 2))
+    assertResult(Array(0, 1, 2))(stage.getTargetClasses)
+  }
+
+  test("HasExplainTarget sets and gets targetClassesCol") {
+    val stage = newStage()
+    stage.setTargetClassesCol("classes")
+    assertResult("classes")(stage.getTargetClassesCol)
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/VerifyRowUtils.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/VerifyRowUtils.scala
new file mode 100644
index 00000000000..691242833d2
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/VerifyRowUtils.scala
@@ -0,0 +1,100 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.explainers
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+
+// Verifies the getAsDouble extension that RowUtils.RowCanGetAsDouble adds to Row.
+// Positional lookups run against schema-less rows built with Row(...); declaring
+// a StructType next to such a row would not attach it, so none is declared.
+// Only the by-name lookup needs a schema, and that test builds a
+// GenericRowWithSchema explicitly.
+class VerifyRowUtils extends TestBase {
+
+  import RowUtils.RowCanGetAsDouble
+
+  test("getAsDouble converts Byte to Double") {
+    val row = Row(10.toByte)
+    val result = row.getAsDouble(0)
+    assert(result === 10.0)
+  }
+
+  test("getAsDouble converts Short to Double") {
+    val row = Row(100.toShort)
+    val result = row.getAsDouble(0)
+    assert(result === 100.0)
+  }
+
+  test("getAsDouble converts Int to Double") {
+    val row = Row(42)
+    val result = row.getAsDouble(0)
+    assert(result === 42.0)
+  }
+
+  test("getAsDouble converts Long to Double") {
+    val row = Row(1000L)
+    val result = row.getAsDouble(0)
+    assert(result === 1000.0)
+  }
+
+  test("getAsDouble converts Float to Double") {
+    // 3.14f is not exactly 3.14 once widened to Double, hence the tolerance
+    val row = Row(3.14f)
+    val result = row.getAsDouble(0)
+    assert(Math.abs(result - 3.14) < 0.001)
+  }
+
+  test("getAsDouble returns Double unchanged") {
+    // No conversion is involved for a Double input, so exact equality is intended
+    val row = Row(2.718)
+    val result = row.getAsDouble(0)
+    assert(result === 2.718)
+  }
+
+  test("getAsDouble by column name") {
+    val schema = StructType(Seq(
+      StructField("other", StringType),
+      StructField("value", IntegerType)
+    ))
+    // Create a row that knows about its schema
+    val row = new org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema(
+      Array("test", 42),
+      schema
+    )
+    val result = row.getAsDouble("value")
+    assert(result === 42.0)
+  }
+
+  test("getAsDouble throws for unsupported types") {
+    // NOTE(review): Exception is broad; narrow it once the thrown type is confirmed
+    val row = Row("not a number")
+    assertThrows[Exception] {
+      row.getAsDouble(0)
+    }
+  }
+
+  test("getAsDouble handles negative numbers") {
+    // The sign has to survive the widening to Double
+    val row = Row(-100)
+    val result = row.getAsDouble(0)
+    assert(result === -100.0)
+  }
+
+  test("getAsDouble handles zero") {
+    val row = Row(0)
+    val result = row.getAsDouble(0)
+    assert(result === 0.0)
+  }
+
+  test("getAsDouble handles large Long values") {
+    val largeValue = Long.MaxValue / 2
+    val row = Row(largeValue)
+    val result = row.getAsDouble(0)
+    // Long.MaxValue / 2 lies beyond 2^53, so the Double is only the nearest
+    // representable value; it must still match the standard Long widening
+    assert(result === largeValue.toDouble)
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/io/binary/VerifyBinaryFileFormat.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/binary/VerifyBinaryFileFormat.scala
new file mode 100644
index 00000000000..f255271f444
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/binary/VerifyBinaryFileFormat.scala
@@ -0,0 +1,123 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.io.binary
+
+import com.microsoft.azure.synapse.ml.core.schema.BinaryFileSchema
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+import org.apache.hadoop.fs.Path
+
+import java.io.File
+import java.nio.file.Files
+
+class VerifyBinaryFileFormat extends TestBase {
+
+  test("BinaryFileFormat shortName returns 'binary'") {
+    val format = new BinaryFileFormat()
+    assert(format.shortName() === "binary")
+  }
+
+  test("BinaryFileFormat toString returns 'Binary'") {
+    val format = new BinaryFileFormat()
+    assert(format.toString === "Binary")
+  }
+
+  test("BinaryFileFormat isSplitable returns false") {
+    val format = new BinaryFileFormat()
+    val result = format.isSplitable(spark, Map.empty, new Path("/test"))
+    assert(!result)
+  }
+
+  test("BinaryFileFormat inferSchema returns BinaryFileSchema") {
+    val format = new BinaryFileFormat()
+    val schema = format.inferSchema(spark, Map.empty, Seq.empty)
+    // Compare the Option as a whole instead of unwrapping it with .get
+    assert(schema === Some(BinaryFileSchema.Schema))
+  }
+
+  test("BinaryFileFormat equals returns true for same type") {
+    val format1 = new BinaryFileFormat()
+    val format2 = new BinaryFileFormat()
+    assert(format1.equals(format2))
+  }
+
+  test("BinaryFileFormat equals returns false for different type") {
+    val format = new BinaryFileFormat()
+    assert(!format.equals("not a format"))
+  }
+
+  test("BinaryFileFormat hashCode is consistent") {
+    val format1 = new BinaryFileFormat()
+    val format2 = new BinaryFileFormat()
+    assert(format1.hashCode() === format2.hashCode())
+  }
+
+  test("ConfUtils.getHConf returns SerializableConfiguration") {
+    import spark.implicits._
+    val df = Seq(1, 2, 3).toDF("num")
+    val hConf = ConfUtils.getHConf(df)
+    assert(hConf != null)
+  }
+
+  test("BinaryFileFormat can read binary files") {
+    // A fresh temp directory holding one file must yield exactly one row
+    val tempDir = Files.createTempDirectory("binary-test").toFile
+    tempDir.deleteOnExit()
+
+    val testFile = new File(tempDir, "test.bin")
+    Files.write(testFile.toPath, Array[Byte](1, 2, 3, 4, 5))
+    testFile.deleteOnExit()
+
+    val df = spark.read.format("binary").load(tempDir.getAbsolutePath)
+    assert(df.count() === 1)
+    assert(df.schema === BinaryFileSchema.Schema)
+  }
+
+  test("BinaryFileFormat reads file content correctly") {
+    val tempDir = Files.createTempDirectory("binary-content-test").toFile
+    tempDir.deleteOnExit()
+
+    val testContent = "Hello, Binary!".getBytes("UTF-8")
+    val testFile = new File(tempDir, "content.bin")
+    Files.write(testFile.toPath, testContent)
+    testFile.deleteOnExit()
+
+    val df = spark.read.format("binary").load(tempDir.getAbsolutePath)
+    val row = df.collect().head
+    val struct = row.getStruct(0)
+    val bytes = struct.getAs[Array[Byte]]("bytes")
+    assert(bytes.sameElements(testContent))
+  }
+
+  test("BinaryFileFormat respects subsample option") {
+    val tempDir = Files.createTempDirectory("binary-subsample-test").toFile
+    tempDir.deleteOnExit()
+
+    // Create multiple files
+    for (i <- 1 to 10) {
+      val testFile = new File(tempDir, s"file$i.bin")
+      Files.write(testFile.toPath, Array[Byte](i.toByte))
+      testFile.deleteOnExit()
+    }
+
+    // A subsample of 0.0 is expected to drop every file, leaving no rows
+    val df = spark.read.format("binary")
+      .option("subsample", "0.0")
+      .load(tempDir.getAbsolutePath)
+    assert(df.count() === 0)
+  }
+
+  test("BinaryFileFormat reads multiple files") {
+    val tempDir = Files.createTempDirectory("binary-multi-test").toFile
+    tempDir.deleteOnExit()
+
+    for (i <- 1 to 3) {
+      val testFile = new File(tempDir, s"multi$i.bin")
+      Files.write(testFile.toPath, s"content$i".getBytes("UTF-8"))
+      testFile.deleteOnExit()
+    }
+
+    val df = spark.read.format("binary").load(tempDir.getAbsolutePath)
+    assert(df.count() === 3)
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/io/http/VerifyClients.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/http/VerifyClients.scala
new file mode 100644
index 00000000000..aab8ffb099a
--- /dev/null
+++
b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/http/VerifyClients.scala
@@ -0,0 +1,178 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.io.http
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+
+import scala.concurrent.ExecutionContext
+import scala.concurrent.duration._
+
+class VerifyClients extends TestBase {
+
+  // Stub SingleThreadedClient: answers "response-<req>" and passes the context through
+  private class TestSingleThreadedClient extends SingleThreadedClient {
+    override protected type Client = Unit
+    override protected type ResponseType = String
+    override protected type RequestType = String
+    override protected val internalClient: Client = ()
+
+    override protected def sendRequestWithContext(
+      request: RequestWithContext): ResponseWithContext = {
+      request.request match {
+        case Some(req) => ResponseWithContext(Some(s"response-$req"), request.context)
+        case None => ResponseWithContext(None, request.context)
+      }
+    }
+  }
+
+  test("SingleThreadedClient processes requests sequentially") {
+    val client = new TestSingleThreadedClient
+    val requests = Iterator(
+      client.RequestWithContext(Some("a"), Some("ctx-a")),
+      client.RequestWithContext(Some("b"), Some("ctx-b")),
+      client.RequestWithContext(Some("c"), Some("ctx-c"))
+    )
+
+    val responses = client.sendRequestsWithContext(requests).toList
+    assert(responses.length === 3)
+    assert(responses(0).response === Some("response-a"))
+    assert(responses(0).context === Some("ctx-a"))
+    assert(responses(1).response === Some("response-b"))
+    assert(responses(2).response === Some("response-c"))
+  }
+
+  test("SingleThreadedClient handles empty requests") {
+    val client = new TestSingleThreadedClient
+    val requests = Iterator(
+      client.RequestWithContext(None, Some("ctx"))
+    )
+
+    val responses = client.sendRequestsWithContext(requests).toList
+    assert(responses.length === 1)
+    assert(responses.head.response === None)
+    assert(responses.head.context === Some("ctx"))
+  }
+
+  test("SingleThreadedClient handles empty iterator") {
+    val client = new TestSingleThreadedClient
+    val requests: Iterator[client.RequestWithContext] = Iterator.empty // covariant: no cast needed
+    val responses = client.sendRequestsWithContext(requests).toList
+    assert(responses.isEmpty)
+  }
+
+  // Stub AsyncClient: answers "async-<req>" after a 10 ms pause and passes the context through
+  private class TestAsyncClient(conc: Int, to: Duration)
+                               (implicit ec: ExecutionContext)
+    extends AsyncClient(conc, to) {
+
+    override protected type Client = Unit
+    override protected type ResponseType = String
+    override protected type RequestType = String
+    override protected val internalClient: Client = ()
+
+    override protected def sendRequestWithContext(
+      request: RequestWithContext): ResponseWithContext = {
+      request.request match {
+        case Some(req) =>
+          // Simulate some work
+          Thread.sleep(10)
+          ResponseWithContext(Some(s"async-$req"), request.context)
+        case None =>
+          ResponseWithContext(None, request.context)
+      }
+    }
+  }
+
+  test("AsyncClient processes requests concurrently") {
+    implicit val ec: ExecutionContext = ExecutionContext.global
+    val client = new TestAsyncClient(4, 30.seconds)
+
+    val requests = Iterator(
+      client.RequestWithContext(Some("1"), None),
+      client.RequestWithContext(Some("2"), None),
+      client.RequestWithContext(Some("3"), None),
+      client.RequestWithContext(Some("4"), None)
+    )
+
+    val responses = client.sendRequestsWithContext(requests).toList
+    assert(responses.length === 4)
+    assert(responses.map(_.response).forall(_.isDefined))
+    assert(responses.flatMap(_.response).toSet === Set("async-1", "async-2", "async-3", "async-4"))
+  }
+
+  test("AsyncClient preserves context") {
+    implicit val ec: ExecutionContext = ExecutionContext.global
+    val client = new TestAsyncClient(2, 30.seconds)
+
+    val requests = Iterator(
+      client.RequestWithContext(Some("a"), Some("context-a")),
+      client.RequestWithContext(Some("b"), Some("context-b"))
+    )
+
+    val responses = client.sendRequestsWithContext(requests).toList
+    assert(responses.length === 2)
+    // Note: order may vary due to concurrency, but contexts should match requests
+    responses.foreach { resp =>
+      resp.response match {
+        case Some("async-a") => assert(resp.context === Some("context-a"))
+        case Some("async-b") => assert(resp.context === Some("context-b"))
+        case _ => fail("Unexpected response")
+      }
+    }
+  }
+
+  test("AsyncClient handles empty requests") {
+    implicit val ec: ExecutionContext = ExecutionContext.global
+    val client = new TestAsyncClient(2, 30.seconds)
+
+    val requests = Iterator(
+      client.RequestWithContext(None, Some("ctx"))
+    )
+
+    val responses = client.sendRequestsWithContext(requests).toList
+    assert(responses.length === 1)
+    assert(responses.head.response === None)
+  }
+
+  test("AsyncClient respects concurrency parameter") {
+    implicit val ec: ExecutionContext = ExecutionContext.global
+    val client = new TestAsyncClient(2, 30.seconds)
+    assert(client.concurrency === 2)
+  }
+
+  test("AsyncClient respects timeout parameter") {
+    implicit val ec: ExecutionContext = ExecutionContext.global
+    val client = new TestAsyncClient(2, 45.seconds)
+    assert(client.timeout === 45.seconds)
+  }
+
+  // The wrappers are path-dependent types, so each is built through a client instance
+  test("RequestWithContext can be created with request only") {
+    val client = new TestSingleThreadedClient
+    val req = new client.RequestWithContext(Some("test"))
+    assert(req.request === Some("test"))
+    assert(req.context === None)
+  }
+
+  test("RequestWithContext can be created with request and context") {
+    val client = new TestSingleThreadedClient
+    val req = client.RequestWithContext(Some("test"), Some("ctx"))
+    assert(req.request === Some("test"))
+    assert(req.context === Some("ctx"))
+  }
+
+  test("ResponseWithContext can be created with response only") {
+    val client = new TestSingleThreadedClient
+    val resp = new client.ResponseWithContext(Some("test"))
+    assert(resp.response === Some("test"))
+
assert(resp.context === None)
+  }
+
+  test("ResponseWithContext can be created with response and context") {
+    val owner = new TestSingleThreadedClient
+    val reply = owner.ResponseWithContext(Some("test"), Some("ctx"))
+    assertResult(Some("test"))(reply.response)
+    assertResult(Some("ctx"))(reply.context)
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/io/http/VerifyHTTPSchema.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/http/VerifyHTTPSchema.scala
new file mode 100644
index 00000000000..5cd1c3843d9
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/http/VerifyHTTPSchema.scala
@@ -0,0 +1,179 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.io.http
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+import org.apache.http.ProtocolVersion
+import org.apache.http.message.BasicHeader
+
+class VerifyHTTPSchema extends TestBase {
+
+  test("HeaderData stores name and value") {
+    val contentType = HeaderData("Content-Type", "application/json")
+    assertResult("Content-Type")(contentType.name)
+    assertResult("application/json")(contentType.value)
+  }
+
+  test("HeaderData.toHTTPCore creates BasicHeader") {
+    val source = HeaderData("Authorization", "Bearer token123")
+    val converted = source.toHTTPCore
+    assertResult("Authorization")(converted.getName)
+    assertResult("Bearer token123")(converted.getValue)
+  }
+
+  test("HeaderData can be created from HTTP Header") {
+    val apacheHeader = new BasicHeader("X-Custom", "custom-value")
+    val wrapped = new HeaderData(apacheHeader)
+    assertResult("X-Custom")(wrapped.name)
+    assertResult("custom-value")(wrapped.value)
+  }
+
+  test("ProtocolVersionData stores protocol info") {
+    val version = ProtocolVersionData("HTTP", 1, 1)
+    assertResult("HTTP")(version.protocol)
+    assertResult(1)(version.major)
+    assertResult(1)(version.minor)
+  }
+
+  test("ProtocolVersionData.toHTTPCore creates ProtocolVersion") {
+    val version = ProtocolVersionData("HTTP", 2, 0)
+    val converted = version.toHTTPCore
+    assertResult("HTTP")(converted.getProtocol)
+    assertResult(2)(converted.getMajor)
+    assertResult(0)(converted.getMinor)
+  }
+
+  test("ProtocolVersionData can be created from ProtocolVersion") {
+    val apacheVersion = new ProtocolVersion("HTTP", 1, 0)
+    val wrapped = new ProtocolVersionData(apacheVersion)
+    assertResult("HTTP")(wrapped.protocol)
+    assertResult(1)(wrapped.major)
+    assertResult(0)(wrapped.minor)
+  }
+
+  test("StatusLineData stores status info") {
+    val version = ProtocolVersionData("HTTP", 1, 1)
+    val status = StatusLineData(version, 200, "OK")
+    assertResult(version)(status.protocolVersion)
+    assertResult(200)(status.statusCode)
+    assertResult("OK")(status.reasonPhrase)
+  }
+
+  test("RequestLineData stores request info") {
+    val version = Some(ProtocolVersionData("HTTP", 1, 1))
+    val line = RequestLineData("GET", "https://example.com", version)
+    assertResult("GET")(line.method)
+    assertResult("https://example.com")(line.uri)
+    assertResult(version)(line.protocolVersion)
+  }
+
+  test("RequestLineData works without protocol version") {
+    val line = RequestLineData("POST", "/api/data", None)
+    assertResult("POST")(line.method)
+    assertResult("/api/data")(line.uri)
+    assert(line.protocolVersion.isEmpty)
+  }
+
+  test("EntityData stores content info") {
+    val payload = "test content".getBytes
+    val body = EntityData(
+      content = payload,
+      contentEncoding = None,
+      contentLength = Some(payload.length.toLong),
+      contentType = Some(HeaderData("Content-Type", "text/plain")),
+      isChunked = false,
+      isRepeatable = true,
+      isStreaming = false
+    )
+    assertResult(payload)(body.content)
+    assertResult(Some(payload.length.toLong))(body.contentLength)
+    assert(!body.isChunked)
+    assert(body.isRepeatable)
+  }
+
+  test("HTTPResponseData stores response info") {
+    val version = ProtocolVersionData("HTTP", 1, 1)
+    val okStatus = StatusLineData(version, 200, "OK")
+    val reply = HTTPResponseData(
+      headers = Array(HeaderData("Content-Type", "application/json")),
+      entity = None,
+      statusLine = okStatus,
+      locale = "en-US"
+    )
+    assertResult(1)(reply.headers.length)
+    assertResult(200)(reply.statusLine.statusCode)
+    assertResult("en-US")(reply.locale)
+    assert(reply.entity.isEmpty)
+  }
+
+  test("HTTPRequestData stores request info") {
+    val getLine = RequestLineData("GET", "https://api.example.com/data", None)
+    val outgoing = HTTPRequestData(
+      requestLine = getLine,
+      headers = Array(HeaderData("Accept", "application/json")),
+      entity = None
+    )
+    assertResult("GET")(outgoing.requestLine.method)
+    assertResult(1)(outgoing.headers.length)
+    assert(outgoing.entity.isEmpty)
+  }
+
+  test("HTTPRequestData with entity") {
+    val payload = """{"key": "value"}""".getBytes
+    val body = EntityData(
+      content = payload,
+      contentEncoding = None,
+      contentLength = Some(payload.length.toLong),
+      contentType = Some(HeaderData("Content-Type", "application/json")),
+      isChunked = false,
+      isRepeatable = true,
+      isStreaming = false
+    )
+    val postLine = RequestLineData("POST", "/api/submit", None)
+    val outgoing = HTTPRequestData(
+      requestLine = postLine,
+      headers = Array(),
+      entity = Some(body)
+    )
+    assert(outgoing.entity.nonEmpty)
+    assertResult(payload)(outgoing.entity.get.content)
+  }
+
+  test("HTTPSchema.Response has schema") {
+    val responseType = HTTPSchema.Response
+    assert(responseType != null)
+  }
+
+  test("HTTPSchema.Request has schema") {
+    val requestType = HTTPSchema.Request
+    assert(requestType != null)
+  }
+
+  test("HTTPSchema.stringToResponse creates response") {
+    val reply = HTTPSchema.stringToResponse("test body", 200, "OK")
+    assertResult(200)(reply.statusLine.statusCode)
+    assertResult("OK")(reply.statusLine.reasonPhrase)
+    assert(reply.entity.nonEmpty)
+  }
+
+  test("HTTPSchema.emptyResponse creates response without entity") {
+    val reply = HTTPSchema.emptyResponse(404, "Not Found")
+    assertResult(404)(reply.statusLine.statusCode)
+    assertResult("Not Found")(reply.statusLine.reasonPhrase)
+    assert(reply.entity.isEmpty)
+  }
+
+
test("HTTPSchema.binaryToResponse creates response with binary content") {
+    val content = Array[Byte](1, 2, 3, 4, 5)
+    val response = HTTPSchema.binaryToResponse(content, 200, "OK")
+    assert(response.entity.isDefined)
+    assert(response.entity.get.content === content)
+  }
+
+  test("HeaderValues.PlatformInfo returns a string") {
+    val platformInfo = HeaderValues.PlatformInfo
+    assert(platformInfo != null)
+    assert(platformInfo.nonEmpty)
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/io/image/VerifyImageUtils.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/image/VerifyImageUtils.scala
new file mode 100644
index 00000000000..d601f7f1ff0
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/image/VerifyImageUtils.scala
@@ -0,0 +1,159 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.io.image
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+
+import java.awt.image.BufferedImage
+
+class VerifyImageUtils extends TestBase {
+
+  test("channelsToType returns TYPE_BYTE_GRAY for 1 channel") {
+    assert(ImageUtils.channelsToType(1) === BufferedImage.TYPE_BYTE_GRAY)
+  }
+
+  test("channelsToType returns TYPE_3BYTE_BGR for 3 channels") {
+    assert(ImageUtils.channelsToType(3) === BufferedImage.TYPE_3BYTE_BGR)
+  }
+
+  test("channelsToType returns TYPE_4BYTE_ABGR for 4 channels") {
+    assert(ImageUtils.channelsToType(4) === BufferedImage.TYPE_4BYTE_ABGR)
+  }
+
+  test("channelsToType throws for unsupported channel count") {
+    assertThrows[UnsupportedOperationException] {
+      ImageUtils.channelsToType(2)
+    }
+    assertThrows[UnsupportedOperationException] {
+      ImageUtils.channelsToType(5)
+    }
+    assertThrows[UnsupportedOperationException] {
+      ImageUtils.channelsToType(0)
+    }
+  }
+
+  test("toBufferedImage creates image from byte array - grayscale") {
+    val width = 2
+    val height = 2
+    val nChannels = 1
+    val bytes = Array[Byte](10, 20, 30, 40)
+
+    val img = ImageUtils.toBufferedImage(bytes, width, height, nChannels)
+
+    assert(img.getWidth === width)
+    assert(img.getHeight === height)
+    assert(img.getType === BufferedImage.TYPE_BYTE_GRAY)
+  }
+
+  test("toBufferedImage creates image from byte array - RGB") {
+    val width = 2
+    val height = 2
+    val nChannels = 3
+    // BGR format: 2x2 pixels = 12 bytes
+    val bytes = Array[Byte](
+      0, 0, 100.toByte, // pixel (0,0)
+      0, 100.toByte, 0, // pixel (1,0)
+      100.toByte, 0, 0, // pixel (0,1)
+      50, 50, 50 // pixel (1,1)
+    )
+
+    val img = ImageUtils.toBufferedImage(bytes, width, height, nChannels)
+
+    assert(img.getWidth === width)
+    assert(img.getHeight === height)
+    assert(img.getType === BufferedImage.TYPE_3BYTE_BGR)
+  }
+
+  test("toBufferedImage creates image from byte array - RGBA") {
+    val width = 2
+    val height = 1
+    val nChannels = 4
+    // ABGR format: 2x1 pixels = 8 bytes
+    val bytes = Array[Byte](0, 0, 100.toByte, -1, 100.toByte, 0, 0, -1)
+
+    val img = ImageUtils.toBufferedImage(bytes, width, height, nChannels)
+
+    assert(img.getWidth === width)
+    assert(img.getHeight === height)
+    assert(img.getType === BufferedImage.TYPE_4BYTE_ABGR)
+  }
+
+  test("safeRead returns None for null bytes") {
+    // scalastyle:off null
+    val result = ImageUtils.safeRead(null)
+    // scalastyle:on null
+    assert(result.isEmpty)
+  }
+
+  test("safeRead returns None for invalid image bytes") {
+    val invalidBytes = Array[Byte](1, 2, 3, 4, 5)
+    val result = ImageUtils.safeRead(invalidBytes)
+    assert(result.isEmpty)
+  }
+
+  test("toSparkImage converts BufferedImage to Row format") {
+    val img = new BufferedImage(10, 10, BufferedImage.TYPE_3BYTE_BGR)
+    val row = ImageUtils.toSparkImage(img, Some("/path/to/image.jpg"))
+
+    assert(row != null)
+    // The row should contain an inner row with image data
+    val innerRow = row.getAs[org.apache.spark.sql.Row](0)
+    assert(innerRow.get(0) === Some("/path/to/image.jpg")) // read untyped: the expected value is an Option
+    assert(innerRow.getAs[Int](1) === 10) // height
+    assert(innerRow.getAs[Int](2) === 10) // width
+    assert(innerRow.getAs[Int](3) === 3) // nChannels
+  }
+
+  test("toSparkImage works without path") {
+    val img = new BufferedImage(5, 5, BufferedImage.TYPE_BYTE_GRAY)
+    val row = ImageUtils.toSparkImage(img, None)
+
+    assert(row != null)
+    val innerRow = row.getAs[org.apache.spark.sql.Row](0)
+    assert(innerRow.getAs[Int](3) === 1) // grayscale = 1 channel
+  }
+
+  test("toSparkImageTuple returns correct tuple for grayscale image") {
+    val img = new BufferedImage(4, 3, BufferedImage.TYPE_BYTE_GRAY)
+    val (path, height, width, nChannels, _, decoded) = ImageUtils.toSparkImageTuple(img, Some("/test"))
+
+    assert(path === Some("/test"))
+    assert(height === 3)
+    assert(width === 4)
+    assert(nChannels === 1)
+    assert(decoded.length === 4 * 3 * 1) // width * height * channels
+  }
+
+  test("toSparkImageTuple returns correct tuple for RGB image") {
+    val img = new BufferedImage(4, 3, BufferedImage.TYPE_3BYTE_BGR)
+    val (path, height, width, nChannels, _, decoded) = ImageUtils.toSparkImageTuple(img)
+
+    assert(path === None)
+    assert(height === 3)
+    assert(width === 4)
+    assert(nChannels === 3)
+    assert(decoded.length === 4 * 3 * 3)
+  }
+
+  test("toSparkImageTuple returns correct tuple for RGBA image") {
+    val img = new BufferedImage(2, 2, BufferedImage.TYPE_4BYTE_ABGR)
+    val (_, height, width, nChannels, _, decoded) = ImageUtils.toSparkImageTuple(img)
+
+    assert(height === 2)
+    assert(width === 2)
+    assert(nChannels === 4)
+    assert(decoded.length === 2 * 2 * 4)
+  }
+
+  test("roundtrip: toSparkImage then toBufferedImage preserves dimensions") {
+    val original = new BufferedImage(8, 6, BufferedImage.TYPE_3BYTE_BGR)
+    val sparkRow = ImageUtils.toSparkImage(original)
+    val innerRow = sparkRow.getAs[org.apache.spark.sql.Row](0)
+
+    val reconstructed = ImageUtils.toBufferedImage(innerRow)
+
+    assert(reconstructed.getWidth === original.getWidth)
+    assert(reconstructed.getHeight
=== original.getHeight)
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/logging/VerifyFeatureNames.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/logging/VerifyFeatureNames.scala
new file mode 100644
index 00000000000..c1366c347ad
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/logging/VerifyFeatureNames.scala
@@ -0,0 +1,63 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.logging
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+
+class VerifyFeatureNames extends TestBase {
+
+  test("AiServices constants have expected values") {
+    assertResult("aiservice-anomalydetection")(FeatureNames.AiServices.Anomaly)
+    assertResult("aiservice-face")(FeatureNames.AiServices.Face)
+    assertResult("aiservice-form")(FeatureNames.AiServices.Form)
+    assertResult("aiservice-language")(FeatureNames.AiServices.Language)
+    assertResult("aiservice-openai")(FeatureNames.AiServices.OpenAI)
+    assertResult("aiservice-search")(FeatureNames.AiServices.Search)
+    assertResult("aiservice-speech")(FeatureNames.AiServices.Speech)
+    assertResult("aiservice-text")(FeatureNames.AiServices.Text)
+    assertResult("aiservice-translate")(FeatureNames.AiServices.Translate)
+    assertResult("aiservice-vision")(FeatureNames.AiServices.Vision)
+  }
+
+  test("ML feature constants have expected values") {
+    assertResult("automl")(FeatureNames.AutoML)
+    assertResult("causal")(FeatureNames.Causal)
+    assertResult("explainers")(FeatureNames.Explainers)
+    assertResult("featurize")(FeatureNames.Featurize)
+    assertResult("geospatial")(FeatureNames.Geospatial)
+    assertResult("image")(FeatureNames.Image)
+    assertResult("isolationforest")(FeatureNames.IsolationForest)
+    assertResult("nearestneighbor")(FeatureNames.NearestNeighbor)
+    assertResult("recommendation")(FeatureNames.Recommendation)
+  }
+
+  test("Deep learning and model feature constants have expected values") {
+    assertResult("deeplearning")(FeatureNames.DeepLearning)
+    assertResult("opencv")(FeatureNames.OpenCV)
+    assertResult("lightgbm")(FeatureNames.LightGBM)
+    assertResult("vowpalwabbit")(FeatureNames.VowpalWabbit)
+  }
+
+  test("Core constant has expected value") {
+    assertResult("core")(FeatureNames.Core)
+  }
+
+  test("All AiServices constants start with 'aiservice-' prefix") {
+    val serviceNames = Seq(
+      FeatureNames.AiServices.Anomaly,
+      FeatureNames.AiServices.Face,
+      FeatureNames.AiServices.Form,
+      FeatureNames.AiServices.Language,
+      FeatureNames.AiServices.OpenAI,
+      FeatureNames.AiServices.Search,
+      FeatureNames.AiServices.Speech,
+      FeatureNames.AiServices.Text,
+      FeatureNames.AiServices.Translate,
+      FeatureNames.AiServices.Vision
+    )
+    serviceNames.foreach { name =>
+      assert(name.startsWith("aiservice-"), s"$name should start with 'aiservice-'")
+    }
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/logging/VerifySynapseMLLogging.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/logging/VerifySynapseMLLogging.scala
new file mode 100644
index 00000000000..12167b2db78
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/logging/VerifySynapseMLLogging.scala
@@ -0,0 +1,93 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.logging
+
+import com.microsoft.azure.synapse.ml.build.BuildInfo
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+
+class VerifySynapseMLLogging extends TestBase {
+
+  test("RequiredLogFields stores uid, className, and method") {
+    val fields = RequiredLogFields("test-uid-123", "TestClass", "testMethod")
+    assert(fields.uid === "test-uid-123")
+    assert(fields.className === "TestClass")
+    assert(fields.method === "testMethod")
+  }
+
+  test("RequiredLogFields.toMap contains all required fields") {
+    val fields = RequiredLogFields("uid1", "MyClass", "myMethod")
+    val map = fields.toMap
+
+    assert(map("modelUid") === "uid1")
+    assert(map("className") === "MyClass")
+    assert(map("method") === "myMethod")
+    assert(map("libraryVersion") === BuildInfo.version)
+    assert(map("libraryName") === "SynapseML")
+    assert(map("protocolVersion") === "0.0.1")
+  }
+
+  test("RequiredLogFields.toMap size is 6") {
+    val fields = RequiredLogFields("uid", "class", "method")
+    assert(fields.toMap.size === 6)
+  }
+
+  test("RequiredErrorFields stores errorType and errorMessage") {
+    val fields = RequiredErrorFields("java.lang.RuntimeException", "Test error message")
+    assert(fields.errorType === "java.lang.RuntimeException")
+    assert(fields.errorMessage === "Test error message")
+  }
+
+  test("RequiredErrorFields.toMap contains error fields") {
+    val fields = RequiredErrorFields("ErrorType", "ErrorMessage")
+    val map = fields.toMap
+
+    assert(map("errorType") === "ErrorType")
+    assert(map("errorMessage") === "ErrorMessage")
+  }
+
+  test("RequiredErrorFields can be created from Exception") {
+    val exception = new RuntimeException("Test exception message")
+    val fields = new RequiredErrorFields(exception)
+
+    assert(fields.errorType === "java.lang.RuntimeException")
+    assert(fields.errorMessage === "Test exception message")
+  }
+
+  test("RequiredErrorFields handles exception with no message") {
+    // None.orNull supplies the missing message without a null literal, so this
+    // test needs no scalastyle suppression around it
+    val exception = new RuntimeException(None.orNull: String)
+    val fields = new RequiredErrorFields(exception)
+
+    assert(fields.errorType === "java.lang.RuntimeException")
+    assert(Option(fields.errorMessage).isEmpty)
+  }
+
+  test("SynapseMLLogging.HadoopKeysToLog contains expected mappings") {
+    val keys = SynapseMLLogging.HadoopKeysToLog
+
+    assert(keys("trident.artifact.id") === "artifactId")
+    assert(keys("trident.workspace.id") === "workspaceId")
+    assert(keys("trident.capacity.id") === "capacityId")
+    assert(keys("trident.artifact.workspace.id") === "artifactWorkspaceId")
+    assert(keys("trident.lakehouse.id") === "lakehouseId")
+    assert(keys("trident.activity.id") === "livyId")
+    assert(keys("trident.artifact.type") === "artifactType")
+    assert(keys("trident.tenant.id") === "tenantId")
+  }
+
+  test("SynapseMLLogging.HadoopKeysToLog size is 8") {
+    assert(SynapseMLLogging.HadoopKeysToLog.size === 8)
+  }
+
+  test("SynapseMLLogging.LoggedClasses is a mutable set") {
+    try {
+      SynapseMLLogging.LoggedClasses.add("TestClass")
+      assert(SynapseMLLogging.LoggedClasses.contains("TestClass"))
+    } finally {
+      // Clean up to avoid leaking state into other tests
+      SynapseMLLogging.LoggedClasses.remove("TestClass")
+    }
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/logging/common/VerifyPlatformDetails.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/logging/common/VerifyPlatformDetails.scala
index 3ca5b1f3f6a..8be4732ef64 100644
--- a/core/src/test/scala/com/microsoft/azure/synapse/ml/logging/common/VerifyPlatformDetails.scala
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/logging/common/VerifyPlatformDetails.scala
@@ -7,24 +7,57 @@ import com.microsoft.azure.synapse.ml.core.test.base.TestBase
 
 class VerifyPlatformDetails extends TestBase {
 
-  test("currentPlatform returns a non-null non-empty string") {
-    val platform = PlatformDetails.currentPlatform()
-    assert(platform != null)
+  test("Platform constants have expected
values") { + assert(PlatformDetails.PlatformSynapseInternal === "synapse_internal") + assert(PlatformDetails.PlatformSynapse === "synapse") + assert(PlatformDetails.PlatformBinder === "binder") + assert(PlatformDetails.PlatformDatabricks === "databricks") + assert(PlatformDetails.PlatformUnknown === "unknown") + assert(PlatformDetails.SynapseProjectName === "Microsoft.ProjectArcadia") + } + + test("CurrentPlatform returns a string") { + val platform = PlatformDetails.CurrentPlatform assert(platform.nonEmpty) } - test("currentPlatform returns one of the known platform values") { - val known = Set( + test("currentPlatform returns a valid platform string") { + val platform = PlatformDetails.currentPlatform() + val validPlatforms = Set( PlatformDetails.PlatformSynapseInternal, PlatformDetails.PlatformSynapse, - PlatformDetails.PlatformDatabricks, PlatformDetails.PlatformBinder, + PlatformDetails.PlatformDatabricks, PlatformDetails.PlatformUnknown ) - assert(known.contains(PlatformDetails.currentPlatform())) + assert(validPlatforms.contains(platform)) + } + + test("runningOnSynapseInternal returns boolean") { + val result = PlatformDetails.runningOnSynapseInternal() + assert(result.isInstanceOf[Boolean]) } - test("runningOnFabric is consistent with runningOnSynapseInternal") { + test("runningOnSynapse returns boolean") { + val result = PlatformDetails.runningOnSynapse() + assert(result.isInstanceOf[Boolean]) + } + + test("runningOnFabric returns same as runningOnSynapseInternal") { assert(PlatformDetails.runningOnFabric() === PlatformDetails.runningOnSynapseInternal()) } + + test("CurrentPlatform returns a known platform value") { + val platform = PlatformDetails.CurrentPlatform + // Expected platforms when running tests on a local/dev environment + val expectedOnDev = Set(PlatformDetails.PlatformUnknown, PlatformDetails.PlatformBinder) + // Allow-list of platforms that may legitimately appear in CI (e.g., Synapse or Databricks) + val ciPlatforms = Set( + 
PlatformDetails.PlatformSynapseInternal, + PlatformDetails.PlatformSynapse, + PlatformDetails.PlatformDatabricks + ) + // Verify that the platform is either a dev-expected value or a known CI platform + assert(expectedOnDev.contains(platform) || ciPlatforms.contains(platform)) + } } diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/logging/common/VerifyScrubber.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/logging/common/VerifyScrubber.scala index 5db71ef9d3f..f97a237c6c6 100644 --- a/core/src/test/scala/com/microsoft/azure/synapse/ml/logging/common/VerifyScrubber.scala +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/logging/common/VerifyScrubber.scala @@ -7,50 +7,67 @@ import com.microsoft.azure.synapse.ml.core.test.base.TestBase class VerifyScrubber extends TestBase { - test("SASScrubber replaces SAS signature in URL") { - val url = "https://storage.blob.core.windows.net/container?sv=2021-06-08" + - "&sig=abcdef1234567890abcdef1234567890abcdef12345%3d&se=2024-01-01" - val scrubbed = SASScrubber.scrub(url) - assert(scrubbed.contains("sig=####")) - assert(!scrubbed.contains("abcdef1234567890")) - } - - test("SASScrubber leaves strings without SAS tokens unchanged") { - val message = "This is a normal log message with no SAS token" - assert(SASScrubber.scrub(message) === message) - } - - test("SASScrubber handles multiple SAS tokens in same string") { - val url1 = "https://storage1.blob.core.windows.net/c1?sig=abcdef1234567890abcdef1234567890abcdef12345%3d" - val url2 = "https://storage2.blob.core.windows.net/c2?sig=123456abcdef7890123456abcdef7890123456abc78%3d" - val combined = s"$url1 and $url2" - val scrubbed = SASScrubber.scrub(combined) - assert(scrubbed === "https://storage1.blob.core.windows.net/c1?sig=####" + - " and https://storage2.blob.core.windows.net/c2?sig=####") - } - - test("SASScrubber is case insensitive") { - val upperUrl = "https://storage.blob.core.windows.net/container" + - 
"?SIG=ABCDEF1234567890ABCDEF1234567890ABCDEF12345%3D&se=2024-01-01" - val mixedUrl = "https://storage.blob.core.windows.net/container" + - "?Sig=AbCdEf1234567890AbCdEf1234567890AbCdEf12345%3d&se=2024-01-01" - val scrubbedUpper = SASScrubber.scrub(upperUrl) - val scrubbedMixed = SASScrubber.scrub(mixedUrl) - assert(scrubbedUpper.contains("sig=####")) - assert(scrubbedMixed.contains("sig=####")) - assert(!scrubbedUpper.contains("ABCDEF1234567890")) - assert(!scrubbedMixed.contains("AbCdEf1234567890")) - } - - test("SASScrubber preserves rest of URL around the signature") { - val url = "https://storage.blob.core.windows.net/container?sv=2021-06-08" + - "&sig=abcdef1234567890abcdef1234567890abcdef12345%3d&se=2024-01-01&sp=r" - val scrubbed = SASScrubber.scrub(url) - assert(scrubbed.contains("sv=2021-06-08")) - assert(scrubbed.contains("se=2024-01-01")) - assert(scrubbed.contains("sp=r")) - assert(scrubbed.contains("sig=####")) - assert(scrubbed === "https://storage.blob.core.windows.net/container?sv=2021-06-08" + - "&sig=####&se=2024-01-01&sp=r") + test("SASScrubber scrubs SAS signature from URL") { + // SAS tokens typically contain sig= followed by base64-like encoded signature + // Dummy URL for testing - not a real endpoint + // scalastyle:off line.size.limit + val urlWithSas = "https://storage.blob.core.windows.net/container/file?sig=abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG%3d" + // scalastyle:on line.size.limit + val result = SASScrubber.scrub(urlWithSas) + assert(result.contains("sig=####")) + assert(!result.contains("abcdefghijklmnopqrstuvwxyz")) + } + + test("SASScrubber handles multiple SAS signatures in one string") { + // Use signatures that match the pattern: sig= followed by 43-63 alphanumeric/% chars ending with %3d + val sig1 = "sig=abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG%3d" + val sig2 = "sig=XYZ987wvu654tsr321qpo098nml765kji432hgf109AB%3d" + val message = s"URL1: $sig1 and URL2: $sig2" + val result = SASScrubber.scrub(message) + // Both 
signatures should be replaced + assert(!result.contains("abcdefghijklmnopqrstuvwxyz")) + assert(!result.contains("XYZ987wvu654")) + assert(result.contains("sig=####")) + } + + test("SASScrubber leaves non-SAS content unchanged") { + val message = "This is a normal log message without any signatures" + val result = SASScrubber.scrub(message) + assert(result === message) + } + + test("SASScrubber is case insensitive for sig parameter") { + val lowerCase = "https://test.com?sig=abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG%3d" + val upperCase = "https://test.com?SIG=abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG%3d" + val mixedCase = "https://test.com?SiG=abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG%3d" + + assert(SASScrubber.scrub(lowerCase).contains("sig=####")) + assert(SASScrubber.scrub(upperCase).contains("sig=####")) + assert(SASScrubber.scrub(mixedCase).contains("sig=####")) + } + + test("SASScrubber handles empty string") { + assert(SASScrubber.scrub("") === "") + } + + test("SASScrubber handles string with only sig= but invalid signature") { + // Too short signature should not be scrubbed + val shortSig = "https://test.com?sig=abc" + assert(SASScrubber.scrub(shortSig) === shortSig) + } + + test("SASScrubber preserves text before and after signature") { + val message = "Prefix text sig=abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG%3d suffix text" + val result = SASScrubber.scrub(message) + assert(result.startsWith("Prefix text")) + assert(result.endsWith("suffix text")) + assert(result.contains("sig=####")) + } + + test("SASScrubber implements Scrubber trait") { + val scrubber: Scrubber = SASScrubber + val message = "Test sig=abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG%3d" + val result = scrubber.scrub(message) + assert(result.contains("sig=####")) } } diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyByteArrayParam.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyByteArrayParam.scala new file mode 100644 index 
00000000000..97fdfba0c6c --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyByteArrayParam.scala @@ -0,0 +1,74 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.param + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.ml.param.{ParamMap, Params} + +class VerifyByteArrayParam extends TestBase { + + private class TestParamsHolder extends Params { + override val uid: String = "test-holder" + val bytesParam = new ByteArrayParam(this, "bytes", "A byte array param") + override def copy(extra: ParamMap): Params = this + } + + test("ByteArrayParam can be created with basic constructor") { + val holder = new TestParamsHolder + assert(holder.bytesParam.name === "bytes") + assert(holder.bytesParam.doc === "A byte array param") + } + + test("ByteArrayParam accepts empty byte array") { + val holder = new TestParamsHolder + holder.set(holder.bytesParam, Array.empty[Byte]) + assert(holder.get(holder.bytesParam).exists(_.isEmpty)) + } + + test("ByteArrayParam accepts byte array with data") { + val holder = new TestParamsHolder + val data = Array[Byte](1, 2, 3, 4, 5) + holder.set(holder.bytesParam, data) + assert(holder.get(holder.bytesParam).exists(_.sameElements(data))) + } + + test("ByteArrayParam accepts large byte array") { + val holder = new TestParamsHolder + val data = Array.fill(1000)(42.toByte) + holder.set(holder.bytesParam, data) + assert(holder.get(holder.bytesParam).exists(_.length === 1000)) + } + + test("ByteArrayParam accepts byte array with all byte values") { + val holder = new TestParamsHolder + val data = (-128 to 127).map(_.toByte).toArray + holder.set(holder.bytesParam, data) + assert(holder.get(holder.bytesParam).exists(_.length === 256)) + } + + test("ByteArrayParam with custom validator") { + val holder = new Params { + override val uid: String = 
"test" + val nonEmptyBytes = new ByteArrayParam( + this, "nonEmpty", "Non-empty byte array", + (arr: Array[Byte]) => arr.nonEmpty + ) + override def copy(extra: ParamMap): Params = this + } + holder.set(holder.nonEmptyBytes, Array[Byte](1, 2, 3)) + } + + test("ByteArrayParam can be cleared") { + val holder = new TestParamsHolder + holder.set(holder.bytesParam, Array[Byte](1, 2, 3)) + assert(holder.isSet(holder.bytesParam)) + holder.clear(holder.bytesParam) + assert(!holder.isSet(holder.bytesParam)) + } + + test("ByteArrayParam returns None when not set") { + val holder = new TestParamsHolder + assert(holder.get(holder.bytesParam).isEmpty) + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyDataFrameParam.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyDataFrameParam.scala new file mode 100644 index 00000000000..485be61190f --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyDataFrameParam.scala @@ -0,0 +1,163 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package com.microsoft.azure.synapse.ml.param + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.ml.linalg.{DenseVector, Vectors} +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.types._ + +class VerifyDataFrameParam extends TestBase { + + import spark.implicits._ + + private class TestParamsHolder extends Params { + override val uid: String = "test-holder" + val dfParam = new DataFrameParam(this, "dataFrame", "A dataframe param") + override def copy(extra: ParamMap): Params = this + } + + // DataFrameParam basic tests + test("DataFrameParam can be created with basic constructor") { + val holder = new TestParamsHolder + assert(holder.dfParam.name === "dataFrame") + assert(holder.dfParam.doc === "A dataframe param") + } + + test("DataFrameParam accepts DataFrame") { + val holder = new TestParamsHolder + val df = Seq(("a", 1), ("b", 2)).toDF("str", "num") + holder.set(holder.dfParam, df) + assert(holder.isSet(holder.dfParam)) + } + + test("DataFrameParam pyValue returns DF reference") { + val holder = new TestParamsHolder + val df = Seq(1, 2, 3).toDF("num") + val pyVal = holder.dfParam.pyValue(df) + assert(pyVal === "dataFrameDF") + } + + test("DataFrameParam pyLoadLine generates Python code") { + val holder = new TestParamsHolder + val pyCode = holder.dfParam.pyLoadLine(1) + assert(pyCode.contains("spark.read.parquet")) + assert(pyCode.contains("model-1.model")) + assert(pyCode.contains("complexParams")) + assert(pyCode.contains("dataFrame")) + } + + test("DataFrameParam rValue returns DF reference") { + val holder = new TestParamsHolder + val df = Seq(1, 2, 3).toDF("num") + val rVal = holder.dfParam.rValue(df) + assert(rVal === "dataFrameDF") + } + + test("DataFrameParam rLoadLine generates R code") { + val holder = new TestParamsHolder + val rCode = holder.dfParam.rLoadLine(2) + assert(rCode.contains("spark_read_parquet")) + 
assert(rCode.contains("model-2.model")) + assert(rCode.contains("complexParams")) + assert(rCode.contains("dataFrame")) + } + + // DataFrameEquality tests + test("DataFrameEquality compares equal DataFrames correctly") { + val holder = new TestParamsHolder + val df1 = Seq(("a", 1), ("b", 2)).toDF("str", "num") + val df2 = Seq(("a", 1), ("b", 2)).toDF("str", "num") + // Should not throw + holder.dfParam.assertEquality(df1, df2) + } + + test("DataFrameEquality detects different DataFrames") { + val holder = new TestParamsHolder + val df1 = Seq(("a", 1), ("b", 2)).toDF("str", "num") + val df2 = Seq(("a", 1), ("c", 3)).toDF("str", "num") + assertThrows[AssertionError] { + holder.dfParam.assertEquality(df1, df2) + } + } + + test("DataFrameEquality throws for non-DataFrame types") { + val holder = new TestParamsHolder + assertThrows[AssertionError] { + holder.dfParam.assertEquality("not a df", "also not a df") + } + } + + // DataFrameEquality implicit tests + test("DataFrameEquality handles doubles with tolerance") { + val holder = new TestParamsHolder + val df1 = Seq(1.0, 2.0, 3.0).toDF("num") + val df2 = Seq(1.00001, 2.00001, 3.00001).toDF("num") + // Should not throw due to tolerance + holder.dfParam.assertEquality(df1, df2) + } + + test("DataFrameEquality handles NaN values") { + val holder = new TestParamsHolder + val df1 = Seq(Double.NaN, 2.0).toDF("num") + val df2 = Seq(Double.NaN, 2.0).toDF("num") + holder.dfParam.assertEquality(df1, df2) + } + + test("DataFrameEquality handles DenseVector columns") { + val holder = new TestParamsHolder + // Create DataFrames with vector columns using VectorAssembler + import org.apache.spark.ml.feature.VectorAssembler + val baseData1 = Seq((1.0, 2.0), (3.0, 4.0)).toDF("a", "b") + val baseData2 = Seq((1.0, 2.0), (3.0, 4.0)).toDF("a", "b") + val assembler = new VectorAssembler().setInputCols(Array("a", "b")).setOutputCol("vec") + val df1 = assembler.transform(baseData1).select("vec") + val df2 = 
assembler.transform(baseData2).select("vec") + holder.dfParam.assertEquality(df1, df2) + } + + test("DataFrameEquality handles binary columns") { + val holder = new TestParamsHolder + val df1 = Seq(Array[Byte](1, 2, 3)).toDF("bytes") + val df2 = Seq(Array[Byte](1, 2, 3)).toDF("bytes") + holder.dfParam.assertEquality(df1, df2) + } + + test("DataFrameEquality detects different column names") { + val holder = new TestParamsHolder + val df1 = Seq(1, 2, 3).toDF("col1") + val df2 = Seq(1, 2, 3).toDF("col2") + assertThrows[AssertionError] { + holder.dfParam.assertEquality(df1, df2) + } + } + + test("DataFrameEquality detects different row counts") { + val holder = new TestParamsHolder + val df1 = Seq(1, 2, 3).toDF("num") + val df2 = Seq(1, 2).toDF("num") + assertThrows[AssertionError] { + holder.dfParam.assertEquality(df1, df2) + } + } + + test("DataFrameParam with custom validator") { + val holder = new Params { + override val uid: String = "test" + val nonEmptyDf = new DataFrameParam( + this, "nonEmpty", "Non-empty dataframe", + (df: DataFrame) => df.count() > 0 + ) + override def copy(extra: ParamMap): Params = this + } + val df = Seq(1, 2, 3).toDF("num") + holder.set(holder.nonEmptyDf, df) + } + + test("DataFrameParam sortInDataframeEquality is true") { + val holder = new TestParamsHolder + assert(holder.dfParam.sortInDataframeEquality) + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyDataTypeParam.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyDataTypeParam.scala new file mode 100644 index 00000000000..9c433f2b22a --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyDataTypeParam.scala @@ -0,0 +1,122 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package com.microsoft.azure.synapse.ml.param + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.sql.types._ + +class VerifyDataTypeParam extends TestBase { + + private class TestParamsHolder extends Params { + override val uid: String = "test-holder" + val dataTypeParam = new DataTypeParam(this, "dataType", "A data type param") + override def copy(extra: ParamMap): Params = this + } + + test("DataTypeParam can be created with basic constructor") { + val holder = new TestParamsHolder + assert(holder.dataTypeParam.name === "dataType") + assert(holder.dataTypeParam.doc === "A data type param") + } + + test("DataTypeParam accepts StringType") { + val holder = new TestParamsHolder + holder.set(holder.dataTypeParam, StringType) + assert(holder.get(holder.dataTypeParam).contains(StringType)) + } + + test("DataTypeParam accepts IntegerType") { + val holder = new TestParamsHolder + holder.set(holder.dataTypeParam, IntegerType) + assert(holder.get(holder.dataTypeParam).contains(IntegerType)) + } + + test("DataTypeParam accepts DoubleType") { + val holder = new TestParamsHolder + holder.set(holder.dataTypeParam, DoubleType) + assert(holder.get(holder.dataTypeParam).contains(DoubleType)) + } + + test("DataTypeParam accepts BooleanType") { + val holder = new TestParamsHolder + holder.set(holder.dataTypeParam, BooleanType) + assert(holder.get(holder.dataTypeParam).contains(BooleanType)) + } + + test("DataTypeParam accepts ArrayType") { + val holder = new TestParamsHolder + val arrayType = ArrayType(StringType) + holder.set(holder.dataTypeParam, arrayType) + assert(holder.get(holder.dataTypeParam).contains(arrayType)) + } + + test("DataTypeParam accepts MapType") { + val holder = new TestParamsHolder + val mapType = MapType(StringType, IntegerType) + holder.set(holder.dataTypeParam, mapType) + assert(holder.get(holder.dataTypeParam).contains(mapType)) + } + + test("DataTypeParam accepts 
StructType") { + val holder = new TestParamsHolder + val structType = StructType(Seq( + StructField("name", StringType), + StructField("age", IntegerType) + )) + holder.set(holder.dataTypeParam, structType) + assert(holder.get(holder.dataTypeParam).contains(structType)) + } + + test("DataTypeParam accepts nested StructType") { + val holder = new TestParamsHolder + val nestedType = StructType(Seq( + StructField("outer", StructType(Seq( + StructField("inner", StringType) + ))) + )) + holder.set(holder.dataTypeParam, nestedType) + assert(holder.get(holder.dataTypeParam).contains(nestedType)) + } + + test("DataTypeParam accepts TimestampType") { + val holder = new TestParamsHolder + holder.set(holder.dataTypeParam, TimestampType) + assert(holder.get(holder.dataTypeParam).contains(TimestampType)) + } + + test("DataTypeParam accepts DateType") { + val holder = new TestParamsHolder + holder.set(holder.dataTypeParam, DateType) + assert(holder.get(holder.dataTypeParam).contains(DateType)) + } + + test("DataTypeParam accepts BinaryType") { + val holder = new TestParamsHolder + holder.set(holder.dataTypeParam, BinaryType) + assert(holder.get(holder.dataTypeParam).contains(BinaryType)) + } + + test("DataTypeParam with custom validator") { + val holder = new Params { + override val uid: String = "test" + val numericOnlyParam = new DataTypeParam( + this, "numericOnly", "Only numeric types", + (dt: DataType) => dt.isInstanceOf[NumericType] + ) + override def copy(extra: ParamMap): Params = this + } + holder.set(holder.numericOnlyParam, IntegerType) + holder.set(holder.numericOnlyParam, DoubleType) + holder.set(holder.numericOnlyParam, FloatType) + } + + test("DataTypeParam can be cleared") { + val holder = new TestParamsHolder + holder.set(holder.dataTypeParam, StringType) + assert(holder.isSet(holder.dataTypeParam)) + holder.clear(holder.dataTypeParam) + assert(!holder.isSet(holder.dataTypeParam)) + } +} diff --git 
a/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyEstimatorArrayParam.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyEstimatorArrayParam.scala new file mode 100644 index 00000000000..4d0e3d0012d --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyEstimatorArrayParam.scala @@ -0,0 +1,89 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.param + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.ml.Estimator +import org.apache.spark.ml.classification.{LogisticRegression, DecisionTreeClassifier} +import org.apache.spark.ml.feature.StringIndexer +import org.apache.spark.ml.param.{ParamMap, Params} + +import java.util.{ArrayList => JArrayList} + +class VerifyEstimatorArrayParam extends TestBase { + + private class TestParamsHolder extends Params { + override val uid: String = "test-holder" + val estimatorsParam = new EstimatorArrayParam(this, "estimators", "An array of estimators") + override def copy(extra: ParamMap): Params = this + } + + test("EstimatorArrayParam can be created with basic constructor") { + val holder = new TestParamsHolder + assert(holder.estimatorsParam.name === "estimators") + assert(holder.estimatorsParam.doc === "An array of estimators") + } + + test("EstimatorArrayParam accepts empty array") { + val holder = new TestParamsHolder + holder.set(holder.estimatorsParam, Array.empty[Estimator[_]]) + assert(holder.get(holder.estimatorsParam).exists(_.isEmpty)) + } + + test("EstimatorArrayParam accepts array with single estimator") { + val holder = new TestParamsHolder + val estimators = Array[Estimator[_]](new LogisticRegression()) + holder.set(holder.estimatorsParam, estimators) + assert(holder.get(holder.estimatorsParam).exists(_.length === 1)) + } + + test("EstimatorArrayParam accepts array with multiple 
estimators") { + val holder = new TestParamsHolder + val estimators = Array[Estimator[_]]( + new LogisticRegression(), + new DecisionTreeClassifier(), + new StringIndexer() + ) + holder.set(holder.estimatorsParam, estimators) + assert(holder.get(holder.estimatorsParam).exists(_.length === 3)) + } + + test("EstimatorArrayParam w() method accepts Java List") { + val holder = new TestParamsHolder + val javaList = new JArrayList[Estimator[_]]() + javaList.add(new LogisticRegression()) + javaList.add(new DecisionTreeClassifier()) + + val paramPair = holder.estimatorsParam.w(javaList) + assert(paramPair.param === holder.estimatorsParam) + assert(paramPair.value.length === 2) + } + + test("EstimatorArrayParam with custom validator") { + val holder = new Params { + override val uid: String = "test" + val nonEmptyEstimators = new EstimatorArrayParam( + this, "nonEmpty", "Non-empty estimator array", + (arr: Array[Estimator[_]]) => arr.nonEmpty + ) + override def copy(extra: ParamMap): Params = this + } + val estimators = Array[Estimator[_]](new LogisticRegression()) + holder.set(holder.nonEmptyEstimators, estimators) + } + + test("EstimatorArrayParam can be cleared") { + val holder = new TestParamsHolder + val estimators = Array[Estimator[_]](new LogisticRegression()) + holder.set(holder.estimatorsParam, estimators) + assert(holder.isSet(holder.estimatorsParam)) + holder.clear(holder.estimatorsParam) + assert(!holder.isSet(holder.estimatorsParam)) + } + + test("EstimatorArrayParam returns None when not set") { + val holder = new TestParamsHolder + assert(holder.get(holder.estimatorsParam).isEmpty) + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyEvaluatorParam.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyEvaluatorParam.scala new file mode 100644 index 00000000000..93dcdf24905 --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyEvaluatorParam.scala @@ -0,0 +1,91 @@ +// Copyright (C) 
Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.param + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.ml.evaluation.{ + BinaryClassificationEvaluator, MulticlassClassificationEvaluator, RegressionEvaluator +} +import org.apache.spark.ml.param.{ParamMap, Params} + +class VerifyEvaluatorParam extends TestBase { + + private class TestParamsHolder extends Params { + override val uid: String = "test-holder" + val evaluatorParam = new EvaluatorParam(this, "evaluator", "An evaluator param") + override def copy(extra: ParamMap): Params = this + } + + test("EvaluatorParam can be created with basic constructor") { + val holder = new TestParamsHolder + assert(holder.evaluatorParam.name === "evaluator") + assert(holder.evaluatorParam.doc === "An evaluator param") + } + + test("EvaluatorParam accepts BinaryClassificationEvaluator") { + val holder = new TestParamsHolder + val evaluator = new BinaryClassificationEvaluator() + holder.set(holder.evaluatorParam, evaluator) + assert(holder.isSet(holder.evaluatorParam)) + } + + test("EvaluatorParam accepts MulticlassClassificationEvaluator") { + val holder = new TestParamsHolder + val evaluator = new MulticlassClassificationEvaluator() + holder.set(holder.evaluatorParam, evaluator) + assert(holder.isSet(holder.evaluatorParam)) + } + + test("EvaluatorParam accepts RegressionEvaluator") { + val holder = new TestParamsHolder + val evaluator = new RegressionEvaluator() + holder.set(holder.evaluatorParam, evaluator) + assert(holder.isSet(holder.evaluatorParam)) + } + + test("EvaluatorParam assertEquality passes for same evaluator type") { + val holder = new TestParamsHolder + val eval1 = new BinaryClassificationEvaluator() + .setMetricName("areaUnderROC") + .setLabelCol("label") + val eval2 = new BinaryClassificationEvaluator() + .setMetricName("areaUnderROC") + 
.setLabelCol("label") + holder.evaluatorParam.assertEquality(eval1, eval2) + } + + test("EvaluatorParam assertEquality throws for non-Evaluator types") { + val holder = new TestParamsHolder + assertThrows[AssertionError] { + holder.evaluatorParam.assertEquality("not an evaluator", "also not") + } + } + + test("EvaluatorParam with custom validator") { + val holder = new Params { + override val uid: String = "test" + val binaryOnly = new EvaluatorParam( + this, "binaryOnly", "Binary evaluator only", + _.isInstanceOf[BinaryClassificationEvaluator] + ) + override def copy(extra: ParamMap): Params = this + } + val evaluator = new BinaryClassificationEvaluator() + holder.set(holder.binaryOnly, evaluator) + } + + test("EvaluatorParam can be cleared") { + val holder = new TestParamsHolder + val evaluator = new RegressionEvaluator() + holder.set(holder.evaluatorParam, evaluator) + assert(holder.isSet(holder.evaluatorParam)) + holder.clear(holder.evaluatorParam) + assert(!holder.isSet(holder.evaluatorParam)) + } + + test("EvaluatorParam returns None when not set") { + val holder = new TestParamsHolder + assert(holder.get(holder.evaluatorParam).isEmpty) + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyGlobalParams.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyGlobalParams.scala new file mode 100644 index 00000000000..5077b4b91d8 --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyGlobalParams.scala @@ -0,0 +1,90 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+
+package com.microsoft.azure.synapse.ml.param
+
+import com.microsoft.azure.synapse.ml.core.test.base.TestBase
+import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.ml.util.Identifiable
+
+class VerifyGlobalParams extends TestBase {
+
+  // Keys used only by this suite; values set through GlobalParams stay set until they are reset
+  case object TestStringKey extends GlobalKey[String]
+  case object TestIntKey extends GlobalKey[Int]
+  case object AnotherStringKey extends GlobalKey[String]
+
+  // Runs around every test: clears all suite keys first so no test depends on execution order
+  override def withFixture(test: NoArgTest): org.scalatest.Outcome = {
+    GlobalParams.resetGlobalParam(TestStringKey)
+    GlobalParams.resetGlobalParam(TestIntKey)
+    GlobalParams.resetGlobalParam(AnotherStringKey)
+    super.withFixture(test)
+  }
+
+  test("setGlobalParam and getGlobalParam work for String") {
+    GlobalParams.setGlobalParam(TestStringKey, "test-value")
+    val result = GlobalParams.getGlobalParam(TestStringKey)
+    assert(result === Some("test-value"))
+  }
+
+  test("setGlobalParam and getGlobalParam work for Int") {
+    GlobalParams.setGlobalParam(TestIntKey, 42)
+    val result = GlobalParams.getGlobalParam(TestIntKey)
+    assert(result === Some(42))
+  }
+
+  test("getGlobalParam returns None for unset key") {
+    val result = GlobalParams.getGlobalParam(TestStringKey)
+    assert(result.isEmpty)
+  }
+
+  test("resetGlobalParam removes the parameter") {
+    GlobalParams.setGlobalParam(TestStringKey, "value")
+    assert(GlobalParams.getGlobalParam(TestStringKey).isDefined)
+    GlobalParams.resetGlobalParam(TestStringKey)
+    assert(GlobalParams.getGlobalParam(TestStringKey).isEmpty)
+  }
+
+  test("setGlobalParam overwrites existing value") {
+    GlobalParams.setGlobalParam(TestStringKey, "first")
+    GlobalParams.setGlobalParam(TestStringKey, "second")
+    assert(GlobalParams.getGlobalParam(TestStringKey) === Some("second"))
+  }
+
+  test("multiple keys can be set independently") {
+    GlobalParams.setGlobalParam(TestStringKey, "string-value")
+    GlobalParams.setGlobalParam(TestIntKey, 100)
+    GlobalParams.setGlobalParam(AnotherStringKey, "another-value")
+
+    assert(GlobalParams.getGlobalParam(TestStringKey) === Some("string-value"))
+    assert(GlobalParams.getGlobalParam(TestIntKey) === Some(100))
+    assert(GlobalParams.getGlobalParam(AnotherStringKey) === Some("another-value"))
+  }
+
+  test("registerParam and getParam work together") {
+    // Minimal Params owner so that a Param can be registered against TestStringKey
+    class TestParams(override val uid: String) extends Params {
+      val testParam = new Param[String](this, "testParam", "test param")
+      override def copy(extra: ParamMap): Params = this
+    }
+
+    val params = new TestParams(Identifiable.randomUID("test"))
+    GlobalParams.registerParam(params.testParam, TestStringKey)
+    GlobalParams.setGlobalParam(TestStringKey, "global-value")
+
+    val result = GlobalParams.getParam(params.testParam)
+    assert(result === Some("global-value"))
+  }
+
+  test("getParam returns None for unregistered param") {
+    class TestParams(override val uid: String) extends Params {
+      val unregisteredParam = new Param[String](this, "unregisteredParam", "not registered")
+      override def copy(extra: ParamMap): Params = this
+    }
+
+    val params = new TestParams(Identifiable.randomUID("test"))
+    val result = GlobalParams.getParam(params.unregisteredParam)
+    assert(result.isEmpty)
+  }
+}
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyModelParam.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyModelParam.scala
new file mode 100644
index 00000000000..6723898d278
--- /dev/null
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyModelParam.scala
@@ -0,0 +1,99 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+ +package com.microsoft.azure.synapse.ml.param + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.ml.Model +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler} +import org.apache.spark.ml.param.{ParamMap, Params} + +class VerifyModelParam extends TestBase { + + import spark.implicits._ + + private class TestParamsHolder extends Params { + override val uid: String = "test-holder" + val modelParam = new ModelParam(this, "model", "A model param") + override def copy(extra: ParamMap): Params = this + } + + // Helper to create a trained model + private def createTrainedModel(): LogisticRegressionModel = { + val data = Seq( + (0.0, 1.0, 0.0), + (1.0, 0.0, 1.0), + (0.0, 1.0, 0.0), + (1.0, 0.0, 1.0) + ).toDF("label", "f1", "f2") + + val assembler = new VectorAssembler() + .setInputCols(Array("f1", "f2")) + .setOutputCol("features") + val assembled = assembler.transform(data) + + val lr = new LogisticRegression() + .setMaxIter(5) + .setLabelCol("label") + .setFeaturesCol("features") + lr.fit(assembled) + } + + test("ModelParam can be created with basic constructor") { + val holder = new TestParamsHolder + assert(holder.modelParam.name === "model") + assert(holder.modelParam.doc === "A model param") + } + + test("ModelParam accepts LogisticRegressionModel") { + val holder = new TestParamsHolder + val model = createTrainedModel() + holder.set(holder.modelParam, model) + assert(holder.isSet(holder.modelParam)) + } + + test("ModelParam pyValue returns model reference") { + val holder = new TestParamsHolder + val model = createTrainedModel() + val pyVal = holder.modelParam.pyValue(model) + assert(pyVal === "modelModel") + } + + test("ModelParam pyLoadLine generates Python code") { + val holder = new TestParamsHolder + val pyCode = holder.modelParam.pyLoadLine(1) + assert(pyCode.contains("Pipeline.load")) + 
assert(pyCode.contains("model-1.model")) + assert(pyCode.contains("complexParams")) + } + + test("ModelParam rValue returns model reference") { + val holder = new TestParamsHolder + val model = createTrainedModel() + val rVal = holder.modelParam.rValue(model) + assert(rVal === "modelModel") + } + + test("ModelParam rLoadLine generates R code") { + val holder = new TestParamsHolder + val rCode = holder.modelParam.rLoadLine(2) + assert(rCode.contains("ml_load")) + assert(rCode.contains("model-2.model")) + assert(rCode.contains("ml_stages")) + } + + test("ModelParam can be cleared") { + val holder = new TestParamsHolder + val model = createTrainedModel() + holder.set(holder.modelParam, model) + assert(holder.isSet(holder.modelParam)) + holder.clear(holder.modelParam) + assert(!holder.isSet(holder.modelParam)) + } + + test("ModelParam returns None when not set") { + val holder = new TestParamsHolder + assert(holder.get(holder.modelParam).isEmpty) + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyPipelineStageParams.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyPipelineStageParams.scala new file mode 100644 index 00000000000..34cef6b75a8 --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyPipelineStageParams.scala @@ -0,0 +1,151 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package com.microsoft.azure.synapse.ml.param + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.ml.{Transformer, Estimator, Model, PipelineStage} +import org.apache.spark.ml.feature.{Tokenizer, HashingTF, StringIndexer, StringIndexerModel} +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.types.StructType + +class VerifyPipelineStageParams extends TestBase { + + // Test class that holds the params + private class TestParamsHolder extends Params { + override val uid: String = "test-holder" + + val transformerParam = new TransformerParam(this, "transformer", "A transformer param") + val estimatorParam = new EstimatorParam(this, "estimator", "An estimator param") + val pipelineStageParam = new PipelineStageParam(this, "pipelineStage", "A pipeline stage param") + + override def copy(extra: ParamMap): Params = this + } + + // TransformerParam tests + test("TransformerParam can be created with basic constructor") { + val holder = new TestParamsHolder + assert(holder.transformerParam.name === "transformer") + assert(holder.transformerParam.doc === "A transformer param") + } + + test("TransformerParam accepts valid Transformer") { + val holder = new TestParamsHolder + val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words") + // Should not throw + holder.set(holder.transformerParam, tokenizer) + } + + test("TransformerParam with custom validator") { + val holder = new Params { + override val uid: String = "test" + val validatedParam = new TransformerParam( + this, "validated", "validated param", + (t: Transformer) => t.isInstanceOf[Tokenizer] + ) + override def copy(extra: ParamMap): Params = this + } + val tokenizer = new Tokenizer() + holder.set(holder.validatedParam, tokenizer) + } + + test("TransformerParam rLoadLine generates correct R code") { + val holder = new TestParamsHolder + val rCode = holder.transformerParam.rLoadLine(1) 
+ assert(rCode.contains("ml_load")) + assert(rCode.contains("model-1.model")) + assert(rCode.contains("complexParams")) + assert(rCode.contains("transformer")) + assert(rCode.contains("ml_stages")) + } + + // EstimatorParam tests + test("EstimatorParam can be created with basic constructor") { + val holder = new TestParamsHolder + assert(holder.estimatorParam.name === "estimator") + assert(holder.estimatorParam.doc === "An estimator param") + } + + test("EstimatorParam accepts valid Estimator") { + val holder = new TestParamsHolder + val indexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel") + holder.set(holder.estimatorParam, indexer) + } + + test("EstimatorParam rLoadLine generates correct R code") { + val holder = new TestParamsHolder + val rCode = holder.estimatorParam.rLoadLine(2) + assert(rCode.contains("ml_load")) + assert(rCode.contains("model-2.model")) + assert(rCode.contains("complexParams")) + assert(rCode.contains("estimator")) + } + + // PipelineStageParam tests + test("PipelineStageParam can be created with basic constructor") { + val holder = new TestParamsHolder + assert(holder.pipelineStageParam.name === "pipelineStage") + assert(holder.pipelineStageParam.doc === "A pipeline stage param") + } + + test("PipelineStageParam accepts Transformer") { + val holder = new TestParamsHolder + val tokenizer = new Tokenizer() + holder.set(holder.pipelineStageParam, tokenizer) + } + + test("PipelineStageParam accepts Estimator") { + val holder = new TestParamsHolder + val indexer = new StringIndexer() + holder.set(holder.pipelineStageParam, indexer) + } + + test("PipelineStageParam rLoadLine generates correct R code") { + val holder = new TestParamsHolder + val rCode = holder.pipelineStageParam.rLoadLine(3) + assert(rCode.contains("ml_load")) + assert(rCode.contains("model-3.model")) + assert(rCode.contains("pipelineStage")) + assert(rCode.contains("ml_stages")) + } + + // PipelineStageWrappable trait tests + 
test("PipelineStageWrappable pyValue returns model reference") { + val holder = new TestParamsHolder + val tokenizer = new Tokenizer() + val pyVal = holder.transformerParam.pyValue(tokenizer) + assert(pyVal === "transformerModel") + } + + test("PipelineStageWrappable pyLoadLine generates Python code") { + val holder = new TestParamsHolder + val pyCode = holder.transformerParam.pyLoadLine(1) + assert(pyCode.contains("Pipeline.load")) + assert(pyCode.contains("model-1.model")) + assert(pyCode.contains("complexParams")) + assert(pyCode.contains("getStages()")) + } + + test("PipelineStageWrappable rValue returns model reference") { + val holder = new TestParamsHolder + val tokenizer = new Tokenizer() + val rVal = holder.transformerParam.rValue(tokenizer) + assert(rVal === "transformerModel") + } + + test("PipelineStageWrappable assertEquality succeeds for same transformer") { + val holder = new TestParamsHolder + val t1 = new Tokenizer().setInputCol("a").setOutputCol("b") + val t2 = new Tokenizer().setInputCol("a").setOutputCol("b") + // Should not throw + holder.transformerParam.assertEquality(t1, t2) + } + + test("PipelineStageWrappable assertEquality throws for non-PipelineStage") { + val holder = new TestParamsHolder + assertThrows[AssertionError] { + holder.transformerParam.assertEquality("not a stage", "also not a stage") + } + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyPythonWrappableParam.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyPythonWrappableParam.scala new file mode 100644 index 00000000000..fc5264333e4 --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyPythonWrappableParam.scala @@ -0,0 +1,126 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package com.microsoft.azure.synapse.ml.param + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.ml.param.{ParamMap, Params} +import spray.json._ +import spray.json.DefaultJsonProtocol._ + +class VerifyPythonWrappableParam extends TestBase { + + test("PythonPrinter converts JsNull to Python None") { + val result = PythonPrinter(JsNull) + assert(result === "None") + } + + test("PythonPrinter converts JsTrue to Python True") { + val result = PythonPrinter(JsTrue) + assert(result === "True") + } + + test("PythonPrinter converts JsFalse to Python False") { + val result = PythonPrinter(JsFalse) + assert(result === "False") + } + + test("PythonPrinter converts JsNumber correctly") { + assert(PythonPrinter(JsNumber(42)) === "42") + assert(PythonPrinter(JsNumber(3.14)) === "3.14") + assert(PythonPrinter(JsNumber(-100)) === "-100") + } + + test("PythonPrinter converts JsString correctly") { + val result = PythonPrinter(JsString("hello")) + assert(result === "\"hello\"") + } + + test("PythonPrinter converts JsArray correctly") { + val arr = JsArray(JsNumber(1), JsNumber(2), JsNumber(3)) + val result = PythonPrinter(arr) + assert(result === "[1,2,3]") + } + + test("PythonPrinter converts JsObject correctly") { + val obj = JsObject("key" -> JsString("value")) + val result = PythonPrinter(obj) + assert(result.contains("key")) + assert(result.contains("value")) + } + + test("PythonPrinter converts nested structures") { + val nested = JsObject( + "bool" -> JsTrue, + "null" -> JsNull, + "number" -> JsNumber(42) + ) + val result = PythonPrinter(nested) + assert(result.contains("True")) + assert(result.contains("None")) + assert(result.contains("42")) + } + + test("pyDefaultRender with JsonFormat") { + val result = PythonWrappableParam.pyDefaultRender("test") + assert(result === "\"test\"") + } + + test("pyDefaultRender with Int") { + val result = PythonWrappableParam.pyDefaultRender(42) + assert(result === "42") + } + + test("pyDefaultRender 
with Boolean true") { + val result = PythonWrappableParam.pyDefaultRender(true) + assert(result === "True") + } + + test("pyDefaultRender with Boolean false") { + val result = PythonWrappableParam.pyDefaultRender(false) + assert(result === "False") + } + + test("pyDefaultRender with custom jsonFunc") { + val result = PythonWrappableParam.pyDefaultRender( + List(1, 2, 3), + (v: List[Int]) => v.toJson.compactPrint + ) + assert(result === "[1,2,3]") + } + + // Test PythonWrappableParam trait implementation + private class TestPythonParam(parent: Params, override val name: String, doc: String) + extends org.apache.spark.ml.param.Param[String](parent, name, doc) + with PythonWrappableParam[String] + + private class TestParams extends Params { + override val uid: String = "test-uid" + val stringParam = new TestPythonParam(this, "testString", "A test string param") + override def copy(extra: ParamMap): Params = this + } + + test("PythonWrappableParam.pyValue renders value correctly") { + val params = new TestParams + val result = params.stringParam.pyValue("hello") + assert(result === "\"hello\"") + } + + test("PythonWrappableParam.pyName returns param name") { + val params = new TestParams + val result = params.stringParam.pyName("anyValue") + assert(result === "testString") + } + + test("PythonWrappableParam.pyConstructorLine generates correct format") { + val params = new TestParams + val result = params.stringParam.pyConstructorLine("world") + assert(result === "testString=\"world\"") + } + + test("PythonWrappableParam.pySetterLine generates correct format") { + val params = new TestParams + val result = params.stringParam.pySetterLine("value") + assert(result === "setTestString(\"value\")") + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyRWrappableParam.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyRWrappableParam.scala new file mode 100644 index 00000000000..ef069b43669 --- /dev/null +++ 
b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyRWrappableParam.scala @@ -0,0 +1,152 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.param + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.ml.param.{ParamMap, Params} +import spray.json._ +import spray.json.DefaultJsonProtocol._ + +class VerifyRWrappableParam extends TestBase { + + test("RPrinter converts JsNull to R NULL") { + val result = RPrinter(JsNull) + assert(result === "NULL") + } + + test("RPrinter converts JsTrue to R TRUE") { + val result = RPrinter(JsTrue) + assert(result === "TRUE") + } + + test("RPrinter converts JsFalse to R FALSE") { + val result = RPrinter(JsFalse) + assert(result === "FALSE") + } + + test("RPrinter converts integer JsNumber with L suffix") { + val result = RPrinter(JsNumber(42)) + assert(result === "42L") + } + + test("RPrinter converts double JsNumber without L suffix") { + val result = RPrinter(JsNumber(3.14)) + assert(result === "3.14") + } + + test("RPrinter converts JsString correctly") { + val result = RPrinter(JsString("hello")) + assert(result === "\"hello\"") + } + + test("RPrinter converts empty JsArray to c()") { + val arr = JsArray() + val result = RPrinter(arr) + assert(result === "c()") + } + + test("RPrinter converts JsArray of numbers to list()") { + val arr = JsArray(JsNumber(1), JsNumber(2), JsNumber(3)) + val result = RPrinter(arr) + assert(result === "list(1L,2L,3L)") + } + + test("RPrinter converts empty JsObject to c()") { + val obj = JsObject() + val result = RPrinter(obj) + assert(result === "c()") + } + + test("RPrinter converts JsObject to list2env") { + val obj = JsObject("key" -> JsString("value")) + val result = RPrinter(obj) + assert(result.contains("list2env")) + assert(result.contains("key")) + assert(result.contains("value")) + } + + test("RPrinter 
converts nested JsObject correctly") { + val nested = JsObject( + "bool" -> JsTrue, + "null" -> JsNull, + "number" -> JsNumber(42) + ) + val result = RPrinter(nested) + assert(result.contains("TRUE")) + assert(result.contains("NULL")) + assert(result.contains("42L")) + } + + test("RPrinter converts JsArray of JsObjects correctly") { + val arr = JsArray( + JsObject("a" -> JsNumber(1)), + JsObject("b" -> JsNumber(2)) + ) + val result = RPrinter(arr) + assert(result.contains("list2env")) + } + + test("rDefaultRender with JsonFormat for String") { + val result = RWrappableParam.rDefaultRender("test") + assert(result === "\"test\"") + } + + test("rDefaultRender with JsonFormat for Int") { + val result = RWrappableParam.rDefaultRender(42) + assert(result === "42L") + } + + test("rDefaultRender with JsonFormat for Boolean true") { + val result = RWrappableParam.rDefaultRender(true) + assert(result === "TRUE") + } + + test("rDefaultRender with JsonFormat for Boolean false") { + val result = RWrappableParam.rDefaultRender(false) + assert(result === "FALSE") + } + + test("rDefaultRender with custom jsonFunc") { + val result = RWrappableParam.rDefaultRender( + List(1, 2, 3), + (v: List[Int]) => v.toJson.compactPrint + ) + assert(result === "list(1L,2L,3L)") + } + + // Test RWrappableParam trait implementation + private class TestRParam(parent: Params, override val name: String, doc: String) + extends org.apache.spark.ml.param.Param[String](parent, name, doc) + with RWrappableParam[String] + + private class TestParams extends Params { + override val uid: String = "test-uid" + val stringParam = new TestRParam(this, "testString", "A test string param") + override def copy(extra: ParamMap): Params = this + } + + test("RWrappableParam.rValue renders value correctly") { + val params = new TestParams + val result = params.stringParam.rValue("hello") + assert(result === "\"hello\"") + } + + test("RWrappableParam.rName returns param name") { + val params = new TestParams + val result 
= params.stringParam.rName("anyValue") + assert(result === "testString") + } + + test("RWrappableParam.rConstructorLine generates correct format") { + val params = new TestParams + val result = params.stringParam.rConstructorLine("world") + assert(result === "testString=\"world\"") + } + + test("RWrappableParam.rSetterLine generates correct format") { + val params = new TestParams + val result = params.stringParam.rSetterLine("value") + assert(result === "setTestString(\"value\")") + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyUDFParam.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyUDFParam.scala new file mode 100644 index 00000000000..caa62719256 --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/VerifyUDFParam.scala @@ -0,0 +1,89 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.param + +import com.microsoft.azure.synapse.ml.core.test.base.TestBase +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.functions.udf + +class VerifyUDFParam extends TestBase { + + private class TestParamsHolder extends Params { + override val uid: String = "test-holder" + val udfParam = new UDFParam(this, "udf", "A UDF param") + override def copy(extra: ParamMap): Params = this + } + + test("UDFParam can be created with basic constructor") { + val holder = new TestParamsHolder + assert(holder.udfParam.name === "udf") + assert(holder.udfParam.doc === "A UDF param") + } + + test("UDFParam accepts simple UDF") { + val holder = new TestParamsHolder + val myUdf = udf((x: Int) => x * 2) + holder.set(holder.udfParam, myUdf) + assert(holder.isSet(holder.udfParam)) + } + + test("UDFParam accepts string transformation UDF") { + val holder = new TestParamsHolder + val myUdf = udf((s: 
String) => s.toUpperCase) + holder.set(holder.udfParam, myUdf) + assert(holder.isSet(holder.udfParam)) + } + + test("UDFParam accepts multi-argument UDF") { + val holder = new TestParamsHolder + val myUdf = udf((a: Int, b: Int) => a + b) + holder.set(holder.udfParam, myUdf) + assert(holder.isSet(holder.udfParam)) + } + + test("UDFParam with custom validator") { + val holder = new Params { + override val uid: String = "test" + // Accept any UDF + val validatedUdf = new UDFParam( + this, "validated", "Validated UDF", + (_: UserDefinedFunction) => true + ) + override def copy(extra: ParamMap): Params = this + } + val myUdf = udf((x: Double) => x * x) + holder.set(holder.validatedUdf, myUdf) + } + + test("UDFParam can be cleared") { + val holder = new TestParamsHolder + val myUdf = udf((x: Int) => x) + holder.set(holder.udfParam, myUdf) + assert(holder.isSet(holder.udfParam)) + holder.clear(holder.udfParam) + assert(!holder.isSet(holder.udfParam)) + } + + test("UDFParam returns None when not set") { + val holder = new TestParamsHolder + assert(holder.get(holder.udfParam).isEmpty) + } + + test("UDFParam assertEquality passes for same UDFs") { + val holder = new TestParamsHolder + val udf1 = udf((x: Int) => x * 2) + val udf2 = udf((x: Int) => x * 2) + // Note: This test verifies the assertEquality method exists and runs + // Exact equality depends on internal UDF representation + holder.udfParam.assertEquality(udf1, udf1) + } + + test("UDFParam assertEquality throws for non-UDF types") { + val holder = new TestParamsHolder + assertThrows[AssertionError] { + holder.udfParam.assertEquality("not a udf", "also not a udf") + } + } +} diff --git a/deep-learning/src/test/scala/com/microsoft/azure/synapse/ml/onnx/ONNXModelSuite.scala b/deep-learning/src/test/scala/com/microsoft/azure/synapse/ml/onnx/ONNXModelSuite.scala index 5cd28f57192..4fee100465e 100644 --- a/deep-learning/src/test/scala/com/microsoft/azure/synapse/ml/onnx/ONNXModelSuite.scala +++ 
b/deep-learning/src/test/scala/com/microsoft/azure/synapse/ml/onnx/ONNXModelSuite.scala @@ -50,8 +50,9 @@ class ONNXModelSuite extends TestBase private implicit val eqFloat: Equality[Float] = TolerantNumerics.tolerantFloatEquality(1E-5f) private implicit val eqMap: Equality[Map[Long, Float]] = mapEq[Long, Float] private implicit val eqSeqDouble: Equality[Seq[Double]] = (a: Seq[Double], b: Any) => { - b match { - case sd: Seq[Double] => a.zip(sd).forall(x => x._1 === x._2) + // Using @unchecked because Seq[Double] type parameter is erased at runtime + (b: @unchecked) match { + case sd: Seq[Double @unchecked] => a.zip(sd).forall(x => x._1 === x._2) case _ => false } } diff --git a/pipeline.yaml b/pipeline.yaml index b59de5d31a3..627497f785c 100644 --- a/pipeline.yaml +++ b/pipeline.yaml @@ -107,7 +107,7 @@ jobs: azureSubscription: 'SynapseML Build' scriptLocation: inlineScript scriptType: bash - inlineScript: 'sbt scalastyle test:scalastyle' + inlineScript: 'sbt scalastyle "Test / scalastyle"' - template: templates/conda.yml - bash: | set -e @@ -505,6 +505,7 @@ jobs: condition: and(succeededOrFailed(), eq(variables.runCoverage, true)) - ${{ if or(eq(variables['Build.Reason'], 'PullRequest'), eq(variables['Build.SourceBranch'], 'refs/heads/master'), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) }}: - template: templates/codecov.yml + - template: templates/publish_coverage_ado.yml - job: RTests timeoutInMinutes: 60 cancelTimeoutInMinutes: 0 @@ -579,6 +580,7 @@ jobs: condition: and(succeededOrFailed(), eq(variables.runCoverage, true)) - ${{ if or(eq(variables['Build.Reason'], 'PullRequest'), eq(variables['Build.SourceBranch'], 'refs/heads/master'), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) }}: - template: templates/codecov.yml + - template: templates/publish_coverage_ado.yml - job: BuildAndCacheCondaEnv cancelTimeoutInMinutes: 0 @@ -627,6 +629,7 @@ jobs: condition: and(succeededOrFailed(), eq(variables.runCoverage, true)) - ${{ if 
or(eq(variables['Build.Reason'], 'PullRequest'), eq(variables['Build.SourceBranch'], 'refs/heads/master'), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) }}: - template: templates/codecov.yml + - template: templates/publish_coverage_ado.yml - job: UnitTests cancelTimeoutInMinutes: 1 @@ -787,3 +790,4 @@ jobs: - template: templates/kv.yml - ${{ if or(eq(variables['Build.Reason'], 'PullRequest'), eq(variables['Build.SourceBranch'], 'refs/heads/master'), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) }}: - template: templates/codecov.yml + - template: templates/publish_coverage_ado.yml diff --git a/src/test/scala/com/microsoft/azure/synapse/ml/core/test/fuzzing/FuzzingTest.scala b/src/test/scala/com/microsoft/azure/synapse/ml/core/test/fuzzing/FuzzingTest.scala index 8905e3f88b4..4e57442a377 100644 --- a/src/test/scala/com/microsoft/azure/synapse/ml/core/test/fuzzing/FuzzingTest.scala +++ b/src/test/scala/com/microsoft/azure/synapse/ml/core/test/fuzzing/FuzzingTest.scala @@ -12,7 +12,7 @@ import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.util.{MLReadable, MLWritable} -import java.lang.reflect.ParameterizedType +import java.lang.reflect.{InvocationTargetException, Modifier, ParameterizedType} import scala.language.existentials /** Tests to validate fuzzing of modules. 
*/ @@ -433,7 +433,19 @@ class FuzzingTest extends TestBase { private lazy val readers: List[MLReadable[_]] = JarLoadingUtils.instantiateObjects[MLReadable[_]]() - private lazy val pipelineStages: List[PipelineStage] = JarLoadingUtils.instantiateServices[PipelineStage]() + private lazy val pipelineStages: List[PipelineStage] = { + JarLoadingUtils.AllClasses + .filter(classOf[PipelineStage].isAssignableFrom(_)) + .filter(clazz => !Modifier.isAbstract(clazz.getModifiers)) + .filterNot(clazz => clazz.getName.contains("$") || clazz.getSimpleName.startsWith("Testable")) + .map { clazz => + try { + clazz.getConstructor().newInstance().asInstanceOf[PipelineStage] + } catch { + case e: InvocationTargetException => throw e.getCause + } + } + } private lazy val experimentFuzzers: List[ExperimentFuzzing[_ <: PipelineStage]] = JarLoadingUtils.instantiateServices[ExperimentFuzzing[_ <: PipelineStage]]() diff --git a/templates/publish_coverage_ado.yml b/templates/publish_coverage_ado.yml new file mode 100644 index 00000000000..cf928d52f23 --- /dev/null +++ b/templates/publish_coverage_ado.yml @@ -0,0 +1,10 @@ +steps: + - task: PublishCodeCoverageResults@2 + displayName: 'Publish Code Coverage to Azure DevOps' + inputs: + # Use Cobertura format XML which Azure DevOps understands + # scoverage generates this with coverageOutputCobertura := true in build.sbt + summaryFileLocation: '**/scoverage-report/cobertura.xml' + pathToSources: '$(Build.SourcesDirectory)' + failIfCoverageEmpty: false + condition: and(succeededOrFailed(), eq(variables.runCoverage, true))