diff --git a/src/Nncase.Compiler/Compiler.cs b/src/Nncase.Compiler/Compiler.cs index 5a0efa3567..d3b84359c3 100644 --- a/src/Nncase.Compiler/Compiler.cs +++ b/src/Nncase.Compiler/Compiler.cs @@ -244,47 +244,34 @@ public void ClearFixShape(IPassManager p) }); } - public async Task CompileWithReportAsync(IProgress progress, CancellationToken token) - { - CancellationTokenSource cts = new(); - var internalToken = cts.Token; - using (var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(internalToken, token)) - { - try - { - var task = Task.Run(CompileAsync, linkedCts.Token); - Report(progress, 9, linkedCts.Token); - await task.WaitAsync(linkedCts.Token); - } - catch (Exception) - { - return; - } - } - } - - public async Task CompileAsync() + public async Task CompileAsync(IProgress? progress = null, CancellationToken token = default) { var target = _compileSession.Target; - await RunPassAsync(p => TargetIndependentPass(p), "TargetIndependentPass"); - await RunPassAsync(p => RegisterTargetIndependQuantPass(p), "TargetIndependentQuantPass"); + await RunPassAsync(p => TargetIndependentPass(p), "TargetIndependentPass", progress, token); + await RunPassAsync(p => RegisterTargetIndependQuantPass(p), "TargetIndependentQuantPass", progress, token); if (_compileSession.CompileOptions.ShapeBucketOptions.Enable) { - await RunPassAsync(p => RegisterShapeBucket(p), "ShapeBucket"); - await RunPassAsync(p => TargetIndependentPass(p), "TargetIndependentPass"); + await RunPassAsync(p => RegisterShapeBucket(p), "ShapeBucket", progress, token); + await RunPassAsync(p => TargetIndependentPass(p), "TargetIndependentPass", progress, token); } await RunPassAsync( p => target.RegisterTargetDependentPass(p, _compileSession.CompileOptions), - "TargetDependentPass"); - await RunPassAsync(p => target.RegisterQuantizePass(p, _compileSession.CompileOptions), "QuantizePass"); + "TargetDependentPass", + progress, + token); + await RunPassAsync(p => target.RegisterQuantizePass(p, _compileSession.CompileOptions), "QuantizePass", progress, token); await RunPassAsync( p => target.RegisterTargetDependentAfterQuantPass(p, _compileSession.CompileOptions), - "TargetDependentAfterQuantPass"); - await RunPassAsync(p => ClearFixShape(p), "ClearFixShape"); + "TargetDependentAfterQuantPass", + progress, + token); + await RunPassAsync(p => ClearFixShape(p), "ClearFixShape", progress, token); await RunPassAsync( p => target.RegisterTargetDependentBeforeCodeGen(p, _compileSession.CompileOptions), - "TargetDependentBeforeCodeGen"); + "TargetDependentBeforeCodeGen", + progress, + token); if (_dumpper.IsEnabled(DumpFlags.Compile)) { DumpScope.Current.DumpModule(_module!, "ModuleAfterCompile"); @@ -297,14 +284,6 @@ public void Gencode(Stream output) linkedModel.Serialize(output); } - private void Report(IProgress progress, int maxPassCount, CancellationToken token) - { - while (_runPassCount < maxPassCount && !token.IsCancellationRequested) - { - progress?.Report(_runPassCount); - } - } - private async Task InitializeModuleAsync(IRModule module) { _module = module; @@ -342,7 +321,7 @@ private void RegisterTargetIndependQuantPass(IPassManager passManager) } } - private async Task RunPassAsync(Action register, string name) + private async Task RunPassAsync(Action register, string name, IProgress? progress = null, CancellationToken token = default) { var newName = $"{_runPassCount++}_" + name; var pmgr = _compileSession.CreatePassManager(newName); @@ -354,5 +333,8 @@ private async Task RunPassAsync(Action register, string name) _dumpper.DumpModule(_module, newName); _dumpper.DumpDotIR(_module.Entry!, newName); } + + progress?.Report(_runPassCount); + token.ThrowIfCancellationRequested(); } } diff --git a/src/Nncase.Core/ICompiler.cs b/src/Nncase.Core/ICompiler.cs index 8aa049f215..d9978b92b5 100644 --- a/src/Nncase.Core/ICompiler.cs +++ b/src/Nncase.Core/ICompiler.cs @@ -47,12 +47,7 @@ public interface ICompiler /// Compile module. /// /// A representing the asynchronous operation. - Task CompileAsync(); - - /// - /// Compile module with report pass number. - /// - Task CompileWithReportAsync(IProgress progress, CancellationToken token); + Task CompileAsync(IProgress? progress = null, CancellationToken token = default); /// /// Generate code to stream. diff --git a/src/Nncase.Studio/Util/CompileConfig.cs b/src/Nncase.Studio/Util/CompileConfig.cs index 491c0f2a90..61e4402c1a 100644 --- a/src/Nncase.Studio/Util/CompileConfig.cs +++ b/src/Nncase.Studio/Util/CompileConfig.cs @@ -5,6 +5,7 @@ using System.IO; using System.Runtime.InteropServices.JavaScript; using Nncase.Diagnostics; +using Nncase.Quantization; using Nncase.Studio.ViewModels; namespace Nncase.Studio.Util; @@ -26,6 +27,8 @@ public CompileConfig() CompileOption.Mean = new[] { 0f }; CompileOption.Std = new[] { 0f }; CompileOption.LetterBoxValue = 0f; + UseQuantize = true; + CompileOption.QuantizeOptions.ModelQuantMode = ModelQuantMode.UsePTQ; } public CompileOptions CompileOption { get; set; } = new(); diff --git a/src/Nncase.Studio/Util/DataUtil.cs b/src/Nncase.Studio/Util/DataUtil.cs index 4b0d1f6271..3144a5c518 100644 --- a/src/Nncase.Studio/Util/DataUtil.cs +++ b/src/Nncase.Studio/Util/DataUtil.cs @@ -14,6 +14,7 @@ using Avalonia.Platform.Storage; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; +using DynamicData; using Nncase.IR; using Nncase.Quantization; using Nncase.Studio.Views; @@ -76,6 +77,11 @@ public static bool TryParseFixVarMap(string input, out Dictionary m try { + if (!input.Contains(":", StringComparison.Ordinal)) + { + return false; + } + map = input.Trim().Split(",").Select(x => x.Trim().Split(":")).ToDictionary(x => x[0], x => int.Parse(x[1])); return true; } @@ -101,7 +107,7 @@ public static bool TryParseRangeInfo(string input, out Dictionary x[0], x => { var pair = x[1].Split(","); - return (int.Parse(pair[0]), int.Parse(pair[1])); + return (int.Parse(pair[0].Trim('(')), int.Parse(pair[1].Trim(')'))); }); return true; } diff --git a/src/Nncase.Studio/ViewModels/CompileViewModel.cs b/src/Nncase.Studio/ViewModels/CompileViewModel.cs index ad7fa5dbdd..918bcf2210 100644 --- a/src/Nncase.Studio/ViewModels/CompileViewModel.cs +++ b/src/Nncase.Studio/ViewModels/CompileViewModel.cs @@ -6,6 +6,7 @@ using System.Threading; using System.Threading.Tasks; using System.Windows.Input; +using Avalonia.Threading; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; using Nncase.IR; @@ -44,15 +45,15 @@ public Task CancelCompile() [RelayCommand] public async Task Compile() { - var info = Context.CheckViewModel(); - if (info.Length != 0) + _cts = new CancellationTokenSource(); + var conf = Context.CompileConfig; + var options = conf.CompileOption; + if (!File.Exists(options.InputFile)) { - Context.OpenDialog($"Error List:\n{string.Join("\n", info)}"); + Context.OpenDialog($"InputFile {options.InputFile} not found"); return; } - var conf = Context.CompileConfig; - var options = conf.CompileOption; if (!Directory.Exists(options.DumpDir)) { Directory.CreateDirectory(options.DumpDir); @@ -87,17 +88,18 @@ public async Task Compile() _cts = new(); ProgressBarMax = 9; + var progress = new Progress(percent => { - ProgressBarValue = percent; + Dispatcher.UIThread.Post(() => + { + ProgressBarValue = percent; + }); }); try { - await Task.Run(async () => - { - await compiler.CompileWithReportAsync(progress, _cts.Token); - }).ContinueWith(_ => Task.CompletedTask, _cts.Token); + await Task.Run(async () => await compiler.CompileAsync(progress, _cts.Token)); } catch (Exception) { @@ -123,5 +125,6 @@ public override void UpdateConfig(CompileConfig config) public override void UpdateViewModelCore(CompileConfig config) { KmodelPath = config.KmodelPath; + ProgressBarValue = 0; } }