Skip to content
This repository was archived by the owner on Dec 7, 2021. It is now read-only.

Commit c10c971

Browse files
committed
feat: CNTK Export Provider (#771)
Adds CNTK export provider into v2 Resolves #754
1 parent 0fe6386 commit c10c971

15 files changed

+353
-28
lines changed

src/common/localization/en-us.ts

+7-4
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,10 @@ export const english: IAppStrings = {
287287
tagged: "Only tagged Assets",
288288
},
289289
},
290+
testTrainSplit: {
291+
title: "Test / Train Split",
292+
description: "The test train split to use for exported data",
293+
},
290294
},
291295
},
292296
vottJson: {
@@ -344,15 +348,14 @@ export const english: IAppStrings = {
344348
},
345349
pascalVoc: {
346350
displayName: "Pascal VOC",
347-
testTrainSplit: {
348-
title: "Test / Train Split",
349-
description: "The test train split to use for exported data",
350-
},
351351
exportUnassigned: {
352352
title: "Export Unassigned",
353353
description: "Whether or not to include unassigned tags in exported data",
354354
},
355355
},
356+
cntk: {
357+
displayName: "Microsoft Cognitive Toolkit (CNTK)",
358+
},
356359
},
357360
messages: {
358361
saveSuccess: "Successfully saved export settings",

src/common/localization/es-cl.ts

+7-4
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,10 @@ export const spanish: IAppStrings = {
289289
tagged: "Solo activos etiquetados",
290290
},
291291
},
292+
testTrainSplit: {
293+
title: "La división para entrenar y comprobar",
294+
description: "La división de datos para utilizar entre el entrenamiento y la comprobación",
295+
},
292296
},
293297
},
294298
vottJson: {
@@ -346,15 +350,14 @@ export const spanish: IAppStrings = {
346350
},
347351
pascalVoc: {
348352
displayName: "Pascal VOC",
349-
testTrainSplit: {
350-
title: "Prueba/tren Split",
351-
description: "La división del tren de prueba que se utilizará para los datos exportados",
352-
},
353353
exportUnassigned: {
354354
title: "Exportar sin asignar",
355355
description: "Si se incluyen o no etiquetas no asignadas en los datos exportados",
356356
},
357357
},
358+
cntk: {
359+
displayName: "Microsoft Cognitive Toolkit (CNTK)",
360+
},
358361
},
359362
messages: {
360363
saveSuccess: "Configuración de exportación guardada correctamente",

src/common/strings.ts

+7-4
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,10 @@ export interface IAppStrings {
285285
tagged: string,
286286
},
287287
},
288+
testTrainSplit: {
289+
title: string,
290+
description: string,
291+
},
288292
},
289293
},
290294
vottJson: {
@@ -342,15 +346,14 @@ export interface IAppStrings {
342346
},
343347
pascalVoc: {
344348
displayName: string,
345-
testTrainSplit: {
346-
title: string,
347-
description: string,
348-
},
349349
exportUnassigned: {
350350
title: string,
351351
description: string,
352352
},
353353
},
354+
cntk: {
355+
displayName: string,
356+
},
354357
},
355358
messages: {
356359
saveSuccess: string;

src/providers/export/azureCustomVision.test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import { ExportProviderFactory } from "./exportProviderFactory";
66
import MockFactory from "../../common/mockFactory";
77
import {
88
IProject, AssetState, IAsset, IAssetMetadata,
9-
RegionType, IRegion, IExportProviderOptions, AssetType,
9+
RegionType, IRegion, IExportProviderOptions,
1010
} from "../../models/applicationState";
1111
import { ExportAssetState } from "./exportProvider";
1212
jest.mock("./azureCustomVision/azureCustomVisionService");

src/providers/export/cntk.json

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
{
2+
"type": "object",
3+
"title": "${strings.export.providers.cntk.displayName}",
4+
"properties": {
5+
"assetState": {
6+
"type": "string",
7+
"title": "${strings.export.providers.common.properties.assetState.title}",
8+
"description": "${strings.export.providers.common.properties.assetState.description}",
9+
"enum": [
10+
"all",
11+
"visited",
12+
"tagged"
13+
],
14+
"default": "visited",
15+
"enumNames": [
16+
"${strings.export.providers.common.properties.assetState.options.all}",
17+
"${strings.export.providers.common.properties.assetState.options.visited}",
18+
"${strings.export.providers.common.properties.assetState.options.tagged}"
19+
]
20+
},
21+
"testTrainSplit": {
22+
"title": "${strings.export.providers.common.properties.testTrainSplit.title}",
23+
"description": "${strings.export.providers.common.properties.testTrainSplit.description}",
24+
"type": "number",
25+
"default": 80
26+
}
27+
}
28+
}

src/providers/export/cntk.test.ts

+167
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import _ from "lodash";
2+
import os from "os";
3+
import { CntkExportProvider, ICntkExportProviderOptions } from "./cntk";
4+
import { IProject, AssetState, IAssetMetadata } from "../../models/applicationState";
5+
import { AssetProviderFactory } from "../storage/assetProviderFactory";
6+
import { ExportAssetState } from "./exportProvider";
7+
import MockFactory from "../../common/mockFactory";
8+
import registerMixins from "../../registerMixins";
9+
import registerProviders from "../../registerProviders";
10+
import { ExportProviderFactory } from "./exportProviderFactory";
11+
jest.mock("../../services/assetService");
12+
import { AssetService } from "../../services/assetService";
13+
14+
jest.mock("../storage/localFileSystemProxy");
15+
import { LocalFileSystemProxy } from "../storage/localFileSystemProxy";
16+
import HtmlFileReader from "../../common/htmlFileReader";
17+
import { appInfo } from "../../common/appInfo";
18+
19+
describe("CNTK Export Provider", () => {
20+
const testAssets = MockFactory.createTestAssets(10, 1);
21+
let testProject: IProject = null;
22+
23+
const defaultOptions: ICntkExportProviderOptions = {
24+
assetState: ExportAssetState.Tagged,
25+
testTrainSplit: 80,
26+
};
27+
28+
function createProvider(project: IProject): CntkExportProvider {
29+
return new CntkExportProvider(
30+
project,
31+
project.exportFormat.providerOptions as ICntkExportProviderOptions,
32+
);
33+
}
34+
35+
beforeAll(() => {
36+
registerMixins();
37+
registerProviders();
38+
39+
HtmlFileReader.getAssetBlob = jest.fn(() => {
40+
return Promise.resolve(new Blob(["Some binary data"]));
41+
});
42+
});
43+
44+
beforeEach(() => {
45+
jest.resetAllMocks();
46+
47+
testAssets.forEach((asset) => {
48+
asset.state = AssetState.Tagged;
49+
});
50+
51+
testProject = {
52+
...MockFactory.createTestProject("TestProject"),
53+
assets: _.keyBy(testAssets, (a) => a.id),
54+
exportFormat: {
55+
providerType: "cntk",
56+
providerOptions: defaultOptions,
57+
},
58+
};
59+
60+
AssetProviderFactory.create = jest.fn(() => {
61+
return {
62+
getAssets: jest.fn(() => Promise.resolve(testAssets)),
63+
};
64+
});
65+
66+
const assetServiceMock = AssetService as jest.Mocked<typeof AssetService>;
67+
assetServiceMock.prototype.getAssetMetadata = jest.fn((asset) => {
68+
const assetMetadata = {
69+
asset: { ...asset },
70+
regions: [
71+
MockFactory.createTestRegion("region-1", ["tag1"]),
72+
MockFactory.createTestRegion("region-2", ["tag1"]),
73+
],
74+
version: appInfo.version,
75+
};
76+
77+
return Promise.resolve(assetMetadata);
78+
});
79+
});
80+
81+
it("Is defined", () => {
82+
expect(CntkExportProvider).toBeDefined();
83+
});
84+
85+
it("Can be instantiated through the factory", () => {
86+
const options: ICntkExportProviderOptions = {
87+
assetState: ExportAssetState.All,
88+
testTrainSplit: 80,
89+
};
90+
const exportProvider = ExportProviderFactory.create("cntk", testProject, options);
91+
expect(exportProvider).not.toBeNull();
92+
expect(exportProvider).toBeInstanceOf(CntkExportProvider);
93+
});
94+
95+
it("Creates correct folder structure", async () => {
96+
const provider = createProvider(testProject);
97+
await provider.export();
98+
99+
const storageProviderMock = LocalFileSystemProxy as any;
100+
const createContainerCalls = storageProviderMock.mock.instances[0].createContainer.mock.calls;
101+
const createContainerArgs = createContainerCalls.map((args) => args[0]);
102+
103+
const expectedFolderPath = "Project-TestProject-CNTK-export";
104+
expect(createContainerArgs).toContain(expectedFolderPath);
105+
expect(createContainerArgs).toContain(`${expectedFolderPath}/positive`);
106+
expect(createContainerArgs).toContain(`${expectedFolderPath}/negative`);
107+
expect(createContainerArgs).toContain(`${expectedFolderPath}/testImages`);
108+
});
109+
110+
it("Writes export files to storage provider", async () => {
111+
const provider = createProvider(testProject);
112+
const getAssetsSpy = jest.spyOn(provider, "getAssetsForExport");
113+
114+
await provider.export();
115+
116+
const assetsToExport = await getAssetsSpy.mock.results[0].value;
117+
const testSplit = (100 - (defaultOptions.testTrainSplit || 80)) / 100;
118+
const testCount = Math.ceil(assetsToExport.length * testSplit);
119+
const testArray = assetsToExport.slice(0, testCount);
120+
const trainArray = assetsToExport.slice(testCount, assetsToExport.length);
121+
122+
const storageProviderMock = LocalFileSystemProxy as any;
123+
const writeBinaryCalls = storageProviderMock.mock.instances[0].writeBinary.mock.calls;
124+
const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls;
125+
126+
expect(writeBinaryCalls).toHaveLength(testAssets.length);
127+
expect(writeTextFileCalls).toHaveLength(testAssets.length * 2);
128+
129+
testArray.forEach((assetMetadata) => {
130+
const testFolderPath = "Project-TestProject-CNTK-export/testImages";
131+
assertExportedAsset(testFolderPath, assetMetadata);
132+
});
133+
134+
trainArray.forEach((assetMetadata) => {
135+
const trainFolderPath = "Project-TestProject-CNTK-export/positive";
136+
assertExportedAsset(trainFolderPath, assetMetadata);
137+
});
138+
});
139+
140+
function assertExportedAsset(folderPath: string, assetMetadata: IAssetMetadata) {
141+
const storageProviderMock = LocalFileSystemProxy as any;
142+
const writeBinaryCalls = storageProviderMock.mock.instances[0].writeBinary.mock.calls;
143+
const writeBinaryFilenameArgs = writeBinaryCalls.map((args) => args[0]);
144+
const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls;
145+
const writeTextFilenameArgs = writeTextFileCalls.map((args) => args[0]);
146+
147+
expect(writeBinaryFilenameArgs).toContain(`${folderPath}/${assetMetadata.asset.name}`);
148+
expect(writeTextFilenameArgs).toContain(`${folderPath}/${assetMetadata.asset.name}.bboxes.labels.tsv`);
149+
expect(writeTextFilenameArgs).toContain(`${folderPath}/${assetMetadata.asset.name}.bboxes.tsv`);
150+
151+
const writeLabelsCall = writeTextFileCalls
152+
.find((args: string[]) => args[0].indexOf(`${assetMetadata.asset.name}.bboxes.labels.tsv`) >= 0);
153+
154+
const writeBoxesCall = writeTextFileCalls
155+
.find((args: string[]) => args[0].indexOf(`${assetMetadata.asset.name}.bboxes.tsv`) >= 0);
156+
157+
const expectedLabelData = `${assetMetadata.regions[0].tags[0]}${os.EOL}${assetMetadata.regions[1].tags[0]}`;
158+
expect(writeLabelsCall[1]).toEqual(expectedLabelData);
159+
160+
const expectedBoxData = [];
161+
// tslint:disable-next-line:max-line-length
162+
expectedBoxData.push(`${assetMetadata.regions[0].boundingBox.left}\t${assetMetadata.regions[0].boundingBox.left + assetMetadata.regions[0].boundingBox.width}\t${assetMetadata.regions[0].boundingBox.top}\t${assetMetadata.regions[0].boundingBox.top + assetMetadata.regions[0].boundingBox.height}`);
163+
// tslint:disable-next-line:max-line-length
164+
expectedBoxData.push(`${assetMetadata.regions[1].boundingBox.left}\t${assetMetadata.regions[1].boundingBox.left + assetMetadata.regions[1].boundingBox.width}\t${assetMetadata.regions[1].boundingBox.top}\t${assetMetadata.regions[1].boundingBox.top + assetMetadata.regions[1].boundingBox.height}`);
165+
expect(writeBoxesCall[1]).toEqual(expectedBoxData.join(os.EOL));
166+
}
167+
});

0 commit comments

Comments
 (0)