Skip to content

Commit

Permalink
Cancel request tool calls when cancelling chat request
Browse files Browse the repository at this point in the history
  • Loading branch information
roblourens committed Dec 20, 2024
1 parent 41793c2 commit 8281320
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 13 deletions.
70 changes: 63 additions & 7 deletions src/vs/workbench/contrib/chat/browser/languageModelToolsService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

import { renderStringAsPlaintext } from '../../../../base/browser/markdownRenderer.js';
import { RunOnceScheduler } from '../../../../base/common/async.js';
import { CancellationToken } from '../../../../base/common/cancellation.js';
import { CancellationToken, CancellationTokenSource } from '../../../../base/common/cancellation.js';
import { CancellationError, isCancellationError } from '../../../../base/common/errors.js';
import { Emitter } from '../../../../base/common/event.js';
import { Iterable } from '../../../../base/common/iterator.js';
import { Disposable, IDisposable, toDisposable } from '../../../../base/common/lifecycle.js';
import { Disposable, DisposableStore, dispose, IDisposable, toDisposable } from '../../../../base/common/lifecycle.js';
import { localize } from '../../../../nls.js';
import { IContextKeyService } from '../../../../platform/contextkey/common/contextkey.js';
import { IDialogService } from '../../../../platform/dialogs/common/dialogs.js';
Expand Down Expand Up @@ -37,6 +37,9 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo
private _tools = new Map<string, IToolEntry>();
private _toolContextKeys = new Set<string>();


private _callsByRequestId = new Map<string, IDisposable[]>();

constructor(
@IExtensionService private readonly _extensionService: IExtensionService,
@IContextKeyService private readonly _contextKeyService: IContextKeyService,
Expand Down Expand Up @@ -141,10 +144,34 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo
// Shortcut to write to the model directly here, but could call all the way back to use the real stream.
let toolInvocation: ChatToolInvocation | undefined;

let requestId: string | undefined;
let store: DisposableStore | undefined;
try {
if (dto.context) {
const model = this._chatService.getSession(dto.context?.sessionId) as ChatModel;
store = new DisposableStore();
const model = this._chatService.getSession(dto.context?.sessionId) as ChatModel | undefined;
if (!model) {
throw new Error(`Tool called for unknown chat session`);
}

const request = model.getRequests().at(-1)!;
requestId = request.id;

// Replace the token with a new token that we can cancel when cancelToolCallsForRequest is called
if (!this._callsByRequestId.has(requestId)) {
this._callsByRequestId.set(requestId, []);
}
this._callsByRequestId.get(requestId)!.push(store);

const source = new CancellationTokenSource();
store.add(toDisposable(() => {
toolInvocation!.confirmed.complete(false);
source.dispose(true);
}));
store.add(token.onCancellationRequested(() => {
source.cancel();
}));
token = source.token;

const prepared = tool.impl.prepareToolInvocation ?
await tool.impl.prepareToolInvocation(dto.parameters, token)
Expand All @@ -153,9 +180,7 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo
const defaultMessage = localize('toolInvocationMessage', "Using {0}", `"${tool.data.displayName}"`);
const invocationMessage = prepared?.invocationMessage ?? defaultMessage;
toolInvocation = new ChatToolInvocation(invocationMessage, prepared?.confirmationMessages);
token.onCancellationRequested(() => {
toolInvocation!.confirmed.complete(false);
});

model.acceptResponseProgress(request, toolInvocation);
if (prepared?.confirmationMessages) {
const userConfirmed = await toolInvocation.confirmed.p;
Expand All @@ -176,7 +201,6 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo
}
}


const result = await tool.impl.invoke(dto, countTokens, token);
this._telemetryService.publicLog2<LanguageModelToolInvokedEvent, LanguageModelToolInvokedClassification>(
'languageModelToolInvoked',
Expand All @@ -200,7 +224,39 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo
throw err;
} finally {
toolInvocation?.isCompleteDeferred.complete();

if (requestId && store) {
this.cleanupCallDisposables(requestId, store);
}
}
}

private cleanupCallDisposables(requestId: string, store: DisposableStore): void {
const disposables = this._callsByRequestId.get(requestId);
if (disposables) {
const index = disposables.indexOf(store);
if (index > -1) {
disposables.splice(index, 1);
}
if (disposables.length === 0) {
this._callsByRequestId.delete(requestId);
}
}
store.dispose();
}

cancelToolCallsForRequest(requestId: string): void {
const calls = this._callsByRequestId.get(requestId);
if (calls) {
calls.forEach(call => call.dispose());
this._callsByRequestId.delete(requestId);
}
}

public override dispose(): void {
super.dispose();

this._callsByRequestId.forEach(calls => dispose(calls));
}
}

Expand Down
11 changes: 9 additions & 2 deletions src/vs/workbench/contrib/chat/common/chatServiceImpl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import { ChatServiceTelemetry } from './chatServiceTelemetry.js';
import { IChatSlashCommandService } from './chatSlashCommands.js';
import { IChatVariablesService } from './chatVariables.js';
import { ChatMessageRole, IChatMessage } from './languageModels.js';
import { ILanguageModelToolsService } from './languageModelToolsService.js';

const serializedChatKey = 'interactive.sessions';

Expand Down Expand Up @@ -86,14 +87,19 @@ const maxPersistedSessions = 25;
class CancellableRequest implements IDisposable {
constructor(
public readonly cancellationTokenSource: CancellationTokenSource,
public requestId?: string | undefined
public requestId: string | undefined,
@ILanguageModelToolsService private readonly toolsService: ILanguageModelToolsService
) { }

dispose() {
this.cancellationTokenSource.dispose();
}

cancel() {
if (this.requestId) {
this.toolsService.cancelToolCallsForRequest(this.requestId);
}

this.cancellationTokenSource.cancel();
}
}
Expand Down Expand Up @@ -778,7 +784,8 @@ export class ChatService extends Disposable implements IChatService {
}
};
const rawResponsePromise = sendRequestInternal();
this._pendingRequests.set(model.sessionId, new CancellableRequest(source));
// Note- requestId is not known at this point, assigned later
this._pendingRequests.set(model.sessionId, this.instantiationService.createInstance(CancellableRequest, source, undefined));
rawResponsePromise.finally(() => {
this._pendingRequests.deleteAndDispose(model.sessionId);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,5 @@ export interface ILanguageModelToolsService {
getTool(id: string): IToolData | undefined;
getToolByName(name: string): IToolData | undefined;
invokeTool(invocation: IToolInvocation, countTokens: CountTokensCallback, token: CancellationToken): Promise<IToolResult>;
cancelToolCallsForRequest(requestId: string): void;
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,32 @@
import * as assert from 'assert';
import { CancellationToken } from '../../../../../base/common/cancellation.js';
import { ensureNoDisposablesAreLeakedInTestSuite } from '../../../../../base/test/common/utils.js';
import { TestConfigurationService } from '../../../../../platform/configuration/test/common/testConfigurationService.js';
import { ContextKeyService } from '../../../../../platform/contextkey/browser/contextKeyService.js';
import { ContextKeyEqualsExpr, IContextKeyService } from '../../../../../platform/contextkey/common/contextkey.js';
import { workbenchInstantiationService } from '../../../../test/browser/workbenchTestServices.js';
import { LanguageModelToolsService } from '../../browser/languageModelToolsService.js';
import { IChatModel } from '../../common/chatModel.js';
import { IChatService } from '../../common/chatService.js';
import { IToolData, IToolImpl, IToolInvocation } from '../../common/languageModelToolsService.js';
import { ContextKeyService } from '../../../../../platform/contextkey/browser/contextKeyService.js';
import { TestConfigurationService } from '../../../../../platform/configuration/test/common/testConfigurationService.js';
import { MockChatService } from '../common/mockChatService.js';
import { CancellationError, isCancellationError } from '../../../../../base/common/errors.js';
import { Barrier } from '../../../../../base/common/async.js';

suite('LanguageModelToolsService', () => {
const store = ensureNoDisposablesAreLeakedInTestSuite();

let contextKeyService: IContextKeyService;
let service: LanguageModelToolsService;
let chatService: MockChatService;

setup(() => {
const instaService = workbenchInstantiationService({
contextKeyService: () => store.add(new ContextKeyService(new TestConfigurationService))
contextKeyService: () => store.add(new ContextKeyService(new TestConfigurationService)),
}, store);
contextKeyService = instaService.get(IContextKeyService);
chatService = new MockChatService();
instaService.stub(IChatService, chatService);
service = store.add(instaService.createInstance(LanguageModelToolsService));
});

Expand Down Expand Up @@ -122,4 +130,61 @@ suite('LanguageModelToolsService', () => {
const result = await service.invokeTool(dto, async () => 0, CancellationToken.None);
assert.strictEqual(result.content[0].value, 'result');
});

test('cancel tool call', async () => {
const toolData: IToolData = {
id: 'testTool',
modelDescription: 'Test Tool',
displayName: 'Test Tool'
};

store.add(service.registerToolData(toolData));

const toolBarrier = new Barrier();
const toolImpl: IToolImpl = {
invoke: async (invocation, countTokens, cancelToken) => {
assert.strictEqual(invocation.callId, '1');
assert.strictEqual(invocation.toolId, 'testTool');
assert.deepStrictEqual(invocation.parameters, { a: 1 });
await toolBarrier.wait();
if (cancelToken.isCancellationRequested) {
throw new CancellationError();
} else {
throw new Error('Tool call should be cancelled');
}
}
};

store.add(service.registerToolImplementation('testTool', toolImpl));

const sessionId = 'sessionId';
const requestId = 'requestId';
const dto: IToolInvocation = {
callId: '1',
toolId: 'testTool',
tokenBudget: 100,
parameters: {
a: 1
},
context: {
sessionId
},
};
chatService.addSession({
sessionId: sessionId,
getRequests: () => {
return [{
id: requestId
}];
},
acceptResponseProgress: () => { }
} as any as IChatModel);

const toolPromise = service.invokeTool(dto, async () => 0, CancellationToken.None);
service.cancelToolCallsForRequest(requestId);
toolBarrier.open();
await assert.rejects(toolPromise, err => {
return isCancellationError(err);
}, 'Expected tool call to be cancelled');
});
});
7 changes: 6 additions & 1 deletion src/vs/workbench/contrib/chat/test/common/mockChatService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ export class MockChatService implements IChatService {
_serviceBrand: undefined;
transferredSessionData: IChatTransferredSessionData | undefined;

private sessions = new Map<string, IChatModel>();

isEnabled(location: ChatAgentLocation): boolean {
throw new Error('Method not implemented.');
}
Expand All @@ -27,9 +29,12 @@ export class MockChatService implements IChatService {
startSession(location: ChatAgentLocation, token: CancellationToken): ChatModel | undefined {
throw new Error('Method not implemented.');
}
addSession(session: IChatModel): void {
this.sessions.set(session.sessionId, session);
}
getSession(sessionId: string): IChatModel | undefined {
// eslint-disable-next-line local/code-no-dangerous-type-assertions
return {} as IChatModel;
return this.sessions.get(sessionId) ?? {} as IChatModel;
}
getOrRestoreSession(sessionId: string): IChatModel | undefined {
throw new Error('Method not implemented.');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ export class MockLanguageModelToolsService implements ILanguageModelToolsService

constructor() { }

cancelToolCallsForRequest(requestId: string): void {
throw new Error('Method not implemented.');
}

onDidChangeTools: Event<void> = Event.None;

registerToolData(toolData: IToolData): IDisposable {
Expand Down

0 comments on commit 8281320

Please sign in to comment.