Compare commits

2 Commits

Author SHA1 Message Date
hdbg
3a210809cf feat(agent): add comprehensive prompt audit system 2026-02-27 14:27:52 +01:00
hdbg
3738272f80 fix(agent): adjusted prompt and removed useless schema for output 2026-02-27 13:22:44 +01:00
21 changed files with 1804 additions and 119 deletions

88
Cargo.lock generated
View File

@@ -107,6 +107,17 @@ dependencies = [
"syn",
]
[[package]]
name = "async-trait"
version = "0.1.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "atomic-waker"
version = "1.1.2"
@@ -274,6 +285,7 @@ dependencies = [
name = "codetaker-agent"
version = "0.1.0"
dependencies = [
"async-trait",
"chrono",
"codetaker-db",
"diesel",
@@ -294,6 +306,7 @@ dependencies = [
"clap",
"codetaker-agent",
"codetaker-db",
"dialoguer",
"git2",
"rig-core",
"serde_json",
@@ -354,6 +367,19 @@ version = "0.4.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75984efb6ed102a0d42db99afb6c1948f0380d1d91808d5529916e6c08b49d8d"
[[package]]
name = "console"
version = "0.16.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03e45a4a8926227e4197636ba97a9fc9b00477e9f4bd711395687c5f0734bec4"
dependencies = [
"encode_unicode",
"libc",
"once_cell",
"unicode-width",
"windows-sys 0.61.2",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
@@ -433,6 +459,18 @@ dependencies = [
"powerfmt",
]
[[package]]
name = "dialoguer"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25f104b501bf2364e78d0d3974cbc774f738f5865306ed128e1e0d7499c0ad96"
dependencies = [
"console",
"shell-words",
"tempfile",
"zeroize",
]
[[package]]
name = "diesel"
version = "2.3.6"
@@ -545,6 +583,12 @@ version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]]
name = "encode_unicode"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
[[package]]
name = "encoding_rs"
version = "0.8.35"
@@ -1182,6 +1226,12 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "linux-raw-sys"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53"
[[package]]
name = "litemap"
version = "0.8.1"
@@ -1738,6 +1788,19 @@ version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustix"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190"
dependencies = [
"bitflags",
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.61.2",
]
[[package]]
name = "rustls"
version = "0.23.37"
@@ -1990,6 +2053,12 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "shell-words"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc6fe69c597f9c37bfeeeeeb33da3530379845f10be461a66d16d03eca2ded77"
[[package]]
name = "shlex"
version = "1.3.0"
@@ -2116,6 +2185,19 @@ dependencies = [
"libc",
]
[[package]]
name = "tempfile"
version = "3.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82a72c767771b47409d2345987fda8628641887d5466101319899796367354a0"
dependencies = [
"fastrand",
"getrandom 0.3.4",
"once_cell",
"rustix",
"windows-sys 0.61.2",
]
[[package]]
name = "thiserror"
version = "1.0.69"
@@ -2440,6 +2522,12 @@ version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
[[package]]
name = "unicode-width"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
[[package]]
name = "untrusted"
version = "0.9.0"

View File

@@ -5,8 +5,10 @@ members = [
]
[workspace.dependencies]
async-trait = "0.1.89"
clap = { version = "4.5.60", features = ["derive", "env"] }
codetaker-db = { path = "crates/codetaker-db" }
dialoguer = "0.12.0"
git2 = { version = "0.20.4", features = ["vendored-libgit2"] }
rig-core = "0.31.0"
schemars = "1.2.1"

View File

@@ -4,6 +4,7 @@ version = "0.1.0"
edition = "2024"
[dependencies]
async-trait.workspace = true
chrono.workspace = true
codetaker-db.workspace = true
diesel.workspace = true

View File

@@ -5,10 +5,11 @@ use rig::agent::AgentBuilder;
use rig::client::{Client, CompletionClient};
use rig::completion::{StructuredOutputError, TypedPrompt};
use crate::audit::{self, AuditEntrypoint};
use crate::error::{AgentError, AgentResult};
use crate::git_access::{self};
use crate::memory::{MemoryStore, ProjectContextSnapshot};
use crate::tools::{AstGrepTool, FileContext, SearchHit};
use crate::memory::{AuditStore, MemoryStore, ProjectContextSnapshot};
use crate::tools::{AstGrepTool, FileContext, ProjectMemoryWriteTool, ReadFileTool, SearchHit};
use crate::types::{ConversationInput, ConversationOutput};
const CONVERSATION_RESPONSE_PREAMBLE: &str = r#"
@@ -44,7 +45,7 @@ impl<Ext, H, M> ConversationAnswerAgent<Ext, H, M> {
impl<Ext, H, M> ConversationAnswerAgent<Ext, H, M>
where
Client<Ext, H>: CompletionClient,
M: MemoryStore,
M: MemoryStore + AuditStore + Clone + Send + Sync + 'static,
{
pub async fn conversation_response(
&self,
@@ -63,6 +64,9 @@ where
)?;
let ast_grep = self.ast_grep.bind_to_request(repo, &input.head_ref)?;
let readfile = ReadFileTool::new().bind_to_request(repo, &input.head_ref)?;
let memory_write =
ProjectMemoryWriteTool::new(self.memory.clone()).bind_to_project(input.project.clone());
let search_hits = self.collect_conversation_search_hits(&ast_grep, &input);
let prompt =
build_conversation_prompt(&input, &memory_snapshot, &anchor_context, &search_hits);
@@ -71,29 +75,56 @@ where
let agent = AgentBuilder::new(conversation_model)
.preamble(CONVERSATION_RESPONSE_PREAMBLE)
.tool(ast_grep)
.tool(readfile)
.tool(memory_write)
.temperature(0.1)
.output_schema::<ConversationOutput>()
.build();
agent
.prompt_typed::<ConversationOutput>(prompt)
.await
.map_err(|err| match err {
StructuredOutputError::PromptError(prompt_err) => AgentError::ModelError {
model: self.conversation_model.clone(),
message: prompt_err.to_string(),
},
StructuredOutputError::DeserializationError(deser_err) => {
AgentError::OutputValidationError {
message: format!("failed to deserialize conversation output: {deser_err}"),
raw_output: None,
}
let audit_session = audit::start_session(
self.memory.clone(),
&input.project,
input.pull_request_id,
AuditEntrypoint::ConversationResponse,
)
.await;
let result = if let Some(session) = &audit_session {
agent
.prompt_typed::<ConversationOutput>(prompt)
.with_hook(session.hook())
.await
} else {
agent.prompt_typed::<ConversationOutput>(prompt).await
};
match &result {
Ok(_) => {
if let Some(session) = &audit_session {
session.finish_success().await;
}
StructuredOutputError::EmptyResponse => AgentError::OutputValidationError {
message: "conversation response was empty".to_owned(),
}
Err(err) => {
if let Some(session) = &audit_session {
session.finish_failure(&err.to_string()).await;
}
}
}
result.map_err(|err| match err {
StructuredOutputError::PromptError(prompt_err) => AgentError::ModelError {
model: self.conversation_model.clone(),
message: prompt_err.to_string(),
},
StructuredOutputError::DeserializationError(deser_err) => {
AgentError::OutputValidationError {
message: format!("failed to deserialize conversation output: {deser_err}"),
raw_output: None,
},
})
}
}
StructuredOutputError::EmptyResponse => AgentError::OutputValidationError {
message: "conversation response was empty".to_owned(),
raw_output: None,
},
})
}
fn collect_conversation_search_hits(
@@ -126,6 +157,7 @@ fn build_conversation_prompt(
input.project.forge.as_db_value()
)
.ok();
writeln!(prompt, "Pull request ID: {}", input.pull_request_id).ok();
writeln!(prompt, "Head reference: '{}'", input.head_ref).ok();
writeln!(

View File

@@ -5,30 +5,23 @@ use rig::agent::AgentBuilder;
use rig::client::{Client, CompletionClient};
use rig::completion::{StructuredOutputError, TypedPrompt};
use crate::audit::{self, AuditEntrypoint};
use crate::error::{AgentError, AgentResult};
use crate::git_access::{self, PullRequestMaterial};
use crate::memory::{MemoryStore, ProjectContextSnapshot};
use crate::tools::{AstGrepTool, FileContext};
use crate::memory::{AuditStore, MemoryStore, ProjectContextSnapshot};
use crate::tools::{AstGrepTool, FileContext, ProjectMemoryWriteTool, ReadFileTool};
use crate::types::{PullRequestReviewInput, PullRequestReviewOutput};
const PULL_REQUEST_REVIEW_PREAMBLE: &str = r#"
You are a code review agent.
Analyze pull request diffs and respond only with valid JSON.
Use available tools when you need extra repository context.
Schema:
{
"review_result": "Approve" | "RequestChanges",
"global_comment": string,
"comments": [
{
"comment": string,
"file": string,
"line": integer
}
]
}
If you found valuable insight that might help future reviews, save it to the project memory with the memory write tool.
Be concise and focus on the most important issues in the code.
Explain your judgement clearly and provide actionable feedback for the developer.
When using tools, only request the specific information you need to make your review.
In the end, pass judgement if you want to accept it or request changes, and explain your reasoning.
Line numbers must target the head/new file version.
Do not include extra keys.
"#;
pub struct PullRequestReviewAgent<Ext, H, M> {
@@ -57,7 +50,7 @@ impl<Ext, H, M> PullRequestReviewAgent<Ext, H, M> {
impl<Ext, H, M> PullRequestReviewAgent<Ext, H, M>
where
Client<Ext, H>: CompletionClient,
M: MemoryStore,
M: MemoryStore + AuditStore + Clone + Send + Sync + 'static,
{
pub async fn pull_request_review(
&self,
@@ -73,36 +66,66 @@ where
build_pull_request_prompt(&input, &memory_snapshot, &pr_material, &file_contexts);
let ast_grep = self.ast_grep.bind_to_request(repo, &input.head_ref)?;
let readfile = ReadFileTool::new().bind_to_request(repo, &input.head_ref)?;
let memory_write =
ProjectMemoryWriteTool::new(self.memory.clone()).bind_to_project(input.project.clone());
let review_model = self.client.completion_model(&self.review_model);
let agent = AgentBuilder::new(review_model)
.preamble(PULL_REQUEST_REVIEW_PREAMBLE)
.tool(ast_grep)
.tool(readfile)
.tool(memory_write)
.temperature(0.0)
.output_schema::<PullRequestReviewOutput>()
.build();
agent
.prompt_typed::<PullRequestReviewOutput>(prompt)
.await
.map_err(|err| match err {
StructuredOutputError::PromptError(prompt_err) => AgentError::ModelError {
model: self.review_model.clone(),
message: prompt_err.to_string(),
},
StructuredOutputError::DeserializationError(deser_err) => {
AgentError::OutputValidationError {
message: format!(
"failed to deserialize pull request review output: {deser_err}"
),
raw_output: None,
}
let audit_session = audit::start_session(
self.memory.clone(),
&input.project,
input.pull_request_id,
AuditEntrypoint::PullRequestReview,
)
.await;
let result = if let Some(session) = &audit_session {
agent
.prompt_typed::<PullRequestReviewOutput>(prompt)
.with_hook(session.hook())
.await
} else {
agent.prompt_typed::<PullRequestReviewOutput>(prompt).await
};
match &result {
Ok(_) => {
if let Some(session) = &audit_session {
session.finish_success().await;
}
StructuredOutputError::EmptyResponse => AgentError::OutputValidationError {
message: "pull request review response was empty".to_owned(),
}
Err(err) => {
if let Some(session) = &audit_session {
session.finish_failure(&err.to_string()).await;
}
}
}
result.map_err(|err| match err {
StructuredOutputError::PromptError(prompt_err) => AgentError::ModelError {
model: self.review_model.clone(),
message: prompt_err.to_string(),
},
StructuredOutputError::DeserializationError(deser_err) => {
AgentError::OutputValidationError {
message: format!(
"failed to deserialize pull request review output: {deser_err}"
),
raw_output: None,
},
})
}
}
StructuredOutputError::EmptyResponse => AgentError::OutputValidationError {
message: "pull request review response was empty".to_owned(),
raw_output: None,
},
})
}
fn collect_diff_file_contexts(
@@ -150,6 +173,7 @@ fn build_pull_request_prompt(
input.project.forge.as_db_value()
)
.ok();
writeln!(prompt, "Pull request ID: {}", input.pull_request_id).ok();
writeln!(
prompt,

View File

@@ -0,0 +1,325 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicI32, Ordering};
use chrono::Utc;
use rig::agent::{HookAction, PromptHook, ToolCallHookAction};
use rig::completion::{CompletionModel, CompletionResponse};
use rig::message::Message;
use serde::Serialize;
use crate::memory::AuditStore;
use crate::types::ProjectRef;
/// Identifies which agent entry point opened a prompt audit session.
#[derive(Debug, Clone, Copy)]
pub enum AuditEntrypoint {
    PullRequestReview,
    ConversationResponse,
}

impl AuditEntrypoint {
    /// Stable string persisted in the audit session's `entrypoint` column.
    fn as_db_value(self) -> &'static str {
        match self {
            AuditEntrypoint::PullRequestReview => "pull_request_review",
            AuditEntrypoint::ConversationResponse => "conversation_response",
        }
    }
}
/// Prompt hook that persists model/tool events into an audit store.
///
/// Clones share the same atomic counter, so every clone draws unique,
/// monotonically increasing sequence numbers for the same session.
#[derive(Clone)]
pub struct PromptAuditHook<M> {
    store: M,
    session_id: i32,
    sequence: Arc<AtomicI32>,
}

impl<M> PromptAuditHook<M> {
    /// Creates a hook bound to an already-started audit session.
    pub fn new(store: M, session_id: i32) -> Self {
        let sequence = Arc::new(AtomicI32::new(0));
        Self {
            store,
            session_id,
            sequence,
        }
    }

    /// Hands out the next 1-based sequence number for this session.
    fn next_sequence(&self) -> i32 {
        let previous = self.sequence.fetch_add(1, Ordering::Relaxed);
        previous + 1
    }
}
pub struct AuditSession<M> {
store: M,
session_id: i32,
}
impl<M> AuditSession<M>
where
M: AuditStore + Clone + Send + Sync + 'static,
{
pub fn hook(&self) -> PromptAuditHook<M> {
PromptAuditHook::new(self.store.clone(), self.session_id)
}
pub async fn finish_success(&self) {
let ended_at = Utc::now().naive_utc();
if let Err(err) = self
.store
.finish_prompt_audit_session(self.session_id, ended_at, "completed", None)
.await
{
tracing::warn!(error = %err, "failed to finish successful prompt audit session");
}
}
pub async fn finish_failure(&self, error_message: &str) {
let ended_at = Utc::now().naive_utc();
if let Err(err) = self
.store
.finish_prompt_audit_session(
self.session_id,
ended_at,
"failed",
Some(error_message),
)
.await
{
tracing::warn!(error = %err, "failed to finish failed prompt audit session");
}
}
}
pub async fn start_session<M>(
store: M,
project: &ProjectRef,
pull_request_id: i64,
entrypoint: AuditEntrypoint,
) -> Option<AuditSession<M>>
where
M: AuditStore + Clone + Send + Sync + 'static,
{
let started_at = Utc::now().naive_utc();
match store
.start_prompt_audit_session(
project,
pull_request_id,
entrypoint.as_db_value(),
started_at,
)
.await
{
Ok(session_id) => Some(AuditSession { store, session_id }),
Err(err) => {
tracing::warn!(error = %err, "failed to start prompt audit session");
None
}
}
}
/// Records every prompt-lifecycle callback as a row in the audit store.
///
/// Every method follows the same pattern: take the next sequence number,
/// timestamp the event, attempt the store write, and on failure emit a
/// `tracing` warning while still returning a "continue" action — auditing
/// is strictly best-effort and never interrupts the agent run.
impl<M, Model> PromptHook<Model> for PromptAuditHook<M>
where
    M: AuditStore + Clone + Send + Sync + 'static,
    Model: CompletionModel,
{
    /// Logs the outgoing completion request (prompt + chat history as JSON).
    async fn on_completion_call(&self, prompt: &Message, history: &[Message]) -> HookAction {
        let sequence_no = self.next_sequence();
        let event_at = Utc::now().naive_utc();
        let prompt_json = serialize_to_json(prompt);
        let history_json = serialize_to_json(history);
        if let Err(err) = self
            .store
            .log_on_completion_call(
                self.session_id,
                event_at,
                sequence_no,
                &prompt_json,
                &history_json,
            )
            .await
        {
            tracing::warn!(error = %err, "failed to log on_completion_call");
        }
        HookAction::cont()
    }

    /// Logs the model's reply, including token-usage counters
    /// (saturated into i64 columns via `u64_to_i64`).
    async fn on_completion_response(
        &self,
        prompt: &Message,
        response: &CompletionResponse<Model::Response>,
    ) -> HookAction {
        let sequence_no = self.next_sequence();
        let event_at = Utc::now().naive_utc();
        let prompt_json = serialize_to_json(prompt);
        let assistant_choice_json = serialize_to_json(&response.choice);
        let usage_input_tokens = u64_to_i64(response.usage.input_tokens);
        let usage_output_tokens = u64_to_i64(response.usage.output_tokens);
        let usage_total_tokens = u64_to_i64(response.usage.total_tokens);
        if let Err(err) = self
            .store
            .log_on_completion_response(
                self.session_id,
                event_at,
                sequence_no,
                &prompt_json,
                &assistant_choice_json,
                usage_input_tokens,
                usage_output_tokens,
                usage_total_tokens,
            )
            .await
        {
            tracing::warn!(error = %err, "failed to log on_completion_response");
        }
        HookAction::cont()
    }

    /// Logs a tool invocation. The hook never vetoes calls, so the stored
    /// action is always "continue" with no reason.
    async fn on_tool_call(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        internal_call_id: &str,
        args: &str,
    ) -> ToolCallHookAction {
        let sequence_no = self.next_sequence();
        let event_at = Utc::now().naive_utc();
        if let Err(err) = self
            .store
            .log_on_tool_call(
                self.session_id,
                event_at,
                sequence_no,
                tool_name,
                tool_call_id.as_deref(),
                internal_call_id,
                args,
                "continue",
                None,
            )
            .await
        {
            tracing::warn!(error = %err, "failed to log on_tool_call");
        }
        ToolCallHookAction::cont()
    }

    /// Logs a tool's result alongside the arguments that produced it.
    async fn on_tool_result(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        internal_call_id: &str,
        args: &str,
        result: &str,
    ) -> HookAction {
        let sequence_no = self.next_sequence();
        let event_at = Utc::now().naive_utc();
        if let Err(err) = self
            .store
            .log_on_tool_result(
                self.session_id,
                event_at,
                sequence_no,
                tool_name,
                tool_call_id.as_deref(),
                internal_call_id,
                args,
                result,
            )
            .await
        {
            tracing::warn!(error = %err, "failed to log on_tool_result");
        }
        HookAction::cont()
    }

    /// Logs a streaming text chunk plus the text aggregated so far.
    /// NOTE(review): this writes one row per delta — verify volume is
    /// acceptable for long streamed responses.
    async fn on_text_delta(&self, text_delta: &str, aggregated_text: &str) -> HookAction {
        let sequence_no = self.next_sequence();
        let event_at = Utc::now().naive_utc();
        if let Err(err) = self
            .store
            .log_on_text_delta(
                self.session_id,
                event_at,
                sequence_no,
                text_delta,
                aggregated_text,
            )
            .await
        {
            tracing::warn!(error = %err, "failed to log on_text_delta");
        }
        HookAction::cont()
    }

    /// Logs an incremental (streamed) fragment of a tool-call payload.
    async fn on_tool_call_delta(
        &self,
        tool_call_id: &str,
        internal_call_id: &str,
        tool_name: Option<&str>,
        tool_call_delta: &str,
    ) -> HookAction {
        let sequence_no = self.next_sequence();
        let event_at = Utc::now().naive_utc();
        if let Err(err) = self
            .store
            .log_on_tool_call_delta(
                self.session_id,
                event_at,
                sequence_no,
                tool_call_id,
                internal_call_id,
                tool_name,
                tool_call_delta,
            )
            .await
        {
            tracing::warn!(error = %err, "failed to log on_tool_call_delta");
        }
        HookAction::cont()
    }

    /// Logs the final summary of a streamed completion once it finishes.
    async fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &Model::StreamingResponse,
    ) -> HookAction {
        let sequence_no = self.next_sequence();
        let event_at = Utc::now().naive_utc();
        let prompt_json = serialize_to_json(prompt);
        let response_summary_json = serialize_to_json(response);
        if let Err(err) = self
            .store
            .log_on_stream_completion_response_finish(
                self.session_id,
                event_at,
                sequence_no,
                &prompt_json,
                &response_summary_json,
            )
            .await
        {
            tracing::warn!(error = %err, "failed to log on_stream_completion_response_finish");
        }
        HookAction::cont()
    }
}
/// Serializes `value` to a JSON string for audit storage.
///
/// On serialization failure, returns a JSON object describing the error.
/// The error text is JSON-encoded via the `json!` macro rather than being
/// spliced into a raw format string, so quotes or backslashes in the
/// message can no longer corrupt the stored JSON.
fn serialize_to_json<T: Serialize + ?Sized>(value: &T) -> String {
    serde_json::to_string(value).unwrap_or_else(|err| {
        serde_json::json!({ "serialization_error": err.to_string() }).to_string()
    })
}
/// Narrows an unsigned token counter to the signed DB column type,
/// saturating at `i64::MAX` instead of wrapping or panicking.
fn u64_to_i64(value: u64) -> i64 {
    if value > i64::MAX as u64 {
        i64::MAX
    } else {
        value as i64
    }
}

View File

@@ -1,4 +1,5 @@
pub mod agent;
pub mod audit;
pub mod error;
pub mod git_access;
pub mod memory;
@@ -8,8 +9,12 @@ pub mod types;
pub use agent::{ConversationAnswerAgent, PullRequestReviewAgent};
pub use error::{AgentError, AgentResult};
pub use git_access::PullRequestMaterial;
pub use memory::{DieselMemoryStore, MemoryStore, ProjectContextSnapshot};
pub use tools::{AstGrepArgs, AstGrepOutput, AstGrepTool, FileContext, SearchHit};
pub use memory::{AuditStore, DieselMemoryStore, MemoryStore, ProjectContextSnapshot};
pub use tools::{
AstGrepArgs, AstGrepOutput, AstGrepTool, FileContext, MemoryWriteOperation,
ProjectMemoryWriteArgs, ProjectMemoryWriteOutput, ProjectMemoryWriteTool, ReadFileArgs,
ReadFileOutput, ReadFileTool, SearchHit,
};
pub use types::{
ConversationInput, ConversationOutput, Forge, MessageAuthor, ProjectRef,
PullRequestReviewInput, PullRequestReviewOutput, ReviewComment, ReviewResult, ThreadMessage,

View File

@@ -2,6 +2,7 @@ mod sqlite;
use std::collections::BTreeMap;
use async_trait::async_trait;
use chrono::NaiveDateTime;
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -24,8 +25,8 @@ pub struct ProjectContextSnapshot {
pub summaries: Vec<MemorySummary>,
}
#[allow(async_fn_in_trait)]
pub trait MemoryStore {
#[async_trait]
pub trait MemoryStore: Send + Sync {
async fn project_context_snapshot(
&self,
project: &ProjectRef,
@@ -49,6 +50,7 @@ pub trait MemoryStore {
async fn create_review_thread(
&self,
project: &ProjectRef,
pull_request_id: i64,
file: &str,
line: i32,
initial_comment: &str,
@@ -63,3 +65,97 @@ pub trait MemoryStore {
async fn load_thread_messages(&self, thread_id: i32) -> AgentResult<Vec<ThreadMessage>>;
}
/// Persistence interface for prompt audit sessions and the per-event rows
/// written by `PromptAuditHook` callbacks.
///
/// Each `log_on_*` method mirrors one hook callback and appends a single
/// event row keyed by `session_id` and a caller-supplied, 1-based
/// `sequence_no` that orders events within a session.
#[async_trait]
pub trait AuditStore: Send + Sync {
    /// Creates a new session row for the given project / pull request and
    /// entrypoint label; returns the new session's id.
    async fn start_prompt_audit_session(
        &self,
        project: &ProjectRef,
        pull_request_id: i64,
        entrypoint: &str,
        started_at: NaiveDateTime,
    ) -> AgentResult<i32>;
    /// Closes a session, recording end time, terminal status string, and an
    /// optional error message for failed runs.
    async fn finish_prompt_audit_session(
        &self,
        session_id: i32,
        ended_at: NaiveDateTime,
        status: &str,
        error_message: Option<&str>,
    ) -> AgentResult<()>;
    /// Records an outgoing completion request (prompt + history as JSON).
    async fn log_on_completion_call(
        &self,
        session_id: i32,
        event_at: NaiveDateTime,
        sequence_no: i32,
        prompt_json: &str,
        history_json: &str,
    ) -> AgentResult<()>;
    /// Records a completion response plus its token-usage counters.
    async fn log_on_completion_response(
        &self,
        session_id: i32,
        event_at: NaiveDateTime,
        sequence_no: i32,
        prompt_json: &str,
        assistant_choice_json: &str,
        usage_input_tokens: i64,
        usage_output_tokens: i64,
        usage_total_tokens: i64,
    ) -> AgentResult<()>;
    /// Records a tool invocation together with the hook's decision
    /// (`action`, optional `reason`).
    async fn log_on_tool_call(
        &self,
        session_id: i32,
        event_at: NaiveDateTime,
        sequence_no: i32,
        tool_name: &str,
        tool_call_id: Option<&str>,
        internal_call_id: &str,
        args_json: &str,
        action: &str,
        reason: Option<&str>,
    ) -> AgentResult<()>;
    /// Records a tool's result alongside the arguments that produced it.
    async fn log_on_tool_result(
        &self,
        session_id: i32,
        event_at: NaiveDateTime,
        sequence_no: i32,
        tool_name: &str,
        tool_call_id: Option<&str>,
        internal_call_id: &str,
        args_json: &str,
        result_json: &str,
    ) -> AgentResult<()>;
    /// Records one streamed text chunk plus the text aggregated so far.
    async fn log_on_text_delta(
        &self,
        session_id: i32,
        event_at: NaiveDateTime,
        sequence_no: i32,
        text_delta: &str,
        aggregated_text: &str,
    ) -> AgentResult<()>;
    /// Records one streamed fragment of a tool-call payload.
    async fn log_on_tool_call_delta(
        &self,
        session_id: i32,
        event_at: NaiveDateTime,
        sequence_no: i32,
        tool_call_id: &str,
        internal_call_id: &str,
        tool_name: Option<&str>,
        tool_call_delta: &str,
    ) -> AgentResult<()>;
    /// Records the final summary of a streamed completion response.
    async fn log_on_stream_completion_response_finish(
        &self,
        session_id: i32,
        event_at: NaiveDateTime,
        sequence_no: i32,
        prompt_json: &str,
        response_summary_json: &str,
    ) -> AgentResult<()>;
}

View File

@@ -1,13 +1,21 @@
use std::collections::BTreeMap;
use async_trait::async_trait;
use chrono::Utc;
use codetaker_db::models::{
NewProjectMemoryEntryRow, NewProjectMemorySummaryRow, NewProjectRow, NewReviewThreadMessageRow,
NewReviewThreadRow, ProjectMemoryEntryRow, ProjectMemorySummaryRow, ReviewThreadMessageRow,
NewProjectMemoryEntryRow, NewProjectMemorySummaryRow, NewProjectRow, NewPromptAuditSessionRow,
NewPromptHookOnCompletionCallEventRow, NewPromptHookOnCompletionResponseEventRow,
NewPromptHookOnStreamCompletionResponseFinishEventRow, NewPromptHookOnTextDeltaEventRow,
NewPromptHookOnToolCallDeltaEventRow, NewPromptHookOnToolCallEventRow,
NewPromptHookOnToolResultEventRow, NewReviewThreadMessageRow, NewReviewThreadRow,
ProjectMemoryEntryRow, ProjectMemorySummaryRow, ReviewThreadMessageRow,
};
use codetaker_db::schema::{
project_memory_entries, project_memory_summaries, projects, review_thread_messages,
review_threads,
project_memory_entries, project_memory_summaries, projects, prompt_audit_sessions,
prompt_hook_on_completion_call_events, prompt_hook_on_completion_response_events,
prompt_hook_on_stream_completion_response_finish_events, prompt_hook_on_text_delta_events,
prompt_hook_on_tool_call_delta_events, prompt_hook_on_tool_call_events,
prompt_hook_on_tool_result_events, review_thread_messages, review_threads,
};
use codetaker_db::{DatabaseConnection, DatabasePool};
use diesel::OptionalExtension;
@@ -16,7 +24,7 @@ use diesel_async::RunQueryDsl;
use serde_json::Value;
use crate::error::{AgentError, AgentResult};
use crate::memory::{MemoryStore, MemorySummary, ProjectContextSnapshot};
use crate::memory::{AuditStore, MemoryStore, MemorySummary, ProjectContextSnapshot};
use crate::types::{MessageAuthor, ProjectRef, ThreadMessage};
type PooledConnection<'a> =
@@ -90,6 +98,7 @@ impl DieselMemoryStore {
}
}
#[async_trait]
impl MemoryStore for DieselMemoryStore {
async fn project_context_snapshot(
&self,
@@ -221,6 +230,7 @@ impl MemoryStore for DieselMemoryStore {
async fn create_review_thread(
&self,
project: &ProjectRef,
pull_request_id: i64,
file: &str,
line: i32,
initial_comment: &str,
@@ -232,6 +242,7 @@ impl MemoryStore for DieselMemoryStore {
let now = Utc::now().naive_utc();
let new_row = NewReviewThreadRow {
project_id,
pull_request_id,
file,
line,
initial_comment,
@@ -248,6 +259,7 @@ impl MemoryStore for DieselMemoryStore {
dsl::review_threads
.filter(dsl::project_id.eq(project_id))
.filter(dsl::pull_request_id.eq(pull_request_id))
.filter(dsl::file.eq(file))
.filter(dsl::line.eq(line))
.filter(dsl::initial_comment.eq(initial_comment))
@@ -317,3 +329,292 @@ impl MemoryStore for DieselMemoryStore {
.collect()
}
}
/// Diesel-backed implementation of `AuditStore`.
///
/// Every method follows the same shape: grab a pooled connection, build a
/// `New…Row` borrowing the string parameters, insert it, and wrap any
/// diesel error in `AgentError::MemoryError` with a method-specific message.
#[async_trait]
impl AuditStore for DieselMemoryStore {
    /// Inserts a `started` session row and returns its id.
    ///
    /// NOTE(review): the id is recovered by re-querying `MAX(id)` after the
    /// insert rather than using a `RETURNING` clause; a concurrent insert
    /// between the two statements could yield the wrong id. Confirm writes
    /// are serialized (e.g. single-writer SQLite) or switch to
    /// `returning(...)` / `get_result`.
    async fn start_prompt_audit_session(
        &self,
        project: &ProjectRef,
        pull_request_id: i64,
        entrypoint: &str,
        started_at: chrono::NaiveDateTime,
    ) -> AgentResult<i32> {
        use codetaker_db::schema::prompt_audit_sessions::dsl;
        let mut conn = self.get_conn().await?;
        // Resolve (or lazily create) the project row backing this session.
        let project_id = Self::ensure_project_id(&mut conn, project).await?;
        let new_row = NewPromptAuditSessionRow {
            project_id,
            pull_request_id,
            entrypoint,
            started_at,
            ended_at: None,
            status: "started",
            error_message: None,
        };
        diesel::insert_into(prompt_audit_sessions::table)
            .values(&new_row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to create prompt audit session: {err}"),
            })?;
        // Fetch the id of the row just inserted (latest id; see NOTE above).
        dsl::prompt_audit_sessions
            .order(dsl::id.desc())
            .select(dsl::id)
            .first::<i32>(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to fetch prompt audit session id: {err}"),
            })
    }
    /// Marks a session finished by updating end time, status, and error.
    async fn finish_prompt_audit_session(
        &self,
        session_id: i32,
        ended_at: chrono::NaiveDateTime,
        status: &str,
        error_message: Option<&str>,
    ) -> AgentResult<()> {
        use codetaker_db::schema::prompt_audit_sessions::dsl;
        let mut conn = self.get_conn().await?;
        diesel::update(dsl::prompt_audit_sessions.filter(dsl::id.eq(session_id)))
            .set((
                dsl::ended_at.eq(Some(ended_at)),
                dsl::status.eq(status),
                dsl::error_message.eq(error_message),
            ))
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to finish prompt audit session {session_id}: {err}"),
            })?;
        Ok(())
    }
    /// Appends an `on_completion_call` event row.
    async fn log_on_completion_call(
        &self,
        session_id: i32,
        event_at: chrono::NaiveDateTime,
        sequence_no: i32,
        prompt_json: &str,
        history_json: &str,
    ) -> AgentResult<()> {
        let mut conn = self.get_conn().await?;
        let row = NewPromptHookOnCompletionCallEventRow {
            session_id,
            event_at,
            sequence_no,
            prompt_json,
            history_json,
        };
        diesel::insert_into(prompt_hook_on_completion_call_events::table)
            .values(&row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to log on_completion_call event: {err}"),
            })?;
        Ok(())
    }
    /// Appends an `on_completion_response` event row with usage counters.
    async fn log_on_completion_response(
        &self,
        session_id: i32,
        event_at: chrono::NaiveDateTime,
        sequence_no: i32,
        prompt_json: &str,
        assistant_choice_json: &str,
        usage_input_tokens: i64,
        usage_output_tokens: i64,
        usage_total_tokens: i64,
    ) -> AgentResult<()> {
        let mut conn = self.get_conn().await?;
        let row = NewPromptHookOnCompletionResponseEventRow {
            session_id,
            event_at,
            sequence_no,
            prompt_json,
            assistant_choice_json,
            usage_input_tokens,
            usage_output_tokens,
            usage_total_tokens,
        };
        diesel::insert_into(prompt_hook_on_completion_response_events::table)
            .values(&row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to log on_completion_response event: {err}"),
            })?;
        Ok(())
    }
    /// Appends an `on_tool_call` event row, including the hook's decision.
    async fn log_on_tool_call(
        &self,
        session_id: i32,
        event_at: chrono::NaiveDateTime,
        sequence_no: i32,
        tool_name: &str,
        tool_call_id: Option<&str>,
        internal_call_id: &str,
        args_json: &str,
        action: &str,
        reason: Option<&str>,
    ) -> AgentResult<()> {
        let mut conn = self.get_conn().await?;
        let row = NewPromptHookOnToolCallEventRow {
            session_id,
            event_at,
            sequence_no,
            tool_name,
            tool_call_id,
            internal_call_id,
            args_json,
            action,
            reason,
        };
        diesel::insert_into(prompt_hook_on_tool_call_events::table)
            .values(&row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to log on_tool_call event: {err}"),
            })?;
        Ok(())
    }
    /// Appends an `on_tool_result` event row.
    async fn log_on_tool_result(
        &self,
        session_id: i32,
        event_at: chrono::NaiveDateTime,
        sequence_no: i32,
        tool_name: &str,
        tool_call_id: Option<&str>,
        internal_call_id: &str,
        args_json: &str,
        result_json: &str,
    ) -> AgentResult<()> {
        let mut conn = self.get_conn().await?;
        let row = NewPromptHookOnToolResultEventRow {
            session_id,
            event_at,
            sequence_no,
            tool_name,
            tool_call_id,
            internal_call_id,
            args_json,
            result_json,
        };
        diesel::insert_into(prompt_hook_on_tool_result_events::table)
            .values(&row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to log on_tool_result event: {err}"),
            })?;
        Ok(())
    }
    /// Appends an `on_text_delta` event row (one per streamed chunk).
    async fn log_on_text_delta(
        &self,
        session_id: i32,
        event_at: chrono::NaiveDateTime,
        sequence_no: i32,
        text_delta: &str,
        aggregated_text: &str,
    ) -> AgentResult<()> {
        let mut conn = self.get_conn().await?;
        let row = NewPromptHookOnTextDeltaEventRow {
            session_id,
            event_at,
            sequence_no,
            text_delta,
            aggregated_text,
        };
        diesel::insert_into(prompt_hook_on_text_delta_events::table)
            .values(&row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to log on_text_delta event: {err}"),
            })?;
        Ok(())
    }
    /// Appends an `on_tool_call_delta` event row.
    async fn log_on_tool_call_delta(
        &self,
        session_id: i32,
        event_at: chrono::NaiveDateTime,
        sequence_no: i32,
        tool_call_id: &str,
        internal_call_id: &str,
        tool_name: Option<&str>,
        tool_call_delta: &str,
    ) -> AgentResult<()> {
        let mut conn = self.get_conn().await?;
        let row = NewPromptHookOnToolCallDeltaEventRow {
            session_id,
            event_at,
            sequence_no,
            tool_call_id,
            internal_call_id,
            tool_name,
            tool_call_delta,
        };
        diesel::insert_into(prompt_hook_on_tool_call_delta_events::table)
            .values(&row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to log on_tool_call_delta event: {err}"),
            })?;
        Ok(())
    }
    /// Appends the final summary row of a streamed completion response.
    async fn log_on_stream_completion_response_finish(
        &self,
        session_id: i32,
        event_at: chrono::NaiveDateTime,
        sequence_no: i32,
        prompt_json: &str,
        response_summary_json: &str,
    ) -> AgentResult<()> {
        let mut conn = self.get_conn().await?;
        let row = NewPromptHookOnStreamCompletionResponseFinishEventRow {
            session_id,
            event_at,
            sequence_no,
            prompt_json,
            response_summary_json,
        };
        diesel::insert_into(prompt_hook_on_stream_completion_response_finish_events::table)
            .values(&row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!(
                    "failed to log on_stream_completion_response_finish event: {err}"
                ),
            })?;
        Ok(())
    }
}

View File

@@ -9,7 +9,7 @@ use serde_json::Value;
use crate::error::{AgentError, AgentResult};
use crate::git_access;
use crate::tools::{FileContext, SearchHit};
use crate::tools::SearchHit;
#[derive(Debug, Clone)]
pub struct AstGrepTool {
@@ -21,24 +21,14 @@ pub struct AstGrepTool {
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "operation", rename_all = "snake_case")]
pub enum AstGrepArgs {
SearchSymbol {
query: String,
},
SearchPattern {
query: String,
},
GetFileContext {
file: String,
line: i32,
radius: i32,
},
SearchSymbol { query: String },
SearchPattern { query: String },
}
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "operation", rename_all = "snake_case")]
pub enum AstGrepOutput {
SearchHits { hits: Vec<SearchHit> },
FileContext { context: FileContext },
}
impl Default for AstGrepTool {
@@ -171,12 +161,6 @@ impl AstGrepTool {
pub fn search_pattern(&self, query: &str) -> AgentResult<Vec<SearchHit>> {
self.run_query(query)
}
pub fn get_file_context(&self, file: &str, line: i32, radius: i32) -> AgentResult<FileContext> {
let (repo_git_dir, head_ref) = self.bound_state()?;
let repo = self.open_repo(&repo_git_dir)?;
git_access::read_file_context_at_ref(&repo, &head_ref, file, line, radius)
}
}
impl Tool for AstGrepTool {
@@ -189,37 +173,18 @@ impl Tool for AstGrepTool {
async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: Self::NAME.to_owned(),
description: "Search source code using ast-grep and fetch file context.".to_owned(),
description: "Search source code using ast-grep.".to_owned(),
parameters: serde_json::json!({
"type": "object",
"oneOf": [
{
"type": "object",
"properties": {
"operation": { "type": "string", "const": "search_symbol" },
"query": { "type": "string", "description": "Symbol-like query" }
},
"required": ["operation", "query"]
"properties": {
"operation": {
"type": "string",
"enum": ["search_symbol", "search_pattern"],
"description": "Tool operation. Both operations require `query`."
},
{
"type": "object",
"properties": {
"operation": { "type": "string", "const": "search_pattern" },
"query": { "type": "string", "description": "Pattern query" }
},
"required": ["operation", "query"]
},
{
"type": "object",
"properties": {
"operation": { "type": "string", "const": "get_file_context" },
"file": { "type": "string", "description": "Path to file in repository" },
"line": { "type": "integer", "description": "1-based line number" },
"radius": { "type": "integer", "description": "Number of lines before/after line" }
},
"required": ["operation", "file", "line", "radius"]
}
]
"query": { "type": "string", "description": "Symbol-like or pattern query" },
},
"required": ["operation"]
}),
}
}
@@ -234,10 +199,6 @@ impl Tool for AstGrepTool {
let hits = self.search_pattern(&query)?;
Ok(AstGrepOutput::SearchHits { hits })
}
AstGrepArgs::GetFileContext { file, line, radius } => {
let context = self.get_file_context(&file, line, radius)?;
Ok(AstGrepOutput::FileContext { context })
}
}
}
}

View File

@@ -0,0 +1,148 @@
use rig::completion::ToolDefinition;
use rig::tool::Tool;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::{AgentError, AgentResult};
use crate::memory::MemoryStore;
use crate::types::ProjectRef;
/// Agent tool that persists project-scoped memory through a backing store `M`.
#[derive(Debug, Clone)]
pub struct ProjectMemoryWriteTool<M> {
    // Backing memory store used for all writes.
    memory: M,
    // `None` until `bind_to_project` scopes the tool; writes fail while unbound.
    project: Option<ProjectRef>,
}
/// Deserialized arguments for `project_memory_write`.
///
/// All fields other than `operation` are optional at the type level; `call`
/// validates the combination required by each operation at runtime.
#[derive(Debug, Clone, Deserialize)]
pub struct ProjectMemoryWriteArgs {
    // Which write to perform; selects the required field set below.
    pub operation: MemoryWriteOperation,
    // Required by `upsert_entry`.
    pub key: Option<String>,
    // Required by `upsert_entry`; parsed as JSON, else stored as a JSON string.
    pub value: Option<String>,
    // Optional for `upsert_entry`; defaults to "agent_tool".
    pub source: Option<String>,
    // Required by `upsert_summary`.
    pub summary_type: Option<String>,
    // Required by `upsert_summary`.
    pub content: Option<String>,
}
/// Operation selector for `project_memory_write` (serialized in snake_case).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryWriteOperation {
    // Insert or update a keyed memory entry.
    UpsertEntry,
    // Insert or update a typed memory summary.
    UpsertSummary,
}
/// Result returned to the model after a memory write.
#[derive(Debug, Clone, Serialize)]
pub struct ProjectMemoryWriteOutput {
    // True when the write was persisted (errors are returned as Err instead).
    pub stored: bool,
    // Human-readable description of what was stored.
    pub detail: String,
}
impl<M> ProjectMemoryWriteTool<M> {
    /// Create an unbound tool; `bind_to_project` must be applied before calls succeed.
    pub fn new(memory: M) -> Self {
        Self {
            memory,
            project: None,
        }
    }

    /// Return a copy of this tool scoped to `project`.
    pub fn bind_to_project(&self, project: ProjectRef) -> Self
    where
        M: Clone,
    {
        let memory = self.memory.clone();
        Self {
            memory,
            project: Some(project),
        }
    }

    /// Resolve the bound project, or fail with a configuration error when unbound.
    fn project(&self) -> AgentResult<&ProjectRef> {
        match self.project.as_ref() {
            Some(project) => Ok(project),
            None => Err(AgentError::ConfigError {
                message: "project memory write tool is not bound to a project context".to_owned(),
            }),
        }
    }
}
impl<M> Tool for ProjectMemoryWriteTool<M>
where
    M: MemoryStore + Clone + Send + Sync,
{
    const NAME: &'static str = "project_memory_write";

    type Error = AgentError;
    type Args = ProjectMemoryWriteArgs;
    type Output = ProjectMemoryWriteOutput;

    /// Advertise the tool's JSON-schema parameters to the model.
    ///
    /// Only `operation` is schema-required; per-operation field requirements
    /// are described in the field docs and enforced in `call`.
    async fn definition(&self, _prompt: String) -> ToolDefinition {
        ToolDefinition {
            name: Self::NAME.to_owned(),
            description: "Persist project-specific memory entries or summaries for future reviews."
                .to_owned(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "operation": {
                        "type": "string",
                        "enum": ["upsert_entry", "upsert_summary"],
                        "description": "upsert_entry requires key and value. upsert_summary requires summary_type and content."
                    },
                    "key": { "type": "string", "description": "Memory key for upsert_entry." },
                    "value": { "type": "string", "description": "Memory value for upsert_entry. JSON text is accepted; plain text is stored as JSON string." },
                    "source": { "type": "string", "description": "Source tag for upsert_entry. Defaults to agent_tool." },
                    "summary_type": { "type": "string", "description": "Summary type for upsert_summary." },
                    "content": { "type": "string", "description": "Summary content for upsert_summary." }
                },
                "required": ["operation"]
            }),
        }
    }

    /// Execute a memory write; errors if the tool is unbound or required
    /// per-operation fields are missing.
    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let project = self.project()?;
        match args.operation {
            MemoryWriteOperation::UpsertEntry => {
                // Both `key` and `value` are mandatory for this operation.
                let key = args.key.ok_or_else(|| AgentError::ToolError {
                    tool: Self::NAME.to_owned(),
                    message: "missing required field `key` for upsert_entry".to_owned(),
                })?;
                let raw_value = args.value.ok_or_else(|| AgentError::ToolError {
                    tool: Self::NAME.to_owned(),
                    message: "missing required field `value` for upsert_entry".to_owned(),
                })?;
                // Accept JSON text as-is; anything unparsable is stored as a JSON string.
                let value: Value =
                    serde_json::from_str(&raw_value).unwrap_or_else(|_| Value::String(raw_value));
                let source = args.source.as_deref().unwrap_or("agent_tool");
                self.memory
                    .upsert_memory_entry(project, &key, &value, source)
                    .await?;
                Ok(ProjectMemoryWriteOutput {
                    stored: true,
                    detail: format!("stored memory entry `{key}`"),
                })
            }
            MemoryWriteOperation::UpsertSummary => {
                // Both `summary_type` and `content` are mandatory for this operation.
                let summary_type = args.summary_type.ok_or_else(|| AgentError::ToolError {
                    tool: Self::NAME.to_owned(),
                    message: "missing required field `summary_type` for upsert_summary".to_owned(),
                })?;
                let content = args.content.ok_or_else(|| AgentError::ToolError {
                    tool: Self::NAME.to_owned(),
                    message: "missing required field `content` for upsert_summary".to_owned(),
                })?;
                self.memory
                    .upsert_memory_summary(project, &summary_type, &content)
                    .await?;
                Ok(ProjectMemoryWriteOutput {
                    stored: true,
                    detail: format!("stored memory summary `{summary_type}`"),
                })
            }
        }
    }
}

View File

@@ -1,8 +1,14 @@
mod ast_grep;
mod memory_write;
mod readfile;
use serde::{Deserialize, Serialize};
pub use ast_grep::{AstGrepArgs, AstGrepOutput, AstGrepTool};
pub use memory_write::{
MemoryWriteOperation, ProjectMemoryWriteArgs, ProjectMemoryWriteOutput, ProjectMemoryWriteTool,
};
pub use readfile::{ReadFileArgs, ReadFileOutput, ReadFileTool};
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SearchHit {

View File

@@ -0,0 +1,115 @@
use std::path::{Path, PathBuf};
use git2::Repository;
use rig::completion::ToolDefinition;
use rig::tool::Tool;
use serde::{Deserialize, Serialize};
use crate::error::{AgentError, AgentResult};
use crate::git_access;
use crate::tools::FileContext;
/// Agent tool that reads file context from a bare git repository at a fixed ref.
#[derive(Debug, Clone)]
pub struct ReadFileTool {
    // Path to the bare repository's git dir; `None` until bound via `bind_to_request`.
    repo_git_dir: Option<PathBuf>,
    // Ref to read files at; `None` until bound via `bind_to_request`.
    head_ref: Option<String>,
}
/// Deserialized arguments for the `readfile` tool.
#[derive(Debug, Clone, Deserialize)]
pub struct ReadFileArgs {
    // Repository-relative path of the file to read.
    pub file: String,
    // 1-based line number to center the context window on (per the tool schema).
    pub line: i32,
    // Number of lines of context before/after `line`.
    pub radius: i32,
}
/// Result of a `readfile` call: the extracted file context window.
#[derive(Debug, Clone, Serialize)]
pub struct ReadFileOutput {
    pub context: FileContext,
}
impl Default for ReadFileTool {
    /// Equivalent to `ReadFileTool::new()`: an unbound tool.
    fn default() -> Self {
        Self::new()
    }
}
impl ReadFileTool {
    /// Construct an unbound tool; `bind_to_request` must be applied before reads.
    pub fn new() -> Self {
        Self {
            repo_git_dir: None,
            head_ref: None,
        }
    }

    /// Return a copy bound to a bare repository and a ref to read at.
    ///
    /// Fails when `repo` is not a bare repository.
    pub fn bind_to_request(
        &self,
        repo: &Repository,
        head_ref: impl Into<String>,
    ) -> AgentResult<Self> {
        git_access::ensure_bare_repository(repo)?;
        let bound = Self {
            repo_git_dir: Some(repo.path().to_path_buf()),
            head_ref: Some(head_ref.into()),
        };
        Ok(bound)
    }

    /// Extract the bound (git dir, head ref) pair, or a config error when unbound.
    fn bound_state(&self) -> AgentResult<(PathBuf, String)> {
        let Some(repo_git_dir) = self.repo_git_dir.clone() else {
            return Err(AgentError::ConfigError {
                message: "readfile tool is not bound to a repository context".to_owned(),
            });
        };
        let Some(head_ref) = self.head_ref.clone() else {
            return Err(AgentError::ConfigError {
                message: "readfile tool is not bound to a head_ref".to_owned(),
            });
        };
        Ok((repo_git_dir, head_ref))
    }

    /// Re-open the bound bare repository from its git dir.
    fn open_repo(&self, repo_git_dir: &Path) -> AgentResult<Repository> {
        match Repository::open_bare(repo_git_dir) {
            Ok(repo) => Ok(repo),
            Err(err) => Err(AgentError::GitError {
                operation: format!("open bare repository at '{}'", repo_git_dir.display()),
                message: err.to_string(),
            }),
        }
    }

    /// Read a `radius`-line context window around `line` of `file` at the bound ref.
    fn read_context(&self, file: &str, line: i32, radius: i32) -> AgentResult<FileContext> {
        let (repo_git_dir, head_ref) = self.bound_state()?;
        let repo = self.open_repo(&repo_git_dir)?;
        git_access::read_file_context_at_ref(&repo, &head_ref, file, line, radius)
    }
}
impl Tool for ReadFileTool {
    const NAME: &'static str = "readfile";

    type Error = AgentError;
    type Args = ReadFileArgs;
    type Output = ReadFileOutput;

    /// Advertise the tool's JSON-schema parameters to the model.
    async fn definition(&self, _prompt: String) -> ToolDefinition {
        ToolDefinition {
            name: Self::NAME.to_owned(),
            description: "Read file context from the repository at head_ref.".to_owned(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "file": { "type": "string", "description": "Path to file in repository" },
                    "line": { "type": "integer", "description": "1-based line number" },
                    "radius": { "type": "integer", "description": "Number of lines before/after line" }
                },
                "required": ["file", "line", "radius"]
            }),
        }
    }

    /// Read the requested context window; errors if the tool is unbound.
    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let context = self.read_context(&args.file, args.line, args.radius)?;
        Ok(ReadFileOutput { context })
    }
}

View File

@@ -50,6 +50,7 @@ pub struct ProjectRef {
/// Input payload for a full pull-request review run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PullRequestReviewInput {
    // Project (forge/owner/repo) the review belongs to.
    pub project: ProjectRef,
    // Forge-side pull request identifier; also keys audit sessions and threads.
    pub pull_request_id: i64,
    // Ref the PR targets (diff base).
    pub base_ref: String,
    // Ref with the PR's changes (diff head).
    pub head_ref: String,
}
@@ -57,6 +58,7 @@ pub struct PullRequestReviewInput {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationInput {
pub project: ProjectRef,
pub pull_request_id: i64,
pub head_ref: String,
pub anchor_file: String,
pub anchor_line: i32,

View File

@@ -1,8 +1,11 @@
use std::path::PathBuf;
use codetaker_agent::memory::{DieselMemoryStore, MemoryStore};
use codetaker_agent::memory::{AuditStore, DieselMemoryStore, MemoryStore};
use codetaker_agent::types::{Forge, MessageAuthor, ProjectRef};
use codetaker_db::create_pool;
use diesel::ExpressionMethods;
use diesel::QueryDsl;
use diesel_async::RunQueryDsl;
use serde_json::json;
fn test_db_path(test_name: &str) -> PathBuf {
@@ -87,7 +90,7 @@ async fn thread_messages_round_trip() {
let project = sample_project();
let thread_id = store
.create_review_thread(&project, "src/lib.rs", 42, "Please avoid unwrap here")
.create_review_thread(&project, 123, "src/lib.rs", 42, "Please avoid unwrap here")
.await
.expect("create review thread");
@@ -113,3 +116,195 @@ async fn thread_messages_round_trip() {
assert!(matches!(messages[0].author, MessageAuthor::User));
assert!(matches!(messages[1].author, MessageAuthor::Agent));
}
#[tokio::test]
async fn prompt_audit_session_lifecycle_success_and_failure() {
    // Each test gets its own on-disk SQLite database to avoid interference.
    let db_path = test_db_path("prompt_audit_session_lifecycle_success_and_failure");
    let database_url = db_path.display().to_string();
    let pool = create_pool(Some(&database_url))
        .await
        .expect("create sqlite pool");
    let store = DieselMemoryStore::new(pool.clone());
    let project = sample_project();
    let now = chrono::Utc::now().naive_utc();

    // Happy path: a session that ends with status "completed" and no error.
    let success_session = store
        .start_prompt_audit_session(&project, 11, "pull_request_review", now)
        .await
        .expect("start success session");
    store
        .finish_prompt_audit_session(
            success_session,
            chrono::Utc::now().naive_utc(),
            "completed",
            None,
        )
        .await
        .expect("finish success session");

    // Failure path: a session that ends with status "failed" and an error message.
    let failed_session = store
        .start_prompt_audit_session(&project, 12, "conversation_response", now)
        .await
        .expect("start failed session");
    store
        .finish_prompt_audit_session(
            failed_session,
            chrono::Utc::now().naive_utc(),
            "failed",
            Some("model error"),
        )
        .await
        .expect("finish failed session");

    // Verify both statuses were persisted by querying the table directly.
    let mut conn = pool.get().await.expect("get pooled connection");
    use codetaker_db::schema::prompt_audit_sessions::dsl;
    let statuses = dsl::prompt_audit_sessions
        .filter(dsl::id.eq_any([success_session, failed_session]))
        .select(dsl::status)
        .load::<String>(&mut conn)
        .await
        .expect("load session statuses");
    assert!(statuses.iter().any(|status| status == "completed"));
    assert!(statuses.iter().any(|status| status == "failed"));
}
#[tokio::test]
async fn prompt_hook_event_methods_write_rows() {
    // Each test gets its own on-disk SQLite database to avoid interference.
    let db_path = test_db_path("prompt_hook_event_methods_write_rows");
    let database_url = db_path.display().to_string();
    let pool = create_pool(Some(&database_url))
        .await
        .expect("create sqlite pool");
    let store = DieselMemoryStore::new(pool.clone());
    let project = sample_project();
    let now = chrono::Utc::now().naive_utc();

    // All hook events below hang off this single audit session.
    let session_id = store
        .start_prompt_audit_session(&project, 99, "pull_request_review", now)
        .await
        .expect("start audit session");

    // Emit exactly one event of each hook type, with increasing sequence numbers.
    store
        .log_on_completion_call(session_id, now, 1, "{\"role\":\"user\"}", "[]")
        .await
        .expect("log completion call");
    store
        .log_on_completion_response(session_id, now, 2, "{\"role\":\"user\"}", "[]", 10, 20, 30)
        .await
        .expect("log completion response");
    store
        .log_on_tool_call(
            session_id,
            now,
            3,
            "ast_grep",
            Some("call_1"),
            "internal_1",
            "{\"query\":\"foo\"}",
            "continue",
            None,
        )
        .await
        .expect("log tool call");
    store
        .log_on_tool_result(
            session_id,
            now,
            4,
            "ast_grep",
            Some("call_1"),
            "internal_1",
            "{\"query\":\"foo\"}",
            "{\"hits\":[]}",
        )
        .await
        .expect("log tool result");
    store
        .log_on_text_delta(session_id, now, 5, "foo", "foobar")
        .await
        .expect("log text delta");
    store
        .log_on_tool_call_delta(
            session_id,
            now,
            6,
            "call_2",
            "internal_2",
            Some("readfile"),
            "{\"path\":\"src/lib.rs\"}",
        )
        .await
        .expect("log tool call delta");
    store
        .log_on_stream_completion_response_finish(session_id, now, 7, "{}", "{\"done\":true}")
        .await
        .expect("log stream finish");

    // Each per-hook table should now contain exactly one row for the session.
    let mut conn = pool.get().await.expect("get pooled connection");
    use codetaker_db::schema::{
        prompt_hook_on_completion_call_events::dsl as completion_call_dsl,
        prompt_hook_on_completion_response_events::dsl as completion_response_dsl,
        prompt_hook_on_stream_completion_response_finish_events::dsl as stream_finish_dsl,
        prompt_hook_on_text_delta_events::dsl as text_delta_dsl,
        prompt_hook_on_tool_call_delta_events::dsl as tool_call_delta_dsl,
        prompt_hook_on_tool_call_events::dsl as tool_call_dsl,
        prompt_hook_on_tool_result_events::dsl as tool_result_dsl,
    };
    let completion_call_count = completion_call_dsl::prompt_hook_on_completion_call_events
        .filter(completion_call_dsl::session_id.eq(session_id))
        .count()
        .get_result::<i64>(&mut conn)
        .await
        .expect("count completion call events");
    assert_eq!(completion_call_count, 1);
    let completion_response_count = completion_response_dsl::prompt_hook_on_completion_response_events
        .filter(completion_response_dsl::session_id.eq(session_id))
        .count()
        .get_result::<i64>(&mut conn)
        .await
        .expect("count completion response events");
    assert_eq!(completion_response_count, 1);
    let tool_call_count = tool_call_dsl::prompt_hook_on_tool_call_events
        .filter(tool_call_dsl::session_id.eq(session_id))
        .count()
        .get_result::<i64>(&mut conn)
        .await
        .expect("count tool call events");
    assert_eq!(tool_call_count, 1);
    let tool_result_count = tool_result_dsl::prompt_hook_on_tool_result_events
        .filter(tool_result_dsl::session_id.eq(session_id))
        .count()
        .get_result::<i64>(&mut conn)
        .await
        .expect("count tool result events");
    assert_eq!(tool_result_count, 1);
    let text_delta_count = text_delta_dsl::prompt_hook_on_text_delta_events
        .filter(text_delta_dsl::session_id.eq(session_id))
        .count()
        .get_result::<i64>(&mut conn)
        .await
        .expect("count text delta events");
    assert_eq!(text_delta_count, 1);
    let tool_delta_count = tool_call_delta_dsl::prompt_hook_on_tool_call_delta_events
        .filter(tool_call_delta_dsl::session_id.eq(session_id))
        .count()
        .get_result::<i64>(&mut conn)
        .await
        .expect("count tool call delta events");
    assert_eq!(tool_delta_count, 1);
    let stream_finish_count =
        stream_finish_dsl::prompt_hook_on_stream_completion_response_finish_events
            .filter(stream_finish_dsl::session_id.eq(session_id))
            .count()
            .get_result::<i64>(&mut conn)
            .await
            .expect("count stream finish events");
    assert_eq!(stream_finish_count, 1);
}

View File

@@ -7,6 +7,7 @@ edition = "2024"
clap.workspace = true
codetaker-agent = { path = "../codetaker-agent" }
codetaker-db.workspace = true
dialoguer.workspace = true
git2.workspace = true
rig-core.workspace = true
serde_json.workspace = true

View File

@@ -31,6 +31,9 @@ struct Cli {
#[arg(long)]
head_ref: String,
#[arg(long)]
pull_request_id: i64,
#[arg(long, default_value = "local")]
owner: String,
@@ -40,6 +43,7 @@ struct Cli {
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
let cli = Cli::parse();
let repo_name = cli.repo.unwrap_or_else(|| {
@@ -63,6 +67,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
owner: cli.owner,
repo: repo_name,
},
pull_request_id: cli.pull_request_id,
base_ref: cli.base_ref,
head_ref: cli.head_ref,
};
@@ -72,3 +77,26 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[cfg(test)]
mod tests {
    use super::Cli;
    use clap::Parser;

    /// Omitting the pull-request id must make CLI parsing fail, since the
    /// argument is declared without a default.
    #[test]
    fn requires_pull_request_id() {
        let argv = [
            "codetaker-cli",
            "--anthropic-api-key",
            "key",
            "--repo-path",
            ".",
            "--base-ref",
            "main",
            "--head-ref",
            "feature",
        ];
        let outcome = Cli::try_parse_from(argv);
        assert!(outcome.is_err());
    }
}

View File

@@ -1,3 +1,11 @@
-- Revert the prompt-audit tables. Child event tables are dropped before
-- prompt_audit_sessions so their session_id foreign keys disappear first,
-- and review_thread_messages before review_threads for the same reason.
DROP TABLE IF EXISTS prompt_hook_on_stream_completion_response_finish_events;
DROP TABLE IF EXISTS prompt_hook_on_tool_call_delta_events;
DROP TABLE IF EXISTS prompt_hook_on_text_delta_events;
DROP TABLE IF EXISTS prompt_hook_on_tool_result_events;
DROP TABLE IF EXISTS prompt_hook_on_tool_call_events;
DROP TABLE IF EXISTS prompt_hook_on_completion_response_events;
DROP TABLE IF EXISTS prompt_hook_on_completion_call_events;
DROP TABLE IF EXISTS prompt_audit_sessions;
DROP TABLE IF EXISTS review_thread_messages;
DROP TABLE IF EXISTS review_threads;
DROP TABLE IF EXISTS project_memory_summaries;

View File

@@ -30,6 +30,7 @@ CREATE TABLE project_memory_summaries (
CREATE TABLE review_threads (
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
project_id INTEGER NOT NULL,
pull_request_id BIGINT NOT NULL,
file TEXT NOT NULL,
line INTEGER NOT NULL,
initial_comment TEXT NOT NULL,
@@ -46,11 +47,125 @@ CREATE TABLE review_thread_messages (
FOREIGN KEY(thread_id) REFERENCES review_threads(id) ON DELETE CASCADE
);
-- One row per agent prompt run; hook event tables below reference it.
CREATE TABLE prompt_audit_sessions (
    id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    project_id INTEGER NOT NULL,
    pull_request_id BIGINT NOT NULL,
    entrypoint TEXT NOT NULL,
    started_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    ended_at TIMESTAMP,
    status TEXT NOT NULL,
    error_message TEXT,
    FOREIGN KEY(project_id) REFERENCES projects(id) ON DELETE CASCADE
);

-- on_completion_call hook: prompt plus history snapshot, both as JSON text.
CREATE TABLE prompt_hook_on_completion_call_events (
    id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    session_id INTEGER NOT NULL,
    event_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    sequence_no INTEGER NOT NULL,
    prompt_json TEXT NOT NULL,
    history_json TEXT NOT NULL,
    FOREIGN KEY(session_id) REFERENCES prompt_audit_sessions(id) ON DELETE CASCADE
);

-- on_completion_response hook: model choice plus token usage counters.
CREATE TABLE prompt_hook_on_completion_response_events (
    id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    session_id INTEGER NOT NULL,
    event_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    sequence_no INTEGER NOT NULL,
    prompt_json TEXT NOT NULL,
    assistant_choice_json TEXT NOT NULL,
    usage_input_tokens BIGINT NOT NULL,
    usage_output_tokens BIGINT NOT NULL,
    usage_total_tokens BIGINT NOT NULL,
    FOREIGN KEY(session_id) REFERENCES prompt_audit_sessions(id) ON DELETE CASCADE
);

-- on_tool_call hook: tool invocation with its args and the hook's decision.
CREATE TABLE prompt_hook_on_tool_call_events (
    id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    session_id INTEGER NOT NULL,
    event_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    sequence_no INTEGER NOT NULL,
    tool_name TEXT NOT NULL,
    tool_call_id TEXT,
    internal_call_id TEXT NOT NULL,
    args_json TEXT NOT NULL,
    action TEXT NOT NULL,
    reason TEXT,
    FOREIGN KEY(session_id) REFERENCES prompt_audit_sessions(id) ON DELETE CASCADE
);

-- on_tool_result hook: tool args together with its serialized result.
CREATE TABLE prompt_hook_on_tool_result_events (
    id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    session_id INTEGER NOT NULL,
    event_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    sequence_no INTEGER NOT NULL,
    tool_name TEXT NOT NULL,
    tool_call_id TEXT,
    internal_call_id TEXT NOT NULL,
    args_json TEXT NOT NULL,
    result_json TEXT NOT NULL,
    FOREIGN KEY(session_id) REFERENCES prompt_audit_sessions(id) ON DELETE CASCADE
);

-- on_text_delta hook: streamed text fragment plus the text accumulated so far.
CREATE TABLE prompt_hook_on_text_delta_events (
    id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    session_id INTEGER NOT NULL,
    event_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    sequence_no INTEGER NOT NULL,
    text_delta TEXT NOT NULL,
    aggregated_text TEXT NOT NULL,
    FOREIGN KEY(session_id) REFERENCES prompt_audit_sessions(id) ON DELETE CASCADE
);

-- on_tool_call_delta hook: streamed tool-call fragment (tool_name may be
-- unknown while the call is still streaming, hence nullable).
CREATE TABLE prompt_hook_on_tool_call_delta_events (
    id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    session_id INTEGER NOT NULL,
    event_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    sequence_no INTEGER NOT NULL,
    tool_call_id TEXT NOT NULL,
    internal_call_id TEXT NOT NULL,
    tool_name TEXT,
    tool_call_delta TEXT NOT NULL,
    FOREIGN KEY(session_id) REFERENCES prompt_audit_sessions(id) ON DELETE CASCADE
);

-- on_stream_completion_response_finish hook: final prompt and response summary.
CREATE TABLE prompt_hook_on_stream_completion_response_finish_events (
    id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    session_id INTEGER NOT NULL,
    event_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    sequence_no INTEGER NOT NULL,
    prompt_json TEXT NOT NULL,
    response_summary_json TEXT NOT NULL,
    FOREIGN KEY(session_id) REFERENCES prompt_audit_sessions(id) ON DELETE CASCADE
);
-- Lookup indexes. Hook-event tables are indexed on (session_id, sequence_no)
-- so a session's events can be fetched in emission order.
CREATE INDEX idx_project_memory_entries_project_id
    ON project_memory_entries(project_id);
CREATE INDEX idx_project_memory_summaries_project_id
    ON project_memory_summaries(project_id);
CREATE INDEX idx_review_threads_project_id
    ON review_threads(project_id);
CREATE INDEX idx_review_threads_project_pr
    ON review_threads(project_id, pull_request_id);
CREATE INDEX idx_review_thread_messages_thread_id
    ON review_thread_messages(thread_id);
CREATE INDEX idx_prompt_audit_sessions_project_pr
    ON prompt_audit_sessions(project_id, pull_request_id);
CREATE INDEX idx_prompt_audit_sessions_entrypoint_started
    ON prompt_audit_sessions(entrypoint, started_at);
CREATE INDEX idx_prompt_hook_completion_call_session
    ON prompt_hook_on_completion_call_events(session_id, sequence_no);
CREATE INDEX idx_prompt_hook_completion_response_session
    ON prompt_hook_on_completion_response_events(session_id, sequence_no);
CREATE INDEX idx_prompt_hook_tool_call_session
    ON prompt_hook_on_tool_call_events(session_id, sequence_no);
CREATE INDEX idx_prompt_hook_tool_result_session
    ON prompt_hook_on_tool_result_events(session_id, sequence_no);
CREATE INDEX idx_prompt_hook_text_delta_session
    ON prompt_hook_on_text_delta_events(session_id, sequence_no);
CREATE INDEX idx_prompt_hook_tool_delta_session
    ON prompt_hook_on_tool_call_delta_events(session_id, sequence_no);
CREATE INDEX idx_prompt_hook_stream_finish_session
    ON prompt_hook_on_stream_completion_response_finish_events(session_id, sequence_no);

View File

@@ -2,8 +2,11 @@ use chrono::NaiveDateTime;
use diesel::{Associations, Identifiable, Insertable, Queryable, Selectable};
use crate::schema::{
project_memory_entries, project_memory_summaries, projects, review_thread_messages,
review_threads,
project_memory_entries, project_memory_summaries, projects, prompt_audit_sessions,
prompt_hook_on_completion_call_events, prompt_hook_on_completion_response_events,
prompt_hook_on_stream_completion_response_finish_events, prompt_hook_on_text_delta_events,
prompt_hook_on_tool_call_delta_events, prompt_hook_on_tool_call_events,
prompt_hook_on_tool_result_events, review_thread_messages, review_threads,
};
#[derive(Debug, Clone, Queryable, Selectable, Identifiable)]
@@ -71,6 +74,7 @@ pub struct NewProjectMemorySummaryRow<'a> {
pub struct ReviewThreadRow {
pub id: i32,
pub project_id: i32,
pub pull_request_id: i64,
pub file: String,
pub line: i32,
pub initial_comment: String,
@@ -81,6 +85,7 @@ pub struct ReviewThreadRow {
#[diesel(table_name = review_threads)]
pub struct NewReviewThreadRow<'a> {
pub project_id: i32,
pub pull_request_id: i64,
pub file: &'a str,
pub line: i32,
pub initial_comment: &'a str,
@@ -106,3 +111,111 @@ pub struct NewReviewThreadMessageRow<'a> {
pub body: &'a str,
pub created_at: NaiveDateTime,
}
/// Queryable row of `prompt_audit_sessions`: one agent prompt run per row.
#[derive(Debug, Clone, Queryable, Selectable, Identifiable, Associations)]
#[diesel(table_name = prompt_audit_sessions)]
#[diesel(belongs_to(ProjectRow, foreign_key = project_id))]
pub struct PromptAuditSessionRow {
    pub id: i32,
    pub project_id: i32,
    pub pull_request_id: i64,
    pub entrypoint: String,
    pub started_at: NaiveDateTime,
    // `None` while the session is still running.
    pub ended_at: Option<NaiveDateTime>,
    pub status: String,
    pub error_message: Option<String>,
}

/// Insertable counterpart of `PromptAuditSessionRow` (borrows its strings).
#[derive(Debug, Insertable)]
#[diesel(table_name = prompt_audit_sessions)]
pub struct NewPromptAuditSessionRow<'a> {
    pub project_id: i32,
    pub pull_request_id: i64,
    pub entrypoint: &'a str,
    pub started_at: NaiveDateTime,
    pub ended_at: Option<NaiveDateTime>,
    pub status: &'a str,
    pub error_message: Option<&'a str>,
}

/// Insertable row for the `on_completion_call` hook event table.
#[derive(Debug, Insertable)]
#[diesel(table_name = prompt_hook_on_completion_call_events)]
pub struct NewPromptHookOnCompletionCallEventRow<'a> {
    pub session_id: i32,
    pub event_at: NaiveDateTime,
    pub sequence_no: i32,
    pub prompt_json: &'a str,
    pub history_json: &'a str,
}

/// Insertable row for the `on_completion_response` hook event table.
#[derive(Debug, Insertable)]
#[diesel(table_name = prompt_hook_on_completion_response_events)]
pub struct NewPromptHookOnCompletionResponseEventRow<'a> {
    pub session_id: i32,
    pub event_at: NaiveDateTime,
    pub sequence_no: i32,
    pub prompt_json: &'a str,
    pub assistant_choice_json: &'a str,
    pub usage_input_tokens: i64,
    pub usage_output_tokens: i64,
    pub usage_total_tokens: i64,
}

/// Insertable row for the `on_tool_call` hook event table.
#[derive(Debug, Insertable)]
#[diesel(table_name = prompt_hook_on_tool_call_events)]
pub struct NewPromptHookOnToolCallEventRow<'a> {
    pub session_id: i32,
    pub event_at: NaiveDateTime,
    pub sequence_no: i32,
    pub tool_name: &'a str,
    pub tool_call_id: Option<&'a str>,
    pub internal_call_id: &'a str,
    pub args_json: &'a str,
    pub action: &'a str,
    pub reason: Option<&'a str>,
}

/// Insertable row for the `on_tool_result` hook event table.
#[derive(Debug, Insertable)]
#[diesel(table_name = prompt_hook_on_tool_result_events)]
pub struct NewPromptHookOnToolResultEventRow<'a> {
    pub session_id: i32,
    pub event_at: NaiveDateTime,
    pub sequence_no: i32,
    pub tool_name: &'a str,
    pub tool_call_id: Option<&'a str>,
    pub internal_call_id: &'a str,
    pub args_json: &'a str,
    pub result_json: &'a str,
}

/// Insertable row for the `on_text_delta` hook event table.
#[derive(Debug, Insertable)]
#[diesel(table_name = prompt_hook_on_text_delta_events)]
pub struct NewPromptHookOnTextDeltaEventRow<'a> {
    pub session_id: i32,
    pub event_at: NaiveDateTime,
    pub sequence_no: i32,
    pub text_delta: &'a str,
    pub aggregated_text: &'a str,
}

/// Insertable row for the `on_tool_call_delta` hook event table.
#[derive(Debug, Insertable)]
#[diesel(table_name = prompt_hook_on_tool_call_delta_events)]
pub struct NewPromptHookOnToolCallDeltaEventRow<'a> {
    pub session_id: i32,
    pub event_at: NaiveDateTime,
    pub sequence_no: i32,
    pub tool_call_id: &'a str,
    pub internal_call_id: &'a str,
    // May be unknown while the tool call is still streaming.
    pub tool_name: Option<&'a str>,
    pub tool_call_delta: &'a str,
}

/// Insertable row for the `on_stream_completion_response_finish` hook event table.
#[derive(Debug, Insertable)]
#[diesel(table_name = prompt_hook_on_stream_completion_response_finish_events)]
pub struct NewPromptHookOnStreamCompletionResponseFinishEventRow<'a> {
    pub session_id: i32,
    pub event_at: NaiveDateTime,
    pub sequence_no: i32,
    pub prompt_json: &'a str,
    pub response_summary_json: &'a str,
}

View File

@@ -30,6 +30,108 @@ diesel::table! {
}
}
// Diesel schema for prompt audit sessions; mirrors the SQL migration.
diesel::table! {
    prompt_audit_sessions (id) {
        id -> Integer,
        project_id -> Integer,
        pull_request_id -> BigInt,
        entrypoint -> Text,
        started_at -> Timestamp,
        ended_at -> Nullable<Timestamp>,
        status -> Text,
        error_message -> Nullable<Text>,
    }
}

// on_completion_call hook events (prompt + history snapshots as JSON text).
diesel::table! {
    prompt_hook_on_completion_call_events (id) {
        id -> Integer,
        session_id -> Integer,
        event_at -> Timestamp,
        sequence_no -> Integer,
        prompt_json -> Text,
        history_json -> Text,
    }
}

// on_completion_response hook events (assistant choice + token usage).
diesel::table! {
    prompt_hook_on_completion_response_events (id) {
        id -> Integer,
        session_id -> Integer,
        event_at -> Timestamp,
        sequence_no -> Integer,
        prompt_json -> Text,
        assistant_choice_json -> Text,
        usage_input_tokens -> BigInt,
        usage_output_tokens -> BigInt,
        usage_total_tokens -> BigInt,
    }
}

// on_stream_completion_response_finish hook events.
diesel::table! {
    prompt_hook_on_stream_completion_response_finish_events (id) {
        id -> Integer,
        session_id -> Integer,
        event_at -> Timestamp,
        sequence_no -> Integer,
        prompt_json -> Text,
        response_summary_json -> Text,
    }
}

// on_text_delta hook events (fragment plus accumulated text).
diesel::table! {
    prompt_hook_on_text_delta_events (id) {
        id -> Integer,
        session_id -> Integer,
        event_at -> Timestamp,
        sequence_no -> Integer,
        text_delta -> Text,
        aggregated_text -> Text,
    }
}

// on_tool_call_delta hook events (tool_name nullable while streaming).
diesel::table! {
    prompt_hook_on_tool_call_delta_events (id) {
        id -> Integer,
        session_id -> Integer,
        event_at -> Timestamp,
        sequence_no -> Integer,
        tool_call_id -> Text,
        internal_call_id -> Text,
        tool_name -> Nullable<Text>,
        tool_call_delta -> Text,
    }
}

// on_tool_call hook events (args plus the hook's action/reason decision).
diesel::table! {
    prompt_hook_on_tool_call_events (id) {
        id -> Integer,
        session_id -> Integer,
        event_at -> Timestamp,
        sequence_no -> Integer,
        tool_name -> Text,
        tool_call_id -> Nullable<Text>,
        internal_call_id -> Text,
        args_json -> Text,
        action -> Text,
        reason -> Nullable<Text>,
    }
}

// on_tool_result hook events (args plus serialized tool result).
diesel::table! {
    prompt_hook_on_tool_result_events (id) {
        id -> Integer,
        session_id -> Integer,
        event_at -> Timestamp,
        sequence_no -> Integer,
        tool_name -> Text,
        tool_call_id -> Nullable<Text>,
        internal_call_id -> Text,
        args_json -> Text,
        result_json -> Text,
    }
}
diesel::table! {
review_thread_messages (id) {
id -> Integer,
@@ -44,6 +146,7 @@ diesel::table! {
review_threads (id) {
id -> Integer,
project_id -> Integer,
pull_request_id -> BigInt,
file -> Text,
line -> Integer,
initial_comment -> Text,
@@ -53,6 +156,14 @@ diesel::table! {
diesel::joinable!(project_memory_entries -> projects (project_id));
diesel::joinable!(project_memory_summaries -> projects (project_id));
diesel::joinable!(prompt_audit_sessions -> projects (project_id));
diesel::joinable!(prompt_hook_on_completion_call_events -> prompt_audit_sessions (session_id));
diesel::joinable!(prompt_hook_on_completion_response_events -> prompt_audit_sessions (session_id));
diesel::joinable!(prompt_hook_on_stream_completion_response_finish_events -> prompt_audit_sessions (session_id));
diesel::joinable!(prompt_hook_on_text_delta_events -> prompt_audit_sessions (session_id));
diesel::joinable!(prompt_hook_on_tool_call_delta_events -> prompt_audit_sessions (session_id));
diesel::joinable!(prompt_hook_on_tool_call_events -> prompt_audit_sessions (session_id));
diesel::joinable!(prompt_hook_on_tool_result_events -> prompt_audit_sessions (session_id));
diesel::joinable!(review_thread_messages -> review_threads (thread_id));
diesel::joinable!(review_threads -> projects (project_id));
@@ -60,6 +171,14 @@ diesel::allow_tables_to_appear_in_same_query!(
projects,
project_memory_entries,
project_memory_summaries,
prompt_audit_sessions,
prompt_hook_on_completion_call_events,
prompt_hook_on_completion_response_events,
prompt_hook_on_stream_completion_response_finish_events,
prompt_hook_on_text_delta_events,
prompt_hook_on_tool_call_delta_events,
prompt_hook_on_tool_call_events,
prompt_hook_on_tool_result_events,
review_threads,
review_thread_messages,
);