feat(agent): add comprehensive prompt audit system

This commit is contained in:
hdbg
2026-02-27 13:23:21 +01:00
parent 3738272f80
commit 3a210809cf
21 changed files with 1800 additions and 107 deletions

View File

@@ -4,6 +4,7 @@ version = "0.1.0"
edition = "2024"
[dependencies]
async-trait.workspace = true
chrono.workspace = true
codetaker-db.workspace = true
diesel.workspace = true

View File

@@ -5,10 +5,11 @@ use rig::agent::AgentBuilder;
use rig::client::{Client, CompletionClient};
use rig::completion::{StructuredOutputError, TypedPrompt};
use crate::audit::{self, AuditEntrypoint};
use crate::error::{AgentError, AgentResult};
use crate::git_access::{self};
use crate::memory::{MemoryStore, ProjectContextSnapshot};
use crate::tools::{AstGrepTool, FileContext, SearchHit};
use crate::memory::{AuditStore, MemoryStore, ProjectContextSnapshot};
use crate::tools::{AstGrepTool, FileContext, ProjectMemoryWriteTool, ReadFileTool, SearchHit};
use crate::types::{ConversationInput, ConversationOutput};
const CONVERSATION_RESPONSE_PREAMBLE: &str = r#"
@@ -44,7 +45,7 @@ impl<Ext, H, M> ConversationAnswerAgent<Ext, H, M> {
impl<Ext, H, M> ConversationAnswerAgent<Ext, H, M>
where
Client<Ext, H>: CompletionClient,
M: MemoryStore,
M: MemoryStore + AuditStore + Clone + Send + Sync + 'static,
{
pub async fn conversation_response(
&self,
@@ -63,6 +64,9 @@ where
)?;
let ast_grep = self.ast_grep.bind_to_request(repo, &input.head_ref)?;
let readfile = ReadFileTool::new().bind_to_request(repo, &input.head_ref)?;
let memory_write =
ProjectMemoryWriteTool::new(self.memory.clone()).bind_to_project(input.project.clone());
let search_hits = self.collect_conversation_search_hits(&ast_grep, &input);
let prompt =
build_conversation_prompt(&input, &memory_snapshot, &anchor_context, &search_hits);
@@ -71,29 +75,56 @@ where
let agent = AgentBuilder::new(conversation_model)
.preamble(CONVERSATION_RESPONSE_PREAMBLE)
.tool(ast_grep)
.tool(readfile)
.tool(memory_write)
.temperature(0.1)
.output_schema::<ConversationOutput>()
.build();
agent
.prompt_typed::<ConversationOutput>(prompt)
.await
.map_err(|err| match err {
StructuredOutputError::PromptError(prompt_err) => AgentError::ModelError {
model: self.conversation_model.clone(),
message: prompt_err.to_string(),
},
StructuredOutputError::DeserializationError(deser_err) => {
AgentError::OutputValidationError {
message: format!("failed to deserialize conversation output: {deser_err}"),
raw_output: None,
}
let audit_session = audit::start_session(
self.memory.clone(),
&input.project,
input.pull_request_id,
AuditEntrypoint::ConversationResponse,
)
.await;
let result = if let Some(session) = &audit_session {
agent
.prompt_typed::<ConversationOutput>(prompt)
.with_hook(session.hook())
.await
} else {
agent.prompt_typed::<ConversationOutput>(prompt).await
};
match &result {
Ok(_) => {
if let Some(session) = &audit_session {
session.finish_success().await;
}
StructuredOutputError::EmptyResponse => AgentError::OutputValidationError {
message: "conversation response was empty".to_owned(),
}
Err(err) => {
if let Some(session) = &audit_session {
session.finish_failure(&err.to_string()).await;
}
}
}
result.map_err(|err| match err {
StructuredOutputError::PromptError(prompt_err) => AgentError::ModelError {
model: self.conversation_model.clone(),
message: prompt_err.to_string(),
},
StructuredOutputError::DeserializationError(deser_err) => {
AgentError::OutputValidationError {
message: format!("failed to deserialize conversation output: {deser_err}"),
raw_output: None,
},
})
}
}
StructuredOutputError::EmptyResponse => AgentError::OutputValidationError {
message: "conversation response was empty".to_owned(),
raw_output: None,
},
})
}
fn collect_conversation_search_hits(
@@ -126,6 +157,7 @@ fn build_conversation_prompt(
input.project.forge.as_db_value()
)
.ok();
writeln!(prompt, "Pull request ID: {}", input.pull_request_id).ok();
writeln!(prompt, "Head reference: '{}'", input.head_ref).ok();
writeln!(

View File

@@ -5,22 +5,23 @@ use rig::agent::AgentBuilder;
use rig::client::{Client, CompletionClient};
use rig::completion::{StructuredOutputError, TypedPrompt};
use crate::audit::{self, AuditEntrypoint};
use crate::error::{AgentError, AgentResult};
use crate::git_access::{self, PullRequestMaterial};
use crate::memory::{MemoryStore, ProjectContextSnapshot};
use crate::tools::{AstGrepTool, FileContext};
use crate::memory::{AuditStore, MemoryStore, ProjectContextSnapshot};
use crate::tools::{AstGrepTool, FileContext, ProjectMemoryWriteTool, ReadFileTool};
use crate::types::{PullRequestReviewInput, PullRequestReviewOutput};
const PULL_REQUEST_REVIEW_PREAMBLE: &str = r#"
You are a code review agent.
Analyze pull request diffs and respond only with valid JSON.
Use available tools when you need extra repository context.
If you found valuable insight that might help future reviews, save it to the project memory with the memory write tool.
Be concise and focus on the most important issues in the code.
Explain your judgement clearly and provide actionable feedback for the developer.
When using tools, only request the specific information you need to make your review.
In the end, pass judgement if you want to accept it or request changes, and explain your reasoning.
Line numbers must target the head/new file version.
Do not include extra keys.
"#;
pub struct PullRequestReviewAgent<Ext, H, M> {
@@ -49,7 +50,7 @@ impl<Ext, H, M> PullRequestReviewAgent<Ext, H, M> {
impl<Ext, H, M> PullRequestReviewAgent<Ext, H, M>
where
Client<Ext, H>: CompletionClient,
M: MemoryStore,
M: MemoryStore + AuditStore + Clone + Send + Sync + 'static,
{
pub async fn pull_request_review(
&self,
@@ -65,36 +66,66 @@ where
build_pull_request_prompt(&input, &memory_snapshot, &pr_material, &file_contexts);
let ast_grep = self.ast_grep.bind_to_request(repo, &input.head_ref)?;
let readfile = ReadFileTool::new().bind_to_request(repo, &input.head_ref)?;
let memory_write =
ProjectMemoryWriteTool::new(self.memory.clone()).bind_to_project(input.project.clone());
let review_model = self.client.completion_model(&self.review_model);
let agent = AgentBuilder::new(review_model)
.preamble(PULL_REQUEST_REVIEW_PREAMBLE)
.tool(ast_grep)
.tool(readfile)
.tool(memory_write)
.temperature(0.0)
.output_schema::<PullRequestReviewOutput>()
.build();
agent
.prompt_typed::<PullRequestReviewOutput>(prompt)
.await
.map_err(|err| match err {
StructuredOutputError::PromptError(prompt_err) => AgentError::ModelError {
model: self.review_model.clone(),
message: prompt_err.to_string(),
},
StructuredOutputError::DeserializationError(deser_err) => {
AgentError::OutputValidationError {
message: format!(
"failed to deserialize pull request review output: {deser_err}"
),
raw_output: None,
}
let audit_session = audit::start_session(
self.memory.clone(),
&input.project,
input.pull_request_id,
AuditEntrypoint::PullRequestReview,
)
.await;
let result = if let Some(session) = &audit_session {
agent
.prompt_typed::<PullRequestReviewOutput>(prompt)
.with_hook(session.hook())
.await
} else {
agent.prompt_typed::<PullRequestReviewOutput>(prompt).await
};
match &result {
Ok(_) => {
if let Some(session) = &audit_session {
session.finish_success().await;
}
StructuredOutputError::EmptyResponse => AgentError::OutputValidationError {
message: "pull request review response was empty".to_owned(),
}
Err(err) => {
if let Some(session) = &audit_session {
session.finish_failure(&err.to_string()).await;
}
}
}
result.map_err(|err| match err {
StructuredOutputError::PromptError(prompt_err) => AgentError::ModelError {
model: self.review_model.clone(),
message: prompt_err.to_string(),
},
StructuredOutputError::DeserializationError(deser_err) => {
AgentError::OutputValidationError {
message: format!(
"failed to deserialize pull request review output: {deser_err}"
),
raw_output: None,
},
})
}
}
StructuredOutputError::EmptyResponse => AgentError::OutputValidationError {
message: "pull request review response was empty".to_owned(),
raw_output: None,
},
})
}
fn collect_diff_file_contexts(
@@ -142,6 +173,7 @@ fn build_pull_request_prompt(
input.project.forge.as_db_value()
)
.ok();
writeln!(prompt, "Pull request ID: {}", input.pull_request_id).ok();
writeln!(
prompt,

View File

@@ -0,0 +1,325 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicI32, Ordering};
use chrono::Utc;
use rig::agent::{HookAction, PromptHook, ToolCallHookAction};
use rig::completion::{CompletionModel, CompletionResponse};
use rig::message::Message;
use serde::Serialize;
use crate::memory::AuditStore;
use crate::types::ProjectRef;
/// Which agent entrypoint opened the prompt audit session.
#[derive(Debug, Clone, Copy)]
pub enum AuditEntrypoint {
    PullRequestReview,
    ConversationResponse,
}

impl AuditEntrypoint {
    /// Stable string identifier persisted in the audit session row.
    fn as_db_value(self) -> &'static str {
        match self {
            AuditEntrypoint::ConversationResponse => "conversation_response",
            AuditEntrypoint::PullRequestReview => "pull_request_review",
        }
    }
}
/// Prompt-lifecycle hook that mirrors agent events into an audit store.
///
/// Clones share the same atomic counter, so every clone of one hook emits a
/// single monotonically increasing event sequence for the session.
#[derive(Clone)]
pub struct PromptAuditHook<M> {
    store: M,
    session_id: i32,
    sequence: Arc<AtomicI32>,
}

impl<M> PromptAuditHook<M> {
    /// Creates a hook bound to an already-started audit session row.
    pub fn new(store: M, session_id: i32) -> Self {
        let sequence = Arc::new(AtomicI32::new(0));
        Self {
            store,
            session_id,
            sequence,
        }
    }

    /// Returns the next 1-based event sequence number (shared across clones).
    fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::Relaxed) + 1
    }
}
/// Handle to one open prompt audit session row.
///
/// Closing the session is best-effort: persistence failures are only warned
/// about so auditing can never break the agent run itself.
pub struct AuditSession<M> {
    store: M,
    session_id: i32,
}

impl<M> AuditSession<M>
where
    M: AuditStore + Clone + Send + Sync + 'static,
{
    /// Builds a hook that logs prompt events into this session.
    pub fn hook(&self) -> PromptAuditHook<M> {
        PromptAuditHook::new(self.store.clone(), self.session_id)
    }

    /// Stamps the session as completed with the current UTC time.
    pub async fn finish_success(&self) {
        let ended_at = Utc::now().naive_utc();
        let outcome = self
            .store
            .finish_prompt_audit_session(self.session_id, ended_at, "completed", None)
            .await;
        if let Err(err) = outcome {
            tracing::warn!(error = %err, "failed to finish successful prompt audit session");
        }
    }

    /// Stamps the session as failed, recording the caller-supplied error text.
    pub async fn finish_failure(&self, error_message: &str) {
        let ended_at = Utc::now().naive_utc();
        let outcome = self
            .store
            .finish_prompt_audit_session(
                self.session_id,
                ended_at,
                "failed",
                Some(error_message),
            )
            .await;
        if let Err(err) = outcome {
            tracing::warn!(error = %err, "failed to finish failed prompt audit session");
        }
    }
}
/// Opens a prompt audit session row for the given project / pull request.
///
/// Returns `None` (after a warning) when the store rejects the insert, so a
/// broken audit backend degrades to "no auditing" rather than failing the run.
pub async fn start_session<M>(
    store: M,
    project: &ProjectRef,
    pull_request_id: i64,
    entrypoint: AuditEntrypoint,
) -> Option<AuditSession<M>>
where
    M: AuditStore + Clone + Send + Sync + 'static,
{
    let started_at = Utc::now().naive_utc();
    let outcome = store
        .start_prompt_audit_session(
            project,
            pull_request_id,
            entrypoint.as_db_value(),
            started_at,
        )
        .await;
    match outcome {
        Ok(session_id) => Some(AuditSession { store, session_id }),
        Err(err) => {
            tracing::warn!(error = %err, "failed to start prompt audit session");
            None
        }
    }
}
// Mirrors every prompt-lifecycle event into the audit store. All logging is
// best-effort: store failures are reported via `tracing::warn!` and every
// hook still returns "continue", so auditing never aborts or alters the run.
// Sequence number and timestamp are captured before the async store call so
// they reflect when the event occurred, not when the row was written.
impl<M, Model> PromptHook<Model> for PromptAuditHook<M>
where
    M: AuditStore + Clone + Send + Sync + 'static,
    Model: CompletionModel,
{
    // Logs the outgoing completion request: the prompt plus full history.
    async fn on_completion_call(&self, prompt: &Message, history: &[Message]) -> HookAction {
        let sequence_no = self.next_sequence();
        let event_at = Utc::now().naive_utc();
        let prompt_json = serialize_to_json(prompt);
        let history_json = serialize_to_json(history);
        if let Err(err) = self
            .store
            .log_on_completion_call(
                self.session_id,
                event_at,
                sequence_no,
                &prompt_json,
                &history_json,
            )
            .await
        {
            tracing::warn!(error = %err, "failed to log on_completion_call");
        }
        HookAction::cont()
    }
    // Logs the model's reply along with its token usage counters.
    async fn on_completion_response(
        &self,
        prompt: &Message,
        response: &CompletionResponse<Model::Response>,
    ) -> HookAction {
        let sequence_no = self.next_sequence();
        let event_at = Utc::now().naive_utc();
        let prompt_json = serialize_to_json(prompt);
        let assistant_choice_json = serialize_to_json(&response.choice);
        // Token counts arrive as u64; clamp into the i64 the database stores.
        let usage_input_tokens = u64_to_i64(response.usage.input_tokens);
        let usage_output_tokens = u64_to_i64(response.usage.output_tokens);
        let usage_total_tokens = u64_to_i64(response.usage.total_tokens);
        if let Err(err) = self
            .store
            .log_on_completion_response(
                self.session_id,
                event_at,
                sequence_no,
                &prompt_json,
                &assistant_choice_json,
                usage_input_tokens,
                usage_output_tokens,
                usage_total_tokens,
            )
            .await
        {
            tracing::warn!(error = %err, "failed to log on_completion_response");
        }
        HookAction::cont()
    }
    // Logs a tool invocation. The recorded action is always "continue"
    // because this hook never blocks or rewrites tool calls.
    async fn on_tool_call(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        internal_call_id: &str,
        args: &str,
    ) -> ToolCallHookAction {
        let sequence_no = self.next_sequence();
        let event_at = Utc::now().naive_utc();
        if let Err(err) = self
            .store
            .log_on_tool_call(
                self.session_id,
                event_at,
                sequence_no,
                tool_name,
                tool_call_id.as_deref(),
                internal_call_id,
                args,
                "continue",
                None,
            )
            .await
        {
            tracing::warn!(error = %err, "failed to log on_tool_call");
        }
        ToolCallHookAction::cont()
    }
    // Logs a tool's result, keyed by the same call identifiers as on_tool_call.
    async fn on_tool_result(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        internal_call_id: &str,
        args: &str,
        result: &str,
    ) -> HookAction {
        let sequence_no = self.next_sequence();
        let event_at = Utc::now().naive_utc();
        if let Err(err) = self
            .store
            .log_on_tool_result(
                self.session_id,
                event_at,
                sequence_no,
                tool_name,
                tool_call_id.as_deref(),
                internal_call_id,
                args,
                result,
            )
            .await
        {
            tracing::warn!(error = %err, "failed to log on_tool_result");
        }
        HookAction::cont()
    }
    // Logs one streamed text fragment plus the text aggregated so far.
    // NOTE(review): this fires per delta, so streaming produces one row per
    // chunk — verify the volume is acceptable for the audit tables.
    async fn on_text_delta(&self, text_delta: &str, aggregated_text: &str) -> HookAction {
        let sequence_no = self.next_sequence();
        let event_at = Utc::now().naive_utc();
        if let Err(err) = self
            .store
            .log_on_text_delta(
                self.session_id,
                event_at,
                sequence_no,
                text_delta,
                aggregated_text,
            )
            .await
        {
            tracing::warn!(error = %err, "failed to log on_text_delta");
        }
        HookAction::cont()
    }
    // Logs an incremental fragment of a streamed tool call's arguments.
    async fn on_tool_call_delta(
        &self,
        tool_call_id: &str,
        internal_call_id: &str,
        tool_name: Option<&str>,
        tool_call_delta: &str,
    ) -> HookAction {
        let sequence_no = self.next_sequence();
        let event_at = Utc::now().naive_utc();
        if let Err(err) = self
            .store
            .log_on_tool_call_delta(
                self.session_id,
                event_at,
                sequence_no,
                tool_call_id,
                internal_call_id,
                tool_name,
                tool_call_delta,
            )
            .await
        {
            tracing::warn!(error = %err, "failed to log on_tool_call_delta");
        }
        HookAction::cont()
    }
    // Logs the final response summary once a streamed completion finishes.
    async fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &Model::StreamingResponse,
    ) -> HookAction {
        let sequence_no = self.next_sequence();
        let event_at = Utc::now().naive_utc();
        let prompt_json = serialize_to_json(prompt);
        let response_summary_json = serialize_to_json(response);
        if let Err(err) = self
            .store
            .log_on_stream_completion_response_finish(
                self.session_id,
                event_at,
                sequence_no,
                &prompt_json,
                &response_summary_json,
            )
            .await
        {
            tracing::warn!(error = %err, "failed to log on_stream_completion_response_finish");
        }
        HookAction::cont()
    }
}
/// Serializes `value` to a JSON string for audit storage.
///
/// On serialization failure this still returns *valid* JSON describing the
/// error. The previous fallback interpolated the error text into a hand-built
/// JSON literal, which produced unparseable output whenever the message
/// contained a quote or backslash; building the fallback with
/// `serde_json::json!` escapes the text correctly.
fn serialize_to_json<T: Serialize + ?Sized>(value: &T) -> String {
    serde_json::to_string(value).unwrap_or_else(|err| {
        serde_json::json!({ "serialization_error": err.to_string() }).to_string()
    })
}
/// Narrows a provider-reported `u64` token count to the database's `i64`,
/// saturating at `i64::MAX` rather than failing on out-of-range values.
fn u64_to_i64(value: u64) -> i64 {
    match i64::try_from(value) {
        Ok(converted) => converted,
        Err(_) => i64::MAX,
    }
}

View File

@@ -1,4 +1,5 @@
pub mod agent;
pub mod audit;
pub mod error;
pub mod git_access;
pub mod memory;
@@ -8,8 +9,12 @@ pub mod types;
pub use agent::{ConversationAnswerAgent, PullRequestReviewAgent};
pub use error::{AgentError, AgentResult};
pub use git_access::PullRequestMaterial;
pub use memory::{DieselMemoryStore, MemoryStore, ProjectContextSnapshot};
pub use tools::{AstGrepArgs, AstGrepOutput, AstGrepTool, FileContext, SearchHit};
pub use memory::{AuditStore, DieselMemoryStore, MemoryStore, ProjectContextSnapshot};
pub use tools::{
AstGrepArgs, AstGrepOutput, AstGrepTool, FileContext, MemoryWriteOperation,
ProjectMemoryWriteArgs, ProjectMemoryWriteOutput, ProjectMemoryWriteTool, ReadFileArgs,
ReadFileOutput, ReadFileTool, SearchHit,
};
pub use types::{
ConversationInput, ConversationOutput, Forge, MessageAuthor, ProjectRef,
PullRequestReviewInput, PullRequestReviewOutput, ReviewComment, ReviewResult, ThreadMessage,

View File

@@ -2,6 +2,7 @@ mod sqlite;
use std::collections::BTreeMap;
use async_trait::async_trait;
use chrono::NaiveDateTime;
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -24,8 +25,8 @@ pub struct ProjectContextSnapshot {
pub summaries: Vec<MemorySummary>,
}
#[allow(async_fn_in_trait)]
pub trait MemoryStore {
#[async_trait]
pub trait MemoryStore: Send + Sync {
async fn project_context_snapshot(
&self,
project: &ProjectRef,
@@ -49,6 +50,7 @@ pub trait MemoryStore {
async fn create_review_thread(
&self,
project: &ProjectRef,
pull_request_id: i64,
file: &str,
line: i32,
initial_comment: &str,
@@ -63,3 +65,97 @@ pub trait MemoryStore {
async fn load_thread_messages(&self, thread_id: i32) -> AgentResult<Vec<ThreadMessage>>;
}
/// Persistence interface for prompt audit sessions and their per-event logs.
///
/// A session is opened with [`AuditStore::start_prompt_audit_session`], filled
/// with `log_*` events keyed by `(session_id, sequence_no)`, and closed with
/// [`AuditStore::finish_prompt_audit_session`]. All payloads are passed as
/// pre-serialized JSON strings; timestamps are naive UTC.
#[async_trait]
pub trait AuditStore: Send + Sync {
    /// Creates a new audit session row and returns its id.
    async fn start_prompt_audit_session(
        &self,
        project: &ProjectRef,
        pull_request_id: i64,
        entrypoint: &str,
        started_at: NaiveDateTime,
    ) -> AgentResult<i32>;
    /// Closes a session with a final status ("completed"/"failed" by
    /// convention of the callers in this crate) and an optional error message.
    async fn finish_prompt_audit_session(
        &self,
        session_id: i32,
        ended_at: NaiveDateTime,
        status: &str,
        error_message: Option<&str>,
    ) -> AgentResult<()>;
    /// Records an outgoing completion request (prompt + history JSON).
    async fn log_on_completion_call(
        &self,
        session_id: i32,
        event_at: NaiveDateTime,
        sequence_no: i32,
        prompt_json: &str,
        history_json: &str,
    ) -> AgentResult<()>;
    /// Records a completion response plus its token usage counters.
    async fn log_on_completion_response(
        &self,
        session_id: i32,
        event_at: NaiveDateTime,
        sequence_no: i32,
        prompt_json: &str,
        assistant_choice_json: &str,
        usage_input_tokens: i64,
        usage_output_tokens: i64,
        usage_total_tokens: i64,
    ) -> AgentResult<()>;
    /// Records a tool invocation and the hook's action/reason for it.
    async fn log_on_tool_call(
        &self,
        session_id: i32,
        event_at: NaiveDateTime,
        sequence_no: i32,
        tool_name: &str,
        tool_call_id: Option<&str>,
        internal_call_id: &str,
        args_json: &str,
        action: &str,
        reason: Option<&str>,
    ) -> AgentResult<()>;
    /// Records a tool's result alongside the arguments it was called with.
    async fn log_on_tool_result(
        &self,
        session_id: i32,
        event_at: NaiveDateTime,
        sequence_no: i32,
        tool_name: &str,
        tool_call_id: Option<&str>,
        internal_call_id: &str,
        args_json: &str,
        result_json: &str,
    ) -> AgentResult<()>;
    /// Records one streamed text fragment and the aggregate text so far.
    async fn log_on_text_delta(
        &self,
        session_id: i32,
        event_at: NaiveDateTime,
        sequence_no: i32,
        text_delta: &str,
        aggregated_text: &str,
    ) -> AgentResult<()>;
    /// Records an incremental fragment of a streamed tool call's arguments.
    async fn log_on_tool_call_delta(
        &self,
        session_id: i32,
        event_at: NaiveDateTime,
        sequence_no: i32,
        tool_call_id: &str,
        internal_call_id: &str,
        tool_name: Option<&str>,
        tool_call_delta: &str,
    ) -> AgentResult<()>;
    /// Records the final summary of a streamed completion response.
    async fn log_on_stream_completion_response_finish(
        &self,
        session_id: i32,
        event_at: NaiveDateTime,
        sequence_no: i32,
        prompt_json: &str,
        response_summary_json: &str,
    ) -> AgentResult<()>;
}

View File

@@ -1,13 +1,21 @@
use std::collections::BTreeMap;
use async_trait::async_trait;
use chrono::Utc;
use codetaker_db::models::{
NewProjectMemoryEntryRow, NewProjectMemorySummaryRow, NewProjectRow, NewReviewThreadMessageRow,
NewReviewThreadRow, ProjectMemoryEntryRow, ProjectMemorySummaryRow, ReviewThreadMessageRow,
NewProjectMemoryEntryRow, NewProjectMemorySummaryRow, NewProjectRow, NewPromptAuditSessionRow,
NewPromptHookOnCompletionCallEventRow, NewPromptHookOnCompletionResponseEventRow,
NewPromptHookOnStreamCompletionResponseFinishEventRow, NewPromptHookOnTextDeltaEventRow,
NewPromptHookOnToolCallDeltaEventRow, NewPromptHookOnToolCallEventRow,
NewPromptHookOnToolResultEventRow, NewReviewThreadMessageRow, NewReviewThreadRow,
ProjectMemoryEntryRow, ProjectMemorySummaryRow, ReviewThreadMessageRow,
};
use codetaker_db::schema::{
project_memory_entries, project_memory_summaries, projects, review_thread_messages,
review_threads,
project_memory_entries, project_memory_summaries, projects, prompt_audit_sessions,
prompt_hook_on_completion_call_events, prompt_hook_on_completion_response_events,
prompt_hook_on_stream_completion_response_finish_events, prompt_hook_on_text_delta_events,
prompt_hook_on_tool_call_delta_events, prompt_hook_on_tool_call_events,
prompt_hook_on_tool_result_events, review_thread_messages, review_threads,
};
use codetaker_db::{DatabaseConnection, DatabasePool};
use diesel::OptionalExtension;
@@ -16,7 +24,7 @@ use diesel_async::RunQueryDsl;
use serde_json::Value;
use crate::error::{AgentError, AgentResult};
use crate::memory::{MemoryStore, MemorySummary, ProjectContextSnapshot};
use crate::memory::{AuditStore, MemoryStore, MemorySummary, ProjectContextSnapshot};
use crate::types::{MessageAuthor, ProjectRef, ThreadMessage};
type PooledConnection<'a> =
@@ -90,6 +98,7 @@ impl DieselMemoryStore {
}
}
#[async_trait]
impl MemoryStore for DieselMemoryStore {
async fn project_context_snapshot(
&self,
@@ -221,6 +230,7 @@ impl MemoryStore for DieselMemoryStore {
async fn create_review_thread(
&self,
project: &ProjectRef,
pull_request_id: i64,
file: &str,
line: i32,
initial_comment: &str,
@@ -232,6 +242,7 @@ impl MemoryStore for DieselMemoryStore {
let now = Utc::now().naive_utc();
let new_row = NewReviewThreadRow {
project_id,
pull_request_id,
file,
line,
initial_comment,
@@ -248,6 +259,7 @@ impl MemoryStore for DieselMemoryStore {
dsl::review_threads
.filter(dsl::project_id.eq(project_id))
.filter(dsl::pull_request_id.eq(pull_request_id))
.filter(dsl::file.eq(file))
.filter(dsl::line.eq(line))
.filter(dsl::initial_comment.eq(initial_comment))
@@ -317,3 +329,292 @@ impl MemoryStore for DieselMemoryStore {
.collect()
}
}
// Diesel-backed implementation of `AuditStore`. Each method opens a pooled
// connection, inserts one row into the matching audit table, and maps any
// diesel error into `AgentError::MemoryError` with a descriptive message.
#[async_trait]
impl AuditStore for DieselMemoryStore {
    // Inserts a session row (status "started", no end time yet) and returns
    // its id.
    async fn start_prompt_audit_session(
        &self,
        project: &ProjectRef,
        pull_request_id: i64,
        entrypoint: &str,
        started_at: chrono::NaiveDateTime,
    ) -> AgentResult<i32> {
        use codetaker_db::schema::prompt_audit_sessions::dsl;
        let mut conn = self.get_conn().await?;
        // Resolve (or lazily create) the owning project row first.
        let project_id = Self::ensure_project_id(&mut conn, project).await?;
        let new_row = NewPromptAuditSessionRow {
            project_id,
            pull_request_id,
            entrypoint,
            started_at,
            ended_at: None,
            status: "started",
            error_message: None,
        };
        diesel::insert_into(prompt_audit_sessions::table)
            .values(&new_row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to create prompt audit session: {err}"),
            })?;
        // NOTE(review): fetching the id as "max id in the table" is racy —
        // a concurrent insert between the two statements can return another
        // session's id. Consider a RETURNING clause / last-insert-rowid on
        // the same connection inside a transaction — TODO confirm backend
        // support before changing.
        dsl::prompt_audit_sessions
            .order(dsl::id.desc())
            .select(dsl::id)
            .first::<i32>(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to fetch prompt audit session id: {err}"),
            })
    }
    // Stamps the session's end time, final status, and optional error text.
    async fn finish_prompt_audit_session(
        &self,
        session_id: i32,
        ended_at: chrono::NaiveDateTime,
        status: &str,
        error_message: Option<&str>,
    ) -> AgentResult<()> {
        use codetaker_db::schema::prompt_audit_sessions::dsl;
        let mut conn = self.get_conn().await?;
        diesel::update(dsl::prompt_audit_sessions.filter(dsl::id.eq(session_id)))
            .set((
                dsl::ended_at.eq(Some(ended_at)),
                dsl::status.eq(status),
                dsl::error_message.eq(error_message),
            ))
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to finish prompt audit session {session_id}: {err}"),
            })?;
        Ok(())
    }
    // Persists one on_completion_call event row.
    async fn log_on_completion_call(
        &self,
        session_id: i32,
        event_at: chrono::NaiveDateTime,
        sequence_no: i32,
        prompt_json: &str,
        history_json: &str,
    ) -> AgentResult<()> {
        let mut conn = self.get_conn().await?;
        let row = NewPromptHookOnCompletionCallEventRow {
            session_id,
            event_at,
            sequence_no,
            prompt_json,
            history_json,
        };
        diesel::insert_into(prompt_hook_on_completion_call_events::table)
            .values(&row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to log on_completion_call event: {err}"),
            })?;
        Ok(())
    }
    // Persists one on_completion_response event row, including token usage.
    async fn log_on_completion_response(
        &self,
        session_id: i32,
        event_at: chrono::NaiveDateTime,
        sequence_no: i32,
        prompt_json: &str,
        assistant_choice_json: &str,
        usage_input_tokens: i64,
        usage_output_tokens: i64,
        usage_total_tokens: i64,
    ) -> AgentResult<()> {
        let mut conn = self.get_conn().await?;
        let row = NewPromptHookOnCompletionResponseEventRow {
            session_id,
            event_at,
            sequence_no,
            prompt_json,
            assistant_choice_json,
            usage_input_tokens,
            usage_output_tokens,
            usage_total_tokens,
        };
        diesel::insert_into(prompt_hook_on_completion_response_events::table)
            .values(&row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to log on_completion_response event: {err}"),
            })?;
        Ok(())
    }
    // Persists one on_tool_call event row (action/reason come from the hook).
    async fn log_on_tool_call(
        &self,
        session_id: i32,
        event_at: chrono::NaiveDateTime,
        sequence_no: i32,
        tool_name: &str,
        tool_call_id: Option<&str>,
        internal_call_id: &str,
        args_json: &str,
        action: &str,
        reason: Option<&str>,
    ) -> AgentResult<()> {
        let mut conn = self.get_conn().await?;
        let row = NewPromptHookOnToolCallEventRow {
            session_id,
            event_at,
            sequence_no,
            tool_name,
            tool_call_id,
            internal_call_id,
            args_json,
            action,
            reason,
        };
        diesel::insert_into(prompt_hook_on_tool_call_events::table)
            .values(&row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to log on_tool_call event: {err}"),
            })?;
        Ok(())
    }
    // Persists one on_tool_result event row.
    async fn log_on_tool_result(
        &self,
        session_id: i32,
        event_at: chrono::NaiveDateTime,
        sequence_no: i32,
        tool_name: &str,
        tool_call_id: Option<&str>,
        internal_call_id: &str,
        args_json: &str,
        result_json: &str,
    ) -> AgentResult<()> {
        let mut conn = self.get_conn().await?;
        let row = NewPromptHookOnToolResultEventRow {
            session_id,
            event_at,
            sequence_no,
            tool_name,
            tool_call_id,
            internal_call_id,
            args_json,
            result_json,
        };
        diesel::insert_into(prompt_hook_on_tool_result_events::table)
            .values(&row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to log on_tool_result event: {err}"),
            })?;
        Ok(())
    }
    // Persists one streamed-text-delta event row.
    async fn log_on_text_delta(
        &self,
        session_id: i32,
        event_at: chrono::NaiveDateTime,
        sequence_no: i32,
        text_delta: &str,
        aggregated_text: &str,
    ) -> AgentResult<()> {
        let mut conn = self.get_conn().await?;
        let row = NewPromptHookOnTextDeltaEventRow {
            session_id,
            event_at,
            sequence_no,
            text_delta,
            aggregated_text,
        };
        diesel::insert_into(prompt_hook_on_text_delta_events::table)
            .values(&row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to log on_text_delta event: {err}"),
            })?;
        Ok(())
    }
    // Persists one streamed-tool-call-delta event row.
    async fn log_on_tool_call_delta(
        &self,
        session_id: i32,
        event_at: chrono::NaiveDateTime,
        sequence_no: i32,
        tool_call_id: &str,
        internal_call_id: &str,
        tool_name: Option<&str>,
        tool_call_delta: &str,
    ) -> AgentResult<()> {
        let mut conn = self.get_conn().await?;
        let row = NewPromptHookOnToolCallDeltaEventRow {
            session_id,
            event_at,
            sequence_no,
            tool_call_id,
            internal_call_id,
            tool_name,
            tool_call_delta,
        };
        diesel::insert_into(prompt_hook_on_tool_call_delta_events::table)
            .values(&row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to log on_tool_call_delta event: {err}"),
            })?;
        Ok(())
    }
    // Persists the final summary row of a streamed completion response.
    async fn log_on_stream_completion_response_finish(
        &self,
        session_id: i32,
        event_at: chrono::NaiveDateTime,
        sequence_no: i32,
        prompt_json: &str,
        response_summary_json: &str,
    ) -> AgentResult<()> {
        let mut conn = self.get_conn().await?;
        let row = NewPromptHookOnStreamCompletionResponseFinishEventRow {
            session_id,
            event_at,
            sequence_no,
            prompt_json,
            response_summary_json,
        };
        diesel::insert_into(prompt_hook_on_stream_completion_response_finish_events::table)
            .values(&row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!(
                    "failed to log on_stream_completion_response_finish event: {err}"
                ),
            })?;
        Ok(())
    }
}

View File

@@ -9,7 +9,7 @@ use serde_json::Value;
use crate::error::{AgentError, AgentResult};
use crate::git_access;
use crate::tools::{FileContext, SearchHit};
use crate::tools::SearchHit;
#[derive(Debug, Clone)]
pub struct AstGrepTool {
@@ -21,24 +21,14 @@ pub struct AstGrepTool {
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "operation", rename_all = "snake_case")]
pub enum AstGrepArgs {
SearchSymbol {
query: String,
},
SearchPattern {
query: String,
},
GetFileContext {
file: String,
line: i32,
radius: i32,
},
SearchSymbol { query: String },
SearchPattern { query: String },
}
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "operation", rename_all = "snake_case")]
pub enum AstGrepOutput {
SearchHits { hits: Vec<SearchHit> },
FileContext { context: FileContext },
}
impl Default for AstGrepTool {
@@ -171,12 +161,6 @@ impl AstGrepTool {
pub fn search_pattern(&self, query: &str) -> AgentResult<Vec<SearchHit>> {
self.run_query(query)
}
pub fn get_file_context(&self, file: &str, line: i32, radius: i32) -> AgentResult<FileContext> {
let (repo_git_dir, head_ref) = self.bound_state()?;
let repo = self.open_repo(&repo_git_dir)?;
git_access::read_file_context_at_ref(&repo, &head_ref, file, line, radius)
}
}
impl Tool for AstGrepTool {
@@ -189,37 +173,18 @@ impl Tool for AstGrepTool {
async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: Self::NAME.to_owned(),
description: "Search source code using ast-grep and fetch file context.".to_owned(),
description: "Search source code using ast-grep.".to_owned(),
parameters: serde_json::json!({
"type": "object",
"oneOf": [
{
"type": "object",
"properties": {
"operation": { "type": "string", "const": "search_symbol" },
"query": { "type": "string", "description": "Symbol-like query" }
},
"required": ["operation", "query"]
"properties": {
"operation": {
"type": "string",
"enum": ["search_symbol", "search_pattern"],
"description": "Tool operation. Both operations require `query`."
},
{
"type": "object",
"properties": {
"operation": { "type": "string", "const": "search_pattern" },
"query": { "type": "string", "description": "Pattern query" }
},
"required": ["operation", "query"]
},
{
"type": "object",
"properties": {
"operation": { "type": "string", "const": "get_file_context" },
"file": { "type": "string", "description": "Path to file in repository" },
"line": { "type": "integer", "description": "1-based line number" },
"radius": { "type": "integer", "description": "Number of lines before/after line" }
},
"required": ["operation", "file", "line", "radius"]
}
]
"query": { "type": "string", "description": "Symbol-like or pattern query" },
},
"required": ["operation"]
}),
}
}
@@ -234,10 +199,6 @@ impl Tool for AstGrepTool {
let hits = self.search_pattern(&query)?;
Ok(AstGrepOutput::SearchHits { hits })
}
AstGrepArgs::GetFileContext { file, line, radius } => {
let context = self.get_file_context(&file, line, radius)?;
Ok(AstGrepOutput::FileContext { context })
}
}
}
}

View File

@@ -0,0 +1,148 @@
use rig::completion::ToolDefinition;
use rig::tool::Tool;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::{AgentError, AgentResult};
use crate::memory::MemoryStore;
use crate::types::ProjectRef;
/// Agent tool that persists project-scoped memory entries and summaries.
///
/// Constructed unbound; `bind_to_project` produces a copy scoped to one
/// project before the tool is handed to an agent.
#[derive(Debug, Clone)]
pub struct ProjectMemoryWriteTool<M> {
    memory: M,
    // `None` until bound; `project()` returns a ConfigError in that state.
    project: Option<ProjectRef>,
}
/// Deserialized tool arguments. Which optional fields are required depends on
/// `operation`: upsert_entry needs `key`/`value`, upsert_summary needs
/// `summary_type`/`content` (enforced in `call`, not by serde).
#[derive(Debug, Clone, Deserialize)]
pub struct ProjectMemoryWriteArgs {
    pub operation: MemoryWriteOperation,
    pub key: Option<String>,
    pub value: Option<String>,
    pub source: Option<String>,
    pub summary_type: Option<String>,
    pub content: Option<String>,
}
/// The two supported write operations, deserialized from snake_case strings.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryWriteOperation {
    UpsertEntry,
    UpsertSummary,
}
/// Serialized tool output: whether the write happened plus a human-readable
/// confirmation for the model.
#[derive(Debug, Clone, Serialize)]
pub struct ProjectMemoryWriteOutput {
    pub stored: bool,
    pub detail: String,
}
impl<M> ProjectMemoryWriteTool<M> {
    /// Creates an unbound tool; bind it with [`Self::bind_to_project`] before use.
    pub fn new(memory: M) -> Self {
        Self { memory, project: None }
    }

    /// Returns a copy of the tool scoped to `project`.
    pub fn bind_to_project(&self, project: ProjectRef) -> Self
    where
        M: Clone,
    {
        Self {
            memory: self.memory.clone(),
            project: Some(project),
        }
    }

    /// Fetches the bound project, or a config error when the tool is unbound.
    fn project(&self) -> AgentResult<&ProjectRef> {
        match self.project.as_ref() {
            Some(project) => Ok(project),
            None => Err(AgentError::ConfigError {
                message: "project memory write tool is not bound to a project context".to_owned(),
            }),
        }
    }
}
// Tool wiring: exposes the memory store to the agent as "project_memory_write".
impl<M> Tool for ProjectMemoryWriteTool<M>
where
    M: MemoryStore + Clone + Send + Sync,
{
    const NAME: &'static str = "project_memory_write";
    type Error = AgentError;
    type Args = ProjectMemoryWriteArgs;
    type Output = ProjectMemoryWriteOutput;
    // JSON-schema definition shown to the model; only `operation` is marked
    // required because the remaining fields depend on the chosen operation.
    async fn definition(&self, _prompt: String) -> ToolDefinition {
        ToolDefinition {
            name: Self::NAME.to_owned(),
            description: "Persist project-specific memory entries or summaries for future reviews."
                .to_owned(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "operation": {
                        "type": "string",
                        "enum": ["upsert_entry", "upsert_summary"],
                        "description": "upsert_entry requires key and value. upsert_summary requires summary_type and content."
                    },
                    "key": { "type": "string", "description": "Memory key for upsert_entry." },
                    "value": { "type": "string", "description": "Memory value for upsert_entry. JSON text is accepted; plain text is stored as JSON string." },
                    "source": { "type": "string", "description": "Source tag for upsert_entry. Defaults to agent_tool." },
                    "summary_type": { "type": "string", "description": "Summary type for upsert_summary." },
                    "content": { "type": "string", "description": "Summary content for upsert_summary." }
                },
                "required": ["operation"]
            }),
        }
    }
    // Validates the operation-specific fields, then delegates to the memory
    // store. Fails with ToolError when a required field is missing and with
    // ConfigError when the tool was never bound to a project.
    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let project = self.project()?;
        match args.operation {
            MemoryWriteOperation::UpsertEntry => {
                let key = args.key.ok_or_else(|| AgentError::ToolError {
                    tool: Self::NAME.to_owned(),
                    message: "missing required field `key` for upsert_entry".to_owned(),
                })?;
                let raw_value = args.value.ok_or_else(|| AgentError::ToolError {
                    tool: Self::NAME.to_owned(),
                    message: "missing required field `value` for upsert_entry".to_owned(),
                })?;
                // Accept either JSON text or plain text; plain text becomes a
                // JSON string value.
                let value: Value =
                    serde_json::from_str(&raw_value).unwrap_or_else(|_| Value::String(raw_value));
                let source = args.source.as_deref().unwrap_or("agent_tool");
                self.memory
                    .upsert_memory_entry(project, &key, &value, source)
                    .await?;
                Ok(ProjectMemoryWriteOutput {
                    stored: true,
                    detail: format!("stored memory entry `{key}`"),
                })
            }
            MemoryWriteOperation::UpsertSummary => {
                let summary_type = args.summary_type.ok_or_else(|| AgentError::ToolError {
                    tool: Self::NAME.to_owned(),
                    message: "missing required field `summary_type` for upsert_summary".to_owned(),
                })?;
                let content = args.content.ok_or_else(|| AgentError::ToolError {
                    tool: Self::NAME.to_owned(),
                    message: "missing required field `content` for upsert_summary".to_owned(),
                })?;
                self.memory
                    .upsert_memory_summary(project, &summary_type, &content)
                    .await?;
                Ok(ProjectMemoryWriteOutput {
                    stored: true,
                    detail: format!("stored memory summary `{summary_type}`"),
                })
            }
        }
    }
}

View File

@@ -1,8 +1,14 @@
mod ast_grep;
mod memory_write;
mod readfile;
use serde::{Deserialize, Serialize};
pub use ast_grep::{AstGrepArgs, AstGrepOutput, AstGrepTool};
pub use memory_write::{
MemoryWriteOperation, ProjectMemoryWriteArgs, ProjectMemoryWriteOutput, ProjectMemoryWriteTool,
};
pub use readfile::{ReadFileArgs, ReadFileOutput, ReadFileTool};
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SearchHit {

View File

@@ -0,0 +1,115 @@
use std::path::{Path, PathBuf};
use git2::Repository;
use rig::completion::ToolDefinition;
use rig::tool::Tool;
use serde::{Deserialize, Serialize};
use crate::error::{AgentError, AgentResult};
use crate::git_access;
use crate::tools::FileContext;
/// Tool that reads a window of lines from a file in a bare git repository.
///
/// Constructed unbound; `bind_to_request` attaches the repository git dir
/// and the ref whose tree is read.
#[derive(Debug, Clone)]
pub struct ReadFileTool {
    // Path to the bare repository's git directory; `None` until bound.
    repo_git_dir: Option<PathBuf>,
    // Ref whose tree is read (e.g. the PR head); `None` until bound.
    head_ref: Option<String>,
}
/// Arguments the model supplies to the `readfile` tool.
#[derive(Debug, Clone, Deserialize)]
pub struct ReadFileArgs {
    /// Repository-relative path of the file to read.
    pub file: String,
    /// 1-based line number at the center of the requested window.
    pub line: i32,
    /// Number of context lines before and after `line`.
    pub radius: i32,
}
/// Output of the `readfile` tool.
#[derive(Debug, Clone, Serialize)]
pub struct ReadFileOutput {
    /// File excerpt read at the bound `head_ref` via
    /// `git_access::read_file_context_at_ref`.
    pub context: FileContext,
}
impl Default for ReadFileTool {
    /// An unbound tool — identical to [`ReadFileTool::new`].
    fn default() -> Self {
        Self {
            repo_git_dir: None,
            head_ref: None,
        }
    }
}
impl ReadFileTool {
    /// Creates an unbound tool; call [`Self::bind_to_request`] before use.
    pub fn new() -> Self {
        Self {
            repo_git_dir: None,
            head_ref: None,
        }
    }

    /// Returns a copy bound to `repo` (which must be bare) and `head_ref`.
    pub fn bind_to_request(
        &self,
        repo: &Repository,
        head_ref: impl Into<String>,
    ) -> AgentResult<Self> {
        git_access::ensure_bare_repository(repo)?;
        let bound = Self {
            repo_git_dir: Some(repo.path().to_path_buf()),
            head_ref: Some(head_ref.into()),
        };
        Ok(bound)
    }

    /// Fetches the bound git dir and ref, or a config error when unbound.
    fn bound_state(&self) -> AgentResult<(PathBuf, String)> {
        let Some(repo_git_dir) = self.repo_git_dir.clone() else {
            return Err(AgentError::ConfigError {
                message: "readfile tool is not bound to a repository context".to_owned(),
            });
        };
        let Some(head_ref) = self.head_ref.clone() else {
            return Err(AgentError::ConfigError {
                message: "readfile tool is not bound to a head_ref".to_owned(),
            });
        };
        Ok((repo_git_dir, head_ref))
    }

    /// Opens the bare repository at `repo_git_dir`, mapping git2 failures
    /// into the agent's error type.
    fn open_repo(&self, repo_git_dir: &Path) -> AgentResult<Repository> {
        match Repository::open_bare(repo_git_dir) {
            Ok(repo) => Ok(repo),
            Err(err) => Err(AgentError::GitError {
                operation: format!("open bare repository at '{}'", repo_git_dir.display()),
                message: err.to_string(),
            }),
        }
    }

    /// Reads `radius` lines of context around `line` in `file` at the bound ref.
    fn read_context(&self, file: &str, line: i32, radius: i32) -> AgentResult<FileContext> {
        let (repo_git_dir, head_ref) = self.bound_state()?;
        let repo = self.open_repo(&repo_git_dir)?;
        git_access::read_file_context_at_ref(&repo, &head_ref, file, line, radius)
    }
}
impl Tool for ReadFileTool {
const NAME: &'static str = "readfile";
type Error = AgentError;
type Args = ReadFileArgs;
type Output = ReadFileOutput;
async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: Self::NAME.to_owned(),
description: "Read file context from the repository at head_ref.".to_owned(),
parameters: serde_json::json!({
"type": "object",
"properties": {
"file": { "type": "string", "description": "Path to file in repository" },
"line": { "type": "integer", "description": "1-based line number" },
"radius": { "type": "integer", "description": "Number of lines before/after line" }
},
"required": ["file", "line", "radius"]
}),
}
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let context = self.read_context(&args.file, args.line, args.radius)?;
Ok(ReadFileOutput { context })
}
}

View File

@@ -50,6 +50,7 @@ pub struct ProjectRef {
/// Input for a full pull-request review run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PullRequestReviewInput {
    /// Project the pull request belongs to.
    pub project: ProjectRef,
    /// Identifier of the pull request under review.
    // NOTE(review): presumably the forge-side PR number, also passed to the
    // audit-session store — confirm against callers.
    pub pull_request_id: i64,
    /// Base branch/ref the pull request targets.
    pub base_ref: String,
    /// Head branch/ref containing the proposed changes.
    pub head_ref: String,
}
@@ -57,6 +58,7 @@ pub struct PullRequestReviewInput {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationInput {
pub project: ProjectRef,
pub pull_request_id: i64,
pub head_ref: String,
pub anchor_file: String,
pub anchor_line: i32,

View File

@@ -1,8 +1,11 @@
use std::path::PathBuf;
use codetaker_agent::memory::{DieselMemoryStore, MemoryStore};
use codetaker_agent::memory::{AuditStore, DieselMemoryStore, MemoryStore};
use codetaker_agent::types::{Forge, MessageAuthor, ProjectRef};
use codetaker_db::create_pool;
use diesel::ExpressionMethods;
use diesel::QueryDsl;
use diesel_async::RunQueryDsl;
use serde_json::json;
fn test_db_path(test_name: &str) -> PathBuf {
@@ -87,7 +90,7 @@ async fn thread_messages_round_trip() {
let project = sample_project();
let thread_id = store
.create_review_thread(&project, "src/lib.rs", 42, "Please avoid unwrap here")
.create_review_thread(&project, 123, "src/lib.rs", 42, "Please avoid unwrap here")
.await
.expect("create review thread");
@@ -113,3 +116,195 @@ async fn thread_messages_round_trip() {
assert!(matches!(messages[0].author, MessageAuthor::User));
assert!(matches!(messages[1].author, MessageAuthor::Agent));
}
#[tokio::test]
async fn prompt_audit_session_lifecycle_success_and_failure() {
    // Fresh on-disk sqlite database scoped to this test.
    let db_path = test_db_path("prompt_audit_session_lifecycle_success_and_failure");
    let database_url = db_path.display().to_string();
    let pool = create_pool(Some(&database_url))
        .await
        .expect("create sqlite pool");
    let store = DieselMemoryStore::new(pool.clone());
    let project = sample_project();
    let started_at = chrono::Utc::now().naive_utc();

    // One session runs to completion with no error detail...
    let completed_id = store
        .start_prompt_audit_session(&project, 11, "pull_request_review", started_at)
        .await
        .expect("start success session");
    store
        .finish_prompt_audit_session(
            completed_id,
            chrono::Utc::now().naive_utc(),
            "completed",
            None,
        )
        .await
        .expect("finish success session");

    // ...and a second session is recorded as failed with an error message.
    let failed_id = store
        .start_prompt_audit_session(&project, 12, "conversation_response", started_at)
        .await
        .expect("start failed session");
    store
        .finish_prompt_audit_session(
            failed_id,
            chrono::Utc::now().naive_utc(),
            "failed",
            Some("model error"),
        )
        .await
        .expect("finish failed session");

    // Both terminal statuses must have landed in the sessions table.
    use codetaker_db::schema::prompt_audit_sessions::dsl;
    let mut conn = pool.get().await.expect("get pooled connection");
    let statuses = dsl::prompt_audit_sessions
        .filter(dsl::id.eq_any([completed_id, failed_id]))
        .select(dsl::status)
        .load::<String>(&mut conn)
        .await
        .expect("load session statuses");
    assert!(statuses.iter().any(|status| status == "completed"));
    assert!(statuses.iter().any(|status| status == "failed"));
}
#[tokio::test]
async fn prompt_hook_event_methods_write_rows() {
    // Fresh on-disk sqlite database scoped to this test.
    let db_path = test_db_path("prompt_hook_event_methods_write_rows");
    let database_url = db_path.display().to_string();
    let pool = create_pool(Some(&database_url))
        .await
        .expect("create sqlite pool");
    let store = DieselMemoryStore::new(pool.clone());
    let project = sample_project();
    let now = chrono::Utc::now().naive_utc();

    let session_id = store
        .start_prompt_audit_session(&project, 99, "pull_request_review", now)
        .await
        .expect("start audit session");

    // Emit exactly one event of every hook kind, with increasing sequence numbers.
    store
        .log_on_completion_call(session_id, now, 1, "{\"role\":\"user\"}", "[]")
        .await
        .expect("log completion call");
    store
        .log_on_completion_response(session_id, now, 2, "{\"role\":\"user\"}", "[]", 10, 20, 30)
        .await
        .expect("log completion response");
    store
        .log_on_tool_call(
            session_id,
            now,
            3,
            "ast_grep",
            Some("call_1"),
            "internal_1",
            "{\"query\":\"foo\"}",
            "continue",
            None,
        )
        .await
        .expect("log tool call");
    store
        .log_on_tool_result(
            session_id,
            now,
            4,
            "ast_grep",
            Some("call_1"),
            "internal_1",
            "{\"query\":\"foo\"}",
            "{\"hits\":[]}",
        )
        .await
        .expect("log tool result");
    store
        .log_on_text_delta(session_id, now, 5, "foo", "foobar")
        .await
        .expect("log text delta");
    store
        .log_on_tool_call_delta(
            session_id,
            now,
            6,
            "call_2",
            "internal_2",
            Some("readfile"),
            "{\"path\":\"src/lib.rs\"}",
        )
        .await
        .expect("log tool call delta");
    store
        .log_on_stream_completion_response_finish(session_id, now, 7, "{}", "{\"done\":true}")
        .await
        .expect("log stream finish");

    // Each hook method must have written exactly one row for this session.
    let mut conn = pool.get().await.expect("get pooled connection");
    use codetaker_db::schema::{
        prompt_hook_on_completion_call_events::dsl as completion_call_dsl,
        prompt_hook_on_completion_response_events::dsl as completion_response_dsl,
        prompt_hook_on_stream_completion_response_finish_events::dsl as stream_finish_dsl,
        prompt_hook_on_text_delta_events::dsl as text_delta_dsl,
        prompt_hook_on_tool_call_delta_events::dsl as tool_call_delta_dsl,
        prompt_hook_on_tool_call_events::dsl as tool_call_dsl,
        prompt_hook_on_tool_result_events::dsl as tool_result_dsl,
    };

    // Each event table has its own row type, so a macro (rather than a loop)
    // stamps out the identical count-and-assert for every table.
    macro_rules! assert_one_event_row {
        ($table:expr, $session_col:expr, $context:expr) => {{
            let rows = $table
                .filter($session_col.eq(session_id))
                .count()
                .get_result::<i64>(&mut conn)
                .await
                .expect($context);
            assert_eq!(rows, 1);
        }};
    }

    assert_one_event_row!(
        completion_call_dsl::prompt_hook_on_completion_call_events,
        completion_call_dsl::session_id,
        "count completion call events"
    );
    assert_one_event_row!(
        completion_response_dsl::prompt_hook_on_completion_response_events,
        completion_response_dsl::session_id,
        "count completion response events"
    );
    assert_one_event_row!(
        tool_call_dsl::prompt_hook_on_tool_call_events,
        tool_call_dsl::session_id,
        "count tool call events"
    );
    assert_one_event_row!(
        tool_result_dsl::prompt_hook_on_tool_result_events,
        tool_result_dsl::session_id,
        "count tool result events"
    );
    assert_one_event_row!(
        text_delta_dsl::prompt_hook_on_text_delta_events,
        text_delta_dsl::session_id,
        "count text delta events"
    );
    assert_one_event_row!(
        tool_call_delta_dsl::prompt_hook_on_tool_call_delta_events,
        tool_call_delta_dsl::session_id,
        "count tool call delta events"
    );
    assert_one_event_row!(
        stream_finish_dsl::prompt_hook_on_stream_completion_response_finish_events,
        stream_finish_dsl::session_id,
        "count stream finish events"
    );
}