Decoding multi-source models in marian-server with --tsv.

Summary of this patch:
  * CHANGELOG.md: announces the feature.
  * data::TextInput: reads "max-length" / "max-length-crop" and crops each encoded line,
    overwriting the last kept token with the EOS id of that stream's own vocabulary.
  * TranslateService::run: when "tsv" is set, splits the request into "tsv-fields" parallel
    newline-joined texts (one per input stream) before building the TextInput corpus.

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3c98bca9e..1b540ff5c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 ## [Unreleased]
 
 ### Added
+- Decoding multi-source models in marian-server with --tsv
 - GitHub workflows on Ubuntu, Windows, and MacOS
 - LSH indexing to replace short list
 - ONNX support for transformer models
diff --git a/src/data/text_input.cpp b/src/data/text_input.cpp
index 5d68d9fb4..7c9df149e 100644
--- a/src/data/text_input.cpp
+++ b/src/data/text_input.cpp
@@ -23,8 +23,11 @@ const SentenceTuple& TextIterator::dereference() const {
 TextInput::TextInput(std::vector<std::string> inputs,
                      std::vector<Ptr<Vocab>> vocabs,
                      Ptr<Options> options)
-    : DatasetBase(inputs, options), vocabs_(vocabs) {
-  // note: inputs are automatically stored in the inherited variable named paths_, but these are
+    : DatasetBase(inputs, options),
+      vocabs_(vocabs),
+      maxLength_(options_->get<size_t>("max-length")),
+      maxLengthCrop_(options_->get<bool>("max-length-crop")) {
+  // Note: inputs are automatically stored in the inherited variable named paths_, but these are
   // texts not paths!
   for(const auto& text : paths_)
     files_.emplace_back(new std::istringstream(text));
@@ -42,6 +45,10 @@ SentenceTuple TextInput::next() {
     std::string line;
     if(io::getline(*files_[i], line)) {
       Words words = vocabs_[i]->encode(line, /*addEOS =*/ true, /*inference =*/ inference_);
+      if(maxLengthCrop_ && maxLength_ > 0 && words.size() > maxLength_) { // maxLength_ > 0: back() on an empty vector is undefined
+        words.resize(maxLength_);
+        words.back() = vocabs_[i]->getEosId(); // EOS of this stream's own vocab; note: this will not work with class-labels
+      }
       if(words.empty())
         words.push_back(Word::ZERO); // @TODO: What is this for? @BUGBUG: addEOS=true, so this can never happen, right?
       tup.push_back(words);
diff --git a/src/data/text_input.h b/src/data/text_input.h
index db99ef6ae..b08a4fdcc 100644
--- a/src/data/text_input.h
+++ b/src/data/text_input.h
@@ -33,6 +33,9 @@ class TextInput : public DatasetBase {
 
   size_t pos_{0};
 
+  size_t maxLength_{0};
+  bool maxLengthCrop_{false};
+
 public:
   typedef SentenceTuple Sample;
 
diff --git a/src/translator/translator.h b/src/translator/translator.h
index cc68a4f01..a042e7973 100755
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -1,5 +1,7 @@
 #pragma once
 
+#include <string>
+
 #include "data/batch_generator.h"
 #include "data/corpus.h"
 #include "data/shortlist.h"
@@ -245,7 +247,11 @@ class TranslateService : public ModelServiceTask {
   }
 
   std::string run(const std::string& input) override {
-    auto corpus_ = New<data::TextInput>(std::vector<std::string>({input}), srcVocabs_, options_);
+    // split tab-separated input into fields if necessary
+    auto inputs = options_->get<bool>("tsv", false)
+                      ? convertTsvToLists(input, options_->get<size_t>("tsv-fields", 1))
+                      : std::vector<std::string>({input});
+    auto corpus_ = New<data::TextInput>(inputs, srcVocabs_, options_);
     data::BatchGenerator<data::TextInput> batchGenerator(corpus_, options_);
 
     auto collector = New<StringCollector>();
@@ -258,7 +264,6 @@ class TranslateService : public ModelServiceTask {
     ThreadPool threadPool_(numDevices_, numDevices_);
 
     for(auto batch : batchGenerator) {
-
       auto task = [=](size_t id) {
         thread_local Ptr<ExpressionGraph> graph;
         thread_local std::vector<Ptr<Scorer>> scorers;
@@ -287,5 +292,30 @@ class TranslateService : public ModelServiceTask {
     auto translations = collector->collect(options_->get<bool>("n-best"));
     return utils::join(translations, "\n");
   }
+
+private:
+  // Converts a multi-line input with tab-separated source(s) and target sentences into separate lists
+  // of sentences from source(s) and target sides, e.g.
+  // "src1 \t trg1 \n src2 \t trg2" -> ["src1 \n src2", "trg1 \n trg2"]
+  std::vector<std::string> convertTsvToLists(const std::string& inputText, size_t numFields) {
+    std::vector<std::string> outputFields(numFields);
+
+    std::string line;
+    std::vector<std::string> lineFields(numFields);
+    std::istringstream inputStream(inputText);
+    bool first = true;
+    while(std::getline(inputStream, line)) {
+      utils::splitTsv(line, lineFields, numFields);
+      for(size_t i = 0; i < numFields; ++i) {
+        if(!first)
+          outputFields[i] += "\n";  // join sentences with a new line sign
+        outputFields[i] += lineFields[i];
+      }
+      if(first)
+        first = false;
+    }
+
+    return outputFields;
+  }
 };
 }  // namespace marian