|
| 1 | +import _ from "lodash"; |
| 2 | +import os from "os"; |
| 3 | +import { CntkExportProvider, ICntkExportProviderOptions } from "./cntk"; |
| 4 | +import { IProject, AssetState, IAssetMetadata } from "../../models/applicationState"; |
| 5 | +import { AssetProviderFactory } from "../storage/assetProviderFactory"; |
| 6 | +import { ExportAssetState } from "./exportProvider"; |
| 7 | +import MockFactory from "../../common/mockFactory"; |
| 8 | +import registerMixins from "../../registerMixins"; |
| 9 | +import registerProviders from "../../registerProviders"; |
| 10 | +import { ExportProviderFactory } from "./exportProviderFactory"; |
| 11 | +jest.mock("../../services/assetService"); |
| 12 | +import { AssetService } from "../../services/assetService"; |
| 13 | + |
| 14 | +jest.mock("../storage/localFileSystemProxy"); |
| 15 | +import { LocalFileSystemProxy } from "../storage/localFileSystemProxy"; |
| 16 | +import HtmlFileReader from "../../common/htmlFileReader"; |
| 17 | +import { appInfo } from "../../common/appInfo"; |
| 18 | + |
| 19 | +describe("CNTK Export Provider", () => { |
| 20 | + const testAssets = MockFactory.createTestAssets(10, 1); |
| 21 | + let testProject: IProject = null; |
| 22 | + |
| 23 | + const defaultOptions: ICntkExportProviderOptions = { |
| 24 | + assetState: ExportAssetState.Tagged, |
| 25 | + testTrainSplit: 80, |
| 26 | + }; |
| 27 | + |
| 28 | + function createProvider(project: IProject): CntkExportProvider { |
| 29 | + return new CntkExportProvider( |
| 30 | + project, |
| 31 | + project.exportFormat.providerOptions as ICntkExportProviderOptions, |
| 32 | + ); |
| 33 | + } |
| 34 | + |
| 35 | + beforeAll(() => { |
| 36 | + registerMixins(); |
| 37 | + registerProviders(); |
| 38 | + |
| 39 | + HtmlFileReader.getAssetBlob = jest.fn(() => { |
| 40 | + return Promise.resolve(new Blob(["Some binary data"])); |
| 41 | + }); |
| 42 | + }); |
| 43 | + |
| 44 | + beforeEach(() => { |
| 45 | + jest.resetAllMocks(); |
| 46 | + |
| 47 | + testAssets.forEach((asset) => { |
| 48 | + asset.state = AssetState.Tagged; |
| 49 | + }); |
| 50 | + |
| 51 | + testProject = { |
| 52 | + ...MockFactory.createTestProject("TestProject"), |
| 53 | + assets: _.keyBy(testAssets, (a) => a.id), |
| 54 | + exportFormat: { |
| 55 | + providerType: "cntk", |
| 56 | + providerOptions: defaultOptions, |
| 57 | + }, |
| 58 | + }; |
| 59 | + |
| 60 | + AssetProviderFactory.create = jest.fn(() => { |
| 61 | + return { |
| 62 | + getAssets: jest.fn(() => Promise.resolve(testAssets)), |
| 63 | + }; |
| 64 | + }); |
| 65 | + |
| 66 | + const assetServiceMock = AssetService as jest.Mocked<typeof AssetService>; |
| 67 | + assetServiceMock.prototype.getAssetMetadata = jest.fn((asset) => { |
| 68 | + const assetMetadata = { |
| 69 | + asset: { ...asset }, |
| 70 | + regions: [ |
| 71 | + MockFactory.createTestRegion("region-1", ["tag1"]), |
| 72 | + MockFactory.createTestRegion("region-2", ["tag1"]), |
| 73 | + ], |
| 74 | + version: appInfo.version, |
| 75 | + }; |
| 76 | + |
| 77 | + return Promise.resolve(assetMetadata); |
| 78 | + }); |
| 79 | + }); |
| 80 | + |
| 81 | + it("Is defined", () => { |
| 82 | + expect(CntkExportProvider).toBeDefined(); |
| 83 | + }); |
| 84 | + |
| 85 | + it("Can be instantiated through the factory", () => { |
| 86 | + const options: ICntkExportProviderOptions = { |
| 87 | + assetState: ExportAssetState.All, |
| 88 | + testTrainSplit: 80, |
| 89 | + }; |
| 90 | + const exportProvider = ExportProviderFactory.create("cntk", testProject, options); |
| 91 | + expect(exportProvider).not.toBeNull(); |
| 92 | + expect(exportProvider).toBeInstanceOf(CntkExportProvider); |
| 93 | + }); |
| 94 | + |
| 95 | + it("Creates correct folder structure", async () => { |
| 96 | + const provider = createProvider(testProject); |
| 97 | + await provider.export(); |
| 98 | + |
| 99 | + const storageProviderMock = LocalFileSystemProxy as any; |
| 100 | + const createContainerCalls = storageProviderMock.mock.instances[0].createContainer.mock.calls; |
| 101 | + const createContainerArgs = createContainerCalls.map((args) => args[0]); |
| 102 | + |
| 103 | + const expectedFolderPath = "Project-TestProject-CNTK-export"; |
| 104 | + expect(createContainerArgs).toContain(expectedFolderPath); |
| 105 | + expect(createContainerArgs).toContain(`${expectedFolderPath}/positive`); |
| 106 | + expect(createContainerArgs).toContain(`${expectedFolderPath}/negative`); |
| 107 | + expect(createContainerArgs).toContain(`${expectedFolderPath}/testImages`); |
| 108 | + }); |
| 109 | + |
| 110 | + it("Writes export files to storage provider", async () => { |
| 111 | + const provider = createProvider(testProject); |
| 112 | + const getAssetsSpy = jest.spyOn(provider, "getAssetsForExport"); |
| 113 | + |
| 114 | + await provider.export(); |
| 115 | + |
| 116 | + const assetsToExport = await getAssetsSpy.mock.results[0].value; |
| 117 | + const testSplit = (100 - (defaultOptions.testTrainSplit || 80)) / 100; |
| 118 | + const testCount = Math.ceil(assetsToExport.length * testSplit); |
| 119 | + const testArray = assetsToExport.slice(0, testCount); |
| 120 | + const trainArray = assetsToExport.slice(testCount, assetsToExport.length); |
| 121 | + |
| 122 | + const storageProviderMock = LocalFileSystemProxy as any; |
| 123 | + const writeBinaryCalls = storageProviderMock.mock.instances[0].writeBinary.mock.calls; |
| 124 | + const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls; |
| 125 | + |
| 126 | + expect(writeBinaryCalls).toHaveLength(testAssets.length); |
| 127 | + expect(writeTextFileCalls).toHaveLength(testAssets.length * 2); |
| 128 | + |
| 129 | + testArray.forEach((assetMetadata) => { |
| 130 | + const testFolderPath = "Project-TestProject-CNTK-export/testImages"; |
| 131 | + assertExportedAsset(testFolderPath, assetMetadata); |
| 132 | + }); |
| 133 | + |
| 134 | + trainArray.forEach((assetMetadata) => { |
| 135 | + const trainFolderPath = "Project-TestProject-CNTK-export/positive"; |
| 136 | + assertExportedAsset(trainFolderPath, assetMetadata); |
| 137 | + }); |
| 138 | + }); |
| 139 | + |
| 140 | + function assertExportedAsset(folderPath: string, assetMetadata: IAssetMetadata) { |
| 141 | + const storageProviderMock = LocalFileSystemProxy as any; |
| 142 | + const writeBinaryCalls = storageProviderMock.mock.instances[0].writeBinary.mock.calls; |
| 143 | + const writeBinaryFilenameArgs = writeBinaryCalls.map((args) => args[0]); |
| 144 | + const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls; |
| 145 | + const writeTextFilenameArgs = writeTextFileCalls.map((args) => args[0]); |
| 146 | + |
| 147 | + expect(writeBinaryFilenameArgs).toContain(`${folderPath}/${assetMetadata.asset.name}`); |
| 148 | + expect(writeTextFilenameArgs).toContain(`${folderPath}/${assetMetadata.asset.name}.bboxes.labels.tsv`); |
| 149 | + expect(writeTextFilenameArgs).toContain(`${folderPath}/${assetMetadata.asset.name}.bboxes.tsv`); |
| 150 | + |
| 151 | + const writeLabelsCall = writeTextFileCalls |
| 152 | + .find((args: string[]) => args[0].indexOf(`${assetMetadata.asset.name}.bboxes.labels.tsv`) >= 0); |
| 153 | + |
| 154 | + const writeBoxesCall = writeTextFileCalls |
| 155 | + .find((args: string[]) => args[0].indexOf(`${assetMetadata.asset.name}.bboxes.tsv`) >= 0); |
| 156 | + |
| 157 | + const expectedLabelData = `${assetMetadata.regions[0].tags[0]}${os.EOL}${assetMetadata.regions[1].tags[0]}`; |
| 158 | + expect(writeLabelsCall[1]).toEqual(expectedLabelData); |
| 159 | + |
| 160 | + const expectedBoxData = []; |
| 161 | + // tslint:disable-next-line:max-line-length |
| 162 | + expectedBoxData.push(`${assetMetadata.regions[0].boundingBox.left}\t${assetMetadata.regions[0].boundingBox.left + assetMetadata.regions[0].boundingBox.width}\t${assetMetadata.regions[0].boundingBox.top}\t${assetMetadata.regions[0].boundingBox.top + assetMetadata.regions[0].boundingBox.height}`); |
| 163 | + // tslint:disable-next-line:max-line-length |
| 164 | + expectedBoxData.push(`${assetMetadata.regions[1].boundingBox.left}\t${assetMetadata.regions[1].boundingBox.left + assetMetadata.regions[1].boundingBox.width}\t${assetMetadata.regions[1].boundingBox.top}\t${assetMetadata.regions[1].boundingBox.top + assetMetadata.regions[1].boundingBox.height}`); |
| 165 | + expect(writeBoxesCall[1]).toEqual(expectedBoxData.join(os.EOL)); |
| 166 | + } |
| 167 | +}); |
0 commit comments