Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 58 additions & 11 deletions sdk/cs/src/Catalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,51 +52,59 @@ internal static async Task<Catalog> CreateAsync(IModelLoadManager modelManager,
return catalog;
}

public async Task<List<Model>> ListModelsAsync(CancellationToken? ct = null)
public async Task<List<IModel>> ListModelsAsync(CancellationToken? ct = null)
{
return await Utils.CallWithExceptionHandling(() => ListModelsImplAsync(ct),
"Error listing models.", _logger).ConfigureAwait(false);
}

public async Task<List<ModelVariant>> GetCachedModelsAsync(CancellationToken? ct = null)
public async Task<List<IModel>> GetCachedModelsAsync(CancellationToken? ct = null)
{
return await Utils.CallWithExceptionHandling(() => GetCachedModelsImplAsync(ct),
"Error getting cached models.", _logger).ConfigureAwait(false);
}

public async Task<List<ModelVariant>> GetLoadedModelsAsync(CancellationToken? ct = null)
public async Task<List<IModel>> GetLoadedModelsAsync(CancellationToken? ct = null)
{
return await Utils.CallWithExceptionHandling(() => GetLoadedModelsImplAsync(ct),
"Error getting loaded models.", _logger).ConfigureAwait(false);
}

public async Task<Model?> GetModelAsync(string modelAlias, CancellationToken? ct = null)
public async Task<IModel?> GetModelAsync(string modelAlias, CancellationToken? ct = null)
{
return await Utils.CallWithExceptionHandling(() => GetModelImplAsync(modelAlias, ct),
$"Error getting model with alias '{modelAlias}'.", _logger)
.ConfigureAwait(false);
}

public async Task<ModelVariant?> GetModelVariantAsync(string modelId, CancellationToken? ct = null)
public async Task<IModel?> GetModelVariantAsync(string modelId, CancellationToken? ct = null)
{
return await Utils.CallWithExceptionHandling(() => GetModelVariantImplAsync(modelId, ct),
$"Error getting model variant with ID '{modelId}'.", _logger)
.ConfigureAwait(false);
}

private async Task<List<Model>> ListModelsImplAsync(CancellationToken? ct = null)
public async Task<IModel> GetLatestVersionAsync(IModel modelOrModelVariant, CancellationToken? ct = null)
{
return await Utils.CallWithExceptionHandling(
() => GetLatestVersionImplAsync(modelOrModelVariant, ct),
$"Error getting latest version for model with name '{modelOrModelVariant.Info.Name}'.",
_logger).ConfigureAwait(false);
}

private async Task<List<IModel>> ListModelsImplAsync(CancellationToken? ct = null)
{
await UpdateModels(ct).ConfigureAwait(false);

using var disposable = await _lock.LockAsync().ConfigureAwait(false);
return _modelAliasToModel.Values.OrderBy(m => m.Alias).ToList();
return _modelAliasToModel.Values.OrderBy(m => m.Alias).Cast<IModel>().ToList();
}

private async Task<List<ModelVariant>> GetCachedModelsImplAsync(CancellationToken? ct = null)
private async Task<List<IModel>> GetCachedModelsImplAsync(CancellationToken? ct = null)
{
var cachedModelIds = await Utils.GetCachedModelIdsAsync(_coreInterop, ct).ConfigureAwait(false);

List<ModelVariant> cachedModels = new();
List<IModel> cachedModels = [];
foreach (var modelId in cachedModelIds)
{
if (_modelIdToModelVariant.TryGetValue(modelId, out ModelVariant? modelVariant))
Expand All @@ -108,10 +116,10 @@ private async Task<List<ModelVariant>> GetCachedModelsImplAsync(CancellationToke
return cachedModels;
}

private async Task<List<ModelVariant>> GetLoadedModelsImplAsync(CancellationToken? ct = null)
private async Task<List<IModel>> GetLoadedModelsImplAsync(CancellationToken? ct = null)
{
var loadedModelIds = await _modelLoadManager.ListLoadedModelsAsync(ct).ConfigureAwait(false);
List<ModelVariant> loadedModels = new();
List<IModel> loadedModels = [];

foreach (var modelId in loadedModelIds)
{
Expand Down Expand Up @@ -143,6 +151,45 @@ private async Task<List<ModelVariant>> GetLoadedModelsImplAsync(CancellationToke
return modelVariant;
}

private async Task<IModel> GetLatestVersionImplAsync(IModel modelOrModelVariant, CancellationToken? ct)
{
Model? model;

if (modelOrModelVariant is ModelVariant)
{
// For ModelVariant, resolve the owning Model via alias.
model = await GetModelImplAsync(modelOrModelVariant.Alias, ct);
}
else
{
// Try to use the concrete Model instance if this is our SDK type.
model = modelOrModelVariant as Model;

// If this is a different IModel implementation (e.g., a test stub),
// fall back to resolving the Model via alias.
if (model == null)
{
model = await GetModelImplAsync(modelOrModelVariant.Alias, ct);
}
}

if (model == null)
{
throw new FoundryLocalException($"Model with alias '{modelOrModelVariant.Alias}' not found in catalog.",
_logger);
}

// variants are sorted by version, so the first one matching the name is the latest version for that variant.
var latest = model!.Variants.FirstOrDefault(v => v.Info.Name == modelOrModelVariant.Info.Name) ??
// should not be possible given we internally manage all the state involved
throw new FoundryLocalException($"Internal error. Mismatch between model (alias:{model.Alias}) and " +
$"model variant (alias:{modelOrModelVariant.Alias}).", _logger);

// if input was the latest return the input (could be model or model variant)
// otherwise return the latest model variant
return latest.Id == modelOrModelVariant.Id ? modelOrModelVariant : latest;
}

private async Task UpdateModels(CancellationToken? ct)
{
// TODO: make this configurable
Expand Down
32 changes: 9 additions & 23 deletions sdk/cs/src/Model.cs → sdk/cs/src/Detail/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ public class Model : IModel
{
private readonly ILogger _logger;

public List<ModelVariant> Variants { get; internal set; }
public ModelVariant SelectedVariant { get; internal set; } = default!;
private readonly List<IModel> _variants;
public IReadOnlyList<IModel> Variants => _variants;
public IModel SelectedVariant { get; internal set; } = default!;

public string Alias { get; init; }
public string Id => SelectedVariant.Id;
public ModelInfo Info => SelectedVariant.Info;

/// <summary>
/// Is the currently selected variant cached locally?
Expand All @@ -33,7 +35,7 @@ internal Model(ModelVariant modelVariant, ILogger logger)
_logger = logger;

Alias = modelVariant.Alias;
Variants = new() { modelVariant };
_variants = [modelVariant];

// variants are sorted by Core, so the first one added is the default
SelectedVariant = modelVariant;
Expand All @@ -48,7 +50,7 @@ internal void AddVariant(ModelVariant variant)
_logger);
}

Variants.Add(variant);
_variants.Add(variant);

// prefer the highest priority locally cached variant
if (variant.Info.Cached && !SelectedVariant.Info.Cached)
Expand All @@ -62,31 +64,15 @@ internal void AddVariant(ModelVariant variant)
/// </summary>
/// <param name="variant">Model variant to select. Must be one of the variants in <see cref="Variants"/>.</param>
/// <exception cref="FoundryLocalException">If variant is not valid for this model.</exception>
public void SelectVariant(ModelVariant variant)
public void SelectVariant(IModel variant)
{
_ = Variants.FirstOrDefault(v => v == variant) ??
// user error so don't log
throw new FoundryLocalException($"Model {Alias} does not have a {variant.Id} variant.");
// user error so don't log.
throw new FoundryLocalException($"Input variant was not found in Variants.");

SelectedVariant = variant;
}

/// <summary>
/// Get the latest version of the specified model variant.
/// </summary>
/// <param name="variant">Model variant.</param>
/// <returns>ModelVariant for latest version. Same as `variant` if that is the latest version.</returns>
/// <exception cref="FoundryLocalException">If variant is not valid for this model.</exception>
public ModelVariant GetLatestVersion(ModelVariant variant)
{
// variants are sorted by version, so the first one matching the name is the latest version for that variant.
var latest = Variants.FirstOrDefault(v => v.Info.Name == variant.Info.Name) ??
// user error so don't log
throw new FoundryLocalException($"Model {Alias} does not have a {variant.Id} variant.");

return latest;
}

public async Task<string> GetPathAsync(CancellationToken? ct = null)
{
return await SelectedVariant.GetPathAsync(ct).ConfigureAwait(false);
Expand Down
12 changes: 11 additions & 1 deletion sdk/cs/src/ModelVariant.cs → sdk/cs/src/Detail/ModelVariant.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Microsoft.AI.Foundry.Local;
using Microsoft.AI.Foundry.Local.Detail;
using Microsoft.Extensions.Logging;

public class ModelVariant : IModel
internal class ModelVariant : IModel
{
private readonly IModelLoadManager _modelLoadManager;
private readonly ICoreInterop _coreInterop;
Expand All @@ -22,6 +22,9 @@ public class ModelVariant : IModel
public string Alias => Info.Alias;
public int Version { get; init; } // parsed from Info.Version if possible, else 0

public IReadOnlyList<IModel> Variants => [this];
public IModel SelectedVariant => this;

internal ModelVariant(ModelInfo modelInfo, IModelLoadManager modelLoadManager, ICoreInterop coreInterop,
ILogger logger)
{
Expand Down Expand Up @@ -190,4 +193,11 @@ private async Task<OpenAIAudioClient> GetAudioClientImplAsync(CancellationToken?

return new OpenAIAudioClient(Id);
}

public void SelectVariant(IModel variant)
{
throw new FoundryLocalException(
$"SelectVariant is not supported on a ModelVariant. " +
$"Call Catalog.GetModelAsync(\"{Alias}\") to get a Model with all variants available.");
}
}
28 changes: 18 additions & 10 deletions sdk/cs/src/ICatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,44 @@ public interface ICatalog
/// List the available models in the catalog.
/// </summary>
/// <param name="ct">Optional CancellationToken.</param>
/// <returns>List of Model instances.</returns>
Task<List<Model>> ListModelsAsync(CancellationToken? ct = null);
/// <returns>List of IModel instances.</returns>
Task<List<IModel>> ListModelsAsync(CancellationToken? ct = null);

/// <summary>
/// Lookup a model by its alias.
/// </summary>
/// <param name="modelAlias">Model alias.</param>
/// <param name="ct">Optional CancellationToken.</param>
/// <returns>The matching Model, or null if no model with the given alias exists.</returns>
Task<Model?> GetModelAsync(string modelAlias, CancellationToken? ct = null);
/// <returns>The matching IModel, or null if no model with the given alias exists.</returns>
Task<IModel?> GetModelAsync(string modelAlias, CancellationToken? ct = null);

/// <summary>
/// Lookup a model variant by its unique model id.
/// </summary>
/// <param name="modelId">Model id.</param>
/// <param name="ct">Optional CancellationToken.</param>
/// <returns>The matching ModelVariant, or null if no variant with the given id exists.</returns>
Task<ModelVariant?> GetModelVariantAsync(string modelId, CancellationToken? ct = null);
/// <returns>The matching IModel, or null if no variant with the given id exists.</returns>
Task<IModel?> GetModelVariantAsync(string modelId, CancellationToken? ct = null);

/// <summary>
/// Get a list of currently downloaded models from the model cache.
/// </summary>
/// <param name="ct">Optional CancellationToken.</param>
/// <returns>List of ModelVariant instances.</returns>
Task<List<ModelVariant>> GetCachedModelsAsync(CancellationToken? ct = null);
/// <returns>List of IModel instances.</returns>
Task<List<IModel>> GetCachedModelsAsync(CancellationToken? ct = null);

/// <summary>
/// Get a list of the currently loaded models.
/// </summary>
/// <param name="ct">Optional CancellationToken.</param>
/// <returns>List of ModelVariant instances.</returns>
Task<List<ModelVariant>> GetLoadedModelsAsync(CancellationToken? ct = null);
/// <returns>List of IModel instances.</returns>
Task<List<IModel>> GetLoadedModelsAsync(CancellationToken? ct = null);

/// <summary>
/// Get the latest version of a model.
/// This is used to check if a newer version of a model is available in the catalog for download.
/// </summary>
/// <param name="model">The model to check for the latest version.</param>
/// <returns>The latest version of the model. Will match the input if it is the latest version.</returns>
Task<IModel> GetLatestVersionAsync(IModel model, CancellationToken? ct = null);
}
19 changes: 19 additions & 0 deletions sdk/cs/src/IModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ public interface IModel
Justification = "Alias is a suitable name in this context.")]
string Alias { get; }

ModelInfo Info { get; }

Task<bool> IsCachedAsync(CancellationToken? ct = null);
Task<bool> IsLoadedAsync(CancellationToken? ct = null);

Expand Down Expand Up @@ -67,4 +69,21 @@ Task DownloadAsync(Action<float>? downloadProgress = null,
/// <param name="ct">Optional cancellation token.</param>
/// <returns>OpenAI.AudioClient</returns>
Task<OpenAIAudioClient> GetAudioClientAsync(CancellationToken? ct = null);

/// <summary>
/// Variants of the model that are available. Variants of the model are optimized for different devices.
/// </summary>
IReadOnlyList<IModel> Variants { get; }

/// <summary>
/// Currently selected model variant in use.
/// </summary>
IModel SelectedVariant { get; }

/// <summary>
/// Select a specific model variant from <see cref="Variants"/> to use for <see cref="IModel"/> operations.
/// </summary>
/// <param name="variant">Model variant to select. Must be one of the variants in <see cref="Variants"/>.</param>
/// <exception cref="FoundryLocalException">If variant is not valid for this model.</exception>
void SelectVariant(IModel variant);
}
2 changes: 1 addition & 1 deletion sdk/cs/test/FoundryLocal.Tests/AudioClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace Microsoft.AI.Foundry.Local.Tests;

internal sealed class AudioClientTests
{
private static Model? model;
private static IModel? model;

[Before(Class)]
public static async Task Setup()
Expand Down
Loading
Loading