Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[basicdataset] Add PennTreebank dataset #1580

Merged
merged 6 commits into from
May 3, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebank.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.basicdataset.nlp;

import ai.djl.Application;
import ai.djl.basicdataset.BasicDatasets;
import ai.djl.basicdataset.RawDataset;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.translate.TranslateException;
import ai.djl.util.Progress;
import java.io.IOException;
import java.nio.file.Path;

/**
* The Penn Treebank (PTB) project selected 2,499 stories from a three year Wall Street Journal
* (WSJ) collection of 98,732 stories for syntactic annotation.
*/
public class PennTreebank implements RawDataset<Path> {

private static final String VERSION = "1.0";
private static final String ARTIFACT_ID = "penntreebank";

private Dataset.Usage usage;
private Path root;

private MRL mrl;
private boolean prepared;

PennTreebank(Builder builder) {
this.usage = builder.usage;
mrl = builder.getMrl();
}

/**
* Creates a builder to build a {@link PennTreebank}.
*
* @return a new {@link PennTreebank.Builder} object
*/
public static Builder builder() {
return new Builder();
}
/**
* Fetches an iterator that can iterate through the {@link Dataset}. This method is not
* implemented for the PennTreebank dataset because the PennTreebank dataset is not suitable for
* iteration. If the method is called, it will directly return {@code null}.
*
* @param manager the dataset to iterate through
* @return an {@link Iterable} of {@link Batch} that contains batches of data from the dataset
*/
@Override
public Iterable<Batch> getData(NDManager manager) throws IOException, TranslateException {
return null;
}

/**
* Get data from the PennTreebank dataset. This method will directly return the path of required
* dataset.
*
* @return a {@link Path} object locating the PennTreebank dataset file
*/
@Override
public Path getData() throws IOException {
prepare(null);
return root;
}

/**
* Prepares the dataset for use with tracked progress.
*
* @param progress the progress tracker
* @throws IOException for various exceptions depending on the dataset
*/
@Override
public void prepare(Progress progress) throws IOException {
if (prepared) {
return;
}
Artifact artifact = mrl.getDefaultArtifact();
mrl.prepare(artifact, progress);
Artifact.Item item;

switch (usage) {
case TRAIN:
item = artifact.getFiles().get("train");
break;
case TEST:
item = artifact.getFiles().get("test");
break;
case VALIDATION:
item = artifact.getFiles().get("valid");
break;
default:
throw new UnsupportedOperationException("Unsupported usage type.");
}
root = mrl.getRepository().getFile(item, "").toAbsolutePath();
prepared = true;
}

/** A builder to construct a {@link PennTreebank} . */
public static final class Builder {

Repository repository;
String groupId;
String artifactId;
Dataset.Usage usage;

/** Constructs a new builder. */
public Builder() {
repository = BasicDatasets.REPOSITORY;
groupId = BasicDatasets.GROUP_ID;
artifactId = ARTIFACT_ID;
usage = Dataset.Usage.TRAIN;
}

/**
* Sets the optional repository for the dataset.
*
* @param repository the new repository
* @return this builder
*/
public Builder optRepository(Repository repository) {
this.repository = repository;
return this;
}

/**
* Sets optional groupId.
*
* @param groupId the groupId
* @return this builder
*/
public Builder optGroupId(String groupId) {
this.groupId = groupId;
return this;
}

/**
* Sets the optional artifactId.
*
* @param artifactId the artifactId
* @return this builder
*/
public Builder optArtifactId(String artifactId) {
if (artifactId.contains(":")) {
String[] tokens = artifactId.split(":");
groupId = tokens[0];
this.artifactId = tokens[1];
} else {
this.artifactId = artifactId;
}
return this;
}

/**
* Sets the optional usage for the dataset.
*
* @param usage the usage
* @return this builder
*/
public Builder optUsage(Dataset.Usage usage) {
this.usage = usage;
return this;
}

/**
* Builds a new {@link PennTreebank} object.
*
* @return the new {@link PennTreebank} object
*/
public PennTreebank build() {
return new PennTreebank(this);
}

MRL getMrl() {
return repository.dataset(Application.NLP.ANY, groupId, artifactId, VERSION);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.basicdataset;

import ai.djl.basicdataset.nlp.PennTreebank;
import ai.djl.repository.Repository;
import ai.djl.training.dataset.Dataset;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import org.testng.Assert;
import org.testng.annotations.Test;

public class PennTreebankTest {

@Test
public void testPennTreebankTrainLocal() throws IOException {
Repository repository = Repository.newInstance("test", "src/test/resources/mlrepo");
PennTreebank trainingSet =
PennTreebank.builder()
.optRepository(repository)
.optUsage(Dataset.Usage.TRAIN)
.build();
Path path = trainingSet.getData();
Assert.assertTrue(Files.isRegularFile(path));
Assert.assertEquals(path.getFileName().toString(), "ptb.train.txt");
}

@Test
public void testPennTreebankTestLocal() throws IOException {
Repository repository = Repository.newInstance("test", "src/test/resources/mlrepo");
PennTreebank trainingSet =
PennTreebank.builder()
.optRepository(repository)
.optUsage(Dataset.Usage.TEST)
.build();
Path path = trainingSet.getData();
Assert.assertTrue(Files.isRegularFile(path));
Assert.assertEquals(path.getFileName().toString(), "ptb.test.txt");
}

@Test
public void testPennTreebankValidationLocal() throws IOException {
Repository repository = Repository.newInstance("test", "src/test/resources/mlrepo");
PennTreebank trainingSet =
PennTreebank.builder()
.optRepository(repository)
.optUsage(Dataset.Usage.VALIDATION)
.build();
Path path = trainingSet.getData();
Assert.assertTrue(Files.isRegularFile(path));
Assert.assertEquals(path.getFileName().toString(), "ptb.valid.txt");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"metadataVersion": "0.2",
"resourceType": "dataset",
"application": "nlp",
"groupId": "ai.djl.basicdataset",
"artifactId": "penntreebank",
"name": "penntreebank",
"description": "The Penn Treebank (PTB) project selected 2,499 stories from a three year Wall Street Journal (WSJ) collection of 98,732 stories for syntactic annotation.",
"website": "https://blog.salesforceairesearch.com/the-wikitext-long-term-dependency-language-modeling-dataset/",
"licenses": {
"license": {
"name": "LDC User Agreement for Non-Members",
"url": "https://catalog.ldc.upenn.edu/license/ldc-non-members-agreement.pdf"
}
},
"artifacts": [
{
"version": "1.0",
"snapshot": false,
"name": "penntreebank",
"files": {
"train":{
"uri" : "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt",
"sha1Hash": "f9ffb014fa33bd5730e5029697ad245184f3a678",
"size": 5101618
},
"test":{
"uri" : "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.test.txt",
"sha1Hash": "5c15c548b42d80bce9332b788514e6635fb0226e",
"size": 449945
},
"valid":{
"uri" : "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.valid.txt",
"sha1Hash": "d9f5fed6afa5e1b82cd1e3e5f5040f6852940228",
"size": 399782
}
}
}
]
}