// Program listing for file batch_translator.cpp
// (documentation for file src/translator/batch_translator.cpp)
#include "batch_translator.h"
#include "batch.h"
#include "byte_array_util.h"
#include "common/logging.h"
#include "data/corpus.h"
#include "data/text_input.h"
#include "translator/beam_search.h"
namespace marian {
namespace bergamot {
/// Records the resources the translator works with. No graph, scorer or shortlist is created here;
/// that happens in initialize().
///
/// @param device           Device the expression graph is bound to in initialize().
/// @param vocabs           Vocabularies; stored in vocabs_ and queried for the source/target vocab.
/// @param options          Option set; moved into options_.
/// @param modelMemory      Pointer to a buffer that may hold a binary model; only the pointer is stored.
/// @param shortlistMemory  Pointer to a buffer that may hold a binary shortlist; only the pointer is stored.
///
/// NOTE(review): the two memory pointers are kept as-is and dereferenced later in initialize(), so the
/// pointees presumably have to outlive this object -- confirm against the caller.
BatchTranslator::BatchTranslator(DeviceId const device, Vocabs &vocabs, Ptr<Options> options,
                                 const AlignedMemory *modelMemory, const AlignedMemory *shortlistMemory)
    : device_(device),
      // `options` is a by-value sink parameter: move it instead of copying the Ptr a second time.
      options_(std::move(options)),
      vocabs_(vocabs),
      modelMemory_(modelMemory),
      shortlistMemory_(shortlistMemory) {}
/// Builds everything translate() relies on, in this order:
///   1. the shortlist generator (only when the "shortlist" option is set),
///   2. the inference-only expression graph, configured from options_,
///   3. the scorers, created either from the in-memory model or from options_ alone,
///   4. scorer initialisation on the graph, followed by one forward pass of the graph.
///
/// Aborts (via ABORT_IF) when an in-memory model is supplied that is not 256-byte aligned, or when
/// "check-bytearray" is enabled and validateBinaryModel() rejects the buffer.
void BatchTranslator::initialize() {
  // Initializes the shortlist generator, if one was requested.
  if (options_->hasAndNotEmpty("shortlist")) {
    const int srcIdx = 0;
    const int trgIdx = 1;
    // vocabs_.sources().front() is used because only one source vocab is currently supported.
    const bool shared_vcb = vocabs_.sources().front() == vocabs_.target();

    // shortlistMemory_ is a raw pointer: check it before dereferencing and treat null like an empty buffer.
    if (shortlistMemory_ != nullptr && shortlistMemory_->size() > 0 && shortlistMemory_->begin() != nullptr) {
      slgen_ = New<data::BinaryShortlistGenerator>(shortlistMemory_->begin(), shortlistMemory_->size(),
                                                   vocabs_.sources().front(), vocabs_.target(), srcIdx, trgIdx,
                                                   shared_vcb, options_->get<bool>("check-bytearray"));
    } else {
      // Changed to BinaryShortlistGenerator to enable loading binary shortlist file
      // This class also supports text shortlist file
      slgen_ = New<data::BinaryShortlistGenerator>(options_, vocabs_.sources().front(), vocabs_.target(), srcIdx,
                                                   trgIdx, shared_vcb);
    }
  }

  // Initializes the graph.
  graph_ = New<ExpressionGraph>(true);  // set the graph to be inference only
  auto prec = options_->get<std::vector<std::string>>("precision", {"float32"});
  graph_->setDefaultElementType(typeFromString(prec[0]));  // only the first precision entry is consulted
  graph_->setDevice(device_);
  graph_->getBackend()->configureDevice(options_);
  graph_->reserveWorkspaceMB(options_->get<size_t>("workspace"));

  // If a byte array containing the model was provided, initialise the model from there, as opposed to
  // reading it in via the config. modelMemory_ is a raw pointer, so null is treated like an empty buffer.
  if (modelMemory_ != nullptr && modelMemory_->size() > 0 && modelMemory_->begin() != nullptr) {
    ABORT_IF(reinterpret_cast<uintptr_t>(modelMemory_->begin()) % 256 != 0,
             "The provided memory is not aligned to 256 bytes and will crash when vector instructions are used on it.");
    if (options_->get<bool>("check-bytearray")) {
      ABORT_IF(!validateBinaryModel(*modelMemory_, modelMemory_->size()),
               "The binary file is invalid. Incomplete or corrupted download?");
    }
    // Marian supports multiple models initialised in this manner hence std::vector.
    // However we will only ever use 1 during decoding.
    const std::vector<const void *> container = {modelMemory_->begin()};
    scorers_ = createScorers(options_, container);
  } else {
    scorers_ = createScorers(options_);
  }

  // Iterate by reference: copying each Ptr per iteration is unnecessary work.
  for (const auto &scorer : scorers_) {
    scorer->init(graph_);
    if (slgen_) {
      scorer->setShortlistGenerator(slgen_);
    }
  }
  graph_->forward();
}
/// Translates every sentence in `batch` and hands the resulting histories back to it.
///
/// Steps:
///   1. wrap each sentence's segment in a data::SentenceTuple whose id is its position in the batch,
///   2. compute, per stream, the longest sequence length (maxDims) and collect the sentence ids,
///   3. fill one SubBatch per stream with the word ids and a mask,
///   4. run beam search on the resulting CorpusBatch and pass the histories to batch.completeBatch().
///
/// @param batch  Supplies the sentences via sentences() and receives the result via completeBatch().
void BatchTranslator::translate(Batch &batch) {
  using SubBatch = marian::data::SubBatch;
  using CorpusBatch = marian::data::CorpusBatch;

  std::vector<data::SentenceTuple> batchVector;

  auto &sentences = batch.sentences();
  size_t batchSequenceNumber{0};
  for (auto &sentence : sentences) {
    data::SentenceTuple sentence_tuple(batchSequenceNumber);
    // Bind by const reference so no extra copy of the segment is made before it is added to the tuple.
    const Segment &segment = sentence.getUnderlyingSegment();
    sentence_tuple.push_back(segment);
    // The tuple is not used again in this iteration, so hand it over instead of copying it.
    batchVector.push_back(std::move(sentence_tuple));

    ++batchSequenceNumber;
  }

  const size_t batchSize = batchVector.size();

  std::vector<size_t> sentenceIds;
  sentenceIds.reserve(batchSize);

  // maxDims[s] ends up as the length of the longest sequence in stream s across the whole batch.
  std::vector<int> maxDims;
  for (auto &ex : batchVector) {
    if (maxDims.size() < ex.size()) {
      maxDims.resize(ex.size(), 0);
    }
    for (size_t i = 0; i < ex.size(); ++i) {
      if (ex[i].size() > static_cast<size_t>(maxDims[i])) {
        maxDims[i] = static_cast<int>(ex[i].size());
      }
    }
    sentenceIds.push_back(ex.getId());
  }

  std::vector<Ptr<SubBatch>> subBatches;
  for (size_t j = 0; j < maxDims.size(); ++j) {
    subBatches.emplace_back(New<SubBatch>(batchSize, maxDims[j], vocabs_.sources().at(j)));
  }

  // words[j] counts the tokens actually written into stream j.
  std::vector<size_t> words(maxDims.size(), 0);
  for (size_t i = 0; i < batchSize; ++i) {
    for (size_t j = 0; j < maxDims.size(); ++j) {
      for (size_t k = 0; k < batchVector[i][j].size(); ++k) {
        // Token k of sentence i lives at index k * batchSize + i. Only written positions get mask 1;
        // the remaining (padding) positions keep whatever SubBatch initialised them to.
        subBatches[j]->data()[k * batchSize + i] = batchVector[i][j][k];
        subBatches[j]->mask()[k * batchSize + i] = 1.f;
        ++words[j];
      }
    }
  }

  for (size_t j = 0; j < maxDims.size(); ++j) {
    subBatches[j]->setWords(words[j]);
  }

  // Use the New<> helper like the rest of this file rather than a raw `new`.
  auto corpus_batch = New<CorpusBatch>(subBatches);
  corpus_batch->setSentenceIds(sentenceIds);

  auto search = New<BeamSearch>(options_, scorers_, vocabs_.target());
  // No std::move around the call: wrapping a returned temporary in std::move only inhibits copy elision.
  auto histories = search->search(graph_, corpus_batch);

  batch.completeBatch(histories);
}
} // namespace bergamot
} // namespace marian