From b398a4641ae5ebb5bde027d12f96059f2bdae814 Mon Sep 17 00:00:00 2001 From: Bartosz Taudul Date: Thu, 22 May 2025 01:51:15 +0200 Subject: [PATCH] Switch from Ollama API to OpenAI API commonly used by all LLM providers. --- cmake/ollama-hpp-badcode.patch | 13 - cmake/ollama-hpp-string.patch | 16 -- cmake/vendor.cmake | 12 - profiler/CMakeLists.txt | 5 +- profiler/src/main.cpp | 4 +- profiler/src/profiler/TracyConfig.hpp | 1 - profiler/src/profiler/TracyLlm.cpp | 365 ++++++++++++------------ profiler/src/profiler/TracyLlm.hpp | 38 +-- profiler/src/profiler/TracyLlmApi.cpp | 177 ++++++++++++ profiler/src/profiler/TracyLlmApi.hpp | 52 ++++ profiler/src/profiler/TracyLlmTools.cpp | 21 +- profiler/src/profiler/TracyLlmTools.hpp | 8 +- profiler/src/profiler/TracyView.cpp | 5 +- 13 files changed, 438 insertions(+), 279 deletions(-) delete mode 100644 cmake/ollama-hpp-badcode.patch delete mode 100644 cmake/ollama-hpp-string.patch create mode 100644 profiler/src/profiler/TracyLlmApi.cpp create mode 100644 profiler/src/profiler/TracyLlmApi.hpp diff --git a/cmake/ollama-hpp-badcode.patch b/cmake/ollama-hpp-badcode.patch deleted file mode 100644 index 2b94b8c4..00000000 --- a/cmake/ollama-hpp-badcode.patch +++ /dev/null @@ -1,13 +0,0 @@ -diff --git i/include/ollama.hpp w/include/ollama.hpp -index eaa7a51..9c8ca18 100644 ---- i/include/ollama.hpp -+++ w/include/ollama.hpp -@@ -851,7 +848,7 @@ class Ollama - } - else - { -- throw ollama::exception("Error retrieving version: "+res->status); -+ throw ollama::exception("Error retrieving version: "+std::to_string(res->status)); - } - - return version; diff --git a/cmake/ollama-hpp-string.patch b/cmake/ollama-hpp-string.patch deleted file mode 100644 index 41957770..00000000 --- a/cmake/ollama-hpp-string.patch +++ /dev/null @@ -1,16 +0,0 @@ -diff --git i/include/ollama.hpp w/include/ollama.hpp -index eaa7a51..0f343b8 100644 ---- i/include/ollama.hpp -+++ w/include/ollama.hpp -@@ -347,10 +347,7 @@ namespace ollama - return type; - } - -- //operator std::string() const { return this->as_simple_string(); } -- operator std::__cxx11::basic_string() const { return this->as_simple_string(); } -- //const operator std::string() const { return this->as_simple_string(); } -- -+ operator std::string() const { return this->as_simple_string(); } - - private: - diff --git a/cmake/vendor.cmake b/cmake/vendor.cmake index 1f19569e..f47b82bf 100644 --- a/cmake/vendor.cmake +++ b/cmake/vendor.cmake @@ -237,18 +237,6 @@ CPMAddPackage( EXCLUDE_FROM_ALL TRUE ) -# ollama-hpp - -CPMAddPackage( - NAME ollama-hpp - GITHUB_REPOSITORY jmont-dev/ollama-hpp - VERSION 0.9.5 - DOWNLOAD_ONLY TRUE - PATCHES - "${CMAKE_CURRENT_LIST_DIR}/ollama-hpp-string.patch" - "${CMAKE_CURRENT_LIST_DIR}/ollama-hpp-badcode.patch" -) - # base64 set(BUILD_SHARED_LIBS_SAVE ${BUILD_SHARED_LIBS}) diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index add2ec2b..95fdf562 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -60,6 +60,7 @@ set(SERVER_FILES TracyFilesystem.cpp TracyImGui.cpp TracyLlm.cpp + TracyLlmApi.cpp TracyLlmTools.cpp TracyMicroArchitecture.cpp TracyMouse.cpp @@ -219,8 +220,8 @@ else() endif() find_package(Threads REQUIRED) -target_link_libraries(${PROJECT_NAME} PRIVATE TracyServer TracyImGui Threads::Threads TracyLibcurl base64 tidy-static TracyPugixml) -target_include_directories(${PROJECT_NAME} PRIVATE ${ollama-hpp_SOURCE_DIR}/include ${tidy_SOURCE_DIR}/include) +target_link_libraries(${PROJECT_NAME} PRIVATE TracyServer TracyImGui Threads::Threads TracyLibcurl 
base64 tidy-static TracyPugixml nlohmann_json::nlohmann_json) +target_include_directories(${PROJECT_NAME} PRIVATE ${tidy_SOURCE_DIR}/include) if(NOT DEFINED GIT_REV) set(GIT_REV "HEAD") diff --git a/profiler/src/main.cpp b/profiler/src/main.cpp index 7762947e..77bc03b5 100644 --- a/profiler/src/main.cpp +++ b/profiler/src/main.cpp @@ -236,12 +236,11 @@ static void LoadConfig() if( ini_sget( ini, "llm", "enabled", "%d", &v ) ) s_config.llm = v; if( v2 = ini_get( ini, "llm", "address" ); v2 ) s_config.llmAddress = v2; if( v2 = ini_get( ini, "llm", "model" ); v2 ) s_config.llmModel = v2; - if( ini_sget( ini, "llm", "context", "%d", &v ) ) s_config.llmContext = v; ini_free( ini ); } -static bool SaveConfig() +bool SaveConfig() { const auto fn = tracy::GetSavePath( "tracy.ini" ); FILE* f = fopen( fn, "wb" ); @@ -275,7 +274,6 @@ static bool SaveConfig() fprintf( f, "enabled = %i\n", (int)s_config.llm ); fprintf( f, "address = %s\n", s_config.llmAddress.c_str() ); fprintf( f, "model = %s\n", s_config.llmModel.c_str() ); - fprintf( f, "context = %i\n", s_config.llmContext ); fclose( f ); return true; diff --git a/profiler/src/profiler/TracyConfig.hpp b/profiler/src/profiler/TracyConfig.hpp index b6f564d0..a6a7be0a 100644 --- a/profiler/src/profiler/TracyConfig.hpp +++ b/profiler/src/profiler/TracyConfig.hpp @@ -27,7 +27,6 @@ struct Config bool llm = true; std::string llmAddress = "http://localhost:11434"; std::string llmModel; - int llmContext = 32*1024; }; } diff --git a/profiler/src/profiler/TracyLlm.cpp b/profiler/src/profiler/TracyLlm.cpp index 4f89af44..cbf3094f 100644 --- a/profiler/src/profiler/TracyLlm.cpp +++ b/profiler/src/profiler/TracyLlm.cpp @@ -1,5 +1,4 @@ #include -#include #include #include #include @@ -7,12 +6,14 @@ #include "TracyConfig.hpp" #include "TracyImGui.hpp" #include "TracyLlm.hpp" +#include "TracyLlmApi.hpp" #include "TracyPrint.hpp" #include "../Fonts.hpp" #include "data/SystemPrompt.hpp" extern tracy::Config s_config; +extern bool SaveConfig(); namespace tracy { @@ -35,30 +36,16 @@ TracyLlm::TracyLlm() atexit( curl_global_cleanup ); } - try - { - m_ollama = std::make_unique( s_config.llmAddress ); - if( !m_ollama->is_running() ) - { - m_ollama.reset(); - return; - } - } - catch( const std::exception& e ) - { - m_ollama.reset(); - return; - } - - m_input = new char[InputBufferSize]; - *m_input = 0; - m_systemPrompt = Unembed( SystemPrompt ); - + m_input = new char[InputBufferSize]; + m_apiInput = new char[InputBufferSize]; ResetChat(); + m_api = std::make_unique(); + + m_busy = true; m_jobs.emplace_back( WorkItem { - .task = Task::LoadModels, + .task = Task::Connect, .callback = [this] { UpdateModels(); } } ); m_thread = std::thread( [this] { Worker(); } ); @@ -67,6 +54,7 @@ TracyLlm::TracyLlm() TracyLlm::~TracyLlm() { delete[] m_input; + delete[] m_apiInput; if( m_thread.joinable() ) { @@ -80,11 +68,6 @@ TracyLlm::~TracyLlm() } } -std::string TracyLlm::GetVersion() const -{ - return m_ollama->get_version(); -} - void TracyLlm::Draw() { const auto scale = GetScale(); @@ -92,22 +75,6 @@ void TracyLlm::Draw() ImGui::Begin( "Tracy AI", &m_show, ImGuiWindowFlags_NoScrollbar ); if( ImGui::GetCurrentWindowRead()->SkipItems ) { ImGui::End(); return; } - if( !m_ollama ) - { - const auto ty = ImGui::GetTextLineHeight(); - ImGui::PushFont( g_fonts.big ); - ImGui::Dummy( ImVec2( 0, ( ImGui::GetContentRegionAvail().y - ImGui::GetTextLineHeight() * 2 - ty ) * 0.5f ) ); - TextCentered( ICON_FA_PLUG_CIRCLE_XMARK ); - TextCentered( "Cannot connect to ollama server!" 
); - ImGui::PopFont(); - ImGui::Dummy( ImVec2( 0, ty * 2 ) ); - ImGui::PushFont( g_fonts.small ); - TextCentered( "Server address:" ); - TextCentered( s_config.llmAddress.c_str() ); - ImGui::PopFont(); - ImGui::End(); - return; - } if( IsBusy() ) { ImGui::PushFont( g_fonts.big ); @@ -123,89 +90,74 @@ void TracyLlm::Draw() auto& style = ImGui::GetStyle(); std::lock_guard lock( m_lock ); - if( !m_models.empty() ) + const auto hasChat = m_chat.size() <= 1 && *m_input == 0; + if( hasChat ) ImGui::BeginDisabled(); + if( ImGui::Button( ICON_FA_BROOM " Clear chat" ) ) { - if( ImGui::Button( ICON_FA_BROOM " Clear chat" ) ) - { - if( m_responding ) m_stop = true; - ResetChat(); - m_chatCache.clear(); - *m_input = 0; - } - ImGui::SameLine(); + if( m_responding ) m_stop = true; + ResetChat(); } - if( ImGui::Button( ICON_FA_ARROWS_ROTATE " Reload models" ) ) + if( hasChat ) ImGui::EndDisabled(); + ImGui::SameLine(); + if( ImGui::Button( ICON_FA_ARROWS_ROTATE " Reconnect" ) ) { if( m_responding ) m_stop = true; m_jobs.emplace_back( WorkItem { - .task = Task::LoadModels, + .task = Task::Connect, .callback = [this] { UpdateModels(); } } ); m_cv.notify_all(); } - if( m_models.empty() ) - { - ImGui::PushFont( g_fonts.big ); - ImGui::Dummy( ImVec2( 0, ( ImGui::GetContentRegionAvail().y - ImGui::GetTextLineHeight() * 10 ) * 0.5f ) ); - TextCentered( ICON_FA_WORM ); - ImGui::Spacing(); - TextCentered( "No models available." ); - ImGui::Dummy( ImVec2( 0, ImGui::GetTextLineHeight() * 1.5f ) ); - ImGui::PopFont(); - ImGui::TextWrapped( "You need to retrieve at least one model with the ollama tools before you can use this feature." ); - ImGui::TextWrapped( "Models can be downloaded by running the 'ollama pull ' command." ); - ImGui::TextWrapped( "The https://ollama.com/ website contains a list of available models. The 'gemma3' model works quite well." 
); - ImGui::End(); - return; - } - ImGui::SameLine(); if( ImGui::TreeNode( "Settings" ) ) { ImGui::Spacing(); ImGui::AlignTextToFramePadding(); - TextDisabledUnformatted( "Model:" ); + TextDisabledUnformatted( "API:" ); ImGui::SameLine(); - if( ImGui::BeginCombo( "##model", m_models[m_modelIdx].name.c_str() ) ) + const auto sz = std::min( InputBufferSize-1, s_config.llmAddress.size() ); + memcpy( m_apiInput, s_config.llmAddress.c_str(), sz ); + m_apiInput[sz] = 0; + if( ImGui::InputTextWithHint( "##api", "http://127.0.0.1:1234", m_apiInput, InputBufferSize ) ) { - for( size_t i = 0; i < m_models.size(); ++i ) - { - const auto& model = m_models[i]; - if( ImGui::Selectable( model.name.c_str(), i == m_modelIdx ) ) - { - m_modelIdx = i; - s_config.llmModel = model.name; - m_tools.SetModelMaxContext( model.ctxSize ); - } - if( m_modelIdx == i ) ImGui::SetItemDefaultFocus(); - ImGui::SameLine(); - ImGui::TextDisabled( "(max context: %s)", tracy::RealToString( m_models[i].ctxSize ) ); - } - ImGui::EndCombo(); + s_config.llmAddress = m_apiInput; + SaveConfig(); + m_jobs.emplace_back( WorkItem { + .task = Task::Connect, + .callback = [this] { UpdateModels(); } + } ); + m_cv.notify_all(); } + const auto& models = m_api->GetModels(); ImGui::AlignTextToFramePadding(); - ImGui::TextUnformatted( "Context size:" ); + TextDisabledUnformatted( "Model:" ); ImGui::SameLine(); - ImGui::SetNextItemWidth( 120 * scale ); - if( ImGui::InputInt( "##contextsize", &s_config.llmContext, 1024, 8192 ) ) + if( models.empty() ) { - s_config.llmContext = std::clamp( s_config.llmContext, 2048, 10240 * 1024 ); + ImGui::TextUnformatted( "No models available" ); + } + else + { + if( ImGui::BeginCombo( "##model", models[m_modelIdx].name.c_str() ) ) + { + for( size_t i = 0; i < models.size(); ++i ) + { + const auto& model = models[i]; + if( ImGui::Selectable( model.name.c_str(), i == m_modelIdx ) ) + { + m_modelIdx = i; + s_config.llmModel = model.name; + SaveConfig(); + } + if( m_modelIdx == i ) ImGui::SetItemDefaultFocus(); + ImGui::SameLine(); + ImGui::TextDisabled( "(%s)", model.quant.c_str() ); + } + ImGui::EndCombo(); + } } - ImGui::Indent(); - if( ImGui::Button( "4K" ) ) s_config.llmContext = 4 * 1024; - ImGui::SameLine(); - if( ImGui::Button( "8K" ) ) s_config.llmContext = 8 * 1024; - ImGui::SameLine(); - if( ImGui::Button( "16K" ) ) s_config.llmContext = 16 * 1024; - ImGui::SameLine(); - if( ImGui::Button( "32K" ) ) s_config.llmContext = 32 * 1024; - ImGui::SameLine(); - if( ImGui::Button( "64K" ) ) s_config.llmContext = 64 * 1024; - ImGui::SameLine(); - if( ImGui::Button( "128K" ) ) s_config.llmContext = 128 * 1024; - ImGui::Unindent(); ImGui::Checkbox( ICON_FA_TEMPERATURE_HALF " Temperature", &m_setTemperature ); ImGui::SameLine(); @@ -217,28 +169,60 @@ void TracyLlm::Draw() ImGui::TreePop(); } - const auto ctxSize = std::min( m_models[m_modelIdx].ctxSize, s_config.llmContext ); - ImGui::Spacing(); - ImGui::PushStyleVar( ImGuiStyleVar_FramePadding, ImVec2( 0, 0 ) ); - ImGui::ProgressBar( m_usedCtx / (float)ctxSize, ImVec2( -1, 0 ), "" ); - ImGui::PopStyleVar(); - if( ImGui::IsItemHovered() ) + if( !m_api->IsConnected() ) { - ImGui::BeginTooltip(); - TextFocused( "Used context size:", RealToString( m_usedCtx ) ); - ImGui::SameLine(); - char buf[64]; - PrintStringPercent( buf, m_usedCtx / (float)ctxSize * 100 ); - tracy::TextDisabledUnformatted( buf ); - TextFocused( "Available context size:", RealToString( ctxSize ) ); - ImGui::Separator(); - tracy::TextDisabledUnformatted( ICON_FA_TRIANGLE_EXCLAMATION " Context use 
may be an estimate" ); - ImGui::EndTooltip(); + ImGui::PushFont( g_fonts.big ); + ImGui::Dummy( ImVec2( 0, ( ImGui::GetContentRegionAvail().y - ImGui::GetTextLineHeight() * 2 ) * 0.5f ) ); + TextCentered( ICON_FA_PLUG_CIRCLE_XMARK ); + TextCentered( "No connection to LLM API" ); + ImGui::PopFont(); + ImGui::End(); + return; + } + + if( m_api->GetModels().empty() ) + { + ImGui::PushFont( g_fonts.big ); + ImGui::Dummy( ImVec2( 0, ( ImGui::GetContentRegionAvail().y - ImGui::GetTextLineHeight() * 2 ) * 0.5f ) ); + TextCentered( ICON_FA_WORM ); + ImGui::Spacing(); + TextCentered( "No models available." ); + ImGui::Dummy( ImVec2( 0, ImGui::GetTextLineHeight() * 1.5f ) ); + ImGui::PopFont(); + ImGui::TextWrapped( "Use the LLM backend tooling to download models." ); + ImGui::End(); + return; + } + + const auto ctxSize = m_api->GetContextSize(); + if( ctxSize > 0 ) + { + ImGui::Spacing(); + ImGui::PushStyleVar( ImGuiStyleVar_FramePadding, ImVec2( 0, 0 ) ); + ImGui::ProgressBar( m_usedCtx / (float)ctxSize, ImVec2( -1, 0 ), "" ); + ImGui::PopStyleVar(); + if( ImGui::IsItemHovered() ) + { + ImGui::BeginTooltip(); + TextFocused( "Used context size:", RealToString( m_usedCtx ) ); + ImGui::SameLine(); + char buf[64]; + PrintStringPercent( buf, m_usedCtx / (float)ctxSize * 100 ); + tracy::TextDisabledUnformatted( buf ); + TextFocused( "Available context size:", RealToString( ctxSize ) ); + ImGui::Separator(); + tracy::TextDisabledUnformatted( ICON_FA_TRIANGLE_EXCLAMATION " Context use may be an estimate" ); + ImGui::EndTooltip(); + } + } + else + { + tracy::TextDisabledUnformatted( ICON_FA_TRIANGLE_EXCLAMATION " Context size is not available" ); } ImGui::Spacing(); - ImGui::BeginChild( "##ollama", ImVec2( 0, -( ImGui::GetFrameHeight() + style.ItemSpacing.y * 2 ) ), ImGuiChildFlags_Borders, ImGuiWindowFlags_AlwaysVerticalScrollbar ); - if( m_chat->size() <= 1 ) // account for system prompt + ImGui::BeginChild( "##chat", ImVec2( 0, -( ImGui::GetFrameHeight() + style.ItemSpacing.y * 2 ) ), ImGuiChildFlags_Borders, ImGuiWindowFlags_AlwaysVerticalScrollbar ); + if( m_chat.size() <= 1 ) // account for system prompt { ImGui::Dummy( ImVec2( 0, ( ImGui::GetContentRegionAvail().y - ImGui::GetTextLineHeight() * 10 ) * 0.5f ) ); ImGui::PushStyleColor( ImGuiCol_Text, style.Colors[ImGuiCol_TextDisabled] ); @@ -255,7 +239,7 @@ void TracyLlm::Draw() int cacheIdx = 0; int treeIdx = 0; int num = 0; - for( auto& line : *m_chat ) + for( auto& line : m_chat ) { const auto uw = ImGui::CalcTextSize( ICON_FA_USER ).x; const auto rw = ImGui::CalcTextSize( ICON_FA_ROBOT ).x; @@ -479,7 +463,7 @@ void TracyLlm::Draw() auto buttonSize = ImGui::CalcTextSize( buttonText ); buttonSize.x += ImGui::GetStyle().FramePadding.x * 2.0f + ImGui::GetStyle().ItemSpacing.x; ImGui::PushItemWidth( ImGui::GetContentRegionAvail().x - buttonSize.x ); - bool send = ImGui::InputTextWithHint( "##ollama_input", "Write your question here...", m_input, InputBufferSize, ImGuiInputTextFlags_EnterReturnsTrue ); + bool send = ImGui::InputTextWithHint( "##chat_input", "Write your question here...", m_input, InputBufferSize, ImGuiInputTextFlags_EnterReturnsTrue ); ImGui::SameLine(); send |= ImGui::Button( buttonText ); if( send ) @@ -492,14 +476,17 @@ void TracyLlm::Draw() } if( *ptr ) { - m_chat->emplace_back( ollama::message( "user", m_input ) ); + nlohmann::json msg; + msg["role"] = "user"; + msg["content"] = m_input; + m_chat.emplace_back( std::move( msg ) ); *m_input = 0; m_responding = true; m_jobs.emplace_back( WorkItem { .task = Task::SendMessage, 
.callback = nullptr, - .chat = std::make_unique( *m_chat ) + .chat = m_chat } ); m_cv.notify_all(); } @@ -527,16 +514,16 @@ void TracyLlm::Worker() switch( job.task ) { - case Task::LoadModels: + case Task::Connect: m_busy = true; lock.unlock(); - LoadModels(); + m_api->Connect( s_config.llmAddress.c_str() ); job.callback(); lock.lock(); m_busy = false; break; case Task::SendMessage: - SendMessage( *job.chat ); + SendMessage( job.chat ); break; default: assert( false ); @@ -545,37 +532,18 @@ void TracyLlm::Worker() } }; -void TracyLlm::LoadModels() -{ - std::vector m; - - const auto models = m_ollama->list_models(); - for( const auto& model : models ) - { - const auto info = m_ollama->show_model_info( model ); - const auto& modelInfo = info["model_info"]; - const auto& architecture = modelInfo["general.architecture"].get_ref(); - const auto& ctx = modelInfo[architecture + ".context_length"]; - m.emplace_back( LlmModel { .name = model, .ctxSize = ctx.get() } ); - } - - m_modelsLock.lock(); - std::swap( m_models, m ); - m_modelsLock.unlock(); -} - void TracyLlm::UpdateModels() { - auto it = std::ranges::find_if( m_models, []( const auto& model ) { return model.name == s_config.llmModel; } ); - if( it == m_models.end() ) + auto& models = m_api->GetModels(); + auto it = std::ranges::find_if( models, []( const auto& model ) { return model.name == s_config.llmModel; } ); + if( it == models.end() ) { m_modelIdx = 0; } else { - m_modelIdx = std::distance( m_models.begin(), it ); + m_modelIdx = std::distance( models.begin(), it ); } - if( !m_models.empty() ) m_tools.SetModelMaxContext( m_models[m_modelIdx].ctxSize ); } void TracyLlm::ResetChat() @@ -583,47 +551,48 @@ void TracyLlm::ResetChat() auto systemPrompt = std::string( m_systemPrompt->data(), m_systemPrompt->size() ); systemPrompt += "The current time is: " + m_tools.GetCurrentTime() + "\n"; - m_chat = std::make_unique(); - m_chat->emplace_back( ollama::message( "system", systemPrompt ) ); + *m_input = 0; + m_chat.clear(); + nlohmann::json msg; + msg["role"] = "system"; + msg["content"] = systemPrompt; + m_chat.emplace_back( std::move( msg ) ); m_chatId++; m_usedCtx = systemPrompt.size() / 4; + m_chatCache.clear(); } -void TracyLlm::SendMessage( const ollama::messages& messages ) +void TracyLlm::SendMessage( const std::vector& messages ) { - // The chat() call will fire a callback right away, so the assistant message needs to be there already - m_chat->emplace_back( ollama::message( "assistant", "" ) ); + nlohmann::json msg; + msg["role"] = "assistant"; + msg["content"] = ""; + m_chat.emplace_back( std::move( msg ) ); m_lock.unlock(); bool res; try { - ollama::request req( ollama::message_type::chat ); - req["model"] = m_models[m_modelIdx].name; - req["messages"] = messages.to_json(); + nlohmann::json req; + req["model"] = m_api->GetModels()[m_modelIdx].name; + req["messages"] = messages; req["stream"] = true; - req["options"]["num_ctx"] = std::min( m_models[m_modelIdx].ctxSize, s_config.llmContext ); - if( m_setTemperature ) req["options"]["temperature"] = m_temperature; + if( m_setTemperature ) req["temperature"] = m_temperature; - res = m_ollama->chat( req, [this]( const ollama::response& response ) -> bool { return OnResponse( response ); }); + res = m_api->ChatCompletion( req, [this]( const nlohmann::json& response ) -> bool { return OnResponse( response ); } ); } catch( std::exception& e ) { m_lock.lock(); - if( !m_chat->empty() && m_chat->back()["role"].get_ref() == "assistant" ) m_chat->pop_back(); - m_chat->emplace_back( 
ollama::message( "error", e.what() ) ); + if( !m_chat.empty() && m_chat.back()["role"].get_ref() == "assistant" ) m_chat.pop_back(); + nlohmann::json err; + err["role"] = "error"; + err["content"] = e.what(); + m_chat.emplace_back( std::move( err ) ); m_responding = false; m_stop = false; return; } - - m_lock.lock(); - if( !res ) - { - m_chat->pop_back(); - m_responding = false; - m_stop = false; - } } static std::vector SplitLines( const std::string& str ) @@ -640,7 +609,7 @@ static std::vector SplitLines( const std::string& str ) return lines; } -bool TracyLlm::OnResponse( const ollama::response& response ) +bool TracyLlm::OnResponse( const nlohmann::json& json ) { std::unique_lock lock( m_lock ); @@ -652,17 +621,37 @@ bool TracyLlm::OnResponse( const ollama::response& response ) return false; } - auto& back = m_chat->back(); + auto& back = m_chat.back(); auto& content = back["content"]; const auto& str = content.get_ref(); - auto responseStr = response.as_simple_string(); - std::erase( responseStr, '\r' ); - content = str + responseStr; - m_usedCtx++; - auto& json = response.as_json(); - auto& message = json["message"]; - if( json["done"] ) + std::string responseStr; + bool done; + try + { + auto& node = json["choices"][0]; + auto& delta = node["delta"]; + if( delta.contains( "content" ) ) responseStr = delta["content"].get_ref(); + done = !node["finish_reason"].empty(); + } + catch( const nlohmann::json::exception& e ) + { + if( m_responding ) + { + m_responding = false; + m_focusInput = true; + } + return false; + } + + if( !responseStr.empty() ) + { + std::erase( responseStr, '\r' ); + content = str + responseStr; + m_usedCtx++; + } + + if( done ) { bool isTool = false; auto& str = back["content"].get_ref(); @@ -680,24 +669,29 @@ bool TracyLlm::OnResponse( const ollama::response& response ) auto tool = lines[0]; lines.erase( lines.begin() ); lock.unlock(); - const auto reply = m_tools.HandleToolCalls( tool, lines ); + const auto reply = m_tools.HandleToolCalls( tool, lines, *m_api ); const auto output = "\n" + reply.reply; lock.lock(); - if( reply.image.empty() ) + //if( reply.image.empty() ) { - m_chat->emplace_back( ollama::message( "user", output ) ); + nlohmann::json msg; + msg["role"] = "user"; + msg["content"] = output; + m_chat.emplace_back( std::move( msg ) ); } + /* else { std::vector images; images.emplace_back( ollama::image::from_base64_string( reply.image ) ); m_chat->emplace_back( ollama::message( "user", output, images ) ); } + */ m_jobs.emplace_back( WorkItem { .task = Task::SendMessage, .callback = nullptr, - .chat = std::make_unique( *m_chat ) + .chat = m_chat } ); m_cv.notify_all(); } @@ -709,7 +703,6 @@ bool TracyLlm::OnResponse( const ollama::response& response ) m_focusInput = true; } - m_usedCtx = json["prompt_eval_count"].get() + json["eval_count"].get(); return false; } @@ -762,7 +755,7 @@ void TracyLlm::PrintLine( LineContext& ctx, const std::string& str, int num ) else { char tmp[64]; - snprintf( tmp, sizeof( tmp ), "##ollama_code_%d", num ); + snprintf( tmp, sizeof( tmp ), "##chat_code_%d", num ); ImGui::BeginChild( tmp, ImVec2( 0, 0 ), ImGuiChildFlags_FrameStyle | ImGuiChildFlags_Borders | ImGuiChildFlags_AutoResizeY ); if( ptr[3] ) { diff --git a/profiler/src/profiler/TracyLlm.hpp b/profiler/src/profiler/TracyLlm.hpp index 35e383d9..c4548a3a 100644 --- a/profiler/src/profiler/TracyLlm.hpp +++ b/profiler/src/profiler/TracyLlm.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -14,23 +15,16 @@ #include 
"TracyLlmTools.hpp" #include "tracy_robin_hood.h" -class Ollama; - -namespace ollama -{ -class message; -class messages; -class response; -} - namespace tracy { +class TracyLlmApi; + class TracyLlm { enum class Task { - LoadModels, + Connect, SendMessage, }; @@ -38,7 +32,7 @@ class TracyLlm { Task task; std::function callback; - std::unique_ptr chat; + std::vector chat; }; struct ChatCache @@ -53,21 +47,11 @@ class TracyLlm }; public: - struct LlmModel - { - std::string name; - int ctxSize; - }; - TracyLlm(); ~TracyLlm(); - [[nodiscard]] bool IsValid() const { return (bool)m_ollama; } [[nodiscard]] bool IsBusy() const { std::lock_guard lock( m_lock); return m_busy; } - [[nodiscard]] std::string GetVersion() const; - [[nodiscard]] std::vector GetModels() const { std::lock_guard lock( m_modelsLock ); return m_models; } - void Draw(); bool m_show = false; @@ -80,8 +64,8 @@ private: void ResetChat(); - void SendMessage( const ollama::messages& messages ); - bool OnResponse( const ollama::response& response ); + void SendMessage( const std::vector& messages ); + bool OnResponse( const nlohmann::json& json ); void UpdateCache( ChatCache& cache, const std::string& str ); @@ -89,10 +73,7 @@ private: void PrintMarkdown( const char* str ); void CleanContext( LineContext& ctx); - std::unique_ptr m_ollama; - - mutable std::mutex m_modelsLock; - std::vector m_models; + std::unique_ptr m_api; size_t m_modelIdx; @@ -112,7 +93,8 @@ private: bool m_setTemperature = false; char* m_input; - std::unique_ptr m_chat; + char* m_apiInput; + std::vector m_chat; unordered_flat_map m_chatCache; std::shared_ptr m_systemPrompt; diff --git a/profiler/src/profiler/TracyLlmApi.cpp b/profiler/src/profiler/TracyLlmApi.cpp new file mode 100644 index 00000000..8702618c --- /dev/null +++ b/profiler/src/profiler/TracyLlmApi.cpp @@ -0,0 +1,177 @@ +#include +#include +#include +#include + +#include "TracyLlmApi.hpp" + +namespace tracy +{ + +static size_t WriteFn( void* _data, size_t size, size_t num, void* ptr ) +{ + const auto data = (unsigned char*)_data; + const auto sz = size*num; + auto& v = *(std::string*)ptr; + v.append( (const char*)data, sz ); + return sz; +} + + +TracyLlmApi::~TracyLlmApi() +{ + if( m_curl ) curl_easy_cleanup( m_curl ); +} + +bool TracyLlmApi::Connect( const char* url ) +{ + m_contextSize = -1; + m_url = url; + m_models.clear(); + if( m_curl ) curl_easy_cleanup( m_curl ); + + m_curl = curl_easy_init(); + if( !m_curl ) return false; + + curl_easy_setopt( m_curl, CURLOPT_NOSIGNAL, 1L ); + curl_easy_setopt( m_curl, CURLOPT_CA_CACHE_TIMEOUT, 604800L ); + curl_easy_setopt( m_curl, CURLOPT_FOLLOWLOCATION, 1L ); + curl_easy_setopt( m_curl, CURLOPT_TIMEOUT, 300 ); + curl_easy_setopt( m_curl, CURLOPT_USERAGENT, "Tracy Profiler" ); + + std::string buf; + if( GetRequest( m_url + "/v1/models", buf ) != 200 ) + { + curl_easy_cleanup( m_curl ); + m_curl = nullptr; + return false; + } + + try + { + m_type = Type::Unknown; + nlohmann::json json = nlohmann::json::parse( buf ); + for( auto& model : json["data"] ) + { + auto& id = model["id"].get_ref(); + m_models.emplace_back( LlmModel { .name = id } ); + + std::string buf2; + if( GetRequest( m_url + "/api/v0/models/" + id, buf2 ) == 200 ) + { + m_type = Type::LmStudio; + auto json2 = nlohmann::json::parse( buf2 ); + m_models.back().quant = json2["quantization"].get_ref(); + } + else if( PostRequest( m_url + "/api/show", "{\"name\":\"" + id + "\"}", buf2 ) == 200 ) + { + m_type = Type::Ollama; + auto json2 = nlohmann::json::parse( buf2 ); + m_models.back().quant = 
json2["details"]["quantization_level"].get_ref(); + } + } + } + catch( const std::exception& e ) + { + m_models.clear(); + curl_easy_cleanup( m_curl ); + m_curl = nullptr; + return false; + } + + return true; +} + +struct StreamData +{ + std::string str; + const std::function& callback; +}; + +static size_t StreamFn( void* _data, size_t size, size_t num, void* ptr ) +{ + auto data = (const char*)_data; + const auto sz = size*num; + auto& v = *(StreamData*)ptr; + v.str.append( data, sz ); + + for(;;) + { + auto pos = v.str.find( "data: " ); + if( pos == std::string::npos ) break; + pos += 6; + auto end = v.str.find( "\n\n", pos ); + if( end == std::string::npos ) break; + + nlohmann::json json = nlohmann::json::parse( v.str.c_str() + pos, v.str.c_str() + end ); + if( !v.callback( json ) ) return 0; + v.str.erase( 0, end + 2 ); + } + return sz; +} + +bool TracyLlmApi::ChatCompletion( const nlohmann::json& req, const std::function& callback ) +{ + assert( m_curl ); + StreamData data = { .callback = callback }; + + const auto url = m_url + "/v1/chat/completions"; + const auto reqStr = req.dump(); + + curl_slist *hdr = nullptr; + hdr = curl_slist_append( hdr, "Accept: application/json" ); + hdr = curl_slist_append( hdr, "Content-Type: application/json" ); + + curl_easy_setopt( m_curl, CURLOPT_URL, url.c_str() ); + curl_easy_setopt( m_curl, CURLOPT_HTTPHEADER, hdr ); + curl_easy_setopt( m_curl, CURLOPT_POSTFIELDS, reqStr.c_str() ); + curl_easy_setopt( m_curl, CURLOPT_POSTFIELDSIZE, reqStr.size() ); + curl_easy_setopt( m_curl, CURLOPT_WRITEDATA, &data.str ); + curl_easy_setopt( m_curl, CURLOPT_WRITEFUNCTION, StreamFn ); + + auto res = curl_easy_perform( m_curl ); + curl_slist_free_all( hdr ); + if( res != CURLE_OK ) return false; + + int64_t http_code = 0; + curl_easy_getinfo( m_curl, CURLINFO_RESPONSE_CODE, &http_code ); + return http_code == 200; +} + +int64_t TracyLlmApi::GetRequest( const std::string& url, std::string& response ) +{ + assert( m_curl ); + response.clear(); + + curl_easy_setopt( m_curl, CURLOPT_URL, url.c_str() ); + curl_easy_setopt( m_curl, CURLOPT_WRITEDATA, &response ); + curl_easy_setopt( m_curl, CURLOPT_WRITEFUNCTION, WriteFn ); + + auto res = curl_easy_perform( m_curl ); + if( res != CURLE_OK ) return -1; + + int64_t http_code = 0; + curl_easy_getinfo( m_curl, CURLINFO_RESPONSE_CODE, &http_code ); + return http_code; +} + +int64_t TracyLlmApi::PostRequest( const std::string& url, const std::string& data, std::string& response ) +{ + assert( m_curl ); + response.clear(); + + curl_easy_setopt( m_curl, CURLOPT_URL, url.c_str() ); + curl_easy_setopt( m_curl, CURLOPT_POSTFIELDS, data.c_str() ); + curl_easy_setopt( m_curl, CURLOPT_POSTFIELDSIZE, data.size() ); + curl_easy_setopt( m_curl, CURLOPT_WRITEDATA, &response ); + curl_easy_setopt( m_curl, CURLOPT_WRITEFUNCTION, WriteFn ); + + auto res = curl_easy_perform( m_curl ); + if( res != CURLE_OK ) return -1; + + int64_t http_code = 0; + curl_easy_getinfo( m_curl, CURLINFO_RESPONSE_CODE, &http_code ); + return http_code; +} + +} diff --git a/profiler/src/profiler/TracyLlmApi.hpp b/profiler/src/profiler/TracyLlmApi.hpp new file mode 100644 index 00000000..e4eb06a5 --- /dev/null +++ b/profiler/src/profiler/TracyLlmApi.hpp @@ -0,0 +1,52 @@ +#ifndef __TRACYLLMAPI_HPP__ +#define __TRACYLLMAPI_HPP__ + +#include +#include +#include +#include +#include + +namespace tracy +{ + +struct LlmModel +{ + std::string name; + std::string quant; +}; + +class TracyLlmApi +{ + enum class Type + { + Unknown, + Ollama, + LmStudio, + }; + +public: + 
~TracyLlmApi();
+
+    bool Connect( const char* url );
+    bool ChatCompletion( const nlohmann::json& req, const std::function<bool( const nlohmann::json& )>& callback );
+
+    [[nodiscard]] bool IsConnected() const { return m_curl != nullptr; }
+    [[nodiscard]] const std::vector<LlmModel>& GetModels() const { return m_models; }
+    [[nodiscard]] int GetContextSize() const { return m_contextSize; }
+
+private:
+    int64_t GetRequest( const std::string& url, std::string& response );
+    int64_t PostRequest( const std::string& url, const std::string& data, std::string& response );
+
+    void* m_curl = nullptr;
+    std::string m_url;
+    Type m_type;
+
+    std::vector<LlmModel> m_models;
+    int m_contextSize;
+};
+
+}
+
+#endif
diff --git a/profiler/src/profiler/TracyLlmTools.cpp b/profiler/src/profiler/TracyLlmTools.cpp
index 0a59e98f..55999dd7 100644
--- a/profiler/src/profiler/TracyLlmTools.cpp
+++ b/profiler/src/profiler/TracyLlmTools.cpp
@@ -1,24 +1,17 @@
 #include
-#include
+#include
 #include
 #include
 #include
 #include
 #include
 
-#include "TracyConfig.hpp"
+#include "TracyLlmApi.hpp"
 #include "TracyLlmTools.hpp"
 
-extern tracy::Config s_config;
-
 namespace tracy
 {
 
-void TracyLlmTools::SetModelMaxContext( int modelMaxContext )
-{
-    m_modelMaxContext = modelMaxContext;
-}
-
 static std::string UrlEncode( const std::string& str )
 {
     std::string out;
@@ -45,8 +38,10 @@ static std::string UrlEncode( const std::string& str )
     return out;
 }
 
-TracyLlmTools::ToolReply TracyLlmTools::HandleToolCalls( const std::string& name, const std::vector<std::string>& args )
+TracyLlmTools::ToolReply TracyLlmTools::HandleToolCalls( const std::string& name, const std::vector<std::string>& args, const TracyLlmApi& api )
 {
+    m_ctxSize = api.GetContextSize();
+
     if( name == "fetch_web_page" )
     {
         if( args.empty() ) return { .reply = "Missing URL argument" };
@@ -91,11 +86,11 @@ std::string TracyLlmTools::GetCurrentTime()
 
 int TracyLlmTools::CalcMaxSize() const
 {
+    if( m_ctxSize <= 0 ) return 32*1024;
+
     // Limit the size of the response to avoid exceeding the context size
     // Assume average token size is 4 bytes. Make space for 3 articles to be retrieved.
- assert( m_modelMaxContext != 0 ); - const auto ctxSize = std::min( m_modelMaxContext, s_config.llmContext ); - const auto maxSize = ( ctxSize * 4 ) / 3; + const auto maxSize = ( m_ctxSize * 4 ) / 3; return maxSize; } diff --git a/profiler/src/profiler/TracyLlmTools.hpp b/profiler/src/profiler/TracyLlmTools.hpp index 2f6d0dca..29ac9851 100644 --- a/profiler/src/profiler/TracyLlmTools.hpp +++ b/profiler/src/profiler/TracyLlmTools.hpp @@ -9,6 +9,8 @@ namespace tracy { +class TracyLlmApi; + class TracyLlmTools { public: @@ -18,9 +20,7 @@ public: std::string image; }; - void SetModelMaxContext( int modelMaxContext ); - - ToolReply HandleToolCalls( const std::string& name, const std::vector& args ); + ToolReply HandleToolCalls( const std::string& name, const std::vector& args, const TracyLlmApi& api ); std::string GetCurrentTime(); bool m_netAccess = true; @@ -36,7 +36,7 @@ private: unordered_flat_map m_webCache; - int m_modelMaxContext = 0; + int m_ctxSize; }; } diff --git a/profiler/src/profiler/TracyView.cpp b/profiler/src/profiler/TracyView.cpp index db6c2a43..87135afb 100644 --- a/profiler/src/profiler/TracyView.cpp +++ b/profiler/src/profiler/TracyView.cpp @@ -13,6 +13,7 @@ #include "imgui.h" +#include "TracyConfig.hpp" #include "TracyFileRead.hpp" #include "TracyFilesystem.hpp" #include "TracyImGui.hpp" @@ -31,6 +32,8 @@ #define M_PI_2 1.57079632679489661923 #endif +extern tracy::Config s_config; + namespace tracy { @@ -952,7 +955,7 @@ bool View::DrawImpl() ImGui::EndPopup(); } } - if( m_llm.IsValid() ) + if( s_config.llm ) { ImGui::SameLine(); ToggleButton( ICON_FA_ROBOT, m_llm.m_show );
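
Note on the streaming format handled above: an OpenAI-compatible /v1/chat/completions endpoint called with "stream": true replies with Server-Sent Events, one "data: {json}" record per chunk, carrying the incremental text in choices[0].delta.content and a non-null finish_reason on the last chunk, which is what StreamFn() and OnResponse() consume. The sketch below is not part of the patch; it is a minimal standalone illustration of that framing, assuming nlohmann::json is available, with ParseSseBuffer and onEvent as invented names. Many servers also terminate the stream with a literal "data: [DONE]" record, which is not JSON and is skipped here.

// Standalone sketch, not part of the patch above. Splits a buffered SSE stream
// from an OpenAI-compatible chat/completions endpoint into JSON events, in the
// same spirit as StreamFn(). ParseSseBuffer and onEvent are illustrative names.
#include <functional>
#include <iostream>
#include <string>

#include <nlohmann/json.hpp>

// Consumes complete "data: ...\n\n" records from the front of buf and invokes
// onEvent for each JSON payload. Returns false if the callback asked to stop.
static bool ParseSseBuffer( std::string& buf, const std::function<bool( const nlohmann::json& )>& onEvent )
{
    for(;;)
    {
        auto pos = buf.find( "data: " );
        if( pos == std::string::npos ) return true;
        pos += 6;
        auto end = buf.find( "\n\n", pos );
        if( end == std::string::npos ) return true;     // incomplete record, wait for more data

        auto payload = buf.substr( pos, end - pos );
        buf.erase( 0, end + 2 );

        if( payload == "[DONE]" ) continue;             // stream terminator used by many servers, not JSON
        if( !onEvent( nlohmann::json::parse( payload ) ) ) return false;
    }
}

int main()
{
    // Two content chunks followed by a finish_reason, as a server might stream them.
    std::string buf =
        "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n"
        "data: {\"choices\":[{\"delta\":{\"content\":\" world\"},\"finish_reason\":null}]}\n\n"
        "data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n"
        "data: [DONE]\n\n";

    std::string reply;
    ParseSseBuffer( buf, [&reply]( const nlohmann::json& ev ) -> bool {
        auto& choice = ev["choices"][0];
        auto& delta = choice["delta"];
        if( delta.contains( "content" ) ) reply += delta["content"].get<std::string>();
        return choice["finish_reason"].is_null();       // keep going until the model reports a finish reason
    } );
    std::cout << reply << "\n";                         // prints "Hello world"
}

The buffering matters because curl hands the write callback arbitrary byte ranges, so a single JSON event may arrive split across several invocations; both StreamFn() and this sketch therefore accumulate into a string and only consume complete "data: ...\n\n" records.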