Text Generation Using LLMs
Convert and Optimize Model
- From Hugging Face
- From Model Scope
Use the optimum-intel package to convert and optimize models:
pip install optimum-intel[openvino]
Download and convert a model to the OpenVINO IR format:
- Compress weights to the int4 precision:
optimum-cli export openvino --model meta-llama/Llama-2-7b-chat-hf --weight-format int4 ov_llama_2_7b_int4 --trust-remote-code
- Keep full model precision:
optimum-cli export openvino --model meta-llama/Llama-2-7b-chat-hf --weight-format fp16 ov_llama_2_7b_fp16 --trust-remote-code
Check a full list of conversion options here.
You can also use pre-converted LLMs.
Use the modelscope and optimum-intel packages to convert and optimize models:
pip install modelscope optimum-intel[openvino]
Download the required model to a local folder:
modelscope download --model 'Qwen/Qwen2-7b' --local_dir model_path
Convert the model and compress weights:
- INT4:
optimum-cli export openvino -m model_path --weight-format int4 ov_qwen2_7b_int4 --task text-generation-with-past
- INT8:
optimum-cli export openvino -m model_path --weight-format int8 ov_qwen2_7b_int8 --task text-generation-with-past
- FP16:
optimum-cli export openvino -m model_path --weight-format fp16 ov_qwen2_7b_fp16 --task text-generation-with-past
Run Model Using OpenVINO™ GenAI
LLMPipeline is the main object used for decoding. You can construct it directly from the folder that contains the converted model.
It automatically loads the main model, tokenizer, detokenizer, and default generation configuration.
- Python
- C++
- CPU
- GPU
import openvino_genai as ov_genai
pipe = ov_genai.LLMPipeline(model_path, "CPU")
print(pipe.generate("What is OpenVINO?", max_new_tokens=100))
import openvino_genai as ov_genai
pipe = ov_genai.LLMPipeline(model_path, "GPU")
print(pipe.generate("What is OpenVINO?", max_new_tokens=100))
- CPU
- GPU
#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>
int main(int argc, char* argv[]) {
    std::string model_path = argv[1];  // path to the converted model folder
    ov::genai::LLMPipeline pipe(model_path, "CPU");
    std::cout << pipe.generate("What is OpenVINO?", ov::genai::max_new_tokens(100)) << '\n';
}
#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>
int main(int argc, char* argv[]) {
    std::string model_path = argv[1];  // path to the converted model folder
    ov::genai::LLMPipeline pipe(model_path, "GPU");
    std::cout << pipe.generate("What is OpenVINO?", ov::genai::max_new_tokens(100)) << '\n';
}
To switch between CPU and GPU, change only the device name; no other code changes are needed.
Additional Usage Options
Use Different Generation Parameters
Fine-tune your LLM's output by adjusting various generation parameters. OpenVINO GenAI supports multiple sampling strategies and generation configurations to help you achieve the desired balance between deterministic and creative outputs.
Basic Generation Configuration
- Get the model default config with get_generation_config()
- Modify parameters
- Apply the updated config in one of the following ways:
  - Use set_generation_config(config)
  - Pass config directly to generate() (e.g. generate(prompt, config))
  - Specify options as inputs in the generate() method (e.g. generate(prompt, max_new_tokens=100))
- Python
- C++
import openvino_genai as ov_genai
pipe = ov_genai.LLMPipeline(model_path, "CPU")
# Get default configuration
config = pipe.get_generation_config()
# Modify parameters
config.max_new_tokens = 100
config.do_sample = True  # enable sampling so temperature, top_k and top_p take effect
config.temperature = 0.7
config.top_k = 50
config.top_p = 0.9
config.repetition_penalty = 1.2
# Generate text with custom configuration
output = pipe.generate("The Sun is yellow because", config)
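#include "openvino/genai/llm_pipeline.hpp"
#include <string>
int main(int argc, char* argv[]) {
    std::string model_path = argv[1];  // path to the converted model folder
    ov::genai::LLMPipeline pipe(model_path, "CPU");
    // Get default configuration
    auto config = pipe.get_generation_config();
    // Modify parameters
    config.max_new_tokens = 100;
    config.do_sample = true;  // enable sampling so temperature, top_k and top_p take effect
    config.temperature = 0.7f;
    config.top_k = 50;
    config.top_p = 0.9f;
    config.repetition_penalty = 1.2f;
    // Generate text with custom configuration
    auto output = pipe.generate("The Sun is yellow because", config);
}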
- max_new_tokens: The maximum number of tokens to generate, excluding the number of tokens in the prompt. max_new_tokens has priority over max_length.
- temperature: Controls the level of creativity in AI-generated text:
  - Low temperature (e.g. 0.2) leads to more focused and deterministic output, choosing tokens with the highest probability.
  - Medium temperature (e.g. 1.0) maintains a balance between creativity and focus, selecting tokens based on their probabilities without significant bias.
  - High temperature (e.g. 2.0) makes output more creative and adventurous, increasing the chances of selecting less likely tokens.
- top_k: Limits token selection to the k most likely next tokens. Higher values allow more diverse outputs.
- top_p: Selects from the smallest set of tokens whose cumulative probability exceeds p. Helps balance diversity and quality.
- repetition_penalty: Reduces the likelihood of repeating tokens. Values above 1.0 discourage repetition.
For the full list of generation parameters, refer to the API reference.
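To make the effect of temperature, top_k and top_p more concrete, the sketch below applies them to a toy list of logits in plain Python. It does not call OpenVINO GenAI and is not how the library is implemented; the helper name filter_distribution, the toy logits, and the convention that top_k=0 means "no limit" are assumptions made for illustration only.

```python
import math
import random

def filter_distribution(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return {token_index: probability} after temperature, top-k and top-p filtering."""
    # Temperature rescales logits before softmax: below 1 sharpens, above 1 flattens.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    weights = [math.exp(x - peak) for x in scaled]
    total = sum(weights)
    ranked = sorted(((w / total, i) for i, w in enumerate(weights)), reverse=True)

    # top_k: keep only the k most likely tokens (0 means no limit in this sketch).
    if top_k > 0:
        ranked = ranked[:top_k]

    # top_p: keep the smallest prefix whose cumulative probability reaches p.
    kept, cumulative = [], 0.0
    for prob, idx in ranked:
        kept.append((prob, idx))
        cumulative += prob
        if cumulative >= top_p:
            break

    # Renormalize the surviving tokens so their probabilities sum to 1.
    norm = sum(prob for prob, _ in kept)
    return {idx: prob / norm for prob, idx in kept}

def sample(dist, rng):
    """Draw one token index from the filtered distribution."""
    tokens = list(dist)
    return rng.choices(tokens, weights=[dist[t] for t in tokens], k=1)[0]

logits = [2.0, 1.0, 0.5, -1.0]  # toy scores for a 4-token vocabulary
focused = filter_distribution(logits, temperature=0.2)
creative = filter_distribution(logits, temperature=2.0)
limited = filter_distribution(logits, top_k=2)
picked = sample(limited, random.Random(0))
```

With the low temperature, nearly all probability mass lands on the most likely token, while the high temperature spreads it across the vocabulary; with top_k=2 only the two most likely tokens can ever be sampled.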
Optimizing Generation with Grouped Beam Search
Beam search helps explore multiple possible text completions simultaneously, often leading to higher quality outputs.
- Python
- C++
import openvino_genai as ov_genai
pipe = ov_genai.LLMPipeline(model_path, "CPU")
# Get default generation config
config = pipe.get_generation_config()
# Modify parameters
config.max_new_tokens = 256
config.num_beams = 15
config.num_beam_groups = 3
config.diversity_penalty = 1.0
# Generate text with custom configuration
print(pipe.generate("The Sun is yellow because", config))
#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>
#include <string>
int main(int argc, char* argv[]) {
    std::string model_path = argv[1];
    ov::genai::LLMPipeline pipe(model_path, "CPU");
    // Get default generation config
    ov::genai::GenerationConfig config = pipe.get_generation_config();
    // Modify parameters
    config.max_new_tokens = 256;
    config.num_beams = 15;
    config.num_beam_groups = 3;
    config.diversity_penalty = 1.0f;
    // Generate text with custom configuration
    std::cout << pipe.generate("The Sun is yellow because", config) << '\n';
}
- max_new_tokens: The maximum number of tokens to generate, excluding the number of tokens in the prompt. max_new_tokens has priority over max_length.
- num_beams: The number of beams for beam search. A value of 1 disables beam search.
- num_beam_groups: The number of groups to divide num_beams into in order to ensure diversity among different groups of beams.
- diversity_penalty: The value subtracted from a beam's score if it generates the same token as any beam from another group at a particular time step.
For the full list of generation parameters, refer to the API reference.
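The interaction between num_beams, num_beam_groups and diversity_penalty can be sketched in a few lines of plain Python. This is an illustration of the idea described above, not the library's implementation: the helper names are hypothetical, the penalty is applied once per token regardless of how many other beams chose it, and the requirement that num_beams be divisible by num_beam_groups is an assumption borrowed from the usual formulation of grouped beam search.

```python
def beams_per_group(num_beams, num_beam_groups):
    """Split beams evenly across groups (assumed to require exact divisibility)."""
    if num_beams % num_beam_groups != 0:
        raise ValueError("num_beams must be divisible by num_beam_groups")
    return num_beams // num_beam_groups

def apply_diversity_penalty(log_probs, tokens_from_other_groups, diversity_penalty):
    """Lower the score of every token that another group already chose at this step."""
    adjusted = list(log_probs)
    for token in set(tokens_from_other_groups):
        adjusted[token] -= diversity_penalty
    return adjusted

group_size = beams_per_group(15, 3)

# Toy scores for a 3-token vocabulary; token 0 is the favourite.
log_probs = [-0.1, -0.5, -2.0]
# An earlier group already picked token 0 at this time step.
adjusted = apply_diversity_penalty(log_probs, [0], diversity_penalty=1.0)
best_before = log_probs.index(max(log_probs))
best_after = adjusted.index(max(adjusted))
```

In this toy case the penalty pushes the later group away from token 0 toward token 1, which is how the groups end up exploring different continuations.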
Using GenAI in Chat Scenario
For chat applications, OpenVINO GenAI provides special optimizations to maintain conversation context and improve performance using KV-cache.
Use start_chat() and finish_chat() to properly manage the chat session's KV-cache. This improves performance by reusing context between messages.
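As a rough illustration of why reusing context helps, the sketch below counts how many prompt tokens would have to be processed over a conversation with and without a retained cache. It is a simplified model under stated assumptions: the function name and the token counts are hypothetical, generated answers are ignored, and it does not reflect the internals of OpenVINO GenAI.

```python
def prompt_tokens_processed(turn_lengths, reuse_cache):
    """Count prompt tokens the model must process over a whole conversation."""
    processed = 0
    history = 0  # tokens already seen in earlier turns
    for length in turn_lengths:
        if reuse_cache:
            processed += length  # only the new message is processed
        else:
            processed += history + length  # the full history is processed again
        history += length
    return processed

turns = [20, 15, 30]  # hypothetical token counts of three user messages
with_cache = prompt_tokens_processed(turns, reuse_cache=True)
without_cache = prompt_tokens_processed(turns, reuse_cache=False)
print(with_cache, without_cache)  # prints "65 120"
```

The gap widens with every additional turn, since without a cache each new message pays again for everything that came before it.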
A simple chat example (with grouped beam search decoding):
- Python
- C++
import openvino_genai as ov_genai
pipe = ov_genai.LLMPipeline(model_path, 'CPU')
# Start from the model defaults and enable grouped beam search
config = pipe.get_generation_config()
config.max_new_tokens = 100
config.num_beam_groups = 3
config.num_beams = 15
config.diversity_penalty = 1.5
pipe.set_generation_config(config)
pipe.start_chat()
while True:
    try:
        prompt = input('question:\n')
    except EOFError:
        break
    answer = pipe.generate(prompt)
    print('answer:\n')
    print(answer)
    print('\n----------\n')
pipe.finish_chat()
#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>
#include <string>
int main(int argc, char* argv[]) {
    std::string prompt;
    std::string model_path = argv[1];
    ov::genai::LLMPipeline pipe(model_path, "CPU");
    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;
    config.num_beam_groups = 3;
    config.num_beams = 15;
    config.diversity_penalty = 1.0f;
    pipe.start_chat();
    std::cout << "question:\n";
    while (std::getline(std::cin, prompt)) {
        std::cout << "answer:\n";
        auto answer = pipe.generate(prompt, config);
        std::cout << answer << std::endl;
        std::cout << "\n----------\n"
                     "question:\n";
    }
    pipe.finish_chat();
}
Streaming the Output
For more interactive UIs during generation, you can stream output tokens.
Streaming Function
In this example, a function outputs words to the console immediately upon generation:
- Python
- C++
import openvino_genai as ov_genai
pipe = ov_genai.LLMPipeline(model_path, "CPU")
# Create a streamer function
def streamer(subword):
    print(subword, end='', flush=True)
    # The returned status tells the pipeline whether to keep generating or stop.
    return ov_genai.StreamingStatus.RUNNING
pipe.start_chat()
while True:
    try:
        prompt = input('question:\n')
    except EOFError:
        break
    pipe.generate(prompt, streamer=streamer, max_new_tokens=100)
    print('\n----------\n')
pipe.finish_chat()
#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>
#include <string>
int main(int argc, char* argv[]) {
    std::string prompt;
    std::string model_path = argv[1];
    ov::genai::LLMPipeline pipe(model_path, "CPU");
    // Create a streamer function
    auto streamer = [](std::string word) {
        std::cout << word << std::flush;
        // The returned status tells the pipeline whether to keep generating or stop.
        return ov::genai::StreamingStatus::RUNNING;
    };
    pipe.start_chat();
    std::cout << "question:\n";
    while (std::getline(std::cin, prompt)) {
        pipe.generate(prompt, ov::genai::streamer(streamer), ov::genai::max_new_tokens(100));
        std::cout << "\n----------\n"
                     "question:\n";
    }
    pipe.finish_chat();
}
For more information, refer to the chat sample.
Custom Streamer Class
You can also create your own streamer class for more sophisticated processing:
- Python
- C++
import openvino_genai as ov_genai
pipe = ov_genai.LLMPipeline(model_path, "CPU")
# Create custom streamer class
class CustomStreamer(ov_genai.StreamerBase):
    def __init__(self):
        super().__init__()
        # Initialization logic.
    def write(self, token_id) -> ov_genai.StreamingStatus:
        # Custom decoding/tokens processing logic.
        # The returned status tells the pipeline whether to keep generating or stop.
        return ov_genai.StreamingStatus.RUNNING
    def end(self):
        # Custom finalization logic.
        pass
pipe.start_chat()
while True:
    try:
        prompt = input('question:\n')
    except EOFError:
        break
    pipe.generate(prompt, streamer=CustomStreamer(), max_new_tokens=100)
    print('\n----------\n')
pipe.finish_chat()
#include "openvino/genai/streamer_base.hpp"
#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>
#include <memory>
#include <string>
// Create custom streamer class
class CustomStreamer : public ov::genai::StreamerBase {
public:
    ov::genai::StreamingStatus write(int64_t token) override {
        // Custom decoding/tokens processing logic.
        // The returned status tells the pipeline whether to keep generating or stop.
        return ov::genai::StreamingStatus::RUNNING;
    }
    void end() override {
        // Custom finalization logic.
    }
};
int main(int argc, char* argv[]) {
    std::string prompt;
    // The pipeline takes a custom streamer as a shared pointer to StreamerBase.
    std::shared_ptr<ov::genai::StreamerBase> custom_streamer = std::make_shared<CustomStreamer>();
    std::string model_path = argv[1];
    ov::genai::LLMPipeline pipe(model_path, "CPU");
    pipe.start_chat();
    std::cout << "question:\n";
    while (std::getline(std::cin, prompt)) {
        pipe.generate(prompt, ov::genai::streamer(custom_streamer), ov::genai::max_new_tokens(100));
        std::cout << "\n----------\n"
                     "question:\n";
    }
    pipe.finish_chat();
}
For a fully implemented iterable CustomStreamer, refer to the multinomial_causal_lm sample.
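The general shape of an iterable streamer can be sketched without the library: the producer pushes pieces into a queue from one thread while the consumer iterates over them in another. The class below is a standalone illustration, not the code of the sample mentioned above. It does not inherit from ov_genai.StreamerBase, it receives ready-made text pieces instead of token IDs, and IterableStreamer and fake_generate are hypothetical names.

```python
import queue
import threading

class IterableStreamer:
    """Collects streamed pieces in a queue so a consumer can iterate over them."""

    _DONE = object()  # sentinel that marks the end of generation

    def __init__(self):
        self._queue = queue.Queue()

    def write(self, piece):
        # Called by the producer for every new piece of output.
        # A real streamer would return a StreamingStatus here instead of True.
        self._queue.put(piece)
        return True

    def end(self):
        # Called once by the producer when generation is finished.
        self._queue.put(self._DONE)

    def __iter__(self):
        return self

    def __next__(self):
        item = self._queue.get()  # blocks until the producer delivers something
        if item is self._DONE:
            raise StopIteration
        return item

def fake_generate(streamer):
    """Stand-in for a generation call; it only feeds fixed pieces to the streamer."""
    for piece in ["What", " is", " OpenVINO", "?"]:
        streamer.write(piece)
    streamer.end()

streamer = IterableStreamer()
worker = threading.Thread(target=fake_generate, args=(streamer,))
worker.start()
streamed_text = "".join(streamer)  # consume pieces as they arrive
worker.join()
```

Running generation in a worker thread keeps the consuming loop free to update a UI or forward pieces over a network connection as they arrive.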
Working with LoRA Adapters
LoRA (Low-Rank Adaptation) allows you to fine-tune models for specific tasks efficiently. OpenVINO GenAI supports dynamic loading and switching between LoRA adapters without recompilation.
- Python
- C++
import openvino_genai as ov_genai
# Initialize pipeline with adapters
adapter_config = ov_genai.AdapterConfig()
# Add multiple adapters with different weights
adapter1 = ov_genai.Adapter("path/to/lora1.safetensors")
adapter2 = ov_genai.Adapter("path/to/lora2.safetensors")
adapter_config.add(adapter1, alpha=0.5)
adapter_config.add(adapter2, alpha=0.5)
pipe = ov_genai.LLMPipeline(
    model_path,
    "CPU",
    adapters=adapter_config
)
# Generate with current adapters
output1 = pipe.generate("Generate story about", max_new_tokens=100)
# Switch to different adapter configuration
new_config = ov_genai.AdapterConfig()
new_config.add(adapter1, alpha=1.0)
output2 = pipe.generate(
    "Generate story about",
    max_new_tokens=100,
    adapters=new_config
)
#include "openvino/genai/llm_pipeline.hpp"
#include <string>
int main(int argc, char* argv[]) {
    std::string model_path = argv[1];  // path to the converted model folder
    ov::genai::AdapterConfig adapter_config;
    // Add multiple adapters with different weights
    ov::genai::Adapter adapter1("path/to/lora1.safetensors");
    ov::genai::Adapter adapter2("path/to/lora2.safetensors");
    adapter_config.add(adapter1, 0.5f);
    adapter_config.add(adapter2, 0.5f);
    ov::genai::LLMPipeline pipe(
        model_path,
        "CPU",
        ov::genai::adapters(adapter_config)
    );
    // Generate with current adapters
    auto output1 = pipe.generate("Generate story about", ov::genai::max_new_tokens(100));
    // Switch to different adapter configuration
    ov::genai::AdapterConfig new_config;
    new_config.add(adapter1, 1.0f);
    auto output2 = pipe.generate(
        "Generate story about",
        ov::genai::adapters(new_config),
        ov::genai::max_new_tokens(100)
    );
}
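Conceptually, a LoRA adapter stores a low-rank update for a weight matrix, and a blending weight such as alpha scales how strongly that update is applied. The sketch below shows this arithmetic on tiny matrices in plain Python. It follows the textbook formulation W' = W + alpha * (B x A) and is only an assumption about how the alpha values in the examples above act; it is not a description of OpenVINO GenAI internals, and the helper names and matrices are hypothetical.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def apply_adapters(base, adapters):
    """Return base + sum(alpha * (B @ A)) over all (A, B, alpha) adapters."""
    merged = [row[:] for row in base]
    for a, b, alpha in adapters:
        delta = matmul(b, a)  # full-size update built from two thin matrices
        for i, row in enumerate(delta):
            for j, value in enumerate(row):
                merged[i][j] += alpha * value
    return merged

base = [[1.0, 0.0],
        [0.0, 1.0]]          # toy 2x2 weight matrix
lora_a = [[3.0, 4.0]]        # rank-1 factor A, shape 1x2
lora_b = [[1.0], [2.0]]      # rank-1 factor B, shape 2x1

half = apply_adapters(base, [(lora_a, lora_b, 0.5)])
off = apply_adapters(base, [(lora_a, lora_b, 0.0)])
```

Because the base weights are left untouched and only the scaled update changes, switching adapters or their weights does not require rebuilding the base model.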
Accelerate Generation via Speculative Decoding
Speculative decoding (or assisted generation in Hugging Face terminology) is a technique that speeds up token generation by using an additional, smaller draft model alongside the main model. This reduces the number of inference requests to the main model, which improves performance.
How Speculative Decoding Works
The draft model predicts the next K tokens one by one in an autoregressive manner, while the main model validates these predictions and corrects them if necessary. We go through each predicted token, and if a difference is detected between the draft and main model, we stop and keep the last token predicted by the main model. Then the draft model gets the latest main prediction and again tries to predict the next K tokens, repeating the cycle.
This approach reduces the number of inference requests to the main model. For instance, in more predictable parts of the text, the draft model can, in the best case, generate K tokens that exactly match the main model's output. These tokens are then validated in a single inference request to the main model (which is bigger and more accurate but slower) instead of K sequential requests.
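The draft-and-validate cycle described above can be sketched with two toy "models" that are ordinary Python functions. This is a simplified greedy illustration, not the OpenVINO GenAI implementation: the function names and the toy models are hypothetical, and the main model is called once per position here but counted as a single validation pass, standing in for the single batched request a real pipeline would issue.

```python
def speculative_generate(main_next, draft_next, prompt, max_new_tokens, k):
    """Greedy speculative decoding on toy models.

    main_next and draft_next map a token sequence to the next token.
    Returns (generated_tokens, number_of_main_model_validation_passes).
    """
    tokens = list(prompt)
    main_passes = 0
    while len(tokens) - len(prompt) < max_new_tokens:
        # 1. The draft model proposes k tokens one by one.
        proposal = []
        for _ in range(k):
            proposal.append(draft_next(tokens + proposal))
        # 2. The main model validates the whole proposal in one pass.
        main_passes += 1
        accepted = []
        for guess in proposal:
            expected = main_next(tokens + accepted)
            accepted.append(expected)
            if expected != guess:
                break  # keep the main model's token, drop the rest of the draft
        tokens.extend(accepted)
    generated = tokens[len(prompt):len(prompt) + max_new_tokens]
    return generated, main_passes

def main_next(seq):
    """Toy main model: counts upward modulo 10."""
    return (seq[-1] + 1) % 10

def draft_next(seq):
    """Toy draft model: agrees with the main model except right after token 4."""
    return 0 if seq[-1] == 4 else (seq[-1] + 1) % 10

generated, passes = speculative_generate(main_next, draft_next, [0], max_new_tokens=8, k=4)
print(generated, passes)  # prints "[1, 2, 3, 4, 5, 6, 7, 8] 3"
```

The output is identical to what the main model would produce on its own, but it takes 3 validation passes instead of 8 single-token requests; the one mismatch after token 4 costs a pass that yields only a single corrected token.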
More details can be found in the original papers:
- Python
- C++
import openvino_genai
def streamer(subword):
    print(subword, end='', flush=True)
    return openvino_genai.StreamingStatus.RUNNING
def infer(model_dir: str, draft_model_dir: str, prompt: str):
    main_device = 'CPU'  # GPU can be used as well.
    draft_device = 'CPU'
    # Configure cache for better performance
    scheduler_config = openvino_genai.SchedulerConfig()
    scheduler_config.cache_size = 2  # in GB
    # Initialize draft model
    draft_model = openvino_genai.draft_model(
        draft_model_dir,
        draft_device
    )
    # Create pipeline with draft model
    pipe = openvino_genai.LLMPipeline(
        model_dir,
        main_device,
        scheduler_config=scheduler_config,
        draft_model=draft_model
    )
    # Configure speculative decoding
    config = openvino_genai.GenerationConfig()
    config.max_new_tokens = 100
    config.num_assistant_tokens = 5  # Number of tokens to predict speculatively
    pipe.generate(prompt, config, streamer)
#include <openvino/openvino.hpp>
#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>
#include <stdexcept>
#include <string>
int main(int argc, char* argv[]) {
    if (4 != argc) {
        throw std::runtime_error(std::string{"Usage: "} + argv[0] + " <MODEL_DIR> <DRAFT_MODEL_DIR> '<PROMPT>'");
    }
    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;
    config.num_assistant_tokens = 5;  // Number of tokens to predict speculatively
    std::string main_model_path = argv[1];
    std::string draft_model_path = argv[2];
    std::string prompt = argv[3];
    std::string main_device = "CPU", draft_device = "CPU";
    ov::genai::SchedulerConfig scheduler_config;
    scheduler_config.cache_size = 5;  // in GB
    ov::genai::LLMPipeline pipe(
        main_model_path,
        main_device,
        ov::genai::draft_model(draft_model_path, draft_device),
        ov::genai::scheduler_config(scheduler_config));
    auto streamer = [](std::string word) {
        std::cout << word << std::flush;
        return ov::genai::StreamingStatus::RUNNING;
    };
    pipe.generate(prompt, config, streamer);
}
For more information, refer to the Speculative Decoding sample.