Llama2.jl
Documentation for Llama2.jl.
Reference
Llama2.Config · Llama2.ProbIndex · Llama2.RunState · Llama2.Sampler · Llama2.Sampler · Llama2.Tokenizer · Llama2.Transformer · Llama2.TransformerWeights · Llama2.decode · Llama2.encode · Llama2.forward! · Llama2.generate · Llama2.read_karpathy · Llama2.read_karpathy_config · Llama2.read_karpathy_weights · Llama2.rmsnorm! · Llama2.sample_argmax · Llama2.sample_mult · Llama2.sample_topp · Llama2.softmax! · Llama2.swiglu!
Llama2.Config — Type

struct Config{T<:Integer}

Used to configure the initial parameters.
Fields
- dim::Integer: Transformer dimension
- hidden_dim::Integer: Dimension of the feed-forward (FFN) layers
- n_layers::Integer: Number of layers
- n_heads::Integer: Number of query heads
- n_kv_heads::Integer: Number of key/value heads
- vocab_size::Integer: Vocabulary size
- seq_len::Integer: Maximum sequence length
Initializes parameters and checks for the correct dimensions. For example, the config can be read from a file using the read_karpathy_config function and is used by the TransformerWeights constructor.
llama2.c correspondence: Config (l. 19)
Llama2.ProbIndex — Type

struct ProbIndex{T<:Real}

Used when sorting probabilities during top-p sampling.
Llama2.RunState — Type

mutable struct RunState{T<:Real}

State of the transformer model. The matrices are modified during a forward pass; it should never be necessary to modify them manually. While some of these arrays hold actual necessary state, others serve as preallocated buffers to speed up computation in the forward! method.
Fields
- x::Vector{T} where T<:Real: Activations at the current time stamp. Shape: (dim,)
- xb::Vector{T} where T<:Real: Activations at the current time stamp inside a residual branch. Shape: (dim,)
- xb2::Vector{T} where T<:Real: An additional activation buffer for convenience. Shape: (dim,)
- hb::Vector{T} where T<:Real: Buffer for the hidden dimension in the feed-forward net. Shape: (hidden_dim,)
- hb2::Vector{T} where T<:Real: Buffer for the hidden dimension in the feed-forward net. Shape: (hidden_dim,)
- q::Vector{T} where T<:Real: Stores the query vector in the attention part. Shape: (n_heads * head_size,)
- att::Matrix{T} where T<:Real: Buffer for the attention scores. Shape: (n_heads, seq_len)
- logits::Vector{T} where T<:Real: The output logits. Shape: (vocab_size,)
- key_cache::Array{T, 3} where T<:Real: Cache for all the keys in the attention part. Shape: (n_kv_heads * head_size, seq_len, n_layers)
- value_cache::Array{T, 3} where T<:Real: Cache for all the values in the attention part. Shape: (n_kv_heads * head_size, seq_len, n_layers)
llama2.c correspondence: RunState (l. 50)
Allocate from config
function RunState(config::Config) where {T<:Real}

Initializes the matrices in RunState based on the shapes provided in the Config.
Llama2.Sampler — Type

struct Sampler{T<:Real}

Sampler()
function Sampler{T}(temperature::T, topp::T, rng_seed::Integer) where {T<:Real}

Used to return a sampled token (index) based on given logits. Depending on the parameters, the sampler supports greedy argmax, multinomial, or top-p sampling. It is recommended to adjust either the temperature or top-p away from its default value, but not both, since they serve a similar purpose (constraining the sampling).
Fields
- temperature::Real: Logits are divided by this value. A higher temperature makes the output more diverse, while a lower temperature makes it more deterministic, converging to greedy argmax sampling at 0.
- topp::Real: Used for top-p sampling. Only consider the set of most likely tokens whose probabilities sum up to this value. If this is 0 or 1, no top-p sampling is used. For other values, this prevents less likely tokens from being sampled.
- rng_state::Random.MersenneTwister
llama2.c correspondence: Sampler (l. 577 - 715)
Example
julia> sampler_mult = Sampler{Float64}(0.5, 0.0, 1)
Sampler{Float64}(0.5, 0.0, Random.MersenneTwister(1))
julia> [sampler_mult([-0.5, 0.5, 0.2]) for i in 1:10]
10-element Vector{Int64}:
2
2
2
1
2
2
3
3
2
3
julia> sampler_det = Sampler{Float64}(0.0, 0.0, 1)
Sampler{Float64}(0.0, 0.0, Random.MersenneTwister(1))
julia> [sampler_det([-0.5, 0.5, 0.2]) for i in 1:10]
10-element Vector{Int64}:
2
2
2
2
2
2
2
2
2
2
julia> sampler_topp = Sampler{Float64}(1.0, 0.5, 1)
Sampler{Float64}(1.0, 0.5, Random.MersenneTwister(1))
julia> [sampler_topp([-0.5, 0.5, 0.2]) for i in 1:10]
10-element Vector{Int64}:
2
2
2
2
2
2
3
3
2
3

Llama2.Sampler — Method

Sample the next token id based on the logits.
The sampling strategy is selected based on the temperature and topp parameters of the Sampler:
- If temperature == 0, always take the token with the highest probability (greedy argmax sampling), see sample_argmax.
- If topp is 0 or 1, apply the temperature to the logits and sample from the predicted probability distribution (multinomial sampling), see sample_mult.
- Otherwise, only sample from the smallest set of most likely tokens whose probabilities sum up to at least topp (top-p sampling), see sample_topp. The temperature is still applied beforehand.
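The dispatch above can be sketched in Python (a language-agnostic illustration, not Llama2.jl's implementation; returned indices are 1-based to match the package, and the top-p branch follows llama2.c in keeping tokens until their cumulative probability strictly exceeds topp):

```python
import math
import random

def softmax(x):
    # Numerically stable softmax: subtract the max before exponentiating.
    m = max(x)
    e = [math.exp(v - m) for v in x]
    s = sum(e)
    return [v / s for v in e]

def sample(logits, temperature, topp, rng=random.Random(0)):
    """Return a 1-based token index, dispatching as described above:
    greedy argmax at temperature 0, multinomial when topp is 0 or 1,
    top-p otherwise (temperature applied first in both random paths)."""
    if temperature == 0.0:
        return max(range(len(logits)), key=lambda i: logits[i]) + 1
    probs = softmax([l / temperature for l in logits])
    coin = rng.random()  # uniform in [0, 1)
    if topp <= 0.0 or topp >= 1.0:
        # Multinomial: walk the CDF until it passes the coin.
        cdf = 0.0
        for i, p in enumerate(probs):
            cdf += p
            if coin < cdf:
                return i + 1
        return len(probs)
    # Top-p: keep the most likely tokens until their cumulative
    # probability exceeds topp, then sample within that set.
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum > topp:
            break
    r = coin * cum  # rescale the coin to the kept probability mass
    cdf = 0.0
    for i in kept:
        cdf += probs[i]
        if r < cdf:
            return i + 1
    return kept[-1] + 1
```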
Llama2.Tokenizer — Type

struct Tokenizer{T<:Real}

Used for mapping from strings to token arrays (Int vectors) and back.
Fields
- index_to_token::Vector{String}: Maps a token index to its string representation, for decoding
- token_to_index::Dict{String, Int64}: Maps a token string to its token index, for encoding
- vocab_scores::Vector{T} where T<:Real: Scores of individual tokens for encoding
llama2.c correspondence: Tokenizer (l. 372)
- index_to_token = vocab
- token_to_index = sorted_vocab
- removed max_token_length (not required in Julia)
- removed byte_pieces (not required in Julia)
Load from Karpathy bin file
Tokenizer(tokenizer_path::String, vocab_size::Int)

Constructs a Tokenizer by loading the vocabulary from a file in the llama2.c format. The vocabulary size must be known from the config.
Example
julia> Tokenizer("bin/tokenizer/tokenizer.bin", 32000)
Tokenizer(["<unk>", "\n<s>\n", "\n</s>\n", "<0x00>", "<0x01>", "<0x02>", "<0x03>", "<0x04>", "<0x05>", "<0x06>" … "ὀ", "げ", "べ", "边", "还", "黃", "왕", "收", "弘", "给"], Dict("âr" => 28727, " properly" => 6285, "chem" => 14970, " patients" => 22070, " Plan" => 8403, "<0x2A>" => 46, "рос" => 10375, "null" => 4305, "rę" => 15387, "ört" => 21069…), Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … -31731.0, -31732.0, -31733.0, -31734.0, -31735.0, -31736.0, -31737.0, -31738.0, -31739.0, -31740.0])

llama2.c correspondence: build_tokenizer (l. 385)
Llama2.Transformer — Type

struct Transformer{T<:Real}

A transformer model, consisting of a config, weights, and a run state.
Fields
- config::Config: Hyperparameters of the architecture
- weights::TransformerWeights: Weights of the module
- state::RunState: Buffers for the wave of activations in the forward pass
llama2.c correspondence: Transformer (l. 67)
Llama2.TransformerWeights — Type

struct TransformerWeights{T<:Real}
function TransformerWeights(config::Config) where {T<:Real}

Holds the weights for the Llama2 transformer model.
Fields
- token_embedding_table::Matrix{T} where T<:Real: Token embedding table, mapping from token index to embedding vector. Shape: (dim, vocab_size)
- rms_att_weight::Matrix{T} where T<:Real: Weights for rmsnorm before the attention for each layer. Shape: (dim, n_layers)
- rms_ffn_weight::Matrix{T} where T<:Real: Weights for rmsnorm before the feed-forward net for each layer. Shape: (dim, n_layers)
- wq::Array{T, 3} where T<:Real: Query weights for each attention layer. Shape: (n_heads * head_size, dim, n_layers)
- wk::Array{T, 3} where T<:Real: Key weights for each attention layer. Shape: (dim, kv_dim, n_layers)
- wv::Array{T, 3} where T<:Real: Value weights for each attention layer. Shape: (dim, kv_dim, n_layers)
- wo::Array{T, 3} where T<:Real: Output weights for each attention layer. Shape: (n_heads * head_size, dim, n_layers)
- w1::Array{T, 3} where T<:Real: First weight matrix for each feed-forward layer (in -> hidden). Shape: (dim, hidden_dim, n_layers)
- w2::Array{T, 3} where T<:Real: Second weight matrix for each feed-forward layer (hidden -> out). Shape: (hidden_dim, dim, n_layers)
- w3::Array{T, 3} where T<:Real: Third weight matrix for each feed-forward layer (in -> hidden). Shape: (dim, hidden_dim, n_layers)
- rms_final_weight::Vector{T} where T<:Real: Weights for the final rmsnorm before the optional classifier head. Shape: (dim,)
- wcls::Matrix{Float32}: Weights for the optional classifier head. If there is no separate classifier (the usual case), this should equal token_embedding_table, translating embeddings back to logits. This is inspired by the original llama2.c implementation. Shape: (dim, vocab_size)
llama2.c correspondence: TransformerWeights (l. 29)
Allocate from config
To create a new TransformerWeights instance with preallocated matrices, use the config constructor:
function TransformerWeights(config::Config) where {T<:Real}

llama2.c correspondence: memory_map_weights (l. 111)
Llama2.decode — Method

decode(
tokenizer::Tokenizer,
prev_token::Int64,
token::Int64
) -> String
Decodes a token index to a string. If the previous token is BOS (=2) and the token string starts with a leading space, the leading space is removed. Token indices are 1-based (different from the 0-based system in llama2.c).
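The behaviour can be sketched as follows (a Python illustration; `decode_piece` is a hypothetical helper and `index_to_token` stands in for the Tokenizer field of the same name):

```python
def decode_piece(index_to_token, prev_token, token):
    """Map a 1-based token index to its string piece.

    After BOS (index 2 in the 1-based scheme), SentencePiece-style
    tokens carry a leading space, which is stripped as described above.
    """
    piece = index_to_token[token - 1]  # convert 1-based index to 0-based
    if prev_token == 2 and piece.startswith(" "):
        piece = piece[1:]
    return piece
```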
Example
julia> [decode(tokenizer, 1, t) for t in [2, 15044, 3187, 29992]]
4-element Vector{String}:
"\n<s>\n"
" Hello"
" world"
"!"
julia> decode(tokenizer, 1, 15044)
" Hello"
julia> decode(tokenizer, 2, 15044) # BOS strips leading space
"Hello"

llama2.c correspondence: decode (l. 418)
Llama2.encode — Function

encode(tokenizer::Tokenizer, text::String) -> Vector{Int64}
encode(
tokenizer::Tokenizer,
text::String,
eos_token::Bool
) -> Vector{Int64}
Encode a string text using a Tokenizer. An optional EOS token can be added. Encoded text can be decoded with the decode function.
Works by encoding each code unit as a single token, then iteratively merging them together according to the Tokenizer's vocab_scores.
Note that token indices are 1-based (different from the 0-based system in llama2.c).
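The merge loop can be sketched in Python (a simplified illustration, not Llama2.jl's implementation: byte fallback and the SentencePiece leading-space convention are omitted, `token_to_index` and `vocab_scores` mirror the Tokenizer fields, indices are 1-based with 2 assumed to be BOS):

```python
def encode_greedy(token_to_index, vocab_scores, text, bos=True):
    """Greedy BPE-style encoding: start from one token per character,
    then repeatedly merge the adjacent pair whose merged token has the
    highest vocab score, until no further merge is possible."""
    index_to_token = {i: t for t, i in token_to_index.items()}
    # Seed with one token per character (every character is assumed
    # to be in the vocabulary; real encoders fall back to bytes).
    tokens = [token_to_index[c] for c in text]
    while True:
        best_score = best_pos = best_id = None
        for pos in range(len(tokens) - 1):
            merged = index_to_token[tokens[pos]] + index_to_token[tokens[pos + 1]]
            tid = token_to_index.get(merged)
            if tid is not None:
                score = vocab_scores[tid - 1]  # scores are stored 1-based here
                if best_score is None or score > best_score:
                    best_score, best_pos, best_id = score, pos, tid
        if best_pos is None:
            break  # no adjacent pair merges into a known token
        tokens[best_pos:best_pos + 2] = [best_id]
    return ([2] if bos else []) + tokens
```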
Example
julia> encode(tokenizer, "Hello world!")
4-element Vector{Int64}:
2
15044
3187
29992

llama2.c correspondence: encode (l. 452)
Llama2.forward! — Method

forward!(
transformer::Transformer{T<:Real},
token::Integer,
pos::Integer
) -> Vector{T} where T<:Real
A single complete transformer forward pass for input token token at position pos, returning the output logits.
- pos is one-based, i.e. 1 <= pos <= seq_len.
- token is also a one-based token index, 1 <= token <= vocab_size.
- The output logits are a vector of length vocab_size, representing the predicted likelihood of each token (before softmax).
This modifies the RunState of the transformer. To generate sequences using the transformer, call this method repeatedly with increasing pos values, starting from 1.
llama2.c correspondence: forward (l. 231)
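The repeated-call pattern can be sketched as follows (a Python illustration; `forward` and `sample_fn` are hypothetical stand-ins for forward! and a Sampler):

```python
def generate_ids(forward, prompt_ids, max_steps, sample_fn):
    """Autoregressive loop: feed one token per position, starting at
    position 1; once past the prompt, sample the next token from the
    returned logits. `forward(token, pos)` stands in for
    forward!(transformer, token, pos)."""
    tokens = list(prompt_ids)
    for pos in range(1, max_steps + 1):  # positions are 1-based
        logits = forward(tokens[pos - 1], pos)
        if pos >= len(tokens):
            tokens.append(sample_fn(logits))
    return tokens
```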
Example
To run token 5 at position 1 through the transformer and get the predicted output logits:
julia> forward!(transformer, 5, 1)
32000-element Vector{Float32}:
-2.1009917
1.664739
-2.1005554
-2.1007848
-2.1005578
-2.1009412
⋮
-2.1007295
-2.100759
-2.1007874
-2.1009996
-2.1009269
-2.1007652

Llama2.generate — Method

generate(
model::Transformer{T<:Real},
tokenizer::Tokenizer,
sampler::Sampler{T<:Real},
prompt::String;
verbose,
display_output,
display_prompt,
max_steps
) -> String
Generate a sequence based on a given language model, tokenizer, sampler and prompt.
There are several optional keyword arguments:
- verbose::Bool: Print the achieved tokens/s.
- display_output::Bool: Print the output.
- display_prompt::Bool: Print the prompt. Ignored if display_output is false.
- max_steps::Int: Maximum number of generation steps.
llama2.c correspondence: generation loop (l. 729-783)
Llama2.read_karpathy — Method

read_karpathy(
file_path::String
) -> Tuple{Config{Int32}, TransformerWeights{Float32}}
Reads a Karpathy model file and returns the Config and TransformerWeights, using the read_karpathy_config and read_karpathy_weights functions.
Llama2.read_karpathy_config — Method

read_karpathy_config(file::IOStream) -> Config{Int32}
Read a Config from a Karpathy model file.
llama2.c correspondence: read_config (l. 147)
Llama2.read_karpathy_weights — Method

read_karpathy_weights(
config::Config,
file::IOStream
) -> TransformerWeights{Float32}
Read the weights from a Karpathy model file and return them as a TransformerWeights instance.
llama2.c correspondence: memory_map_weights (l. 111)
Llama2.rmsnorm! — Method

rmsnorm!(
o::AbstractArray{T<:Real},
x::AbstractArray{T<:Real},
weight::AbstractArray{T<:Real}
)
Normalize x by its root mean square and scale it elementwise by weight, storing the result in o. Reference in llama2.c: lines 182-195.
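A minimal Python sketch of the computation (an illustration, not Llama2.jl's implementation; llama2.c adds an epsilon of 1e-5 to the mean square to guard against division by zero, which is assumed here):

```python
import math

def rmsnorm(x, weight, eps=1e-5):
    """Scale x by 1/RMS(x), then elementwise by the learned weight.
    RMS(x) = sqrt(mean(x_i^2) + eps)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for w, v in zip(weight, x)]
```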
Llama2.sample_argmax — Method

sample_argmax(logits::AbstractArray{T<:Real, 1}) -> Any
Deterministically sample the token with the highest probability.
Example
julia> sample_argmax([-0.5, 0.0, 0.5])
3

Llama2.sample_mult — Method

sample_mult(
probabilities::AbstractArray{T<:Real, 1},
coin::Real
) -> Any
Sample an index from a probability distribution (must sum to 1). coin is a random number in [0, 1); the returned index is the one whose CDF bin coin falls into.
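A Python sketch of the CDF walk (an illustration; the return index is 1-based to match the package):

```python
def sample_mult(probabilities, coin):
    """Return the 1-based index of the CDF bin that `coin` falls into.
    `probabilities` must sum to 1 and `coin` is uniform in [0, 1)."""
    cdf = 0.0
    for i, p in enumerate(probabilities):
        cdf += p
        if coin < cdf:
            return i + 1
    return len(probabilities)  # guard against floating-point rounding
```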
Examples
julia> sample_mult([0.1, 0.2, 0.3, 0.4], 0.05)
1
julia> sample_mult([0.1, 0.2, 0.3, 0.4], 0.15)
2
julia> sample_mult([0.1, 0.2, 0.3, 0.4], 0.8)
4

Llama2.sample_topp — Method

sample_topp(
probabilities::AbstractArray{T<:Real, 1},
topp::Real,
coin::Real
) -> Any
Top-p sampling (or "nucleus sampling") samples from the smallest set of most likely tokens whose cumulative probability exceeds topp. This way we never sample tokens that have very low probabilities and are less likely to go "off the rails". coin is a random number in [0, 1).
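A Python sketch of the algorithm (an illustration following llama2.c: tokens are kept until their cumulative probability strictly exceeds topp, and the coin is rescaled to the kept mass; llama2.c's initial (1 - topp)/(n - 1) pre-filter is omitted):

```python
def sample_topp(probabilities, topp, coin):
    """Top-p sampling. Returns a 1-based index, matching the package."""
    # Consider tokens in order of decreasing probability.
    order = sorted(range(len(probabilities)), key=lambda i: -probabilities[i])
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probabilities[i]
        if cum > topp:
            break  # smallest set whose cumulative probability exceeds topp
    r = coin * cum  # rescale the coin to the kept probability mass
    cdf = 0.0
    for i in kept:
        cdf += probabilities[i]
        if r < cdf:
            return i + 1
    return kept[-1] + 1  # guard against floating-point rounding
```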
Examples
julia> sample_topp([0.1, 0.2, 0.3, 0.4], 1.0, 0.9)
1
julia> sample_topp([0.1, 0.2, 0.3, 0.4], 0.5, 0.9)
3
julia> sample_topp([0.1, 0.2, 0.3, 0.4], 0.4, 0.9)
3
julia> sample_topp([0.1, 0.2, 0.3, 0.4], 0.39, 0.9)
4

Llama2.softmax! — Method

softmax!(x::AbstractArray{T<:Real})
Calculate the softmax of a vector in place. Reference in llama2.c: lines 197-215.
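A Python sketch of a numerically stable softmax (an illustration; softmax! works in place, whereas this version returns a new list):

```python
import math

def softmax(x):
    """Numerically stable softmax: subtract the max before
    exponentiating so large logits do not overflow."""
    m = max(x)
    e = [math.exp(v - m) for v in x]
    s = sum(e)
    return [v / s for v in e]
```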
Llama2.swiglu! — Method

swiglu!(
x::AbstractArray{T<:Real},
x2::AbstractArray{T<:Real}
)
Activation function that combines GLU and Swish functions.
\[\mathrm{swiglu}(x, x_2) = x \cdot x_2 \cdot \mathrm{sigmoid}(x)\]
Reference in llama2.c lines 338-345
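A Python sketch of the elementwise operation (an illustration; the in-place swiglu! overwrites its first argument, whereas this version returns a new list):

```python
import math

def swiglu(x, x2):
    """Elementwise SwiGLU: x * sigmoid(x) * x2, i.e. the swish
    (SiLU) of x gated by x2."""
    return [a * (1.0 / (1.0 + math.exp(-a))) * b for a, b in zip(x, x2)]
```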