Skip to content

Commit

Permalink
fix(core): unable to explain image when network search is active (#10228
Browse files Browse the repository at this point in the history
  • Loading branch information
akumatus committed Feb 18, 2025
1 parent eed00e0 commit 015452e
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) {
);
}

private get _promptName() {
private _getPromptName() {
if (this._isNetworkDisabled) {
return PROMPT_NAME_AFFINE_AI;
}
Expand All @@ -297,12 +297,12 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) {
: PROMPT_NAME_AFFINE_AI;
}

private async _updatePromptName() {
if (this._lastPromptName !== this._promptName) {
this._lastPromptName = this._promptName;
private async _updatePromptName(promptName: string) {
if (this._lastPromptName !== promptName) {
const sessionId = await this.getSessionId();
if (sessionId) {
await AIProvider.session?.updateSession(sessionId, this._promptName);
if (sessionId && AIProvider.session) {
await AIProvider.session.updateSession(sessionId, promptName);
this._lastPromptName = promptName;
}
}
}
Expand Down Expand Up @@ -457,7 +457,7 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) {
}}
@keydown=${async (evt: KeyboardEvent) => {
if (evt.key === 'Enter' && !evt.shiftKey && !evt.isComposing) {
this._onTextareaSend(evt);
await this._onTextareaSend(evt);
}
}}
@focus=${() => {
Expand Down Expand Up @@ -538,7 +538,7 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) {
</div>`;
}

private readonly _onTextareaSend = (e: MouseEvent | KeyboardEvent) => {
private readonly _onTextareaSend = async (e: MouseEvent | KeyboardEvent) => {
e.preventDefault();
e.stopPropagation();

Expand All @@ -549,17 +549,17 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) {
this.isInputEmpty = true;
this.textarea.style.height = 'unset';

this.send(value).catch(console.error);
await this.send(value);
};

send = async (text: string) => {
const { status, markdown, chips } = this.chatContextValue;
const { status, markdown, chips, images } = this.chatContextValue;
if (status === 'loading' || status === 'transmitting') return;
if (!text) return;

try {
const { images } = this.chatContextValue;
const { doc } = this.host;
const promptName = this._getPromptName();

this.updateContext({
images: [],
Expand Down Expand Up @@ -593,7 +593,9 @@ export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) {
],
});

await this._updatePromptName();
// must update prompt name after local chat message is updated
// otherwise, the unauthorized error can not be rendered properly
await this._updatePromptName(promptName);

const abortController = new AbortController();
const sessionId = await this.getSessionId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ export class ChatBlockInput extends SignalWatcher(LitElement) {
${status === 'transmitting'
? html`<div @click=${this._handleAbort}>${ChatAbortIcon}</div>`
: html`<div
@click="${this._send}"
@click=${this._onTextareaSend}
class="chat-panel-send"
aria-disabled=${this._isInputEmpty}
>
Expand Down Expand Up @@ -306,7 +306,7 @@ export class ChatBlockInput extends SignalWatcher(LitElement) {
return !!this.chatContext.images.length;
}

private get _promptName() {
private _getPromptName() {
if (this._isNetworkDisabled) {
return PROMPT_NAME_AFFINE_AI;
}
Expand All @@ -315,15 +315,12 @@ export class ChatBlockInput extends SignalWatcher(LitElement) {
: PROMPT_NAME_AFFINE_AI;
}

private async _updatePromptName() {
if (this._lastPromptName !== this._promptName) {
this._lastPromptName = this._promptName;
private async _updatePromptName(promptName: string) {
if (this._lastPromptName !== promptName) {
const { currentSessionId } = this.chatContext;
if (currentSessionId) {
await AIProvider.session?.updateSession(
currentSessionId,
this._promptName
);
if (currentSessionId && AIProvider.session) {
await AIProvider.session.updateSession(currentSessionId, promptName);
this._lastPromptName = promptName;
}
}
}
Expand All @@ -346,7 +343,7 @@ export class ChatBlockInput extends SignalWatcher(LitElement) {
private readonly _handleKeyDown = async (evt: KeyboardEvent) => {
if (evt.key === 'Enter' && !evt.shiftKey && !evt.isComposing) {
evt.preventDefault();
await this._send();
await this._onTextareaSend(evt);
}
};

Expand Down Expand Up @@ -452,56 +449,70 @@ export class ChatBlockInput extends SignalWatcher(LitElement) {
`;
}

private readonly _send = async () => {
const { images, status } = this.chatContext;
if (status === 'loading' || status === 'transmitting') return;
private readonly _onTextareaSend = async (e: MouseEvent | KeyboardEvent) => {
e.preventDefault();
e.stopPropagation();

const text = this.textarea.value;
if (!text && !images.length) {
return;
}
const value = this.textarea.value.trim();
if (value.length === 0) return;

const { doc } = this.host;
this.textarea.value = '';
this._isInputEmpty = true;
this.textarea.style.height = 'unset';
this.updateContext({
images: [],
status: 'loading',
error: null,
});

const attachments = await Promise.all(
images?.map(image => readBlobAsURL(image))
);

const userInfo = await AIProvider.userInfo;
this.updateContext({
messages: [
...this.chatContext.messages,
{
id: '',
content: text,
role: 'user',
createdAt: new Date().toISOString(),
attachments,
userId: userInfo?.id,
userName: userInfo?.name,
avatarUrl: userInfo?.avatarUrl ?? undefined,
},
{
id: '',
content: '',
role: 'assistant',
createdAt: new Date().toISOString(),
},
],
});
await this._send(value);
};

const { currentChatBlockId, currentSessionId } = this.chatContext;
let content = '';
private readonly _send = async (text: string) => {
const { images, status, currentChatBlockId, currentSessionId } =
this.chatContext;
const chatBlockExists = !!currentChatBlockId;
let content = '';

if (status === 'loading' || status === 'transmitting') return;
if (!text) return;

try {
const { doc } = this.host;
const promptName = this._getPromptName();

this.updateContext({
images: [],
status: 'loading',
error: null,
});

const attachments = await Promise.all(
images?.map(image => readBlobAsURL(image))
);

const userInfo = await AIProvider.userInfo;
this.updateContext({
messages: [
...this.chatContext.messages,
{
id: '',
content: text,
role: 'user',
createdAt: new Date().toISOString(),
attachments,
userId: userInfo?.id,
userName: userInfo?.name,
avatarUrl: userInfo?.avatarUrl ?? undefined,
},
{
id: '',
content: '',
role: 'assistant',
createdAt: new Date().toISOString(),
},
],
});

// must update prompt name after local chat message is updated
// otherwise, the unauthorized error can not be rendered properly
await this._updatePromptName(promptName);

// If has not forked a chat session, fork a new one
let chatSessionId = currentSessionId;
if (!chatSessionId) {
Expand All @@ -518,8 +529,6 @@ export class ChatBlockInput extends SignalWatcher(LitElement) {
chatSessionId = forkSessionId;
}

await this._updatePromptName();

const abortController = new AbortController();
const stream = AIProvider.actions.chat?.({
input: text,
Expand Down
39 changes: 39 additions & 0 deletions tests/affine-cloud-copilot/e2e/copilot.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,45 @@ test.describe('chat panel', () => {
);
});

test('can identify shape color, even if network search is active', async ({
page,
}) => {
await page.reload();
await clickSideBarAllPageButton(page);
await page.waitForTimeout(200);
await createLocalWorkspace({ name: 'test' }, page);
await clickNewPageButton(page);

await openChat(page);
await page.getByTestId('chat-network-search').click();

await switchToEdgelessMode(page);

const shapeButton = await page.waitForSelector(
'edgeless-shape-tool-button'
);
await shapeButton.click();
await page.mouse.click(400, 400);

const askAIButton = await page.waitForSelector('.copilot-icon-button');
await askAIButton.click();

await page.waitForTimeout(1000);
await page.keyboard.type('What color is this shape?');
await page.keyboard.press('Enter');

const history = await collectChat(page);
expect(history[0]).toEqual({
name: 'You',
content: 'What color is this shape?',
});
expect(history[1].name).toBe('AFFiNE AI');
expect(history[1].content).toContain('yellow');
expect(await page.locator('chat-panel affine-footnote-node').count()).toBe(
0
);
});

test('can trigger inline ai input and action panel by clicking Start with AI button', async ({
page,
}) => {
Expand Down

0 comments on commit 015452e

Please sign in to comment.