Skip to content

Commit

Permalink
Improve AAD fallback if key authentication is disabled (#2290)
Browse files Browse the repository at this point in the history
* Improve AAD fallback if key authentication is disabled

Don't load and preserve account keys if user has opted in for AAD
or if the account has local authentication disabled.

* Show a warning if listing account keys failed.

* Improve error handling and reporting for CosmosDB accounts
  • Loading branch information
sevoku authored Sep 9, 2024
1 parent d311ab3 commit f0f4e0d
Showing 1 changed file with 76 additions and 41 deletions.
117 changes: 76 additions & 41 deletions src/tree/SubscriptionTreeItem.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import { CosmosDBManagementClient } from '@azure/arm-cosmosdb';
import { DatabaseAccountGetResults, DatabaseAccountListKeysResult } from '@azure/arm-cosmosdb/src/models';
import { ILocationWizardContext, LocationListStep, ResourceGroupListStep, SubscriptionTreeItemBase, getResourceGroupFromId, uiUtils } from '@microsoft/vscode-azext-azureutils';
import { AzExtParentTreeItem, AzExtTreeItem, AzureWizard, AzureWizardPromptStep, IActionContext } from '@microsoft/vscode-azext-utils';
import { AzExtParentTreeItem, AzExtTreeItem, AzureWizard, AzureWizardPromptStep, IActionContext, callWithTelemetryAndErrorHandling } from '@microsoft/vscode-azext-utils';
import * as vscode from 'vscode';
import { API, Experience, getExperienceLabel, tryGetExperience } from '../AzureDBExperiences';
import { CosmosDBCredential } from '../docdb/getCosmosClient';
import { CosmosDBCredential, CosmosDBKeyCredential } from '../docdb/getCosmosClient';
import { DocDBAccountTreeItem } from "../docdb/tree/DocDBAccountTreeItem";
import { ext } from '../extensionVariables';
import { tryGetGremlinEndpointFromAzure } from '../graph/gremlinEndpoints';
Expand Down Expand Up @@ -121,51 +121,86 @@ export class SubscriptionTreeItem extends SubscriptionTreeItemBase {
const label: string = name + (accountKindLabel ? ` (${accountKindLabel})` : ``);
const isEmulator: boolean = false;

if (experience && experience.api === "MongoDB") {
const result = await client.databaseAccounts.listConnectionStrings(resourceGroup, name);
const connectionString: URL = new URL(nonNullProp(nonNullProp(result, 'connectionStrings')[0], 'connectionString'));
// for any Mongo connectionString, append this query param because the Cosmos Mongo API v3.6 doesn't support retrywrites
// but the newer node.js drivers started breaking this
const searchParam: string = 'retrywrites';
if (!connectionString.searchParams.has(searchParam)) {
connectionString.searchParams.set(searchParam, 'false');
}

// Use the default connection string
return new MongoAccountTreeItem(parent, id, label, connectionString.toString(), isEmulator, databaseAccount);
} else {
let keyResult: DatabaseAccountListKeysResult | undefined;
try {
keyResult = await client.databaseAccounts.listKeys(resourceGroup, name);
} catch (error) {
// If the client failed to list keys, proceed without using keys
}
const newNode = await callWithTelemetryAndErrorHandling('cosmosDB.initCosmosDBChild', async (context: IActionContext) => {
// leave error handling to the caller (command or tree node)
context.errorHandling.suppressDisplay = true;
// rethrow all errors to satisfy initCosmosDBChild contract
context.errorHandling.rethrow = true;
context.telemetry.properties.experience = experience?.api;

if (experience && experience.api === "MongoDB") {
const result = await client.databaseAccounts.listConnectionStrings(resourceGroup, name);
const connectionString: URL = new URL(nonNullProp(nonNullProp(result, 'connectionStrings')[0], 'connectionString'));
// for any Mongo connectionString, append this query param because the Cosmos Mongo API v3.6 doesn't support retrywrites
// but the newer node.js drivers started breaking this
const searchParam: string = 'retrywrites';
if (!connectionString.searchParams.has(searchParam)) {
connectionString.searchParams.set(searchParam, 'false');
}

let keyCred = keyResult?.primaryMasterKey ? {
type: "key",
key: keyResult.primaryMasterKey
} : undefined;
const testCosmosAuth = vscode.workspace.getConfiguration().get<boolean>("azureDatabases.useCosmosOAuth");
if (testCosmosAuth) {
keyCred = undefined;
}
const authCred = { type: "auth" };
const credentials = [keyCred, authCred].filter((cred): cred is CosmosDBCredential => cred !== undefined);
switch (experience && experience.api) {
case "Table":
return new TableAccountTreeItem(parent, id, label, documentEndpoint, credentials, isEmulator, databaseAccount);
case "Graph": {
const gremlinEndpoint = await tryGetGremlinEndpointFromAzure(client, resourceGroup, name);
return new GraphAccountTreeItem(parent, id, label, documentEndpoint, gremlinEndpoint, credentials, isEmulator, databaseAccount);
// Use the default connection string
return new MongoAccountTreeItem(parent, id, label, connectionString.toString(), isEmulator, databaseAccount);
} else {
let keyCred: CosmosDBKeyCredential | undefined = undefined;

const forceOAuth = vscode.workspace.getConfiguration().get<boolean>("azureDatabases.useCosmosOAuth");
context.telemetry.properties.useCosmosOAuth = (forceOAuth ?? false).toString();

// disable key auth if the user has opted in to OAuth (AAD/Entra ID)
if (!forceOAuth) {
try {
const acc = await client.databaseAccounts.get(resourceGroup, name);
const localAuthDisabled = acc.disableLocalAuth === true;
context.telemetry.properties.localAuthDisabled = localAuthDisabled.toString();
let keyResult: DatabaseAccountListKeysResult | undefined;
// If the account has local auth disabled, don't even try to use key auth
if (!localAuthDisabled) {
keyResult = await client.databaseAccounts.listKeys(resourceGroup, name);
keyCred = keyResult?.primaryMasterKey ? {
type: "key",
key: keyResult.primaryMasterKey
} : undefined;
context.telemetry.properties.receivedKeyCreds = "true";
} else {
throw new Error("Local auth is disabled");
}
} catch (error) {
context.telemetry.properties.receivedKeyCreds = "false";
const message = localize("keyPermissionErrorMsg", "You do not have the required permissions to list auth keys for [{0}].\nFalling back to using Entra ID.\nYou can change the default authentication in the settings.", name);
const openSettingsItem = localize("openSettings", "Open Settings");
void vscode.window.showWarningMessage(message, ...[openSettingsItem]).then((item) => {
if (item === openSettingsItem) {
void vscode.commands.executeCommand('workbench.action.openSettings', 'azureDatabases.useCosmosOAuth');
}
});
}
}
case "Core":
default:
// Default to DocumentDB, the base type for all Cosmos DB Accounts
return new DocDBAccountTreeItem(parent, id, label, documentEndpoint, credentials, isEmulator, databaseAccount);

// OAuth is always enabled for Cosmos DB and will be used as a fall back if key auth is unavailable
const authCred = { type: "auth" };
const credentials = [keyCred, authCred].filter((cred): cred is CosmosDBCredential => cred !== undefined);
switch (experience && experience.api) {
case "Table":
return new TableAccountTreeItem(parent, id, label, documentEndpoint, credentials, isEmulator, databaseAccount);
case "Graph": {
const gremlinEndpoint = await tryGetGremlinEndpointFromAzure(client, resourceGroup, name);
return new GraphAccountTreeItem(parent, id, label, documentEndpoint, gremlinEndpoint, credentials, isEmulator, databaseAccount);
}
case "Core":
default:
// Default to DocumentDB, the base type for all Cosmos DB Accounts
return new DocDBAccountTreeItem(parent, id, label, documentEndpoint, credentials, isEmulator, databaseAccount);

}
}
});
if (!(newNode instanceof AzExtTreeItem)) {
// note: this should never happen, callWithTelemetryAndErrorHandling will rethrow all errors
throw new Error(localize('invalidCosmosDBAccount', 'Invalid Cosmos DB account.'));
}
return newNode;
}

public static async initPostgresChild(server: PostgresAbstractServer, parent: AzExtParentTreeItem): Promise<AzExtTreeItem> {
const connectionString: string = createPostgresConnectionString(nonNullProp(server, 'fullyQualifiedDomainName'));
const parsedCS: ParsedPostgresConnectionString = parsePostgresConnectionString(connectionString);
Expand Down

0 comments on commit f0f4e0d

Please sign in to comment.