From 015452e8fbb998dff6f286f50f90d2ec1f7faeab Mon Sep 17 00:00:00 2001 From: akumatus Date: Tue, 18 Feb 2025 02:20:20 +0000 Subject: [PATCH] fix(core): unable to explain image when network search is active (#10228) Fix issue [PD-2316](https://linear.app/affine-design/issue/PD-2316). --- .../presets/ai/chat-panel/chat-panel-input.ts | 26 ++-- .../presets/ai/peek-view/chat-block-input.ts | 117 ++++++++++-------- .../affine-cloud-copilot/e2e/copilot.spec.ts | 39 ++++++ 3 files changed, 116 insertions(+), 66 deletions(-) diff --git a/packages/frontend/core/src/blocksuite/presets/ai/chat-panel/chat-panel-input.ts b/packages/frontend/core/src/blocksuite/presets/ai/chat-panel/chat-panel-input.ts index 6ad48b6699194..d30e9069c71aa 100644 --- a/packages/frontend/core/src/blocksuite/presets/ai/chat-panel/chat-panel-input.ts +++ b/packages/frontend/core/src/blocksuite/presets/ai/chat-panel/chat-panel-input.ts @@ -288,7 +288,7 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) { ); } - private get _promptName() { + private _getPromptName() { if (this._isNetworkDisabled) { return PROMPT_NAME_AFFINE_AI; } @@ -297,12 +297,12 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) { : PROMPT_NAME_AFFINE_AI; } - private async _updatePromptName() { - if (this._lastPromptName !== this._promptName) { - this._lastPromptName = this._promptName; + private async _updatePromptName(promptName: string) { + if (this._lastPromptName !== promptName) { const sessionId = await this.getSessionId(); - if (sessionId) { - await AIProvider.session?.updateSession(sessionId, this._promptName); + if (sessionId && AIProvider.session) { + await AIProvider.session.updateSession(sessionId, promptName); + this._lastPromptName = promptName; } } } @@ -457,7 +457,7 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) { }} @keydown=${async (evt: KeyboardEvent) => { if (evt.key === 'Enter' && !evt.shiftKey && !evt.isComposing) { - this._onTextareaSend(evt); + await this._onTextareaSend(evt); } }} @focus=${() => { @@ -538,7 +538,7 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) { `; } - private readonly _onTextareaSend = (e: MouseEvent | KeyboardEvent) => { + private readonly _onTextareaSend = async (e: MouseEvent | KeyboardEvent) => { e.preventDefault(); e.stopPropagation(); @@ -549,17 +549,17 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) { this.isInputEmpty = true; this.textarea.style.height = 'unset'; - this.send(value).catch(console.error); + await this.send(value); }; send = async (text: string) => { - const { status, markdown, chips } = this.chatContextValue; + const { status, markdown, chips, images } = this.chatContextValue; if (status === 'loading' || status === 'transmitting') return; if (!text) return; try { - const { images } = this.chatContextValue; const { doc } = this.host; + const promptName = this._getPromptName(); this.updateContext({ images: [], @@ -593,7 +593,9 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) { ], }); - await this._updatePromptName(); + // must update prompt name after local chat message is updated + // otherwise, the unauthorized error can not be rendered properly + await this._updatePromptName(promptName); const abortController = new AbortController(); const sessionId = await this.getSessionId(); diff --git a/packages/frontend/core/src/blocksuite/presets/ai/peek-view/chat-block-input.ts b/packages/frontend/core/src/blocksuite/presets/ai/peek-view/chat-block-input.ts index bf70214fa5f25..de0d1bdffbfb9 100644 --- a/packages/frontend/core/src/blocksuite/presets/ai/peek-view/chat-block-input.ts +++ b/packages/frontend/core/src/blocksuite/presets/ai/peek-view/chat-block-input.ts @@ -241,7 +241,7 @@ export class ChatBlockInput extends SignalWatcher(LitElement) { ${status === 'transmitting' ? html`
${ChatAbortIcon}
` : html`
@@ -306,7 +306,7 @@ export class ChatBlockInput extends SignalWatcher(LitElement) { return !!this.chatContext.images.length; } - private get _promptName() { + private _getPromptName() { if (this._isNetworkDisabled) { return PROMPT_NAME_AFFINE_AI; } @@ -315,15 +315,12 @@ export class ChatBlockInput extends SignalWatcher(LitElement) { : PROMPT_NAME_AFFINE_AI; } - private async _updatePromptName() { - if (this._lastPromptName !== this._promptName) { - this._lastPromptName = this._promptName; + private async _updatePromptName(promptName: string) { + if (this._lastPromptName !== promptName) { const { currentSessionId } = this.chatContext; - if (currentSessionId) { - await AIProvider.session?.updateSession( - currentSessionId, - this._promptName - ); + if (currentSessionId && AIProvider.session) { + await AIProvider.session.updateSession(currentSessionId, promptName); + this._lastPromptName = promptName; } } } @@ -346,7 +343,7 @@ export class ChatBlockInput extends SignalWatcher(LitElement) { private readonly _handleKeyDown = async (evt: KeyboardEvent) => { if (evt.key === 'Enter' && !evt.shiftKey && !evt.isComposing) { evt.preventDefault(); - await this._send(); + await this._onTextareaSend(evt); } }; @@ -452,56 +449,70 @@ export class ChatBlockInput extends SignalWatcher(LitElement) { `; } - private readonly _send = async () => { - const { images, status } = this.chatContext; - if (status === 'loading' || status === 'transmitting') return; + private readonly _onTextareaSend = async (e: MouseEvent | KeyboardEvent) => { + e.preventDefault(); + e.stopPropagation(); - const text = this.textarea.value; - if (!text && !images.length) { - return; - } + const value = this.textarea.value.trim(); + if (value.length === 0) return; - const { doc } = this.host; this.textarea.value = ''; this._isInputEmpty = true; this.textarea.style.height = 'unset'; - this.updateContext({ - images: [], - status: 'loading', - error: null, - }); - - const attachments = await Promise.all( - images?.map(image => readBlobAsURL(image)) - ); - const userInfo = await AIProvider.userInfo; - this.updateContext({ - messages: [ - ...this.chatContext.messages, - { - id: '', - content: text, - role: 'user', - createdAt: new Date().toISOString(), - attachments, - userId: userInfo?.id, - userName: userInfo?.name, - avatarUrl: userInfo?.avatarUrl ?? undefined, - }, - { - id: '', - content: '', - role: 'assistant', - createdAt: new Date().toISOString(), - }, - ], - }); + await this._send(value); + }; - const { currentChatBlockId, currentSessionId } = this.chatContext; - let content = ''; + private readonly _send = async (text: string) => { + const { images, status, currentChatBlockId, currentSessionId } = + this.chatContext; const chatBlockExists = !!currentChatBlockId; + let content = ''; + + if (status === 'loading' || status === 'transmitting') return; + if (!text) return; + try { + const { doc } = this.host; + const promptName = this._getPromptName(); + + this.updateContext({ + images: [], + status: 'loading', + error: null, + }); + + const attachments = await Promise.all( + images?.map(image => readBlobAsURL(image)) + ); + + const userInfo = await AIProvider.userInfo; + this.updateContext({ + messages: [ + ...this.chatContext.messages, + { + id: '', + content: text, + role: 'user', + createdAt: new Date().toISOString(), + attachments, + userId: userInfo?.id, + userName: userInfo?.name, + avatarUrl: userInfo?.avatarUrl ?? undefined, + }, + { + id: '', + content: '', + role: 'assistant', + createdAt: new Date().toISOString(), + }, + ], + }); + + // must update prompt name after local chat message is updated + // otherwise, the unauthorized error can not be rendered properly + await this._updatePromptName(promptName); + // If has not forked a chat session, fork a new one let chatSessionId = currentSessionId; if (!chatSessionId) { @@ -518,8 +529,6 @@ export class ChatBlockInput extends SignalWatcher(LitElement) { chatSessionId = forkSessionId; } - await this._updatePromptName(); - const abortController = new AbortController(); const stream = AIProvider.actions.chat?.({ input: text, diff --git a/tests/affine-cloud-copilot/e2e/copilot.spec.ts b/tests/affine-cloud-copilot/e2e/copilot.spec.ts index 41973efc0e997..44b800ac98e0c 100644 --- a/tests/affine-cloud-copilot/e2e/copilot.spec.ts +++ b/tests/affine-cloud-copilot/e2e/copilot.spec.ts @@ -517,6 +517,45 @@ test.describe('chat panel', () => { ); }); + test('can identify shape color, even if network search is active', async ({ + page, + }) => { + await page.reload(); + await clickSideBarAllPageButton(page); + await page.waitForTimeout(200); + await createLocalWorkspace({ name: 'test' }, page); + await clickNewPageButton(page); + + await openChat(page); + await page.getByTestId('chat-network-search').click(); + + await switchToEdgelessMode(page); + + const shapeButton = await page.waitForSelector( + 'edgeless-shape-tool-button' + ); + await shapeButton.click(); + await page.mouse.click(400, 400); + + const askAIButton = await page.waitForSelector('.copilot-icon-button'); + await askAIButton.click(); + + await page.waitForTimeout(1000); + await page.keyboard.type('What color is this shape?'); + await page.keyboard.press('Enter'); + + const history = await collectChat(page); + expect(history[0]).toEqual({ + name: 'You', + content: 'What color is this shape?', + }); + expect(history[1].name).toBe('AFFiNE AI'); + expect(history[1].content).toContain('yellow'); + expect(await page.locator('chat-panel affine-footnote-node').count()).toBe( + 0 + ); + }); + test('can trigger inline ai input and action panel by clicking Start with AI button', async ({ page, }) => {