与早期版本(详情可以参考这里:4e23ad8 23'7/27)不同,目前的版本有两种模式,一是generate, 根据给定的prompt生成一个简短的英文小故事。二是chat, 需要输入system prompt和问题,生成一个回答并退出,没有多轮对话,并且回答与问题没有太大关联性而且有不少问题(尤其是15M模型),所以该模式还不完善,并且有一个 feature/chat 分支处于开发中有兴趣可以关注该分支。我们主要关注代码更新的部分。

执行流程

该版本代码进行了重构,sampler封装到结构体里,其次把分阶段的执行流程各自封装到函数里面,调用逻辑会更清晰一些。同样的,main主要处理命令行参数,然后是模型推理流程。

main 执行流程

  1. 解析命令行参数,包括 checkpoint 文件路径, temperature, steps等。
  2. 调用build_transformer,会调用read_checkpoint加载checkpoint 文件内容,包括模型配置和权重。其次调用malloc_run_state初始化运行时变量,包括kv cache等。注意,这里没有再对s->k, s->v分配内存,直接指向kv cache相应的位置,省掉一次内存拷贝操作。
  3. 调用build_tokenizer,加载tokenizer.bin
  4. 调用build_sampler,初始化采样对象。
  5. 接下来就是根据给定运行模式,调用generate或者chat执行生成或问答流程。
  6. 内存清理并退出

generate

计算过程:

  1. 首先调用encode,把用户输入的prompt编码为token,存储到prompt_tokens,维度为(strlen(prompt)+3,)的int数组。
  2. 然后从prompt_tokens的第一个token开始(通常是BOS)执行主循环生成文本,即根据给定的**seq_len(steps)**执行循环 [0, steps]
    1. 调用forward得到当前位置poslogits
      1. 如果当前词元还是prompt,则从prompt里面取下一个token
      2. 否则,根据logits调用sample采样得到下一个token
    2. 如果下一个token是BOS,则终止循环。
    3. 调用decode从tokenizer里面解码得到token对应的单词,并打印。
    4. 继续下一个循环
  3. 生成完毕,释放prompt_tokens临时堆内存

chat

Todo:

forward

模型推理的计算都封装到该函数里了,具体计算过程与4e23ad8 23'7/27里面的transformer基本类似,差异是增加了对**Multi-Query Attention(MQA)Grouped-Query Attention(GQA)**的支持。
先看下MHAMQAGQA之间的区别,

在该代码实现里,

  1. 如果n_kv_heads==n_heads则为MHA,每个query都对应一组独立的key和value。
  2. 如果n_kv_heads==1则为MQA,所有query对应一组key和value。
  3. 否则1<n_kv_heads<n_heads则为GQA,多个query对应一组key和value,并且不止一组。

具体的代码实现流程如下:

  1. 初始化局部变量,新增kv_dim = dim/n_heads * n_kv_heads = head_dim * n_kv_heads,因为要跟query保持一致,所以kv head的head_dim也是head_dim = dim/n_heads固定的与n_kv_heads无关。补充一下,这里kv_dim代表每个token对应在kv cache里面的偏移大小或者说是每个token的dim维度,具体可以参考后面kv_cache内存结构。接下来进入正题,从token_embedding_table中取出当前token的embedding(dim, ) 放入 s->x 中作为模型输入。
  2. 循环遍历每一层for l in 0..n_layers
    1. s->x计算rmsnorm并存入 s->xb;计算该层kv cache偏移,loff=l*seq_len*kv_dim,如上分析,kv_dim就是每个token的维度。
    2. 计算 Q、K、V,结果直接存入 s->q, s->k, s->v,而后两者是指向kv cache当前位置的,所以省掉一次内存拷贝。
    3. 计算RoPE编码的Q、K,结果存入 s->q, s->k。注意这里的改进,多头合在一起一并计算,效率更高。
    4. 计算多头注意力,循环遍历每个头for h in 0..n_heads:
      1. Q起始位置,Q原始维度(dim, )拆分多头后为(n_heads, head_dim),所以第h头的起始位置为**float* q = s->q + h * head_size
      2. atten score (n_heads, seq_len),因为是按头存放的,所以在第h头的起始位置就是 s->att + h * p->seq_len
      3. kv cache在dim维度拆分多头,并考虑MQA/GQA后维度为(layer, seq_len, n_kv_heads, head_dim)或者简化为(layer, seq_len, kv_dim),因为n_kv_heads <= n_heads要均分共享,这里的关键是需要计算query的第h个head对应的kv cache偏移在哪里。代码实现里面有个小技巧,我们先来看sequence维度,t位置的偏移为t * kv_dim,这个好理解,前面已经讲过,然后再计算kv_dim里面的偏移是多少,这里会用到临时变量kv_mul = n_kv_heads / n_heads, 则query第h头在第t个token维度内也就是kv_dim内,偏移为(h/kv_mul) * head_size
      4. 计算该词元对它之前所有词元的注意力分数(包含自己),其实就是masked attention,for t in 0..pos+1:
        • t 个词元对应在第 h 头内的偏移为 t * kv_dim,这个好理解,前面已经讲过。综上,t位置的cache为 s->key_cache + loff + (h/kv_mul) * head_size + t * kv_dim
        • 计算 $att(t)=\sum_{i=0}^{head\_dim}Q_i*K_i/\sqrt{head\_dim}$ ,其维度为 (seq_len, ),但因为仅仅[0, pos]的值为有效的,也即每个已生成的词元都有一个分数。(注意,att数组内容在计算下一个token时被重新填充)
      5. att经过softmax转换为注意力权重 $att=softmax(att)$
      6. 接下来一步,是把每个词元的V向量与其对应的att值相乘,然后加总形成一个新的V向量,存储到s->xb中,计算公式为 $\hat{V}=\sum_{t=0}^{pos}att(t)*V(t)$,其维度为(head_dim,)V相同,所以head计算完成之后,维度为(dim,)至此,多头循环结束。
    5. 注意力最后一步,与 $W^O$ 矩阵相乘,得到该层最终的注意力输出向量 (dim, ),存储到 s->xb2
    6. 接下来是 residual 的 Add&Norm,elemwise add 和 rmsnorm,结果存入 s->xb
    7. 接下来是FFN网络计算: self.w2(F.silu(self.w1(x)) * self.w3(x))
      1. s->hb=matmul(w->w1, s->xb)
      2. s->hb2=matmul(w->w3, s->xb)
      3. silu(x)=x*σ(x) for s->hb
      4. elemwise mul s->hb * s->hb2
      5. matmul,得到FFN输出,存储在 s->xb
    8. layer计算最后一步,residual 的 Add&Norm,注意该实现中rmsnorm放到了layer循环的最开始,这样除最后一层外都会residual add之后执行rmsnorm,而最后一层则放在classifier 的rmsnorm。至此,layer循环结束
  3. Classifier 前面的 final rmsnorm, 就是针对最后一层residual的结果做rmsnorm,存入 s->xb 中
  4. classifier into logits ,是个Linear层,从(dim,) → (vocab_size,) 的转换, 就是一个matmul: (vacob_size, dim) * (dim,1) → (vocab_size, 1) 结果就是logits,对应字典里每个token的分数。

其它函数

build_transformer

做了两件事情:

  1. 读取checkpoint文件并解析模型配置参数权重参数,存入结构体变量t里。checkpoint文件结构详见上一篇文章。
  2. 初始化运行时的RunState结构体,分配内存空间,包括模型输入输出,attention分数,kv cache等。

    注:t->s->kt->s->v不分配内存空间,计算是直接读取kv cache对应的内存区域。这个跟之前的版本有区别。

build_transformerrun.c
164
165
166
167
168
169
void build_transformer(Transformer *t, char* checkpoint_path) {
// read in the Config and the Weights from the checkpoint
read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
// allocate the RunState buffers
malloc_run_state(&t->state, &t->config);
}

build_tokenizer

tokenizer文件结构详见后面章节,注意跟之前的版本有差别且不兼容。此外新增了

  • vocab_scores维度(vocab_size,)
  • byte_pieces维度(512,),初始化为 0~255 字符每个都带’\0’结尾隔开,总共512个字符。

该函数代码基本上就是文件和内存操作,初始化Tokenizer结构体,具体来说就是把tokenizer.bin文件读取到内存中,总共32k个词元字符串,每个词元开辟一块字符串数组内存区域单独存储,所有字符串指针放到一个数组里,可以按顺序遍历,也就是给定token,作为index即可取出对应词元的字符串。

文件结构详细信息可参考上一篇文章。

Tokenizerrun.c
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
// ----------------------------------------------------------------------------
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens

typedef struct {
char *str;
int id;
} TokenIndex;

typedef struct {
char** vocab;
float*vocab_scores;
TokenIndex*sorted_vocab;
int vocab_size;
unsigned int max_token_length;
unsigned char byte_pieces[512]; // stores all single-byte strings
} Tokenizer;
build_tokenizerrun.c
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
void build_tokenizer(Tokenizer*t, char* tokenizer_path, int vocab_size) {
// i should have written the vocab_size into the tokenizer file... sigh
t->vocab_size = vocab_size;
// malloc space to hold the scores and the strings
t->vocab = (char**)malloc(vocab_size *sizeof(char*));
t->vocab_scores = (float*)malloc(vocab_size* sizeof(float));
t->sorted_vocab = NULL; // initialized lazily
for (int i = 0; i < 256; i++) {
t->byte_pieces[i * 2] = (unsigned char)i;
t->byte_pieces[i * 2 + 1] = '\0';
}
// read in the file
FILE *file = fopen(tokenizer_path, "rb");
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
int len;
for (int i = 0; i < vocab_size; i++) {
if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
t->vocab[i] = (char*)malloc(len + 1);
if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
t->vocab[i][len] = '\0'; // add the string terminating token
}
fclose(file);
}

build_sampler

该版本将logits到文本字符串的采样过程重新做了封装,相关参数放到了Sampler结构体里。其次新增了ProbIndex以支持top-p采样方式,总共支持三种采样方式:

  1. 贪心算法,argmax取最大值
  2. 随机采样
  3. topp采样
Samplerrun.c
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
// ----------------------------------------------------------------------------
// The Sampler, which takes logits and returns a sampled token
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling

typedef struct {
float prob;
int index;
} ProbIndex; // struct used when sorting probabilities during top-p sampling

typedef struct {
int vocab_size;
ProbIndex* probindex; // buffer used in top-p sampling
float temperature;
float topp;
unsigned long long rng_state;
} Sampler;

该函数初始化结构体及分配内存。probindex是维度为(vocab_size,)ProbIndex结构体数组。

build_samplerrun.c
667
668
669
670
671
672
673
674
void build_sampler(Sampler*sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
sampler->vocab_size = vocab_size;
sampler->temperature = temperature;
sampler->topp = topp;
sampler->rng_state = rng_seed;
// buffer only used with nucleus sampling; may not need but it's ~small
sampler->probindex = malloc(sampler->vocab_size* sizeof(ProbIndex));
}

encode

Todo:

decode

解码过程就是token:int -> vocab:str的转换,此处处理了空格及一些特殊符号。

decoderun.c
418
419
420
421
422
423
424
425
426
427
428
429
char* decode(Tokenizer* t, int prev_token, int token) {
char *piece = t->vocab[token];
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
if (prev_token == 1 && piece[0] == ' ') { piece++; }
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
// parse this and convert and return the actual byte
unsigned char byte_val;
if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
piece = (char*)t->byte_pieces + byte_val * 2;
}
return piece;
}

sample

根据temperature、topp等超参数,将推理输出结果logits采样得到token id。根据超参数值的不同,支持三种采样方式:

  1. temperature == 0.0f,则采用贪心算法,取argmax最大值,返回sample_argmax结果
  2. topp<=0 || topp>=1,随机采样,返回sample_mult结果
  3. 否则 top-p 采样,返回sample_topp结果
samplerun.c 691
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
int sample(Sampler*sampler, float* logits) {
// sample the token given the logits and some hyperparameters
int next;
if (sampler->temperature == 0.0f) {
// greedy argmax sampling: take the token with the highest probability
next = sample_argmax(logits, sampler->vocab_size);
} else {
// apply the temperature to the logits
for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
// apply softmax to the logits to get the probabilities for next token
softmax(logits, sampler->vocab_size);
// flip a (float) coin (this is our source of entropy for sampling)
float coin = random_f32(&sampler->rng_state);
// we sample from this distribution to get the next token
if (sampler->topp <= 0 || sampler->topp >= 1) {
// simply sample from the predicted probability distribution
next = sample_mult(logits, sampler->vocab_size, coin);
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
}
}
return next;
}

sample_argmax

循环遍历所有元素,返回最大值的索引。

sample_argmaxrun.c 590
590
591
592
593
594
595
596
597
598
599
600
601
int sample_argmax(float* probabilities, int n) {
// return the index that has the highest probability
int max_i = 0;
float max_p = probabilities[0];
for (int i = 1; i < n; i++) {
if (probabilities[i] > max_p) {
max_i = i;
max_p = probabilities[i];
}
}
return max_i;
}

sample_mult

累加结果超过给定随机值,就返回当前的索引。

sample_multrun.c 603
603
604
605
606
607
608
609
610
611
612
613
614
int sample_mult(float* probabilities, int n, float coin) {
// sample index from probabilities (they must sum to 1!)
// coin is a random number in [0, 1), usually from random_f32()
float cdf = 0.0f;
for (int i = 0; i < n; i++) {
cdf += probabilities[i];
if (coin < cdf) {
return i;
}
}
return n - 1; // in case of rounding errors
}

sample_topp

$$ topp(p) = \max{k : \sum_{i=1}^{k} P(w_i) \leq p} $$

计算过程:

  1. 首先设定阈值为 float cutoff = (1.0f - topp) / (n-1)
  2. 初筛,挑选出概率大于阈值的所有token,其它的舍弃,结果存入probindex数组中。
  3. 再按照概率值降序排列
  4. 细筛,从前往后累加,得到不大于topp的和cumulative_prob及对应的所以last_idx,然后截取last_idx之前的部分作为候选列表。
  5. 采样,设定阈值float r = coin * cumulative_prob,从候选列表里找累加概率值不大于阈值r的token返回

关键数据结构

kv_cache内存结构

首先kv cache每个head的维度必须要与query保持一致才能做点积所以都是:head_dim=dim/n_heads,其次是kv各自有n_kv_heads个head,所以很容易得出kv cache的维度为(n_layers, seq_len, n_kv_heads, head_dim)。那么从layer层面看(尚未区分多头)的话,每个token对应kv的维度就是kv_dim = n_kv_heads * head_dim,其内存结构示意图如下。这跟query(dim)是不一样的,切记切记。相应的,$W_k$, $W_v$权重矩阵维度也变成了(layer, dim, kv_dim)跟Query不一样。切记切记。所以我们可以看到,MQA/GQA减少了kv head,实际上是缩小了两个权重矩阵的大小,同时有另外一层隐藏的含义,就是每个token对应的kv dim维度也变小了,也就是信息容量更少了,性能肯定受影响。

注:也许好奇,GQA为何kv总的维度不合q保持一致都是dim,因为head_dimdim只能有一个相同,否则head数量就得一致,也就是MHA了。

接下来,因为n_kv_heads<=n_heads所以必然有kv_dim是要被多个相邻的query head共享的,举个例子,假如n_heads=6,n_kv_heads=3,则kv_dim = 3 * head_dim, 那么每个kv head会被2个query head共享。所以6个query head对应的kv head索引分别是:(0, 0, 1, 1, 2, 2),每个长度是head_dim所以加起来正好是kv_dim,相邻两个query对应的kv head是同一个。

layer-0 layer-1 layer-n
token-0 token-1 token-k token-0 token-1 token-k token-0 token-1 token-k
(kv_dim,) (kv_dim,) (kv_dim,) (kv_dim,) (kv_dim,) (kv_dim,) (kv_dim,) (kv_dim,) (kv_dim,)

tokenizer文件

该文件结构与之前版本有差异,文件开头增加了一个4字节的intmax_token_length,统计了所有token最大的长度。其次是遍历每个token的内容,包括新增的一个floatvocab_scores,然后是4个字节int型的token长度,紧接着是token内容(不包括结束符的字符串),所有token一个挨一个存储。内存结构示意图如下。

vocab_size 存储在 checkpoint 文件里。

文件头 token-0 token-1 token-n
max_token_length vocab_scores token len token content vocab_scores token len token content vocab_scores token len token content
27 0.0 1 l 0.0 4 like 0.0 11 suggestions

总结

MQAGQA是最近出现的技术,原本的注意力机制是Multi-Head Attention(MHA),每个query head对应一个key、value的head,最简化的情况是MQA,所有query head共享一组key、value head,$W_k$和$W_v$矩阵维度缩小为原来的1/n_heads,内存占用和推理速度提升非常明显,但相应的性能下降了(每个token对应的维度也变为原来的1/n_heads,信息被严重压缩)。折中的方案就是GQA,处于两者中间,且可以配置n_kv_heads大小,自由权衡取舍。

GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

Joshua AinslieGQA