Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
changes:
- section: "Features Added"
description: "Added an optional `--schema` parameter to `azmcp postgres list` for listing tables in non-public PostgreSQL schemas (defaults to `public`), and removed unused `--subscription`/`--resource-group` parameters from the PostgreSQL data-plane commands: `database query`, `table schema get`."
13 changes: 5 additions & 8 deletions servers/Azure.Mcp.Server/docs/azmcp-commands.md
Original file line number Diff line number Diff line change
Expand Up @@ -2269,28 +2269,25 @@ azmcp mysql server param set --subscription <subscription> \
# Hierarchical list command for PostgreSQL resources
# Without parameters: lists all PostgreSQL servers in the resource group
# With --server: lists all databases on that server
# With --server and --database: lists all tables in that database
# With --server and --database: lists all tables in that database (optionally scoped to a --schema, defaults to 'public')
# ❌ Destructive | ✅ Idempotent | ❌ OpenWorld | ✅ ReadOnly | ❌ Secret | ❌ LocalRequired
azmcp postgres list --subscription <subscription> \
--resource-group <resource-group> \
--user <user> \
[--server <server>] \
[--database <database>]
[--database <database>] \
[--schema <schema>]

# Execute a query on a PostgreSQL database
# ❌ Destructive | ✅ Idempotent | ❌ OpenWorld | ✅ ReadOnly | ❌ Secret | ❌ LocalRequired
azmcp postgres database query --subscription <subscription> \
--resource-group <resource-group> \
--user <user> \
azmcp postgres database query --user <user> \
--server <server> \
--database <database> \
--query <query>

# Get the schema of a specific table in a PostgreSQL database
# ❌ Destructive | ✅ Idempotent | ❌ OpenWorld | ✅ ReadOnly | ❌ Secret | ❌ LocalRequired
azmcp postgres table schema get --subscription <subscription> \
--resource-group <resource-group> \
--user <user> \
azmcp postgres table schema get --user <user> \
--server <server> \
--database <database> \
--table <table>
Expand Down
1 change: 1 addition & 0 deletions servers/Azure.Mcp.Server/docs/e2eTestPrompts.md
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ This file contains prompts used for end-to-end testing to ensure each tool is in
| postgres_list | Show me the PostgreSQL databases in server \<server> |
| postgres_list | List all tables in the PostgreSQL database \<database> in server \<server> |
| postgres_list | Show me the tables in the PostgreSQL database \<database> in server \<server> |
| postgres_list | List all tables in the \<schema> schema of the PostgreSQL database \<database> in server \<server> |
| postgres_database_query | Show me all items that contain the word \<search_term> in the PostgreSQL database \<database> in server \<server> |
| postgres_server_config_get | Show me the configuration of PostgreSQL server \<server> |
| postgres_server_param_get | Show me if the parameter my PostgreSQL server \<server> has replication enabled |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,21 @@

namespace Azure.Mcp.Tools.Postgres.Commands;

// Data-plane commands connect directly to PostgreSQL via Npgsql and therefore do not
// require ARM-scoping options (--subscription / --resource-group). They derive from
// GlobalCommand rather than the subscription-based hierarchy so those options are not
// part of the MCP input schema.
public abstract class BaseDatabaseCommand<
[DynamicallyAccessedMembers(TrimAnnotations.CommandAnnotations)] TOptions>(ILogger<BasePostgresCommand<TOptions>> logger)
: BaseServerCommand<TOptions>(logger) where TOptions : BasePostgresOptions, new()
[DynamicallyAccessedMembers(TrimAnnotations.CommandAnnotations)] TOptions>(ILogger<BaseDatabaseCommand<TOptions>> logger)
: GlobalCommand<TOptions> where TOptions : BasePostgresOptions, new()
{
protected readonly ILogger<BaseDatabaseCommand<TOptions>> _logger = logger;

protected override void RegisterOptions(Command command)
{
base.RegisterOptions(command);
command.Options.Add(PostgresOptionDefinitions.Server);
command.Options.Add(PostgresOptionDefinitions.User);
command.Options.Add(PostgresOptionDefinitions.Database);
command.Options.Add(PostgresOptionDefinitions.AuthType);
command.Options.Add(PostgresOptionDefinitions.Password);
Expand All @@ -24,6 +32,8 @@ protected override void RegisterOptions(Command command)
protected override TOptions BindOptions(ParseResult parseResult)
{
var options = base.BindOptions(parseResult);
options.Server = parseResult.GetValueOrDefault<string>(PostgresOptionDefinitions.Server.Name);
options.User = parseResult.GetValueOrDefault<string>(PostgresOptionDefinitions.User.Name);
options.Database = parseResult.GetValueOrDefault<string>(PostgresOptionDefinitions.Database.Name);
options.AuthType = parseResult.GetValueOrDefault<string>(PostgresOptionDefinitions.AuthType.Name);
options.Password = parseResult.GetValueOrDefault<string>(PostgresOptionDefinitions.Password.Name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
// Validate the query early to avoid sending unsafe SQL to the server.
SqlQueryValidator.EnsureReadOnlySelect(options.Query);
List<string> queryResult = await _postgresService.ExecuteQueryAsync(
options.Subscription!,
options.ResourceGroup!,
Comment thread
vcolin7 marked this conversation as resolved.
options.AuthType!,
options.User!,
options.Password,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ protected override void RegisterOptions(Command command)
command.Options.Add(PostgresOptionDefinitions.User.AsOptional());
command.Options.Add(PostgresOptionDefinitions.ServerOptional);
command.Options.Add(PostgresOptionDefinitions.DatabaseOptional);
command.Options.Add(PostgresOptionDefinitions.Schema);
command.Options.Add(PostgresOptionDefinitions.AuthType);
command.Options.Add(PostgresOptionDefinitions.Password);
command.Validators.Add(result =>
Expand All @@ -58,6 +59,7 @@ protected override BasePostgresOptions BindOptions(ParseResult parseResult)
var options = base.BindOptions(parseResult);
options.Server = parseResult.GetValueOrDefault<string>(PostgresOptionDefinitions.ServerOptional.Name);
options.Database = parseResult.GetValueOrDefault<string>(PostgresOptionDefinitions.DatabaseOptional.Name);
options.Schema = parseResult.GetValueOrDefault<string>(PostgresOptionDefinitions.Schema.Name);
options.AuthType = parseResult.GetValueOrDefault<string>(PostgresOptionDefinitions.AuthType.Name);
options.Password = parseResult.GetValueOrDefault<string>(PostgresOptionDefinitions.Password.Name);
return options;
Expand All @@ -81,13 +83,12 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
{
// List tables in specified database
List<string> tables = await _postgresService.ListTablesAsync(
options.Subscription!,
options.ResourceGroup!,
options.AuthType!,
options.User!,
options.Password,
options.Server!,
options.Database!,
string.IsNullOrEmpty(options.Schema) ? "public" : options.Schema,
cancellationToken);

context.Response.Results = ResponseResult.Create(
Expand All @@ -98,8 +99,6 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
{
// List databases on specified server
List<string> databases = await _postgresService.ListDatabasesAsync(
options.Subscription!,
options.ResourceGroup!,
options.AuthType!,
options.User!,
options.Password,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
{

List<string> schema = await _postgresService.GetTableSchemaAsync(
options.Subscription!,
options.ResourceGroup!,
options.AuthType!,
options.User!,
options.Password,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,7 @@ public class BasePostgresOptions : SubscriptionOptions

[JsonPropertyName(PostgresOptionDefinitions.DatabaseName)]
public string? Database { get; set; }

[JsonPropertyName(PostgresOptionDefinitions.SchemaName)]
public string? Schema { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ public static class PostgresOptionDefinitions
public const string PasswordText = "password";
public const string ServerName = "server";
public const string DatabaseName = "database";
public const string SchemaName = "schema";
public const string TableName = "table";
public const string QueryText = "query";
public const string ParamName = "param";
Expand Down Expand Up @@ -72,6 +73,15 @@ public static class PostgresOptionDefinitions
Description = "The PostgreSQL database to list tables from (optional, requires --server)."
};

public static readonly Option<string?> Schema = new(
$"--{SchemaName}"
)
{
Description = "The PostgreSQL schema to list tables from when listing tables (optional, defaults to 'public').",
Arity = ArgumentArity.ZeroOrOne,
Required = false
};

public static readonly Option<string> Table = new(
$"--{TableName}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,13 @@ namespace Azure.Mcp.Tools.Postgres.Services;
public interface IPostgresService
{
Task<List<string>> ListDatabasesAsync(
string subscriptionId,
string resourceGroup,
string authType,
string user,
string? password,
string server,
CancellationToken cancellationToken);

Task<List<string>> ExecuteQueryAsync(
string subscriptionId,
string resourceGroup,
string authType,
string user,
string? password,
Expand All @@ -28,18 +24,15 @@ Task<List<string>> ExecuteQueryAsync(
CancellationToken cancellationToken);

Task<List<string>> ListTablesAsync(
string subscriptionId,
string resourceGroup,
string authType,
string user,
string? password,
string server,
string database,
string schema,
CancellationToken cancellationToken);

Task<List<string>> GetTableSchemaAsync(
string subscriptionId,
string resourceGroup,
string authType,
string user,
string? password,
Expand Down
12 changes: 3 additions & 9 deletions tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ private string NormalizeServerName(string server)
}

public async Task<List<string>> ListDatabasesAsync(
string subscriptionId,
string resourceGroup,
string authType,
string user,
string? password,
Expand All @@ -104,8 +102,6 @@ public async Task<List<string>> ListDatabasesAsync(
}

public async Task<List<string>> ExecuteQueryAsync(
string subscriptionId,
string resourceGroup,
string authType,
string user,
string? password,
Expand Down Expand Up @@ -162,22 +158,22 @@ public async Task<List<string>> ExecuteQueryAsync(
}

public async Task<List<string>> ListTablesAsync(
string subscriptionId,
string resourceGroup,
string authType,
string user,
string? password,
string server,
string database,
string schema,
CancellationToken cancellationToken)
{
string? passwordToUse = await GetPassword(authType, password, cancellationToken);
var host = NormalizeServerName(server);
var connectionString = BuildConnectionString(host, database, user, passwordToUse);

var query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';";
var query = "SELECT table_name FROM information_schema.tables WHERE table_schema = @schema ORDER BY table_name;";
await using IPostgresResource resource = await _dbProvider.GetPostgresResource(connectionString, authType, cancellationToken);
await using NpgsqlCommand command = _dbProvider.GetCommand(query, resource);
command.Parameters.AddWithValue("schema", schema);
await using DbDataReader reader = await _dbProvider.ExecuteReaderAsync(command, cancellationToken);
var tables = new List<string>();
while (await reader.ReadAsync(cancellationToken))
Expand All @@ -188,8 +184,6 @@ public async Task<List<string>> ListTablesAsync(
}

public async Task<List<string>> GetTableSchemaAsync(
string subscriptionId,
string resourceGroup,
string authType,
string user,
string? password,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ public async Task ExecuteAsync_ReturnsQueryResults_WhenQueryIsValid()
{
var expectedResults = new List<string> { "result1", "result2" };

Service.ExecuteQueryAsync("sub123", "rg1", AuthTypes.MicrosoftEntra, "user1", null, "server1", "db123", "SELECT * FROM test;", Arg.Any<CancellationToken>())
Service.ExecuteQueryAsync(AuthTypes.MicrosoftEntra, "user1", null, "server1", "db123", "SELECT * FROM test;", Arg.Any<CancellationToken>())
.Returns(expectedResults);

var response = await ExecuteCommandAsync(
"--subscription", "sub123",
"--resource-group", "rg1",
$"--{PostgresOptionDefinitions.AuthTypeText}", AuthTypes.MicrosoftEntra,
"--user", "user1",
"--server", "server1",
Expand All @@ -41,12 +39,10 @@ public async Task ExecuteAsync_ReturnsQueryResults_WhenQueryIsValid()
[Fact]
public async Task ExecuteAsync_ReturnsEmpty_WhenQueryFails()
{
Service.ExecuteQueryAsync("sub123", "rg1", AuthTypes.MicrosoftEntra, "user1", null, "server1", "db123", "SELECT * FROM test;", Arg.Any<CancellationToken>())
Service.ExecuteQueryAsync(AuthTypes.MicrosoftEntra, "user1", null, "server1", "db123", "SELECT * FROM test;", Arg.Any<CancellationToken>())
.Returns([]);

var response = await ExecuteCommandAsync(
"--subscription", "sub123",
"--resource-group", "rg1",
$"--{PostgresOptionDefinitions.AuthTypeText}", AuthTypes.MicrosoftEntra,
"--user", "user1",
"--server", "server1",
Expand All @@ -58,17 +54,13 @@ public async Task ExecuteAsync_ReturnsEmpty_WhenQueryFails()
}

[Theory]
[InlineData("--subscription")]
[InlineData("--resource-group")]
[InlineData("--user")]
[InlineData("--server")]
[InlineData("--database")]
[InlineData("--query")]
public async Task ExecuteAsync_ReturnsError_WhenParameterIsMissing(string missingParameter)
{
var response = await ExecuteCommandAsync(ArgBuilder.BuildArgs(missingParameter,
("--subscription", "sub123"),
("--resource-group", "rg1"),
($"--{PostgresOptionDefinitions.AuthTypeText}", AuthTypes.MicrosoftEntra),
("--user", "user1"),
("--server", "server123"),
Expand All @@ -81,6 +73,18 @@ public async Task ExecuteAsync_ReturnsError_WhenParameterIsMissing(string missin
Assert.Equal($"Missing Required options: {missingParameter}", response.Message);
}

[Fact]
public void Command_DoesNotExposeArmScopingOptions()
{
var optionNames = CommandDefinition.Options.Select(o => o.Name.TrimStart('-')).ToList();

Assert.DoesNotContain("subscription", optionNames);
Assert.DoesNotContain("resource-group", optionNames);
Assert.Contains(PostgresOptionDefinitions.UserName, optionNames);
Assert.Contains(PostgresOptionDefinitions.ServerName, optionNames);
Assert.Contains(PostgresOptionDefinitions.DatabaseName, optionNames);
}

[Theory]
[InlineData("DELETE FROM users;")]
[InlineData("SELECT * FROM users; DROP TABLE users;")]
Expand Down Expand Up @@ -126,8 +130,6 @@ public async Task ExecuteAsync_ReturnsError_WhenParameterIsMissing(string missin
public async Task ExecuteAsync_InvalidQuery_ValidationError(string badQuery)
{
var response = await ExecuteCommandAsync(
"--subscription", "sub123",
"--resource-group", "rg1",
$"--{PostgresOptionDefinitions.AuthTypeText}", AuthTypes.MicrosoftEntra,
"--user", "user1",
"--server", "server1",
Expand All @@ -137,16 +139,14 @@ public async Task ExecuteAsync_InvalidQuery_ValidationError(string badQuery)
Assert.NotNull(response);
Assert.Equal(HttpStatusCode.BadRequest, response.Status); // CommandValidationException => 400
// Service should never be called for invalid queries.
await Service.DidNotReceive().ExecuteQueryAsync(Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<CancellationToken>());
await Service.DidNotReceive().ExecuteQueryAsync(Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<CancellationToken>());
}

[Fact]
public async Task ExecuteAsync_LongQuery_ValidationError()
{
var longSelect = "SELECT " + new string('a', 6000) + " FROM test"; // exceeds max length
var response = await ExecuteCommandAsync(
"--subscription", "sub123",
"--resource-group", "rg1",
$"--{PostgresOptionDefinitions.AuthTypeText}", AuthTypes.MicrosoftEntra,
"--user", "user1",
"--server", "server1",
Expand All @@ -155,6 +155,7 @@ public async Task ExecuteAsync_LongQuery_ValidationError()

Assert.NotNull(response);
Assert.Equal(HttpStatusCode.BadRequest, response.Status);
await Service.DidNotReceive().ExecuteQueryAsync(Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<CancellationToken>());
await Service.DidNotReceive().ExecuteQueryAsync(Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<CancellationToken>());
}
}

Loading