Skip to content
This repository was archived by the owner on Dec 7, 2021. It is now read-only.

feat: CNTK Export Provider #771

Merged
merged 4 commits into from
Apr 17, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/common/localization/en-us.ts
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,10 @@ export const english: IAppStrings = {
tagged: "Only tagged Assets",
},
},
testTrainSplit: {
title: "Test / Train Split",
description: "The test train split to use for exported data",
},
},
},
vottJson: {
Expand Down Expand Up @@ -344,15 +348,14 @@ export const english: IAppStrings = {
},
pascalVoc: {
displayName: "Pascal VOC",
testTrainSplit: {
title: "Test / Train Split",
description: "The test train split to use for exported data",
},
exportUnassigned: {
title: "Export Unassigned",
description: "Whether or not to include unassigned tags in exported data",
},
},
cntk: {
displayName: "Microsoft Cognitive Toolkit (CNTK)",
},
},
messages: {
saveSuccess: "Successfully saved export settings",
Expand Down
11 changes: 7 additions & 4 deletions src/common/localization/es-cl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,10 @@ export const spanish: IAppStrings = {
tagged: "Solo activos etiquetados",
},
},
testTrainSplit: {
title: "Prueba/tren Split",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

La división para entrenar y comprobar

description: "La división del tren de prueba que se utilizará para los datos exportados",
Copy link
Contributor

@tbarlow12 tbarlow12 Apr 17, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

La división de datos para utilizar entre el entrenamiento y la comprobación

},
},
},
vottJson: {
Expand Down Expand Up @@ -346,15 +350,14 @@ export const spanish: IAppStrings = {
},
pascalVoc: {
displayName: "Pascal VOC",
testTrainSplit: {
title: "Prueba/tren Split",
description: "La división del tren de prueba que se utilizará para los datos exportados",
},
exportUnassigned: {
title: "Exportar sin asignar",
description: "Si se incluyen o no etiquetas no asignadas en los datos exportados",
},
},
cntk: {
displayName: "Microsoft Cognitive Toolkit (CNTK)",
},
},
messages: {
saveSuccess: "Configuración de exportación guardada correctamente",
Expand Down
11 changes: 7 additions & 4 deletions src/common/strings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ export interface IAppStrings {
tagged: string,
},
},
testTrainSplit: {
title: string,
description: string,
},
},
},
vottJson: {
Expand Down Expand Up @@ -342,15 +346,14 @@ export interface IAppStrings {
},
pascalVoc: {
displayName: string,
testTrainSplit: {
title: string,
description: string,
},
exportUnassigned: {
title: string,
description: string,
},
},
cntk: {
displayName: string,
},
},
messages: {
saveSuccess: string;
Expand Down
2 changes: 1 addition & 1 deletion src/providers/export/azureCustomVision.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { ExportProviderFactory } from "./exportProviderFactory";
import MockFactory from "../../common/mockFactory";
import {
IProject, AssetState, IAsset, IAssetMetadata,
RegionType, IRegion, IExportProviderOptions, AssetType,
RegionType, IRegion, IExportProviderOptions,
} from "../../models/applicationState";
import { ExportAssetState } from "./exportProvider";
jest.mock("./azureCustomVision/azureCustomVisionService");
Expand Down
28 changes: 28 additions & 0 deletions src/providers/export/cntk.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"type": "object",
"title": "${strings.export.providers.cntk.displayName}",
"properties": {
"assetState": {
"type": "string",
"title": "${strings.export.providers.common.properties.assetState.title}",
"description": "${strings.export.providers.common.properties.assetState.description}",
"enum": [
"all",
"visited",
"tagged"
],
"default": "visited",
"enumNames": [
"${strings.export.providers.common.properties.assetState.options.all}",
"${strings.export.providers.common.properties.assetState.options.visited}",
"${strings.export.providers.common.properties.assetState.options.tagged}"
]
},
"testTrainSplit": {
"title": "${strings.export.providers.common.properties.testTrainSplit.title}",
"description": "${strings.export.providers.common.properties.testTrainSplit.description}",
"type": "number",
"default": 80
}
}
}
167 changes: 167 additions & 0 deletions src/providers/export/cntk.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import _ from "lodash";
import os from "os";
import { CntkExportProvider, ICntkExportProviderOptions } from "./cntk";
import { IProject, AssetState, IAssetMetadata } from "../../models/applicationState";
import { AssetProviderFactory } from "../storage/assetProviderFactory";
import { ExportAssetState } from "./exportProvider";
import MockFactory from "../../common/mockFactory";
import registerMixins from "../../registerMixins";
import registerProviders from "../../registerProviders";
import { ExportProviderFactory } from "./exportProviderFactory";
jest.mock("../../services/assetService");
import { AssetService } from "../../services/assetService";

jest.mock("../storage/localFileSystemProxy");
import { LocalFileSystemProxy } from "../storage/localFileSystemProxy";
import HtmlFileReader from "../../common/htmlFileReader";
import { appInfo } from "../../common/appInfo";

describe("CNTK Export Provider", () => {
const testAssets = MockFactory.createTestAssets(10, 1);
let testProject: IProject = null;

const defaultOptions: ICntkExportProviderOptions = {
assetState: ExportAssetState.Tagged,
testTrainSplit: 80,
};

function createProvider(project: IProject): CntkExportProvider {
return new CntkExportProvider(
project,
project.exportFormat.providerOptions as ICntkExportProviderOptions,
);
}

beforeAll(() => {
registerMixins();
registerProviders();

HtmlFileReader.getAssetBlob = jest.fn(() => {
return Promise.resolve(new Blob(["Some binary data"]));
});
});

beforeEach(() => {
jest.resetAllMocks();

testAssets.forEach((asset) => {
asset.state = AssetState.Tagged;
});

testProject = {
...MockFactory.createTestProject("TestProject"),
assets: _.keyBy(testAssets, (a) => a.id),
exportFormat: {
providerType: "cntk",
providerOptions: defaultOptions,
},
};

AssetProviderFactory.create = jest.fn(() => {
return {
getAssets: jest.fn(() => Promise.resolve(testAssets)),
};
});

const assetServiceMock = AssetService as jest.Mocked<typeof AssetService>;
assetServiceMock.prototype.getAssetMetadata = jest.fn((asset) => {
const assetMetadata = {
asset: { ...asset },
regions: [
MockFactory.createTestRegion("region-1", ["tag1"]),
MockFactory.createTestRegion("region-2", ["tag1"]),
],
version: appInfo.version,
};

return Promise.resolve(assetMetadata);
});
});

it("Is defined", () => {
expect(CntkExportProvider).toBeDefined();
});

it("Can be instantiated through the factory", () => {
const options: ICntkExportProviderOptions = {
assetState: ExportAssetState.All,
testTrainSplit: 80,
};
const exportProvider = ExportProviderFactory.create("cntk", testProject, options);
expect(exportProvider).not.toBeNull();
expect(exportProvider).toBeInstanceOf(CntkExportProvider);
});

it("Creates correct folder structure", async () => {
const provider = createProvider(testProject);
await provider.export();

const storageProviderMock = LocalFileSystemProxy as any;
const createContainerCalls = storageProviderMock.mock.instances[0].createContainer.mock.calls;
const createContainerArgs = createContainerCalls.map((args) => args[0]);

const expectedFolderPath = "Project-TestProject-CNTK-export";
expect(createContainerArgs).toContain(expectedFolderPath);
expect(createContainerArgs).toContain(`${expectedFolderPath}/positive`);
expect(createContainerArgs).toContain(`${expectedFolderPath}/negative`);
expect(createContainerArgs).toContain(`${expectedFolderPath}/testImages`);
});

it("Writes export files to storage provider", async () => {
const provider = createProvider(testProject);
const getAssetsSpy = jest.spyOn(provider, "getAssetsForExport");

await provider.export();

const assetsToExport = await getAssetsSpy.mock.results[0].value;
const testSplit = (100 - (defaultOptions.testTrainSplit || 80)) / 100;
const testCount = Math.ceil(assetsToExport.length * testSplit);
const testArray = assetsToExport.slice(0, testCount);
const trainArray = assetsToExport.slice(testCount, assetsToExport.length);

const storageProviderMock = LocalFileSystemProxy as any;
const writeBinaryCalls = storageProviderMock.mock.instances[0].writeBinary.mock.calls;
const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls;

expect(writeBinaryCalls).toHaveLength(testAssets.length);
expect(writeTextFileCalls).toHaveLength(testAssets.length * 2);

testArray.forEach((assetMetadata) => {
const testFolderPath = "Project-TestProject-CNTK-export/testImages";
assertExportedAsset(testFolderPath, assetMetadata);
});

trainArray.forEach((assetMetadata) => {
const trainFolderPath = "Project-TestProject-CNTK-export/positive";
assertExportedAsset(trainFolderPath, assetMetadata);
});
});

function assertExportedAsset(folderPath: string, assetMetadata: IAssetMetadata) {
const storageProviderMock = LocalFileSystemProxy as any;
const writeBinaryCalls = storageProviderMock.mock.instances[0].writeBinary.mock.calls;
const writeBinaryFilenameArgs = writeBinaryCalls.map((args) => args[0]);
const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls;
const writeTextFilenameArgs = writeTextFileCalls.map((args) => args[0]);

expect(writeBinaryFilenameArgs).toContain(`${folderPath}/${assetMetadata.asset.name}`);
expect(writeTextFilenameArgs).toContain(`${folderPath}/${assetMetadata.asset.name}.bboxes.labels.tsv`);
expect(writeTextFilenameArgs).toContain(`${folderPath}/${assetMetadata.asset.name}.bboxes.tsv`);

const writeLabelsCall = writeTextFileCalls
.find((args: string[]) => args[0].indexOf(`${assetMetadata.asset.name}.bboxes.labels.tsv`) >= 0);

const writeBoxesCall = writeTextFileCalls
.find((args: string[]) => args[0].indexOf(`${assetMetadata.asset.name}.bboxes.tsv`) >= 0);

const expectedLabelData = `${assetMetadata.regions[0].tags[0]}${os.EOL}${assetMetadata.regions[1].tags[0]}`;
expect(writeLabelsCall[1]).toEqual(expectedLabelData);

const expectedBoxData = [];
// tslint:disable-next-line:max-line-length
expectedBoxData.push(`${assetMetadata.regions[0].boundingBox.left}\t${assetMetadata.regions[0].boundingBox.left + assetMetadata.regions[0].boundingBox.width}\t${assetMetadata.regions[0].boundingBox.top}\t${assetMetadata.regions[0].boundingBox.top + assetMetadata.regions[0].boundingBox.height}`);
// tslint:disable-next-line:max-line-length
expectedBoxData.push(`${assetMetadata.regions[1].boundingBox.left}\t${assetMetadata.regions[1].boundingBox.left + assetMetadata.regions[1].boundingBox.width}\t${assetMetadata.regions[1].boundingBox.top}\t${assetMetadata.regions[1].boundingBox.top + assetMetadata.regions[1].boundingBox.height}`);
expect(writeBoxesCall[1]).toEqual(expectedBoxData.join(os.EOL));
}
});
Loading