|
| 1 | +// Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file |
| 2 | +// for details. All rights reserved. Use of this source code is governed by a |
| 3 | +// BSD-style license that can be found in the LICENSE file. |
| 4 | + |
| 5 | +import 'dart:async'; |
| 6 | +import 'dart:isolate'; |
| 7 | +import 'dart:math'; |
| 8 | + |
| 9 | +import 'package:analysis_server/src/provisional/completion/dart/completion_dart.dart'; |
| 10 | +import 'package:analysis_server/src/services/completion/dart/completion_ranking_internal.dart'; |
| 11 | +import 'package:analysis_server/src/protocol_server.dart'; |
| 12 | +import 'package:analysis_server/src/services/completion/dart/language_model.dart'; |
| 13 | +import 'package:analyzer/dart/analysis/features.dart'; |
| 14 | +import 'package:analyzer/dart/ast/token.dart'; |
| 15 | + |
| 16 | +// Minimum probability to prioritize model-only suggestion. |
| 17 | +const double _MODEL_RELEVANCE_CUTOFF = 0.5; |
| 18 | +// Maximum [AvailableSuggestionSet] relevance to account for. |
| 19 | +const int _MAX_BASE_RELEVANCE = 9; |
| 20 | +// Number of lookback tokens. |
| 21 | +const int _LOOKBACK = 100; |
| 22 | + |
| 23 | +/// Prediction service run by the model isolate. |
| 24 | +void entrypoint(SendPort sendPort) { |
| 25 | + LanguageModel model; |
| 26 | + final port = ReceivePort(); |
| 27 | + sendPort.send(port.sendPort); |
| 28 | + port.listen((message) { |
| 29 | + Map<String, dynamic> response = {}; |
| 30 | + switch (message['method']) { |
| 31 | + case 'load': |
| 32 | + model = LanguageModel.load(message['args'][0]); |
| 33 | + break; |
| 34 | + case 'predict': |
| 35 | + response['data'] = model.predictWithScores(message['args']); |
| 36 | + break; |
| 37 | + } |
| 38 | + |
| 39 | + message['port'].send(response); |
| 40 | + }); |
| 41 | +} |
| 42 | + |
| 43 | +class CompletionRanking { |
| 44 | + // Singleton instance. |
| 45 | + static CompletionRanking instance; |
| 46 | + |
| 47 | + // Filesystem location of model files. |
| 48 | + final String _directory; |
| 49 | + |
| 50 | + // Isolate in which to make tflite model predictions. |
| 51 | + Isolate _isolate; |
| 52 | + |
| 53 | + // Port to communicate from main to model isolate. |
| 54 | + SendPort _write; |
| 55 | + |
| 56 | + CompletionRanking(this._directory); |
| 57 | + |
| 58 | + /// Spins up the model isolate and tells it to load the tflite model. |
| 59 | + Future<void> start() async { |
| 60 | + final port = ReceivePort(); |
| 61 | + this._isolate = await Isolate.spawn(entrypoint, port.sendPort); |
| 62 | + this._write = await port.first; |
| 63 | + await makeRequest('load', [_directory]); |
| 64 | + } |
| 65 | + |
| 66 | + /// Makes a next-token prediction starting at the completion request |
| 67 | + /// cursor and walking back to find previous input tokens. |
| 68 | + Future<Map<String, double>> predict(DartCompletionRequest request) { |
| 69 | + final query = constructQuery(request, _LOOKBACK); |
| 70 | + if (query == null) { |
| 71 | + return Future.value(null); |
| 72 | + } |
| 73 | + |
| 74 | + return makeRequest('predict', query); |
| 75 | + } |
| 76 | + |
| 77 | + /// Transforms [CompletionSuggestion] relevances and |
| 78 | + /// [IncludedSuggestionRelevanceTag] relevanceBoosts based on language model |
| 79 | + /// predicted next-token probability distribution. |
| 80 | + Future<List<CompletionSuggestion>> rerank( |
| 81 | + Future<Map<String, double>> probabilityFuture, |
| 82 | + List<CompletionSuggestion> suggestions, |
| 83 | + List<IncludedSuggestionRelevanceTag> includedSuggestionRelevanceTags, |
| 84 | + DartCompletionRequest request, |
| 85 | + FeatureSet featureSet) async { |
| 86 | + final probability = await probabilityFuture |
| 87 | + .timeout(const Duration(milliseconds: 500), onTimeout: () => null); |
| 88 | + if (probability == null || probability.isEmpty) { |
| 89 | + // Failed to compute probability distribution, don't rerank. |
| 90 | + return suggestions; |
| 91 | + } |
| 92 | + |
| 93 | + // Discard the type-based relevance boosts. |
| 94 | + if (includedSuggestionRelevanceTags != null) { |
| 95 | + includedSuggestionRelevanceTags.forEach((tag) { |
| 96 | + tag.relevanceBoost = 0; |
| 97 | + }); |
| 98 | + } |
| 99 | + |
| 100 | + // Intersection between static analysis and model suggestions. |
| 101 | + var middle = DART_RELEVANCE_HIGH + probability.length; |
| 102 | + // Up to one suggestion from model with very high confidence. |
| 103 | + var high = middle + probability.length; |
| 104 | + // Lower relevance, model-only suggestions (perhaps literals). |
| 105 | + var low = DART_RELEVANCE_LOW - 1; |
| 106 | + |
| 107 | + List<MapEntry> entries = probability.entries.toList() |
| 108 | + ..sort((a, b) => b.value.compareTo(a.value)); |
| 109 | + |
| 110 | + // If completion is requested inside of quotes, since static analysis does |
| 111 | + // not suggest string literals, only return completion suggestions from |
| 112 | + // model which are string literal. |
| 113 | + if (testInsideQuotes(request)) { |
| 114 | + return entries |
| 115 | + .where((MapEntry entry) => |
| 116 | + isStringLiteral(entry.key) && isNotWhitespace(entry.key)) |
| 117 | + .map<CompletionSuggestion>((MapEntry entry) => |
| 118 | + createCompletionSuggestion( |
| 119 | + entry.key.replaceAll("'", ''), featureSet, low--)) |
| 120 | + .toList(); |
| 121 | + } |
| 122 | + |
| 123 | + // If analysis server thinks this is a declaration context, |
| 124 | + // remove all of the model-suggested literals. |
| 125 | + // TODO(lambdabaa): Ask Brian for help leveraging |
| 126 | + // SimpleIdentifier#inDeclarationContext. |
| 127 | + if (request.opType.includeVarNameSuggestions && |
| 128 | + suggestions.every((CompletionSuggestion suggestion) => |
| 129 | + suggestion.kind == CompletionSuggestionKind.IDENTIFIER)) { |
| 130 | + entries.retainWhere((MapEntry entry) => !isLiteral(entry.key)); |
| 131 | + } |
| 132 | + |
| 133 | + var isRequestFollowingDot = testFollowingDot(request); |
| 134 | + entries.forEach((MapEntry entry) { |
| 135 | + // There may be multiple like |
| 136 | + // CompletionSuggestion and CompletionSuggestion(). |
| 137 | + final completionSuggestions = suggestions.where((suggestion) => |
| 138 | + areCompletionsEquivalent(suggestion.completion, entry.key)); |
| 139 | + List<IncludedSuggestionRelevanceTag> includedSuggestions; |
| 140 | + if (includedSuggestionRelevanceTags != null) { |
| 141 | + includedSuggestions = includedSuggestionRelevanceTags |
| 142 | + .where((tag) => areCompletionsEquivalent(tag.tag, entry.key)) |
| 143 | + .toList(); |
| 144 | + } else { |
| 145 | + includedSuggestions = []; |
| 146 | + } |
| 147 | + if (!isRequestFollowingDot && entry.value > _MODEL_RELEVANCE_CUTOFF) { |
| 148 | + if (completionSuggestions.isNotEmpty || |
| 149 | + includedSuggestions.isNotEmpty) { |
| 150 | + final relevance = high--; |
| 151 | + completionSuggestions.forEach((completionSuggestion) { |
| 152 | + completionSuggestion.relevance = relevance; |
| 153 | + }); |
| 154 | + includedSuggestions.forEach((includedSuggestion) { |
| 155 | + includedSuggestion.relevanceBoost = relevance - _MAX_BASE_RELEVANCE; |
| 156 | + }); |
| 157 | + } else { |
| 158 | + suggestions |
| 159 | + .add(createCompletionSuggestion(entry.key, featureSet, high--)); |
| 160 | + } |
| 161 | + } else if (completionSuggestions.isNotEmpty || |
| 162 | + includedSuggestions.isNotEmpty) { |
| 163 | + final relevance = middle--; |
| 164 | + completionSuggestions.forEach((completionSuggestion) { |
| 165 | + completionSuggestion.relevance = relevance; |
| 166 | + }); |
| 167 | + includedSuggestions.forEach((includedSuggestion) { |
| 168 | + includedSuggestion.relevanceBoost = relevance - _MAX_BASE_RELEVANCE; |
| 169 | + }); |
| 170 | + } else if (!isRequestFollowingDot) { |
| 171 | + suggestions |
| 172 | + .add(createCompletionSuggestion(entry.key, featureSet, low--)); |
| 173 | + } |
| 174 | + }); |
| 175 | + |
| 176 | + return suggestions; |
| 177 | + } |
| 178 | + |
| 179 | + /// Send an RPC to the isolate worker and wait for it to respond. |
| 180 | + Future<Map<String, dynamic>> makeRequest( |
| 181 | + String method, List<String> args) async { |
| 182 | + final port = ReceivePort(); |
| 183 | + _write.send({ |
| 184 | + 'method': method, |
| 185 | + 'args': args, |
| 186 | + 'port': port.sendPort, |
| 187 | + }); |
| 188 | + return await port.first; |
| 189 | + } |
| 190 | +} |
0 commit comments