Skip to content

Commit fc82ed2

Browse files
committed
Isolate-based CompletionRanking
Screenshots: https://i.imgur.com/IOLKOCU.png, https://i.imgur.com/dM9VsNr.png Change-Id: Ieae0f9066b7a349bb4adff6b327b67bcae49aab2 Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/110880 Reviewed-by: Brian Wilkerson <brianwilkerson@google.com>
1 parent e5a6fca commit fc82ed2

File tree

6 files changed

+616
-0
lines changed

6 files changed

+616
-0
lines changed

pkg/analysis_server/lib/src/server/driver.dart

+11
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import 'package:analysis_server/src/server/features.dart';
1818
import 'package:analysis_server/src/server/http_server.dart';
1919
import 'package:analysis_server/src/server/lsp_stdio_server.dart';
2020
import 'package:analysis_server/src/server/stdio_server.dart';
21+
import 'package:analysis_server/src/services/completion/dart/completion_ranking.dart';
2122
import 'package:analysis_server/src/services/completion/dart/uri_contributor.dart'
2223
show UriContributor;
2324
import 'package:analysis_server/src/socket_server.dart';
@@ -399,6 +400,16 @@ class Driver implements ServerStarter {
399400
return null;
400401
}
401402

403+
if (analysisServerOptions.completionModelFolder != null) {
404+
CompletionRanking.instance =
405+
CompletionRanking(analysisServerOptions.completionModelFolder);
406+
CompletionRanking.instance.start().catchError(() {
407+
// Disable smart ranking if model startup fails.
408+
analysisServerOptions.completionModelFolder = null;
409+
CompletionRanking.instance = null;
410+
});
411+
}
412+
402413
final defaultSdkPath = _getSdkPath(results);
403414
final dartSdkManager = new DartSdkManager(defaultSdkPath, true);
404415

pkg/analysis_server/lib/src/services/completion/dart/completion_manager.dart

+13
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import 'package:analysis_server/src/services/completion/completion_performance.d
1313
import 'package:analysis_server/src/services/completion/dart/arglist_contributor.dart';
1414
import 'package:analysis_server/src/services/completion/dart/combinator_contributor.dart';
1515
import 'package:analysis_server/src/services/completion/dart/common_usage_sorter.dart';
16+
import 'package:analysis_server/src/services/completion/dart/completion_ranking.dart';
1617
import 'package:analysis_server/src/services/completion/dart/contribution_sorter.dart';
1718
import 'package:analysis_server/src/services/completion/dart/field_formal_contributor.dart';
1819
import 'package:analysis_server/src/services/completion/dart/imported_reference_contributor.dart';
@@ -93,6 +94,10 @@ class DartCompletionManager implements CompletionContributor {
9394
return const <CompletionSuggestion>[];
9495
}
9596

97+
final ranking = CompletionRanking.instance;
98+
Future<Map<String, double>> probabilityFuture =
99+
ranking != null ? ranking.predict(dartRequest) : Future.value(null);
100+
96101
SourceRange range =
97102
dartRequest.target.computeReplacementRange(dartRequest.offset);
98103
(request as CompletionRequestImpl)
@@ -173,6 +178,14 @@ class DartCompletionManager implements CompletionContributor {
173178
const SORT_TAG = 'DartCompletionManager - sort';
174179
performance.logStartTime(SORT_TAG);
175180
await contributionSorter.sort(dartRequest, suggestions);
181+
if (ranking != null) {
182+
suggestions = await ranking.rerank(
183+
probabilityFuture,
184+
suggestions,
185+
includedSuggestionRelevanceTags,
186+
dartRequest,
187+
request.result.unit.featureSet);
188+
}
176189
performance.logElapseTime(SORT_TAG);
177190
request.checkAborted();
178191
return suggestions;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
// Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file
2+
// for details. All rights reserved. Use of this source code is governed by a
3+
// BSD-style license that can be found in the LICENSE file.
4+
5+
import 'dart:async';
6+
import 'dart:isolate';
7+
import 'dart:math';
8+
9+
import 'package:analysis_server/src/provisional/completion/dart/completion_dart.dart';
10+
import 'package:analysis_server/src/services/completion/dart/completion_ranking_internal.dart';
11+
import 'package:analysis_server/src/protocol_server.dart';
12+
import 'package:analysis_server/src/services/completion/dart/language_model.dart';
13+
import 'package:analyzer/dart/analysis/features.dart';
14+
import 'package:analyzer/dart/ast/token.dart';
15+
16+
// Minimum probability to prioritize model-only suggestion.
17+
const double _MODEL_RELEVANCE_CUTOFF = 0.5;
18+
// Maximum [AvailableSuggestionSet] relevance to account for.
19+
const int _MAX_BASE_RELEVANCE = 9;
20+
// Number of lookback tokens.
21+
const int _LOOKBACK = 100;
22+
23+
/// Prediction service run by the model isolate.
24+
void entrypoint(SendPort sendPort) {
25+
LanguageModel model;
26+
final port = ReceivePort();
27+
sendPort.send(port.sendPort);
28+
port.listen((message) {
29+
Map<String, dynamic> response = {};
30+
switch (message['method']) {
31+
case 'load':
32+
model = LanguageModel.load(message['args'][0]);
33+
break;
34+
case 'predict':
35+
response['data'] = model.predictWithScores(message['args']);
36+
break;
37+
}
38+
39+
message['port'].send(response);
40+
});
41+
}
42+
43+
class CompletionRanking {
44+
// Singleton instance.
45+
static CompletionRanking instance;
46+
47+
// Filesystem location of model files.
48+
final String _directory;
49+
50+
// Isolate in which to make tflite model predictions.
51+
Isolate _isolate;
52+
53+
// Port to communicate from main to model isolate.
54+
SendPort _write;
55+
56+
CompletionRanking(this._directory);
57+
58+
/// Spins up the model isolate and tells it to load the tflite model.
59+
Future<void> start() async {
60+
final port = ReceivePort();
61+
this._isolate = await Isolate.spawn(entrypoint, port.sendPort);
62+
this._write = await port.first;
63+
await makeRequest('load', [_directory]);
64+
}
65+
66+
/// Makes a next-token prediction starting at the completion request
67+
/// cursor and walking back to find previous input tokens.
68+
Future<Map<String, double>> predict(DartCompletionRequest request) {
69+
final query = constructQuery(request, _LOOKBACK);
70+
if (query == null) {
71+
return Future.value(null);
72+
}
73+
74+
return makeRequest('predict', query);
75+
}
76+
77+
/// Transforms [CompletionSuggestion] relevances and
78+
/// [IncludedSuggestionRelevanceTag] relevanceBoosts based on language model
79+
/// predicted next-token probability distribution.
80+
Future<List<CompletionSuggestion>> rerank(
81+
Future<Map<String, double>> probabilityFuture,
82+
List<CompletionSuggestion> suggestions,
83+
List<IncludedSuggestionRelevanceTag> includedSuggestionRelevanceTags,
84+
DartCompletionRequest request,
85+
FeatureSet featureSet) async {
86+
final probability = await probabilityFuture
87+
.timeout(const Duration(milliseconds: 500), onTimeout: () => null);
88+
if (probability == null || probability.isEmpty) {
89+
// Failed to compute probability distribution, don't rerank.
90+
return suggestions;
91+
}
92+
93+
// Discard the type-based relevance boosts.
94+
if (includedSuggestionRelevanceTags != null) {
95+
includedSuggestionRelevanceTags.forEach((tag) {
96+
tag.relevanceBoost = 0;
97+
});
98+
}
99+
100+
// Intersection between static analysis and model suggestions.
101+
var middle = DART_RELEVANCE_HIGH + probability.length;
102+
// Up to one suggestion from model with very high confidence.
103+
var high = middle + probability.length;
104+
// Lower relevance, model-only suggestions (perhaps literals).
105+
var low = DART_RELEVANCE_LOW - 1;
106+
107+
List<MapEntry> entries = probability.entries.toList()
108+
..sort((a, b) => b.value.compareTo(a.value));
109+
110+
// If completion is requested inside of quotes, since static analysis does
111+
// not suggest string literals, only return completion suggestions from
112+
// model which are string literal.
113+
if (testInsideQuotes(request)) {
114+
return entries
115+
.where((MapEntry entry) =>
116+
isStringLiteral(entry.key) && isNotWhitespace(entry.key))
117+
.map<CompletionSuggestion>((MapEntry entry) =>
118+
createCompletionSuggestion(
119+
entry.key.replaceAll("'", ''), featureSet, low--))
120+
.toList();
121+
}
122+
123+
// If analysis server thinks this is a declaration context,
124+
// remove all of the model-suggested literals.
125+
// TODO(lambdabaa): Ask Brian for help leveraging
126+
// SimpleIdentifier#inDeclarationContext.
127+
if (request.opType.includeVarNameSuggestions &&
128+
suggestions.every((CompletionSuggestion suggestion) =>
129+
suggestion.kind == CompletionSuggestionKind.IDENTIFIER)) {
130+
entries.retainWhere((MapEntry entry) => !isLiteral(entry.key));
131+
}
132+
133+
var isRequestFollowingDot = testFollowingDot(request);
134+
entries.forEach((MapEntry entry) {
135+
// There may be multiple like
136+
// CompletionSuggestion and CompletionSuggestion().
137+
final completionSuggestions = suggestions.where((suggestion) =>
138+
areCompletionsEquivalent(suggestion.completion, entry.key));
139+
List<IncludedSuggestionRelevanceTag> includedSuggestions;
140+
if (includedSuggestionRelevanceTags != null) {
141+
includedSuggestions = includedSuggestionRelevanceTags
142+
.where((tag) => areCompletionsEquivalent(tag.tag, entry.key))
143+
.toList();
144+
} else {
145+
includedSuggestions = [];
146+
}
147+
if (!isRequestFollowingDot && entry.value > _MODEL_RELEVANCE_CUTOFF) {
148+
if (completionSuggestions.isNotEmpty ||
149+
includedSuggestions.isNotEmpty) {
150+
final relevance = high--;
151+
completionSuggestions.forEach((completionSuggestion) {
152+
completionSuggestion.relevance = relevance;
153+
});
154+
includedSuggestions.forEach((includedSuggestion) {
155+
includedSuggestion.relevanceBoost = relevance - _MAX_BASE_RELEVANCE;
156+
});
157+
} else {
158+
suggestions
159+
.add(createCompletionSuggestion(entry.key, featureSet, high--));
160+
}
161+
} else if (completionSuggestions.isNotEmpty ||
162+
includedSuggestions.isNotEmpty) {
163+
final relevance = middle--;
164+
completionSuggestions.forEach((completionSuggestion) {
165+
completionSuggestion.relevance = relevance;
166+
});
167+
includedSuggestions.forEach((includedSuggestion) {
168+
includedSuggestion.relevanceBoost = relevance - _MAX_BASE_RELEVANCE;
169+
});
170+
} else if (!isRequestFollowingDot) {
171+
suggestions
172+
.add(createCompletionSuggestion(entry.key, featureSet, low--));
173+
}
174+
});
175+
176+
return suggestions;
177+
}
178+
179+
/// Send an RPC to the isolate worker and wait for it to respond.
180+
Future<Map<String, dynamic>> makeRequest(
181+
String method, List<String> args) async {
182+
final port = ReceivePort();
183+
_write.send({
184+
'method': method,
185+
'args': args,
186+
'port': port.sendPort,
187+
});
188+
return await port.first;
189+
}
190+
}

0 commit comments

Comments
 (0)