This post is about the development of slimt, software I have been working on for the past 10 days or so. It's slim machine-translation (MT) inference code. It was fun to build, and some bits of the journey seemed worth documenting.

Flashback

One of the first things I was tasked with when starting undergraduate research at university was to move an Optical Character Recognition (OCR) library from C++ to Python. This was circa 2015 - PyTorch was new, and TensorFlow was dark magic that research seniors rejected in favour of PyTorch. But surprise: what the lab had powering its document-research efforts was rnnlib. The story goes that some researcher managed to get it to work, and it had been powering text-recognition ever since. Training was on CPUs; the library predated ImageNet. But Bidirectional LSTMs were the best the lab had back then.

I went ahead, checked the source-code, and reported back that we should just do it in PyTorch. Back then I'd blame the lack of documentation and a missing map to the source-code. My advisor, unsurprisingly, held on to the requirement - it's just a bunch of matrix multiplies, why is it so hard to move from C++ to Python? Basically, I had to find the parts where the said matrices showed up in code and use them - except it was taking long. I tried a lot of stuff - used Doxygen to generate diagrams, opened up source files based on their names, stepped through the GDB debugger to find which lines of source were getting executed. I'd naively go back and report these things to my advisor, who mostly works in computer vision, only to see blank faces every time. The effort didn't succeed (or rather, we did not wait for it to).

On the bright side, I did learn quite a bit of C++ tooling and developed some interest in the area. Fast forward 5 years: I have moved from training machine-learning models in Python to AI inference, I am learning to write machine-learning frameworks, and I have come full circle to a similar task.

The task

A task long pending on my to-do list is to polish lemonade, a translation input-method engine, for easier use. The blocker was that the input-method kept hanging on language-switches, and I could not work out from the large bergamot source why that was happening. I was procrastinating on diving deeper, until contemporary efforts such as ggerganov/ggml and karpathy/llama2.c got me thinking - wouldn't my translation inference need specialized code for only one class of models?

The models are transformers with only minor modifications. The tiny11 class of models is actually quite small after 8-bit quantization - ende.student.tiny11 is 17MB on my system. And marian-dev is source-code I'd been lurking around for about 2 years.

bergamot-translator builds on top of marian-dev and uses the inference code-path from marian-dev. While marian is a capable neural network library with a focus on machine translation, all the bells and whistles that come with it are not necessary to run inference on client machines (e.g. autograd, multiple sequence-to-sequence architecture support, beam-search). When I started, marian used to be a monstrous 30-minute compile on my laptop, which I brought down using ccache and wrote about - stripping the code to inference only would make that look lame and useless. I would not even need ccache anymore.

So - write my own transformer forward pass, bring compile-times down like crazy, reduce source-complexity*, minimize dependencies, get something of possible value out of it - a lot of boxes were checking themselves. It would be a shame if I decided not to give it a go, especially given all the free time I have.

So, there's a clear destination - but the approach and path still had elements of uncertainty. This is essentially the same task as 7 years ago - produce new inference code for a model trained with some library.

Approach

I had some angles in mind: (1) inspect the model binary, (2) step through a debugger checking the activated paths, understand the logic behind them, and then implement. Copying code was okay (and a requirement, since the code computing float operations had to match), so we'd have it slightly easier.

Python

I thought of projecting the 8-bit ints to 32-bit floats and doing the operations in PyTorch for a faster first iteration. The thinking was: once there was clarity on the network architecture and I had managed to realize and verify it quickly in Python, implementing it in C++ should be easier.
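
The projection itself is simple in principle - below is a minimal sketch (in C++ here, although the first pass was meant for PyTorch), assuming a per-tensor scale where quantized ≈ round(value × scale), so value ≈ quantized / scale. Where exactly marian keeps that scale for each tensor was one of the details I still had to dig out.

#include <cstddef>
#include <cstdint>
#include <vector>

// Project int8 weights back to float32. `scale` is the assumed per-tensor
// quantization multiplier: value ≈ quantized / scale.
std::vector<float> dequantize(const std::vector<int8_t> &quantized, float scale) {
  std::vector<float> values(quantized.size());
  for (size_t i = 0; i < quantized.size(); ++i) {
    values[i] = static_cast<float>(quantized[i]) / scale;
  }
  return values;
}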

I wrote a quick script to load the model in Python.

model.bin (truncated)

Item(Wemb, intgemm8, Shape([32000, 256]), 8192256)
Item(Wemb_QuantMultA, intgemm8, Shape([1, 1]), 256)
Item(decoder_ff_logit_out_b, float32, Shape([1, 32000]), 128000)
Item(decoder_l1_context_Wk, intgemm8, Shape([256, 256]), 65792)
...
Item(decoder_l1_ffn_b2, float32, Shape([1, 256]), 1024)
Item(decoder_l1_ffn_ffn_ln_bias, float32, Shape([1, 256]), 1024)
Item(decoder_l1_ffn_ffn_ln_scale, float32, Shape([1, 256]), 1024)
Item(decoder_l1_rnn_W, intgemm8, Shape([256, 256]), 65792)
Item(decoder_l1_rnn_W_QuantMultA, float32, Shape([1, 1]), 256)
...
Item(decoder_l2_context_Wq, intgemm8, Shape([256, 256]), 65792)
Item(decoder_l2_context_Wq_QuantMultA, float32, Shape([1, 1]), 256)
Item(decoder_l2_context_Wv, intgemm8, Shape([256, 256]), 65792)
Item(decoder_l2_context_Wv_QuantMultA, float32, Shape([1, 1]), 256)
Item(decoder_l2_context_bk, float32, Shape([1, 256]), 1024)
Item(decoder_l2_context_bo, float32, Shape([1, 256]), 1024)
...
Item(encoder_l1_ffn_W1, intgemm8, Shape([256, 1536]), 393472)
Item(encoder_l1_ffn_W1_QuantMultA, float32, Shape([1, 1]), 256)
Item(encoder_l1_ffn_b1, float32, Shape([1, 1536]), 6144)
...
Item(encoder_l1_ffn_ffn_ln_bias, float32, Shape([1, 256]), 1024)
Item(encoder_l1_ffn_ffn_ln_scale, float32, Shape([1, 256]), 1024)
Item(encoder_l1_self_Wk, intgemm8, Shape([256, 256]), 65792)
Item(encoder_l1_self_Wk_QuantMultA, float32, Shape([1, 1]), 256)

Some things make sense. There are attention layers, feed-forward networks, RNNs, an output layer, and some layernorms. But I have the flat data, not the structure. I was hoping to recover the neural network structure from a research or system-description paper of some sort. I checked Kim et al. (2019); there's some mention of the RNN (SSRU), but a network diagram is missing. I could have searched further through the papers this paper refers to. Without the structure, it would be difficult to reproduce the network in Python, and I'd also need some tooling to verify I was on the right path. So I decided to see if throwing the debugger at this problem and stepping through would provide some information I could use. After all, the code was the absolute truth.

Debugger

IDEs back in 2015, when I took on the previous porting task, had debuggers as well, I'd guess. But back then I knew only gdb and mostly used vim. I still use vim most of the time, and VSCode with a vim emulation layer (I currently use VSCodium, which is a fork). VSCode has a nice debugger. Internally it uses gdb, but it's easier to move around and inspect.

Starting with bergamot's translateBatch entrypoint, I jumped to function definitions and checked a few call-stacks. I created a script to link and save these, to refer back to as I unrolled the code while porting. See a few examples below.

Embedding (index_select)
marian::cpu::CopyRows(marian::Tensor out_, const marian::Tensor in_, const marian::Tensor indices) 
marian::CopyRows(marian::Tensor arg1, const marian::Tensor arg2, const marian::Tensor arg3) 
marian::RowsNodeOp::forwardOps()::{lambda()#1}::operator()() const(const struct {...} * const __closure) 
marian::Node::runForward(std::vector, std::allocator > > const&)(marian::Node * const this, const marian::NodeOps & ops) 
marian::Node::forward(marian::Node * const this) 
marian::ExpressionGraph::forward(marian::ExpressionGraph * const this, std::__cxx11::list > >, std::allocator > > > > & forwardTape, bool finalPass) 
marian::ExpressionGraph::forwardNext(marian::ExpressionGraph * const this) 
marian::ExpressionGraph::forward(marian::ExpressionGraph * const this) 
marian::BeamSearch::search(marian::BeamSearch * const this, marian::Ptr graph, marian::Ptr batch) 
marian::bergamot::TranslationModel::translateBatch(marian::bergamot::TranslationModel * const this, marian::bergamot::Workspace & workspace, marian::bergamot::Batch & batch)

Positional Embeddings
marian::inits::sinusoidalPositionEmbeddings(int start) 
marian::Transformer::addPositionalEmbeddings(const marian::Transformer * const this, marian::Expr input, int start, bool trainPosEmbeddings) 
marian::Transformer::addSpecialEmbeddings(const marian::Transformer * const this, marian::Expr input, int start) 
marian::EncoderTransformer::apply(marian::EncoderTransformer * const this, marian::Ptr batch) 
marian::EncoderTransformer::build(marian::EncoderTransformer * const this, marian::Ptr graph, marian::Ptr batch) 
marian::EncoderDecoder::startState(marian::EncoderDecoder * const this, marian::Ptr graph, marian::Ptr batch) 
marian::models::Stepwise::startState(marian::models::Stepwise * const this, marian::Ptr graph, marian::Ptr batch) 
marian::ScorerWrapper::startState(marian::ScorerWrapper * const this, marian::Ptr graph, marian::Ptr batch) 
marian::BeamSearch::search(marian::BeamSearch * const this, marian::Ptr graph, marian::Ptr batch) 
marian::bergamot::TranslationModel::translateBatch(marian::bergamot::TranslationModel * const this, marian::bergamot::Workspace & workspace, marian::bergamot::Batch & batch) 
operator()(const struct {...} * const __closure)

Encoder forward
marian::EncoderTransformer::apply(marian::EncoderTransformer * const this, marian::Ptr batch) 
marian::EncoderTransformer::build(marian::EncoderTransformer * const this, marian::Ptr graph, marian::Ptr batch) 
marian::EncoderDecoder::startState(marian::EncoderDecoder * const this, marian::Ptr graph, marian::Ptr batch) 
marian::models::Stepwise::startState(marian::models::Stepwise * const this, marian::Ptr graph, marian::Ptr batch) 
marian::ScorerWrapper::startState(marian::ScorerWrapper * const this, marian::Ptr graph, marian::Ptr batch) 
marian::BeamSearch::search(marian::BeamSearch * const this, marian::Ptr graph, marian::Ptr batch) 
marian::bergamot::TranslationModel::translateBatch(marian::bergamot::TranslationModel * const this, marian::bergamot::Workspace & workspace, marian::bergamot::Batch & batch) 
operator()(const struct {...} * const __closure)

transposed log mask
marian::Transformer::transposedLogMask(marian::Expr mask) 
marian::EncoderTransformer::apply(marian::EncoderTransformer * const this, marian::Ptr batch) 
marian::EncoderTransformer::build(marian::EncoderTransformer * const this, marian::Ptr graph, marian::Ptr batch) 
marian::EncoderDecoder::startState(marian::EncoderDecoder * const this, marian::Ptr graph, marian::Ptr batch) 
marian::models::Stepwise::startState(marian::models::Stepwise * const this, marian::Ptr graph, marian::Ptr batch) 
marian::ScorerWrapper::startState(marian::ScorerWrapper * const this, marian::Ptr graph, marian::Ptr batch) 
marian::BeamSearch::search(marian::BeamSearch * const this, marian::Ptr graph, marian::Ptr batch) 
marian::bergamot::TranslationModel::translateBatch(marian::bergamot::TranslationModel * const this, marian::bergamot::Workspace & workspace, marian::bergamot::Batch & batch) 
operator()(const struct {...} * const __closure)

postProcess: applied after each FFN (takes care of the skip connection, dropout - off at inference - and LayerNorm)
marian::Transformer::postProcess(const marian::Transformer * const this, std::string prefix, std::string ops, marian::Expr input, marian::Expr prevInput, float dropProb) 
marian::Transformer::LayerAttention(marian::Transformer * const this, std::string prefix, marian::Expr input, const marian::Expr & keys, const marian::Expr & values, const marian::Expr & mask, int dimHeads, bool cache, bool saveAttentionWeights) 
marian::EncoderTransformer::apply(marian::EncoderTransformer * const this, marian::Ptr batch) 
marian::EncoderTransformer::build(marian::EncoderTransformer * const this, marian::Ptr graph, marian::Ptr batch) 
marian::EncoderDecoder::startState(marian::EncoderDecoder * const this, marian::Ptr graph, marian::Ptr batch) 
marian::models::Stepwise::startState(marian::models::Stepwise * const this, marian::Ptr graph, marian::Ptr batch) 
marian::ScorerWrapper::startState(marian::ScorerWrapper * const this, marian::Ptr graph, marian::Ptr batch) 
marian::BeamSearch::search(marian::BeamSearch * const this, marian::Ptr graph, marian::Ptr batch) 
marian::bergamot::TranslationModel::translateBatch(marian::bergamot::TranslationModel * const this, marian::bergamot::Workspace & workspace, marian::bergamot::Batch & batch) 
operator()(const struct {...} * const __closure)

I quickly discovered the code-paths that were consistently getting activated when I ran input through, and found the in-code realizations of some parts of the transformer network discussed in the paper. But this process was slower than I'd wished for. Success was viable with this particular angle, just not fast enough. I needed something faster.

Breakthrough hack

During this process, I discovered the NodeOp macro, which looked as follows.

#define NodeOp(op) [=]() { op; }

For some reason, marian was using this to describe forward and backward in the computational graph, mostly consistently. Since searching for the macro, finding it, and manually placing breakpoints wasn't cutting it, I quickly googled whether I could stop for the debugger programmatically. Turns out I can - std::raise(SIGTRAP).

The operating system notifies the debugger on a SIGTRAP signal (if no debugger is listening to handle it, the program simply exits). The debugger can map the instruction pointer to the line in source and back (provided the binary is compiled with -g and ideally -O0, i.e. -DCMAKE_BUILD_TYPE=Debug).

I set a programmatic breakpoint and tried to extract call-stack programmatically as well:

#define NodeOp(op)                                                  \
  [=]() {                                                           \
    std::raise(SIGTRAP);                                            \
    std::string callstack = marian::getCallStack(/*skipLevels=*/3); \
    std::cerr << callstack << "\n";                                 \
    std::cerr << __PRETTY_FUNCTION__ ;                              \
    std::cerr << " " << __FILE__ << ":";                            \
    std::cerr << __LINE__ << "\n";                                  \
    op;                                                             \
  }

I quickly realized I no longer needed the SIGTRAP to know which functions were getting called. Dropping it and just printing the __PRETTY_FUNCTION__ information identified the ops that used NodeOp(...).

Ops
marian::cpu::integer::fetchAlphaFromModelNodeOp::forwardOps()::<lambda>() 
marian::DotBatchedNodeOp::forwardOps()::<lambda>() 
marian::GatherNodeOp::forwardOps()::<lambda>() 
marian::HighwayNodeOp::forwardOps()::<lambda>() 
marian::LayerNormalizationOp::forwardOps()::<lambda>() 
marian::LogSoftmaxNodeOp::forwardOps()::<lambda>() 
marian::NegNodeOp::forwardOps()::<lambda>() 
marian::PlusNodeOp::forwardOps()::<lambda>() 
marian::ReLUNodeOp::forwardOps()::<lambda>() 
marian::RowsNodeOp::forwardOps()::<lambda>() 
marian::ScalarAddNodeOp::forwardOps()::<lambda>() 
marian::ScalarMultNodeOp::forwardOps()::<lambda>() 
marian::SoftmaxNodeOp::forwardOps()::<lambda>() 
marian::TransposeNodeOp::forwardOps()::<lambda>()

The ops the trace rendered at first were not exhaustive - some functions simply use a capturing lambda and not the macro. That's okay; I don't think the authors of the macro ever intended it to be used this way. To my surprise, I discovered NodeOp was even mentioned by a documentation effort. Macros are supposed to be bad and evil per established wisdom.

If, as a reader, you feel that this article is jumping all over the place - know that it is an accurate reflection of my mental state and uncertainty about the path at this point. But from here on, I had clarity.

Tracing execution

I'm mostly lurking and operating at the intersection of compilers and machine learning now. I've also followed the minitorch tutorial twice so far - once in Python and once in C++ - to learn what a computation graph is and how to build autograd. This equips me with some theory and understanding that could make life simpler from here on.

All that aside, we will consider a toy language whose single data-type is Expr, indicating an expression (more specifically, the result of an expression). We'll see an example of what we can do with Exprs shortly.

We’ll make a simplified definition of Expr as follows:

// Expr lhs = Op(rhs[0], rhs[1], ...)
struct Expr { 
   float* value;  // Holds underlying storage.
   size_t size;   // Size of the storage, in count(float).

   using Operands = std::vector<Expr>; 
   Operands rhs; // operands that populate the result value.

   using Op  = std::function<void(void)>; 
   using Ops = std::vector<Op>;

   Ops forward() { 
     auto op = [=](){ 
       // Open up rhs, apply intended function.
       // write to value.  
     }; 
     return { op }; 
   }

   float *grad; // Same size as value, holds gradient

   // Operates on grad, after receiving gradients from Expr(s) ahead.  
   Ops backward(float *grad_from_successor_node);  
};

An abstraction like Expr forms the basis for autograd frameworks. Since we’re tracing the inference path, we’re not interested in backward and grad.

Consider an expression as an LHS obtained by some operation on the rhs operands:

Expr x = ones(2, 2);
Expr y = zero(2, 2);
Expr z = x + y;

Consider z, the result of + on [x, y]. Its Expr would be as follows:

struct Add: public Expr {
public:
  // ...
  Ops forward() {
    // Open up rhs, apply intended function.
    // write to value.
    // x, y are available in rhs.
    Op add = [=](){
      for(size_t i = 0; i < size; i++){
          value[i] = rhs[0].value[i] + rhs[1].value[i];
      }
    };
    return { add };
  }
};

Note that we’re only recording that so and so operations must be done using so and so storage locations - thunks. We’ve not actually executed them yet. Execution would look as follows:

// Compute loss
Expr loss = f(rhs1, rhs2, ...);

// Note: Topological order begins from first expr, it is the
// reverse-topological-order that starts with loss. 
std::vector<Expr> order = topological_order(loss); 

// Run forward ops
for(auto &expr: order){
  auto forward_ops = expr.forward();
  // Execute functions, ends up in order of construction.
  for(auto &op: forward_ops){
    op();
  }
}

Okay... what's the point of all this? Since NodeOp is used to package the thunk, I can use the macro to add pre and post hooks around the statements. With the following modification, I can inspect the value of the lhs (value) before and after the op, and also inspect the state of the rhs (operands) during the op.

The modification I’m looking for looks like:

#define WrapStatementInLambda(statement)                            \
  [=]() {                                                           \
    /* Open up and inspect value (before the op).      */           \
    /* (Not usually required.)                         */           \
                                                                    \
    /* Extra local information.                        */           \
    std::cerr << __PRETTY_FUNCTION__ ;                              \
    std::cerr << " " << __FILE__ << ":";                            \
    std::cerr << __LINE__ << "\n";                                  \
                                                                    \
    /* Execute the operation. Leaving the arg as-is    */           \
    /* unrolls the statements wrapped by the macro.    */           \
                                                                    \
    statement;                                                      \
                                                                    \
    /* Save value (lhs), rhs[0], rhs[1], ... to disk(?) */          \
  }


#define NodeOp(op)  WrapStatementInLambda(op)

The richer version I eventually ended up using is available in slimt/marian-trace-gen.h. I was also able to extract shape metadata at runtime, along with some name and unique-identifier information stored in the values. The unique id meant I could conditionally stop during execution based on the identifier value. The macro-modification is a one-off throwaway creation, but it can be refined to trace the exact final operations executed by a marian forward pass and backward pass if need be. If I want to take advantage of MLIR provisions to optimize these op primitives for any (supported) target hardware, this could be a viable route - but I digress. There were a few Exprs not using NodeOp, but they were easy to tame.
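
For instance, the conditional stop is only a couple of lines on top of the earlier SIGTRAP trick - a minimal sketch, where break_on_node and the id plumbing are hypothetical stand-ins for wherever the identifier actually lives on the node:

#include <csignal>
#include <cstddef>

// Stop under the debugger only when the node carrying a particular
// unique id is about to execute its op.
inline void break_on_node(std::size_t id, std::size_t target) {
  if (id == target) {
    std::raise(SIGTRAP);  // If no debugger is attached, the process exits here.
  }
}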

My traces looked like below:

file: "/home/jerin/code/bergamot-translator/3rd_party/marian-dev/src/graph/node_operators_binary.h"
line: 836
fn: "marian::PlusNodeOp::forwardOps()::<lambda()>"
op: { Element(_1 = _2 + _3, val_, child(0)->val(), child(1)->val()) }
before: var_45 float32 [2x8x4x4]
after: var_45 float32 [2x8x4x4] var_45-PlusNodeOp-float32_2x8x4x4-lhs.bin
operands: 
  - var_44 float32 [2x8x4x4] var_45-PlusNodeOp-float32_2x8x4x4-rhs0-float32_2x8x4x4.bin
  - var_16 float32 [2x1x1x4] var_45-PlusNodeOp-float32_2x8x4x4-rhs1-float32_2x1x1x4.bin


file: "/home/jerin/code/bergamot-translator/3rd_party/marian-dev/src/graph/node_operators_unary.h"
line: 425
fn: "marian::SoftmaxNodeOp::forwardOps()::<lambda()>"
op: { Softmax(val_, child(0)->val()) }
before: var_46 float32 [2x8x4x4]
after: var_46 float32 [2x8x4x4] var_46-SoftmaxNodeOp-float32_2x8x4x4-lhs.bin
operands: 
  - var_45 float32 [2x8x4x4] var_46-SoftmaxNodeOp-float32_2x8x4x4-rhs0-float32_2x8x4x4.bin

The full execution trace is available here.

Notice how the LHS and RHS for the ops are saved onto disk under unique names. This meant I could even unit-test my ops. The trace was linear, unlike the nested functions I'd been hopping through back and forth, context-switching. The linear nature made the underlying operations easier to reason about. With some domain knowledge, it's easy to recognize the trace above as the softmax in attention, applied after adding the mask for pad-tokens.
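
For instance, to unit-test a softmax against the trace above, one could read both saved buffers back and compare - a minimal sketch, assuming the .bin dumps are raw float32 buffers with the last axis contiguous, and with softmax_rows() standing in for the slimt-side op under test:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Read a raw float32 dump written by the trace.
static std::vector<float> load_f32(const char *path) {
  std::vector<float> xs;
  if (FILE *fp = std::fopen(path, "rb")) {
    float x;
    while (std::fread(&x, sizeof(float), 1, fp) == 1) xs.push_back(x);
    std::fclose(fp);
  }
  return xs;
}

// Softmax over contiguous rows of length `cols`.
static void softmax_rows(std::vector<float> &xs, size_t cols) {
  for (size_t start = 0; start + cols <= xs.size(); start += cols) {
    float max = *std::max_element(xs.begin() + start, xs.begin() + start + cols);
    float sum = 0.0f;
    for (size_t i = 0; i < cols; ++i) {
      xs[start + i] = std::exp(xs[start + i] - max);
      sum += xs[start + i];
    }
    for (size_t i = 0; i < cols; ++i) xs[start + i] /= sum;
  }
}

int main() {
  auto x = load_f32("var_46-SoftmaxNodeOp-float32_2x8x4x4-rhs0-float32_2x8x4x4.bin");
  auto expected = load_f32("var_46-SoftmaxNodeOp-float32_2x8x4x4-lhs.bin");
  softmax_rows(x, /*cols=*/4);  // Shape [2x8x4x4]: softmax over the last axis.
  if (x.empty() || x.size() != expected.size()) return 1;
  for (size_t i = 0; i < x.size(); ++i) {
    if (std::fabs(x[i] - expected[i]) > 1e-5f) return 1;
  }
  return 0;  // Matched within tolerance.
}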

At this point, I knew what I wanted was realizable at a pace I was happy with. The problem was more or less solved inside my head; I had all the missing pieces, and a really small chance of failure. Note that I hadn't completed the solution yet - I had just figured it out.

Finishing up

I had made the process mechanical. I traversed the trace, porting code step-by-step and checking that the LHS and RHS tensors matched what my ported code computed. I also built some verification convenience functions to check things as I progressed, which looked as follows in source:

Tensor &encoder_out = x;
VERIFY_MATCH(encoder_out,
             "var_394-LayerNormalizationOp-float32_1x2x4x256-lhs.bin");
return decoder_.decode(encoder_out, mask, batch.words());

I hit some hiccups with the 8-bit matrix multiplies using intgemm (and ruy, later on) in marian. But the saved tensor input/output pairs I had left for myself via the tracing helped a lot.

Being familiar with the components helped speed things up a bit. Some corners were cut, but no problem - we can fix them slowly if need be. There is a lot more room for optimizations. I am currently trying my hand at compilers and parallel programming, and weak baselines should be an opportunity to learn more things along the way.

Changing Node.h and recompiling takes, I'd estimate, 20+ minutes on my laptop, which is enough time to walk away and do something else while developing the tracing scripts - so the new, more powerful PC helped a bit.

The source-code was made public on achieving bare-minimum functionality on x86_64, and I have a PR open to support aarch64. Some code I wrote for ARM support for Mozilla back in the day, and the experience that came with it, ended up helping. As of now, I am aware of KDE using bergamot's models in KTextAddons. A refined version of this could be useful to Mozilla, who I know to be using only tiny11 models.

This post mostly deals with the development process. I hope to write in the future about the actual technical and math content surrounding these models.

References

Young Jin Kim, Marcin Junczys-Dowmunt, Hany Hassan Awadalla, Alham Fikri Aji, Kenneth Heafield, Roman Grundkiewicz, and Nikolay Bogoychev. 2019. From research to production and back: Ludicrously fast neural machine translation. In Proceedings of the 3rd workshop on neural generation and translation, pages 280–288.