diff --git a/Cargo.lock b/Cargo.lock index 4b73986..e626987 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -107,6 +107,17 @@ dependencies = [ "syn", ] +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -274,6 +285,7 @@ dependencies = [ name = "codetaker-agent" version = "0.1.0" dependencies = [ + "async-trait", "chrono", "codetaker-db", "diesel", @@ -294,6 +306,7 @@ dependencies = [ "clap", "codetaker-agent", "codetaker-db", + "dialoguer", "git2", "rig-core", "serde_json", @@ -354,6 +367,19 @@ version = "0.4.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75984efb6ed102a0d42db99afb6c1948f0380d1d91808d5529916e6c08b49d8d" +[[package]] +name = "console" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03e45a4a8926227e4197636ba97a9fc9b00477e9f4bd711395687c5f0734bec4" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.61.2", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -433,6 +459,18 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "dialoguer" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25f104b501bf2364e78d0d3974cbc774f738f5865306ed128e1e0d7499c0ad96" +dependencies = [ + "console", + "shell-words", + "tempfile", + "zeroize", +] + [[package]] name = "diesel" version = "2.3.6" @@ -545,6 +583,12 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "encoding_rs" version = "0.8.35" @@ -1182,6 +1226,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + [[package]] name = "litemap" version = "0.8.1" @@ -1738,6 +1788,19 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +[[package]] +name = "rustix" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + [[package]] name = "rustls" version = "0.23.37" @@ -1990,6 +2053,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shell-words" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc6fe69c597f9c37bfeeeeeb33da3530379845f10be461a66d16d03eca2ded77" + [[package]] name = "shlex" version = "1.3.0" @@ -2116,6 +2185,19 @@ dependencies = [ "libc", ] +[[package]] +name = "tempfile" +version = "3.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82a72c767771b47409d2345987fda8628641887d5466101319899796367354a0" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -2440,6 +2522,12 @@ version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + [[package]] name = "untrusted" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index 34c6314..77977f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,8 +5,10 @@ members = [ ] [workspace.dependencies] +async-trait = "0.1.89" clap = { version = "4.5.60", features = ["derive", "env"] } codetaker-db = { path = "crates/codetaker-db" } +dialoguer = "0.12.0" git2 = { version = "0.20.4", features = ["vendored-libgit2"] } rig-core = "0.31.0" schemars = "1.2.1" diff --git a/crates/codetaker-agent/Cargo.toml b/crates/codetaker-agent/Cargo.toml index df11bd6..7671816 100644 --- a/crates/codetaker-agent/Cargo.toml +++ b/crates/codetaker-agent/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2024" [dependencies] +async-trait.workspace = true chrono.workspace = true codetaker-db.workspace = true diesel.workspace = true diff --git a/crates/codetaker-agent/src/agent/conversation_answer.rs b/crates/codetaker-agent/src/agent/conversation_answer.rs index 53fa090..826312c 100644 --- a/crates/codetaker-agent/src/agent/conversation_answer.rs +++ b/crates/codetaker-agent/src/agent/conversation_answer.rs @@ -5,10 +5,11 @@ use rig::agent::AgentBuilder; use rig::client::{Client, CompletionClient}; use rig::completion::{StructuredOutputError, TypedPrompt}; +use crate::audit::{self, AuditEntrypoint}; use crate::error::{AgentError, AgentResult}; use crate::git_access::{self}; -use crate::memory::{MemoryStore, ProjectContextSnapshot}; -use crate::tools::{AstGrepTool, FileContext, SearchHit}; +use crate::memory::{AuditStore, MemoryStore, ProjectContextSnapshot}; +use crate::tools::{AstGrepTool, FileContext, ProjectMemoryWriteTool, ReadFileTool, SearchHit}; use crate::types::{ConversationInput, ConversationOutput}; const CONVERSATION_RESPONSE_PREAMBLE: &str = r#" @@ -44,7 +45,7 @@ impl ConversationAnswerAgent { impl ConversationAnswerAgent where Client: CompletionClient, - M: MemoryStore, + M: MemoryStore + AuditStore + Clone + Send + Sync + 'static, { pub async fn conversation_response( &self, @@ -63,6 +64,9 @@ where )?; let ast_grep = self.ast_grep.bind_to_request(repo, &input.head_ref)?; + let readfile = ReadFileTool::new().bind_to_request(repo, &input.head_ref)?; + let memory_write = + ProjectMemoryWriteTool::new(self.memory.clone()).bind_to_project(input.project.clone()); let search_hits = self.collect_conversation_search_hits(&ast_grep, &input); let prompt = build_conversation_prompt(&input, &memory_snapshot, &anchor_context, &search_hits); @@ -71,29 +75,56 @@ where let agent = AgentBuilder::new(conversation_model) .preamble(CONVERSATION_RESPONSE_PREAMBLE) .tool(ast_grep) + .tool(readfile) + .tool(memory_write) .temperature(0.1) .output_schema::() .build(); - agent - .prompt_typed::(prompt) - .await - .map_err(|err| match err { - StructuredOutputError::PromptError(prompt_err) => AgentError::ModelError { - model: self.conversation_model.clone(), - message: prompt_err.to_string(), - }, - StructuredOutputError::DeserializationError(deser_err) => { - AgentError::OutputValidationError { - message: format!("failed to deserialize conversation output: {deser_err}"), - raw_output: None, - } + let audit_session = audit::start_session( + self.memory.clone(), + &input.project, + input.pull_request_id, + AuditEntrypoint::ConversationResponse, + ) + .await; + let result = if let Some(session) = &audit_session { + agent + .prompt_typed::(prompt) + .with_hook(session.hook()) + .await + } else { + agent.prompt_typed::(prompt).await + }; + match &result { + Ok(_) => { + if let Some(session) = &audit_session { + session.finish_success().await; } - StructuredOutputError::EmptyResponse => AgentError::OutputValidationError { - message: "conversation response was empty".to_owned(), + } + Err(err) => { + if let Some(session) = &audit_session { + session.finish_failure(&err.to_string()).await; + } + } + } + + result.map_err(|err| match err { + StructuredOutputError::PromptError(prompt_err) => AgentError::ModelError { + model: self.conversation_model.clone(), + message: prompt_err.to_string(), + }, + StructuredOutputError::DeserializationError(deser_err) => { + AgentError::OutputValidationError { + message: format!("failed to deserialize conversation output: {deser_err}"), raw_output: None, - }, - }) + } + } + StructuredOutputError::EmptyResponse => AgentError::OutputValidationError { + message: "conversation response was empty".to_owned(), + raw_output: None, + }, + }) } fn collect_conversation_search_hits( @@ -126,6 +157,7 @@ fn build_conversation_prompt( input.project.forge.as_db_value() ) .ok(); + writeln!(prompt, "Pull request ID: {}", input.pull_request_id).ok(); writeln!(prompt, "Head reference: '{}'", input.head_ref).ok(); writeln!( diff --git a/crates/codetaker-agent/src/agent/pull_request_review.rs b/crates/codetaker-agent/src/agent/pull_request_review.rs index 45a37d8..78adcc8 100644 --- a/crates/codetaker-agent/src/agent/pull_request_review.rs +++ b/crates/codetaker-agent/src/agent/pull_request_review.rs @@ -5,22 +5,23 @@ use rig::agent::AgentBuilder; use rig::client::{Client, CompletionClient}; use rig::completion::{StructuredOutputError, TypedPrompt}; +use crate::audit::{self, AuditEntrypoint}; use crate::error::{AgentError, AgentResult}; use crate::git_access::{self, PullRequestMaterial}; -use crate::memory::{MemoryStore, ProjectContextSnapshot}; -use crate::tools::{AstGrepTool, FileContext}; +use crate::memory::{AuditStore, MemoryStore, ProjectContextSnapshot}; +use crate::tools::{AstGrepTool, FileContext, ProjectMemoryWriteTool, ReadFileTool}; use crate::types::{PullRequestReviewInput, PullRequestReviewOutput}; const PULL_REQUEST_REVIEW_PREAMBLE: &str = r#" You are a code review agent. Analyze pull request diffs and respond only with valid JSON. Use available tools when you need extra repository context. +If you found valuable insight that might help future reviews, save it to the project memory with the memory write tool. Be concise and focus on the most important issues in the code. Explain your judgement clearly and provide actionable feedback for the developer. When using tools, only request the specific information you need to make your review. In the end, pass judgement if you want to accept it or request changes, and explain your reasoning. Line numbers must target the head/new file version. -Do not include extra keys. "#; pub struct PullRequestReviewAgent { @@ -49,7 +50,7 @@ impl PullRequestReviewAgent { impl PullRequestReviewAgent where Client: CompletionClient, - M: MemoryStore, + M: MemoryStore + AuditStore + Clone + Send + Sync + 'static, { pub async fn pull_request_review( &self, @@ -65,36 +66,66 @@ where build_pull_request_prompt(&input, &memory_snapshot, &pr_material, &file_contexts); let ast_grep = self.ast_grep.bind_to_request(repo, &input.head_ref)?; + let readfile = ReadFileTool::new().bind_to_request(repo, &input.head_ref)?; + let memory_write = + ProjectMemoryWriteTool::new(self.memory.clone()).bind_to_project(input.project.clone()); let review_model = self.client.completion_model(&self.review_model); let agent = AgentBuilder::new(review_model) .preamble(PULL_REQUEST_REVIEW_PREAMBLE) .tool(ast_grep) + .tool(readfile) + .tool(memory_write) .temperature(0.0) .output_schema::() .build(); - agent - .prompt_typed::(prompt) - .await - .map_err(|err| match err { - StructuredOutputError::PromptError(prompt_err) => AgentError::ModelError { - model: self.review_model.clone(), - message: prompt_err.to_string(), - }, - StructuredOutputError::DeserializationError(deser_err) => { - AgentError::OutputValidationError { - message: format!( - "failed to deserialize pull request review output: {deser_err}" - ), - raw_output: None, - } + let audit_session = audit::start_session( + self.memory.clone(), + &input.project, + input.pull_request_id, + AuditEntrypoint::PullRequestReview, + ) + .await; + let result = if let Some(session) = &audit_session { + agent + .prompt_typed::(prompt) + .with_hook(session.hook()) + .await + } else { + agent.prompt_typed::(prompt).await + }; + match &result { + Ok(_) => { + if let Some(session) = &audit_session { + session.finish_success().await; } - StructuredOutputError::EmptyResponse => AgentError::OutputValidationError { - message: "pull request review response was empty".to_owned(), + } + Err(err) => { + if let Some(session) = &audit_session { + session.finish_failure(&err.to_string()).await; + } + } + } + + result.map_err(|err| match err { + StructuredOutputError::PromptError(prompt_err) => AgentError::ModelError { + model: self.review_model.clone(), + message: prompt_err.to_string(), + }, + StructuredOutputError::DeserializationError(deser_err) => { + AgentError::OutputValidationError { + message: format!( + "failed to deserialize pull request review output: {deser_err}" + ), raw_output: None, - }, - }) + } + } + StructuredOutputError::EmptyResponse => AgentError::OutputValidationError { + message: "pull request review response was empty".to_owned(), + raw_output: None, + }, + }) } fn collect_diff_file_contexts( @@ -142,6 +173,7 @@ fn build_pull_request_prompt( input.project.forge.as_db_value() ) .ok(); + writeln!(prompt, "Pull request ID: {}", input.pull_request_id).ok(); writeln!( prompt, diff --git a/crates/codetaker-agent/src/audit/mod.rs b/crates/codetaker-agent/src/audit/mod.rs new file mode 100644 index 0000000..19ff92b --- /dev/null +++ b/crates/codetaker-agent/src/audit/mod.rs @@ -0,0 +1,325 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicI32, Ordering}; + +use chrono::Utc; +use rig::agent::{HookAction, PromptHook, ToolCallHookAction}; +use rig::completion::{CompletionModel, CompletionResponse}; +use rig::message::Message; +use serde::Serialize; + +use crate::memory::AuditStore; +use crate::types::ProjectRef; + +#[derive(Debug, Clone, Copy)] +pub enum AuditEntrypoint { + PullRequestReview, + ConversationResponse, +} + +impl AuditEntrypoint { + fn as_db_value(self) -> &'static str { + match self { + Self::PullRequestReview => "pull_request_review", + Self::ConversationResponse => "conversation_response", + } + } +} + +#[derive(Clone)] +pub struct PromptAuditHook { + store: M, + session_id: i32, + sequence: Arc, +} + +impl PromptAuditHook { + pub fn new(store: M, session_id: i32) -> Self { + Self { + store, + session_id, + sequence: Arc::new(AtomicI32::new(0)), + } + } + + fn next_sequence(&self) -> i32 { + self.sequence.fetch_add(1, Ordering::Relaxed) + 1 + } +} + +pub struct AuditSession { + store: M, + session_id: i32, +} + +impl AuditSession +where + M: AuditStore + Clone + Send + Sync + 'static, +{ + pub fn hook(&self) -> PromptAuditHook { + PromptAuditHook::new(self.store.clone(), self.session_id) + } + + pub async fn finish_success(&self) { + let ended_at = Utc::now().naive_utc(); + if let Err(err) = self + .store + .finish_prompt_audit_session(self.session_id, ended_at, "completed", None) + .await + { + tracing::warn!(error = %err, "failed to finish successful prompt audit session"); + } + } + + pub async fn finish_failure(&self, error_message: &str) { + let ended_at = Utc::now().naive_utc(); + if let Err(err) = self + .store + .finish_prompt_audit_session( + self.session_id, + ended_at, + "failed", + Some(error_message), + ) + .await + { + tracing::warn!(error = %err, "failed to finish failed prompt audit session"); + } + } +} + +pub async fn start_session( + store: M, + project: &ProjectRef, + pull_request_id: i64, + entrypoint: AuditEntrypoint, +) -> Option> +where + M: AuditStore + Clone + Send + Sync + 'static, +{ + let started_at = Utc::now().naive_utc(); + match store + .start_prompt_audit_session( + project, + pull_request_id, + entrypoint.as_db_value(), + started_at, + ) + .await + { + Ok(session_id) => Some(AuditSession { store, session_id }), + Err(err) => { + tracing::warn!(error = %err, "failed to start prompt audit session"); + None + } + } +} + +impl PromptHook for PromptAuditHook +where + M: AuditStore + Clone + Send + Sync + 'static, + Model: CompletionModel, +{ + async fn on_completion_call(&self, prompt: &Message, history: &[Message]) -> HookAction { + let sequence_no = self.next_sequence(); + let event_at = Utc::now().naive_utc(); + let prompt_json = serialize_to_json(prompt); + let history_json = serialize_to_json(history); + + if let Err(err) = self + .store + .log_on_completion_call( + self.session_id, + event_at, + sequence_no, + &prompt_json, + &history_json, + ) + .await + { + tracing::warn!(error = %err, "failed to log on_completion_call"); + } + + HookAction::cont() + } + + async fn on_completion_response( + &self, + prompt: &Message, + response: &CompletionResponse, + ) -> HookAction { + let sequence_no = self.next_sequence(); + let event_at = Utc::now().naive_utc(); + let prompt_json = serialize_to_json(prompt); + let assistant_choice_json = serialize_to_json(&response.choice); + let usage_input_tokens = u64_to_i64(response.usage.input_tokens); + let usage_output_tokens = u64_to_i64(response.usage.output_tokens); + let usage_total_tokens = u64_to_i64(response.usage.total_tokens); + + if let Err(err) = self + .store + .log_on_completion_response( + self.session_id, + event_at, + sequence_no, + &prompt_json, + &assistant_choice_json, + usage_input_tokens, + usage_output_tokens, + usage_total_tokens, + ) + .await + { + tracing::warn!(error = %err, "failed to log on_completion_response"); + } + + HookAction::cont() + } + + async fn on_tool_call( + &self, + tool_name: &str, + tool_call_id: Option, + internal_call_id: &str, + args: &str, + ) -> ToolCallHookAction { + let sequence_no = self.next_sequence(); + let event_at = Utc::now().naive_utc(); + + if let Err(err) = self + .store + .log_on_tool_call( + self.session_id, + event_at, + sequence_no, + tool_name, + tool_call_id.as_deref(), + internal_call_id, + args, + "continue", + None, + ) + .await + { + tracing::warn!(error = %err, "failed to log on_tool_call"); + } + + ToolCallHookAction::cont() + } + + async fn on_tool_result( + &self, + tool_name: &str, + tool_call_id: Option, + internal_call_id: &str, + args: &str, + result: &str, + ) -> HookAction { + let sequence_no = self.next_sequence(); + let event_at = Utc::now().naive_utc(); + + if let Err(err) = self + .store + .log_on_tool_result( + self.session_id, + event_at, + sequence_no, + tool_name, + tool_call_id.as_deref(), + internal_call_id, + args, + result, + ) + .await + { + tracing::warn!(error = %err, "failed to log on_tool_result"); + } + + HookAction::cont() + } + + async fn on_text_delta(&self, text_delta: &str, aggregated_text: &str) -> HookAction { + let sequence_no = self.next_sequence(); + let event_at = Utc::now().naive_utc(); + + if let Err(err) = self + .store + .log_on_text_delta( + self.session_id, + event_at, + sequence_no, + text_delta, + aggregated_text, + ) + .await + { + tracing::warn!(error = %err, "failed to log on_text_delta"); + } + + HookAction::cont() + } + + async fn on_tool_call_delta( + &self, + tool_call_id: &str, + internal_call_id: &str, + tool_name: Option<&str>, + tool_call_delta: &str, + ) -> HookAction { + let sequence_no = self.next_sequence(); + let event_at = Utc::now().naive_utc(); + + if let Err(err) = self + .store + .log_on_tool_call_delta( + self.session_id, + event_at, + sequence_no, + tool_call_id, + internal_call_id, + tool_name, + tool_call_delta, + ) + .await + { + tracing::warn!(error = %err, "failed to log on_tool_call_delta"); + } + + HookAction::cont() + } + + async fn on_stream_completion_response_finish( + &self, + prompt: &Message, + response: &Model::StreamingResponse, + ) -> HookAction { + let sequence_no = self.next_sequence(); + let event_at = Utc::now().naive_utc(); + let prompt_json = serialize_to_json(prompt); + let response_summary_json = serialize_to_json(response); + + if let Err(err) = self + .store + .log_on_stream_completion_response_finish( + self.session_id, + event_at, + sequence_no, + &prompt_json, + &response_summary_json, + ) + .await + { + tracing::warn!(error = %err, "failed to log on_stream_completion_response_finish"); + } + + HookAction::cont() + } +} + +fn serialize_to_json(value: &T) -> String { + serde_json::to_string(value) + .unwrap_or_else(|err| format!(r#"{{"serialization_error":"{}"}}"#, err)) +} + +fn u64_to_i64(value: u64) -> i64 { + i64::try_from(value).unwrap_or(i64::MAX) +} diff --git a/crates/codetaker-agent/src/lib.rs b/crates/codetaker-agent/src/lib.rs index 067c6df..ad166c3 100644 --- a/crates/codetaker-agent/src/lib.rs +++ b/crates/codetaker-agent/src/lib.rs @@ -1,4 +1,5 @@ pub mod agent; +pub mod audit; pub mod error; pub mod git_access; pub mod memory; @@ -8,8 +9,12 @@ pub mod types; pub use agent::{ConversationAnswerAgent, PullRequestReviewAgent}; pub use error::{AgentError, AgentResult}; pub use git_access::PullRequestMaterial; -pub use memory::{DieselMemoryStore, MemoryStore, ProjectContextSnapshot}; -pub use tools::{AstGrepArgs, AstGrepOutput, AstGrepTool, FileContext, SearchHit}; +pub use memory::{AuditStore, DieselMemoryStore, MemoryStore, ProjectContextSnapshot}; +pub use tools::{ + AstGrepArgs, AstGrepOutput, AstGrepTool, FileContext, MemoryWriteOperation, + ProjectMemoryWriteArgs, ProjectMemoryWriteOutput, ProjectMemoryWriteTool, ReadFileArgs, + ReadFileOutput, ReadFileTool, SearchHit, +}; pub use types::{ ConversationInput, ConversationOutput, Forge, MessageAuthor, ProjectRef, PullRequestReviewInput, PullRequestReviewOutput, ReviewComment, ReviewResult, ThreadMessage, diff --git a/crates/codetaker-agent/src/memory/mod.rs b/crates/codetaker-agent/src/memory/mod.rs index 6455c2d..e40bfc6 100644 --- a/crates/codetaker-agent/src/memory/mod.rs +++ b/crates/codetaker-agent/src/memory/mod.rs @@ -2,6 +2,7 @@ mod sqlite; use std::collections::BTreeMap; +use async_trait::async_trait; use chrono::NaiveDateTime; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -24,8 +25,8 @@ pub struct ProjectContextSnapshot { pub summaries: Vec, } -#[allow(async_fn_in_trait)] -pub trait MemoryStore { +#[async_trait] +pub trait MemoryStore: Send + Sync { async fn project_context_snapshot( &self, project: &ProjectRef, @@ -49,6 +50,7 @@ pub trait MemoryStore { async fn create_review_thread( &self, project: &ProjectRef, + pull_request_id: i64, file: &str, line: i32, initial_comment: &str, @@ -63,3 +65,97 @@ pub trait MemoryStore { async fn load_thread_messages(&self, thread_id: i32) -> AgentResult>; } + +#[async_trait] +pub trait AuditStore: Send + Sync { + async fn start_prompt_audit_session( + &self, + project: &ProjectRef, + pull_request_id: i64, + entrypoint: &str, + started_at: NaiveDateTime, + ) -> AgentResult; + + async fn finish_prompt_audit_session( + &self, + session_id: i32, + ended_at: NaiveDateTime, + status: &str, + error_message: Option<&str>, + ) -> AgentResult<()>; + + async fn log_on_completion_call( + &self, + session_id: i32, + event_at: NaiveDateTime, + sequence_no: i32, + prompt_json: &str, + history_json: &str, + ) -> AgentResult<()>; + + async fn log_on_completion_response( + &self, + session_id: i32, + event_at: NaiveDateTime, + sequence_no: i32, + prompt_json: &str, + assistant_choice_json: &str, + usage_input_tokens: i64, + usage_output_tokens: i64, + usage_total_tokens: i64, + ) -> AgentResult<()>; + + async fn log_on_tool_call( + &self, + session_id: i32, + event_at: NaiveDateTime, + sequence_no: i32, + tool_name: &str, + tool_call_id: Option<&str>, + internal_call_id: &str, + args_json: &str, + action: &str, + reason: Option<&str>, + ) -> AgentResult<()>; + + async fn log_on_tool_result( + &self, + session_id: i32, + event_at: NaiveDateTime, + sequence_no: i32, + tool_name: &str, + tool_call_id: Option<&str>, + internal_call_id: &str, + args_json: &str, + result_json: &str, + ) -> AgentResult<()>; + + async fn log_on_text_delta( + &self, + session_id: i32, + event_at: NaiveDateTime, + sequence_no: i32, + text_delta: &str, + aggregated_text: &str, + ) -> AgentResult<()>; + + async fn log_on_tool_call_delta( + &self, + session_id: i32, + event_at: NaiveDateTime, + sequence_no: i32, + tool_call_id: &str, + internal_call_id: &str, + tool_name: Option<&str>, + tool_call_delta: &str, + ) -> AgentResult<()>; + + async fn log_on_stream_completion_response_finish( + &self, + session_id: i32, + event_at: NaiveDateTime, + sequence_no: i32, + prompt_json: &str, + response_summary_json: &str, + ) -> AgentResult<()>; +} diff --git a/crates/codetaker-agent/src/memory/sqlite.rs b/crates/codetaker-agent/src/memory/sqlite.rs index 0254f1f..f4d24cf 100644 --- a/crates/codetaker-agent/src/memory/sqlite.rs +++ b/crates/codetaker-agent/src/memory/sqlite.rs @@ -1,13 +1,21 @@ use std::collections::BTreeMap; +use async_trait::async_trait; use chrono::Utc; use codetaker_db::models::{ - NewProjectMemoryEntryRow, NewProjectMemorySummaryRow, NewProjectRow, NewReviewThreadMessageRow, - NewReviewThreadRow, ProjectMemoryEntryRow, ProjectMemorySummaryRow, ReviewThreadMessageRow, + NewProjectMemoryEntryRow, NewProjectMemorySummaryRow, NewProjectRow, NewPromptAuditSessionRow, + NewPromptHookOnCompletionCallEventRow, NewPromptHookOnCompletionResponseEventRow, + NewPromptHookOnStreamCompletionResponseFinishEventRow, NewPromptHookOnTextDeltaEventRow, + NewPromptHookOnToolCallDeltaEventRow, NewPromptHookOnToolCallEventRow, + NewPromptHookOnToolResultEventRow, NewReviewThreadMessageRow, NewReviewThreadRow, + ProjectMemoryEntryRow, ProjectMemorySummaryRow, ReviewThreadMessageRow, }; use codetaker_db::schema::{ - project_memory_entries, project_memory_summaries, projects, review_thread_messages, - review_threads, + project_memory_entries, project_memory_summaries, projects, prompt_audit_sessions, + prompt_hook_on_completion_call_events, prompt_hook_on_completion_response_events, + prompt_hook_on_stream_completion_response_finish_events, prompt_hook_on_text_delta_events, + prompt_hook_on_tool_call_delta_events, prompt_hook_on_tool_call_events, + prompt_hook_on_tool_result_events, review_thread_messages, review_threads, }; use codetaker_db::{DatabaseConnection, DatabasePool}; use diesel::OptionalExtension; @@ -16,7 +24,7 @@ use diesel_async::RunQueryDsl; use serde_json::Value; use crate::error::{AgentError, AgentResult}; -use crate::memory::{MemoryStore, MemorySummary, ProjectContextSnapshot}; +use crate::memory::{AuditStore, MemoryStore, MemorySummary, ProjectContextSnapshot}; use crate::types::{MessageAuthor, ProjectRef, ThreadMessage}; type PooledConnection<'a> = @@ -90,6 +98,7 @@ impl DieselMemoryStore { } } +#[async_trait] impl MemoryStore for DieselMemoryStore { async fn project_context_snapshot( &self, @@ -221,6 +230,7 @@ impl MemoryStore for DieselMemoryStore { async fn create_review_thread( &self, project: &ProjectRef, + pull_request_id: i64, file: &str, line: i32, initial_comment: &str, @@ -232,6 +242,7 @@ impl MemoryStore for DieselMemoryStore { let now = Utc::now().naive_utc(); let new_row = NewReviewThreadRow { project_id, + pull_request_id, file, line, initial_comment, @@ -248,6 +259,7 @@ impl MemoryStore for DieselMemoryStore { dsl::review_threads .filter(dsl::project_id.eq(project_id)) + .filter(dsl::pull_request_id.eq(pull_request_id)) .filter(dsl::file.eq(file)) .filter(dsl::line.eq(line)) .filter(dsl::initial_comment.eq(initial_comment)) @@ -317,3 +329,292 @@ impl MemoryStore for DieselMemoryStore { .collect() } } + +#[async_trait] +impl AuditStore for DieselMemoryStore { + async fn start_prompt_audit_session( + &self, + project: &ProjectRef, + pull_request_id: i64, + entrypoint: &str, + started_at: chrono::NaiveDateTime, + ) -> AgentResult { + use codetaker_db::schema::prompt_audit_sessions::dsl; + + let mut conn = self.get_conn().await?; + let project_id = Self::ensure_project_id(&mut conn, project).await?; + let new_row = NewPromptAuditSessionRow { + project_id, + pull_request_id, + entrypoint, + started_at, + ended_at: None, + status: "started", + error_message: None, + }; + + diesel::insert_into(prompt_audit_sessions::table) + .values(&new_row) + .execute(&mut conn) + .await + .map_err(|err| AgentError::MemoryError { + message: format!("failed to create prompt audit session: {err}"), + })?; + + dsl::prompt_audit_sessions + .order(dsl::id.desc()) + .select(dsl::id) + .first::(&mut conn) + .await + .map_err(|err| AgentError::MemoryError { + message: format!("failed to fetch prompt audit session id: {err}"), + }) + } + + async fn finish_prompt_audit_session( + &self, + session_id: i32, + ended_at: chrono::NaiveDateTime, + status: &str, + error_message: Option<&str>, + ) -> AgentResult<()> { + use codetaker_db::schema::prompt_audit_sessions::dsl; + + let mut conn = self.get_conn().await?; + diesel::update(dsl::prompt_audit_sessions.filter(dsl::id.eq(session_id))) + .set(( + dsl::ended_at.eq(Some(ended_at)), + dsl::status.eq(status), + dsl::error_message.eq(error_message), + )) + .execute(&mut conn) + .await + .map_err(|err| AgentError::MemoryError { + message: format!("failed to finish prompt audit session {session_id}: {err}"), + })?; + + Ok(()) + } + + async fn log_on_completion_call( + &self, + session_id: i32, + event_at: chrono::NaiveDateTime, + sequence_no: i32, + prompt_json: &str, + history_json: &str, + ) -> AgentResult<()> { + let mut conn = self.get_conn().await?; + let row = NewPromptHookOnCompletionCallEventRow { + session_id, + event_at, + sequence_no, + prompt_json, + history_json, + }; + + diesel::insert_into(prompt_hook_on_completion_call_events::table) + .values(&row) + .execute(&mut conn) + .await + .map_err(|err| AgentError::MemoryError { + message: format!("failed to log on_completion_call event: {err}"), + })?; + + Ok(()) + } + + async fn log_on_completion_response( + &self, + session_id: i32, + event_at: chrono::NaiveDateTime, + sequence_no: i32, + prompt_json: &str, + assistant_choice_json: &str, + usage_input_tokens: i64, + usage_output_tokens: i64, + usage_total_tokens: i64, + ) -> AgentResult<()> { + let mut conn = self.get_conn().await?; + let row = NewPromptHookOnCompletionResponseEventRow { + session_id, + event_at, + sequence_no, + prompt_json, + assistant_choice_json, + usage_input_tokens, + usage_output_tokens, + usage_total_tokens, + }; + + diesel::insert_into(prompt_hook_on_completion_response_events::table) + .values(&row) + .execute(&mut conn) + .await + .map_err(|err| AgentError::MemoryError { + message: format!("failed to log on_completion_response event: {err}"), + })?; + + Ok(()) + } + + async fn log_on_tool_call( + &self, + session_id: i32, + event_at: chrono::NaiveDateTime, + sequence_no: i32, + tool_name: &str, + tool_call_id: Option<&str>, + internal_call_id: &str, + args_json: &str, + action: &str, + reason: Option<&str>, + ) -> AgentResult<()> { + let mut conn = self.get_conn().await?; + let row = NewPromptHookOnToolCallEventRow { + session_id, + event_at, + sequence_no, + tool_name, + tool_call_id, + internal_call_id, + args_json, + action, + reason, + }; + + diesel::insert_into(prompt_hook_on_tool_call_events::table) + .values(&row) + .execute(&mut conn) + .await + .map_err(|err| AgentError::MemoryError { + message: format!("failed to log on_tool_call event: {err}"), + })?; + + Ok(()) + } + + async fn log_on_tool_result( + &self, + session_id: i32, + event_at: chrono::NaiveDateTime, + sequence_no: i32, + tool_name: &str, + tool_call_id: Option<&str>, + internal_call_id: &str, + args_json: &str, + result_json: &str, + ) -> AgentResult<()> { + let mut conn = self.get_conn().await?; + let row = NewPromptHookOnToolResultEventRow { + session_id, + event_at, + sequence_no, + tool_name, + tool_call_id, + internal_call_id, + args_json, + result_json, + }; + + diesel::insert_into(prompt_hook_on_tool_result_events::table) + .values(&row) + .execute(&mut conn) + .await + .map_err(|err| AgentError::MemoryError { + message: format!("failed to log on_tool_result event: {err}"), + })?; + + Ok(()) + } + + async fn log_on_text_delta( + &self, + session_id: i32, + event_at: chrono::NaiveDateTime, + sequence_no: i32, + text_delta: &str, + aggregated_text: &str, + ) -> AgentResult<()> { + let mut conn = self.get_conn().await?; + let row = NewPromptHookOnTextDeltaEventRow { + session_id, + event_at, + sequence_no, + text_delta, + aggregated_text, + }; + + diesel::insert_into(prompt_hook_on_text_delta_events::table) + .values(&row) + .execute(&mut conn) + .await + .map_err(|err| AgentError::MemoryError { + message: format!("failed to log on_text_delta event: {err}"), + })?; + + Ok(()) + } + + async fn log_on_tool_call_delta( + &self, + session_id: i32, + event_at: chrono::NaiveDateTime, + sequence_no: i32, + tool_call_id: &str, + internal_call_id: &str, + tool_name: Option<&str>, + tool_call_delta: &str, + ) -> AgentResult<()> { + let mut conn = self.get_conn().await?; + let row = NewPromptHookOnToolCallDeltaEventRow { + session_id, + event_at, + sequence_no, + tool_call_id, + internal_call_id, + tool_name, + tool_call_delta, + }; + + diesel::insert_into(prompt_hook_on_tool_call_delta_events::table) + .values(&row) + .execute(&mut conn) + .await + .map_err(|err| AgentError::MemoryError { + message: format!("failed to log on_tool_call_delta event: {err}"), + })?; + + Ok(()) + } + + async fn log_on_stream_completion_response_finish( + &self, + session_id: i32, + event_at: chrono::NaiveDateTime, + sequence_no: i32, + prompt_json: &str, + response_summary_json: &str, + ) -> AgentResult<()> { + let mut conn = self.get_conn().await?; + let row = NewPromptHookOnStreamCompletionResponseFinishEventRow { + session_id, + event_at, + sequence_no, + prompt_json, + response_summary_json, + }; + + diesel::insert_into(prompt_hook_on_stream_completion_response_finish_events::table) + .values(&row) + .execute(&mut conn) + .await + .map_err(|err| AgentError::MemoryError { + message: format!( + "failed to log on_stream_completion_response_finish event: {err}" + ), + })?; + + Ok(()) + } +} diff --git a/crates/codetaker-agent/src/tools/ast_grep.rs b/crates/codetaker-agent/src/tools/ast_grep.rs index 3c9af5a..90a0073 100644 --- a/crates/codetaker-agent/src/tools/ast_grep.rs +++ b/crates/codetaker-agent/src/tools/ast_grep.rs @@ -9,7 +9,7 @@ use serde_json::Value; use crate::error::{AgentError, AgentResult}; use crate::git_access; -use crate::tools::{FileContext, SearchHit}; +use crate::tools::SearchHit; #[derive(Debug, Clone)] pub struct AstGrepTool { @@ -21,24 +21,14 @@ pub struct AstGrepTool { #[derive(Debug, Clone, Deserialize)] #[serde(tag = "operation", rename_all = "snake_case")] pub enum AstGrepArgs { - SearchSymbol { - query: String, - }, - SearchPattern { - query: String, - }, - GetFileContext { - file: String, - line: i32, - radius: i32, - }, + SearchSymbol { query: String }, + SearchPattern { query: String }, } #[derive(Debug, Clone, Serialize)] #[serde(tag = "operation", rename_all = "snake_case")] pub enum AstGrepOutput { SearchHits { hits: Vec }, - FileContext { context: FileContext }, } impl Default for AstGrepTool { @@ -171,12 +161,6 @@ impl AstGrepTool { pub fn search_pattern(&self, query: &str) -> AgentResult> { self.run_query(query) } - - pub fn get_file_context(&self, file: &str, line: i32, radius: i32) -> AgentResult { - let (repo_git_dir, head_ref) = self.bound_state()?; - let repo = self.open_repo(&repo_git_dir)?; - git_access::read_file_context_at_ref(&repo, &head_ref, file, line, radius) - } } impl Tool for AstGrepTool { @@ -189,37 +173,18 @@ impl Tool for AstGrepTool { async fn definition(&self, _prompt: String) -> ToolDefinition { ToolDefinition { name: Self::NAME.to_owned(), - description: "Search source code using ast-grep and fetch file context.".to_owned(), + description: "Search source code using ast-grep.".to_owned(), parameters: serde_json::json!({ "type": "object", - "oneOf": [ - { - "type": "object", - "properties": { - "operation": { "type": "string", "const": "search_symbol" }, - "query": { "type": "string", "description": "Symbol-like query" } - }, - "required": ["operation", "query"] + "properties": { + "operation": { + "type": "string", + "enum": ["search_symbol", "search_pattern"], + "description": "Tool operation. Both operations require `query`." }, - { - "type": "object", - "properties": { - "operation": { "type": "string", "const": "search_pattern" }, - "query": { "type": "string", "description": "Pattern query" } - }, - "required": ["operation", "query"] - }, - { - "type": "object", - "properties": { - "operation": { "type": "string", "const": "get_file_context" }, - "file": { "type": "string", "description": "Path to file in repository" }, - "line": { "type": "integer", "description": "1-based line number" }, - "radius": { "type": "integer", "description": "Number of lines before/after line" } - }, - "required": ["operation", "file", "line", "radius"] - } - ] + "query": { "type": "string", "description": "Symbol-like or pattern query" }, + }, + "required": ["operation"] }), } } @@ -234,10 +199,6 @@ impl Tool for AstGrepTool { let hits = self.search_pattern(&query)?; Ok(AstGrepOutput::SearchHits { hits }) } - AstGrepArgs::GetFileContext { file, line, radius } => { - let context = self.get_file_context(&file, line, radius)?; - Ok(AstGrepOutput::FileContext { context }) - } } } } diff --git a/crates/codetaker-agent/src/tools/memory_write.rs b/crates/codetaker-agent/src/tools/memory_write.rs new file mode 100644 index 0000000..23f3809 --- /dev/null +++ b/crates/codetaker-agent/src/tools/memory_write.rs @@ -0,0 +1,148 @@ +use rig::completion::ToolDefinition; +use rig::tool::Tool; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::error::{AgentError, AgentResult}; +use crate::memory::MemoryStore; +use crate::types::ProjectRef; + +#[derive(Debug, Clone)] +pub struct ProjectMemoryWriteTool { + memory: M, + project: Option, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ProjectMemoryWriteArgs { + pub operation: MemoryWriteOperation, + pub key: Option, + pub value: Option, + pub source: Option, + pub summary_type: Option, + pub content: Option, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum MemoryWriteOperation { + UpsertEntry, + UpsertSummary, +} + +#[derive(Debug, Clone, Serialize)] +pub struct ProjectMemoryWriteOutput { + pub stored: bool, + pub detail: String, +} + +impl ProjectMemoryWriteTool { + pub fn new(memory: M) -> Self { + Self { + memory, + project: None, + } + } + + pub fn bind_to_project(&self, project: ProjectRef) -> Self + where + M: Clone, + { + Self { + memory: self.memory.clone(), + project: Some(project), + } + } + + fn project(&self) -> AgentResult<&ProjectRef> { + self.project + .as_ref() + .ok_or_else(|| AgentError::ConfigError { + message: "project memory write tool is not bound to a project context".to_owned(), + }) + } +} + +impl Tool for ProjectMemoryWriteTool +where + M: MemoryStore + Clone + Send + Sync, +{ + const NAME: &'static str = "project_memory_write"; + + type Error = AgentError; + type Args = ProjectMemoryWriteArgs; + type Output = ProjectMemoryWriteOutput; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + ToolDefinition { + name: Self::NAME.to_owned(), + description: "Persist project-specific memory entries or summaries for future reviews." + .to_owned(), + parameters: serde_json::json!({ + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": ["upsert_entry", "upsert_summary"], + "description": "upsert_entry requires key and value. upsert_summary requires summary_type and content." + }, + "key": { "type": "string", "description": "Memory key for upsert_entry." }, + "value": { "type": "string", "description": "Memory value for upsert_entry. JSON text is accepted; plain text is stored as JSON string." }, + "source": { "type": "string", "description": "Source tag for upsert_entry. Defaults to agent_tool." }, + "summary_type": { "type": "string", "description": "Summary type for upsert_summary." }, + "content": { "type": "string", "description": "Summary content for upsert_summary." } + }, + "required": ["operation"] + }), + } + } + + async fn call(&self, args: Self::Args) -> Result { + let project = self.project()?; + + match args.operation { + MemoryWriteOperation::UpsertEntry => { + let key = args.key.ok_or_else(|| AgentError::ToolError { + tool: Self::NAME.to_owned(), + message: "missing required field `key` for upsert_entry".to_owned(), + })?; + let raw_value = args.value.ok_or_else(|| AgentError::ToolError { + tool: Self::NAME.to_owned(), + message: "missing required field `value` for upsert_entry".to_owned(), + })?; + + let value: Value = + serde_json::from_str(&raw_value).unwrap_or_else(|_| Value::String(raw_value)); + let source = args.source.as_deref().unwrap_or("agent_tool"); + + self.memory + .upsert_memory_entry(project, &key, &value, source) + .await?; + + Ok(ProjectMemoryWriteOutput { + stored: true, + detail: format!("stored memory entry `{key}`"), + }) + } + MemoryWriteOperation::UpsertSummary => { + let summary_type = args.summary_type.ok_or_else(|| AgentError::ToolError { + tool: Self::NAME.to_owned(), + message: "missing required field `summary_type` for upsert_summary".to_owned(), + })?; + let content = args.content.ok_or_else(|| AgentError::ToolError { + tool: Self::NAME.to_owned(), + message: "missing required field `content` for upsert_summary".to_owned(), + })?; + + self.memory + .upsert_memory_summary(project, &summary_type, &content) + .await?; + + Ok(ProjectMemoryWriteOutput { + stored: true, + detail: format!("stored memory summary `{summary_type}`"), + }) + } + } + } +} diff --git a/crates/codetaker-agent/src/tools/mod.rs b/crates/codetaker-agent/src/tools/mod.rs index 08e842a..15bede1 100644 --- a/crates/codetaker-agent/src/tools/mod.rs +++ b/crates/codetaker-agent/src/tools/mod.rs @@ -1,8 +1,14 @@ mod ast_grep; +mod memory_write; +mod readfile; use serde::{Deserialize, Serialize}; pub use ast_grep::{AstGrepArgs, AstGrepOutput, AstGrepTool}; +pub use memory_write::{ + MemoryWriteOperation, ProjectMemoryWriteArgs, ProjectMemoryWriteOutput, ProjectMemoryWriteTool, +}; +pub use readfile::{ReadFileArgs, ReadFileOutput, ReadFileTool}; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct SearchHit { diff --git a/crates/codetaker-agent/src/tools/readfile.rs b/crates/codetaker-agent/src/tools/readfile.rs new file mode 100644 index 0000000..d473bd0 --- /dev/null +++ b/crates/codetaker-agent/src/tools/readfile.rs @@ -0,0 +1,115 @@ +use std::path::{Path, PathBuf}; + +use git2::Repository; +use rig::completion::ToolDefinition; +use rig::tool::Tool; +use serde::{Deserialize, Serialize}; + +use crate::error::{AgentError, AgentResult}; +use crate::git_access; +use crate::tools::FileContext; + +#[derive(Debug, Clone)] +pub struct ReadFileTool { + repo_git_dir: Option, + head_ref: Option, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ReadFileArgs { + pub file: String, + pub line: i32, + pub radius: i32, +} + +#[derive(Debug, Clone, Serialize)] +pub struct ReadFileOutput { + pub context: FileContext, +} + +impl Default for ReadFileTool { + fn default() -> Self { + Self::new() + } +} + +impl ReadFileTool { + pub fn new() -> Self { + Self { + repo_git_dir: None, + head_ref: None, + } + } + + pub fn bind_to_request( + &self, + repo: &Repository, + head_ref: impl Into, + ) -> AgentResult { + git_access::ensure_bare_repository(repo)?; + Ok(Self { + repo_git_dir: Some(repo.path().to_path_buf()), + head_ref: Some(head_ref.into()), + }) + } + + fn bound_state(&self) -> AgentResult<(PathBuf, String)> { + let repo_git_dir = self + .repo_git_dir + .clone() + .ok_or_else(|| AgentError::ConfigError { + message: "readfile tool is not bound to a repository context".to_owned(), + })?; + + let head_ref = self + .head_ref + .clone() + .ok_or_else(|| AgentError::ConfigError { + message: "readfile tool is not bound to a head_ref".to_owned(), + })?; + + Ok((repo_git_dir, head_ref)) + } + + fn open_repo(&self, repo_git_dir: &Path) -> AgentResult { + Repository::open_bare(repo_git_dir).map_err(|err| AgentError::GitError { + operation: format!("open bare repository at '{}'", repo_git_dir.display()), + message: err.to_string(), + }) + } + + fn read_context(&self, file: &str, line: i32, radius: i32) -> AgentResult { + let (repo_git_dir, head_ref) = self.bound_state()?; + let repo = self.open_repo(&repo_git_dir)?; + git_access::read_file_context_at_ref(&repo, &head_ref, file, line, radius) + } +} + +impl Tool for ReadFileTool { + const NAME: &'static str = "readfile"; + + type Error = AgentError; + type Args = ReadFileArgs; + type Output = ReadFileOutput; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + ToolDefinition { + name: Self::NAME.to_owned(), + description: "Read file context from the repository at head_ref.".to_owned(), + parameters: serde_json::json!({ + "type": "object", + "properties": { + "file": { "type": "string", "description": "Path to file in repository" }, + "line": { "type": "integer", "description": "1-based line number" }, + "radius": { "type": "integer", "description": "Number of lines before/after line" } + }, + "required": ["file", "line", "radius"] + }), + } + } + + async fn call(&self, args: Self::Args) -> Result { + let context = self.read_context(&args.file, args.line, args.radius)?; + Ok(ReadFileOutput { context }) + } +} diff --git a/crates/codetaker-agent/src/types.rs b/crates/codetaker-agent/src/types.rs index 7547216..5da8cf5 100644 --- a/crates/codetaker-agent/src/types.rs +++ b/crates/codetaker-agent/src/types.rs @@ -50,6 +50,7 @@ pub struct ProjectRef { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PullRequestReviewInput { pub project: ProjectRef, + pub pull_request_id: i64, pub base_ref: String, pub head_ref: String, } @@ -57,6 +58,7 @@ pub struct PullRequestReviewInput { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConversationInput { pub project: ProjectRef, + pub pull_request_id: i64, pub head_ref: String, pub anchor_file: String, pub anchor_line: i32, diff --git a/crates/codetaker-agent/tests/sqlite_memory_store.rs b/crates/codetaker-agent/tests/sqlite_memory_store.rs index 75f2160..776bd0e 100644 --- a/crates/codetaker-agent/tests/sqlite_memory_store.rs +++ b/crates/codetaker-agent/tests/sqlite_memory_store.rs @@ -1,8 +1,11 @@ use std::path::PathBuf; -use codetaker_agent::memory::{DieselMemoryStore, MemoryStore}; +use codetaker_agent::memory::{AuditStore, DieselMemoryStore, MemoryStore}; use codetaker_agent::types::{Forge, MessageAuthor, ProjectRef}; use codetaker_db::create_pool; +use diesel::ExpressionMethods; +use diesel::QueryDsl; +use diesel_async::RunQueryDsl; use serde_json::json; fn test_db_path(test_name: &str) -> PathBuf { @@ -87,7 +90,7 @@ async fn thread_messages_round_trip() { let project = sample_project(); let thread_id = store - .create_review_thread(&project, "src/lib.rs", 42, "Please avoid unwrap here") + .create_review_thread(&project, 123, "src/lib.rs", 42, "Please avoid unwrap here") .await .expect("create review thread"); @@ -113,3 +116,195 @@ async fn thread_messages_round_trip() { assert!(matches!(messages[0].author, MessageAuthor::User)); assert!(matches!(messages[1].author, MessageAuthor::Agent)); } + +#[tokio::test] +async fn prompt_audit_session_lifecycle_success_and_failure() { + let db_path = test_db_path("prompt_audit_session_lifecycle_success_and_failure"); + let database_url = db_path.display().to_string(); + let pool = create_pool(Some(&database_url)) + .await + .expect("create sqlite pool"); + let store = DieselMemoryStore::new(pool.clone()); + let project = sample_project(); + let now = chrono::Utc::now().naive_utc(); + + let success_session = store + .start_prompt_audit_session(&project, 11, "pull_request_review", now) + .await + .expect("start success session"); + store + .finish_prompt_audit_session( + success_session, + chrono::Utc::now().naive_utc(), + "completed", + None, + ) + .await + .expect("finish success session"); + + let failed_session = store + .start_prompt_audit_session(&project, 12, "conversation_response", now) + .await + .expect("start failed session"); + store + .finish_prompt_audit_session( + failed_session, + chrono::Utc::now().naive_utc(), + "failed", + Some("model error"), + ) + .await + .expect("finish failed session"); + + let mut conn = pool.get().await.expect("get pooled connection"); + use codetaker_db::schema::prompt_audit_sessions::dsl; + let statuses = dsl::prompt_audit_sessions + .filter(dsl::id.eq_any([success_session, failed_session])) + .select(dsl::status) + .load::(&mut conn) + .await + .expect("load session statuses"); + assert!(statuses.iter().any(|status| status == "completed")); + assert!(statuses.iter().any(|status| status == "failed")); +} + +#[tokio::test] +async fn prompt_hook_event_methods_write_rows() { + let db_path = test_db_path("prompt_hook_event_methods_write_rows"); + let database_url = db_path.display().to_string(); + let pool = create_pool(Some(&database_url)) + .await + .expect("create sqlite pool"); + let store = DieselMemoryStore::new(pool.clone()); + let project = sample_project(); + let now = chrono::Utc::now().naive_utc(); + + let session_id = store + .start_prompt_audit_session(&project, 99, "pull_request_review", now) + .await + .expect("start audit session"); + + store + .log_on_completion_call(session_id, now, 1, "{\"role\":\"user\"}", "[]") + .await + .expect("log completion call"); + store + .log_on_completion_response(session_id, now, 2, "{\"role\":\"user\"}", "[]", 10, 20, 30) + .await + .expect("log completion response"); + store + .log_on_tool_call( + session_id, + now, + 3, + "ast_grep", + Some("call_1"), + "internal_1", + "{\"query\":\"foo\"}", + "continue", + None, + ) + .await + .expect("log tool call"); + store + .log_on_tool_result( + session_id, + now, + 4, + "ast_grep", + Some("call_1"), + "internal_1", + "{\"query\":\"foo\"}", + "{\"hits\":[]}", + ) + .await + .expect("log tool result"); + store + .log_on_text_delta(session_id, now, 5, "foo", "foobar") + .await + .expect("log text delta"); + store + .log_on_tool_call_delta( + session_id, + now, + 6, + "call_2", + "internal_2", + Some("readfile"), + "{\"path\":\"src/lib.rs\"}", + ) + .await + .expect("log tool call delta"); + store + .log_on_stream_completion_response_finish(session_id, now, 7, "{}", "{\"done\":true}") + .await + .expect("log stream finish"); + + let mut conn = pool.get().await.expect("get pooled connection"); + use codetaker_db::schema::{ + prompt_hook_on_completion_call_events::dsl as completion_call_dsl, + prompt_hook_on_completion_response_events::dsl as completion_response_dsl, + prompt_hook_on_stream_completion_response_finish_events::dsl as stream_finish_dsl, + prompt_hook_on_text_delta_events::dsl as text_delta_dsl, + prompt_hook_on_tool_call_delta_events::dsl as tool_call_delta_dsl, + prompt_hook_on_tool_call_events::dsl as tool_call_dsl, + prompt_hook_on_tool_result_events::dsl as tool_result_dsl, + }; + + let completion_call_count = completion_call_dsl::prompt_hook_on_completion_call_events + .filter(completion_call_dsl::session_id.eq(session_id)) + .count() + .get_result::(&mut conn) + .await + .expect("count completion call events"); + assert_eq!(completion_call_count, 1); + + let completion_response_count = completion_response_dsl::prompt_hook_on_completion_response_events + .filter(completion_response_dsl::session_id.eq(session_id)) + .count() + .get_result::(&mut conn) + .await + .expect("count completion response events"); + assert_eq!(completion_response_count, 1); + + let tool_call_count = tool_call_dsl::prompt_hook_on_tool_call_events + .filter(tool_call_dsl::session_id.eq(session_id)) + .count() + .get_result::(&mut conn) + .await + .expect("count tool call events"); + assert_eq!(tool_call_count, 1); + + let tool_result_count = tool_result_dsl::prompt_hook_on_tool_result_events + .filter(tool_result_dsl::session_id.eq(session_id)) + .count() + .get_result::(&mut conn) + .await + .expect("count tool result events"); + assert_eq!(tool_result_count, 1); + + let text_delta_count = text_delta_dsl::prompt_hook_on_text_delta_events + .filter(text_delta_dsl::session_id.eq(session_id)) + .count() + .get_result::(&mut conn) + .await + .expect("count text delta events"); + assert_eq!(text_delta_count, 1); + + let tool_delta_count = tool_call_delta_dsl::prompt_hook_on_tool_call_delta_events + .filter(tool_call_delta_dsl::session_id.eq(session_id)) + .count() + .get_result::(&mut conn) + .await + .expect("count tool call delta events"); + assert_eq!(tool_delta_count, 1); + + let stream_finish_count = + stream_finish_dsl::prompt_hook_on_stream_completion_response_finish_events + .filter(stream_finish_dsl::session_id.eq(session_id)) + .count() + .get_result::(&mut conn) + .await + .expect("count stream finish events"); + assert_eq!(stream_finish_count, 1); +} diff --git a/crates/codetaker-cli/Cargo.toml b/crates/codetaker-cli/Cargo.toml index 038008a..858302e 100644 --- a/crates/codetaker-cli/Cargo.toml +++ b/crates/codetaker-cli/Cargo.toml @@ -7,6 +7,7 @@ edition = "2024" clap.workspace = true codetaker-agent = { path = "../codetaker-agent" } codetaker-db.workspace = true +dialoguer.workspace = true git2.workspace = true rig-core.workspace = true serde_json.workspace = true diff --git a/crates/codetaker-cli/src/main.rs b/crates/codetaker-cli/src/main.rs index 4971159..d09dc14 100644 --- a/crates/codetaker-cli/src/main.rs +++ b/crates/codetaker-cli/src/main.rs @@ -31,6 +31,9 @@ struct Cli { #[arg(long)] head_ref: String, + #[arg(long)] + pull_request_id: i64, + #[arg(long, default_value = "local")] owner: String, @@ -40,6 +43,7 @@ struct Cli { #[tokio::main] async fn main() -> Result<(), Box> { + tracing_subscriber::fmt::init(); let cli = Cli::parse(); let repo_name = cli.repo.unwrap_or_else(|| { @@ -63,6 +67,7 @@ async fn main() -> Result<(), Box> { owner: cli.owner, repo: repo_name, }, + pull_request_id: cli.pull_request_id, base_ref: cli.base_ref, head_ref: cli.head_ref, }; @@ -72,3 +77,26 @@ async fn main() -> Result<(), Box> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::Cli; + use clap::Parser; + + #[test] + fn requires_pull_request_id() { + let parsed = Cli::try_parse_from([ + "codetaker-cli", + "--anthropic-api-key", + "key", + "--repo-path", + ".", + "--base-ref", + "main", + "--head-ref", + "feature", + ]); + + assert!(parsed.is_err()); + } +} diff --git a/crates/codetaker-db/migrations/2026-02-27-094438-0000_initial/down.sql b/crates/codetaker-db/migrations/2026-02-27-094438-0000_initial/down.sql index 897a80a..499198f 100644 --- a/crates/codetaker-db/migrations/2026-02-27-094438-0000_initial/down.sql +++ b/crates/codetaker-db/migrations/2026-02-27-094438-0000_initial/down.sql @@ -1,3 +1,11 @@ +DROP TABLE IF EXISTS prompt_hook_on_stream_completion_response_finish_events; +DROP TABLE IF EXISTS prompt_hook_on_tool_call_delta_events; +DROP TABLE IF EXISTS prompt_hook_on_text_delta_events; +DROP TABLE IF EXISTS prompt_hook_on_tool_result_events; +DROP TABLE IF EXISTS prompt_hook_on_tool_call_events; +DROP TABLE IF EXISTS prompt_hook_on_completion_response_events; +DROP TABLE IF EXISTS prompt_hook_on_completion_call_events; +DROP TABLE IF EXISTS prompt_audit_sessions; DROP TABLE IF EXISTS review_thread_messages; DROP TABLE IF EXISTS review_threads; DROP TABLE IF EXISTS project_memory_summaries; diff --git a/crates/codetaker-db/migrations/2026-02-27-094438-0000_initial/up.sql b/crates/codetaker-db/migrations/2026-02-27-094438-0000_initial/up.sql index 0aceaf5..b432419 100644 --- a/crates/codetaker-db/migrations/2026-02-27-094438-0000_initial/up.sql +++ b/crates/codetaker-db/migrations/2026-02-27-094438-0000_initial/up.sql @@ -30,6 +30,7 @@ CREATE TABLE project_memory_summaries ( CREATE TABLE review_threads ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, project_id INTEGER NOT NULL, + pull_request_id BIGINT NOT NULL, file TEXT NOT NULL, line INTEGER NOT NULL, initial_comment TEXT NOT NULL, @@ -46,11 +47,125 @@ CREATE TABLE review_thread_messages ( FOREIGN KEY(thread_id) REFERENCES review_threads(id) ON DELETE CASCADE ); +CREATE TABLE prompt_audit_sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + project_id INTEGER NOT NULL, + pull_request_id BIGINT NOT NULL, + entrypoint TEXT NOT NULL, + started_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + ended_at TIMESTAMP, + status TEXT NOT NULL, + error_message TEXT, + FOREIGN KEY(project_id) REFERENCES projects(id) ON DELETE CASCADE +); + +CREATE TABLE prompt_hook_on_completion_call_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + session_id INTEGER NOT NULL, + event_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + sequence_no INTEGER NOT NULL, + prompt_json TEXT NOT NULL, + history_json TEXT NOT NULL, + FOREIGN KEY(session_id) REFERENCES prompt_audit_sessions(id) ON DELETE CASCADE +); + +CREATE TABLE prompt_hook_on_completion_response_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + session_id INTEGER NOT NULL, + event_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + sequence_no INTEGER NOT NULL, + prompt_json TEXT NOT NULL, + assistant_choice_json TEXT NOT NULL, + usage_input_tokens BIGINT NOT NULL, + usage_output_tokens BIGINT NOT NULL, + usage_total_tokens BIGINT NOT NULL, + FOREIGN KEY(session_id) REFERENCES prompt_audit_sessions(id) ON DELETE CASCADE +); + +CREATE TABLE prompt_hook_on_tool_call_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + session_id INTEGER NOT NULL, + event_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + sequence_no INTEGER NOT NULL, + tool_name TEXT NOT NULL, + tool_call_id TEXT, + internal_call_id TEXT NOT NULL, + args_json TEXT NOT NULL, + action TEXT NOT NULL, + reason TEXT, + FOREIGN KEY(session_id) REFERENCES prompt_audit_sessions(id) ON DELETE CASCADE +); + +CREATE TABLE prompt_hook_on_tool_result_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + session_id INTEGER NOT NULL, + event_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + sequence_no INTEGER NOT NULL, + tool_name TEXT NOT NULL, + tool_call_id TEXT, + internal_call_id TEXT NOT NULL, + args_json TEXT NOT NULL, + result_json TEXT NOT NULL, + FOREIGN KEY(session_id) REFERENCES prompt_audit_sessions(id) ON DELETE CASCADE +); + +CREATE TABLE prompt_hook_on_text_delta_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + session_id INTEGER NOT NULL, + event_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + sequence_no INTEGER NOT NULL, + text_delta TEXT NOT NULL, + aggregated_text TEXT NOT NULL, + FOREIGN KEY(session_id) REFERENCES prompt_audit_sessions(id) ON DELETE CASCADE +); + +CREATE TABLE prompt_hook_on_tool_call_delta_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + session_id INTEGER NOT NULL, + event_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + sequence_no INTEGER NOT NULL, + tool_call_id TEXT NOT NULL, + internal_call_id TEXT NOT NULL, + tool_name TEXT, + tool_call_delta TEXT NOT NULL, + FOREIGN KEY(session_id) REFERENCES prompt_audit_sessions(id) ON DELETE CASCADE +); + +CREATE TABLE prompt_hook_on_stream_completion_response_finish_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + session_id INTEGER NOT NULL, + event_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + sequence_no INTEGER NOT NULL, + prompt_json TEXT NOT NULL, + response_summary_json TEXT NOT NULL, + FOREIGN KEY(session_id) REFERENCES prompt_audit_sessions(id) ON DELETE CASCADE +); + CREATE INDEX idx_project_memory_entries_project_id ON project_memory_entries(project_id); CREATE INDEX idx_project_memory_summaries_project_id ON project_memory_summaries(project_id); CREATE INDEX idx_review_threads_project_id ON review_threads(project_id); +CREATE INDEX idx_review_threads_project_pr + ON review_threads(project_id, pull_request_id); CREATE INDEX idx_review_thread_messages_thread_id ON review_thread_messages(thread_id); +CREATE INDEX idx_prompt_audit_sessions_project_pr + ON prompt_audit_sessions(project_id, pull_request_id); +CREATE INDEX idx_prompt_audit_sessions_entrypoint_started + ON prompt_audit_sessions(entrypoint, started_at); +CREATE INDEX idx_prompt_hook_completion_call_session + ON prompt_hook_on_completion_call_events(session_id, sequence_no); +CREATE INDEX idx_prompt_hook_completion_response_session + ON prompt_hook_on_completion_response_events(session_id, sequence_no); +CREATE INDEX idx_prompt_hook_tool_call_session + ON prompt_hook_on_tool_call_events(session_id, sequence_no); +CREATE INDEX idx_prompt_hook_tool_result_session + ON prompt_hook_on_tool_result_events(session_id, sequence_no); +CREATE INDEX idx_prompt_hook_text_delta_session + ON prompt_hook_on_text_delta_events(session_id, sequence_no); +CREATE INDEX idx_prompt_hook_tool_delta_session + ON prompt_hook_on_tool_call_delta_events(session_id, sequence_no); +CREATE INDEX idx_prompt_hook_stream_finish_session + ON prompt_hook_on_stream_completion_response_finish_events(session_id, sequence_no); diff --git a/crates/codetaker-db/src/models.rs b/crates/codetaker-db/src/models.rs index f4f430e..755548e 100644 --- a/crates/codetaker-db/src/models.rs +++ b/crates/codetaker-db/src/models.rs @@ -2,8 +2,11 @@ use chrono::NaiveDateTime; use diesel::{Associations, Identifiable, Insertable, Queryable, Selectable}; use crate::schema::{ - project_memory_entries, project_memory_summaries, projects, review_thread_messages, - review_threads, + project_memory_entries, project_memory_summaries, projects, prompt_audit_sessions, + prompt_hook_on_completion_call_events, prompt_hook_on_completion_response_events, + prompt_hook_on_stream_completion_response_finish_events, prompt_hook_on_text_delta_events, + prompt_hook_on_tool_call_delta_events, prompt_hook_on_tool_call_events, + prompt_hook_on_tool_result_events, review_thread_messages, review_threads, }; #[derive(Debug, Clone, Queryable, Selectable, Identifiable)] @@ -71,6 +74,7 @@ pub struct NewProjectMemorySummaryRow<'a> { pub struct ReviewThreadRow { pub id: i32, pub project_id: i32, + pub pull_request_id: i64, pub file: String, pub line: i32, pub initial_comment: String, @@ -81,6 +85,7 @@ pub struct ReviewThreadRow { #[diesel(table_name = review_threads)] pub struct NewReviewThreadRow<'a> { pub project_id: i32, + pub pull_request_id: i64, pub file: &'a str, pub line: i32, pub initial_comment: &'a str, @@ -106,3 +111,111 @@ pub struct NewReviewThreadMessageRow<'a> { pub body: &'a str, pub created_at: NaiveDateTime, } + +#[derive(Debug, Clone, Queryable, Selectable, Identifiable, Associations)] +#[diesel(table_name = prompt_audit_sessions)] +#[diesel(belongs_to(ProjectRow, foreign_key = project_id))] +pub struct PromptAuditSessionRow { + pub id: i32, + pub project_id: i32, + pub pull_request_id: i64, + pub entrypoint: String, + pub started_at: NaiveDateTime, + pub ended_at: Option, + pub status: String, + pub error_message: Option, +} + +#[derive(Debug, Insertable)] +#[diesel(table_name = prompt_audit_sessions)] +pub struct NewPromptAuditSessionRow<'a> { + pub project_id: i32, + pub pull_request_id: i64, + pub entrypoint: &'a str, + pub started_at: NaiveDateTime, + pub ended_at: Option, + pub status: &'a str, + pub error_message: Option<&'a str>, +} + +#[derive(Debug, Insertable)] +#[diesel(table_name = prompt_hook_on_completion_call_events)] +pub struct NewPromptHookOnCompletionCallEventRow<'a> { + pub session_id: i32, + pub event_at: NaiveDateTime, + pub sequence_no: i32, + pub prompt_json: &'a str, + pub history_json: &'a str, +} + +#[derive(Debug, Insertable)] +#[diesel(table_name = prompt_hook_on_completion_response_events)] +pub struct NewPromptHookOnCompletionResponseEventRow<'a> { + pub session_id: i32, + pub event_at: NaiveDateTime, + pub sequence_no: i32, + pub prompt_json: &'a str, + pub assistant_choice_json: &'a str, + pub usage_input_tokens: i64, + pub usage_output_tokens: i64, + pub usage_total_tokens: i64, +} + +#[derive(Debug, Insertable)] +#[diesel(table_name = prompt_hook_on_tool_call_events)] +pub struct NewPromptHookOnToolCallEventRow<'a> { + pub session_id: i32, + pub event_at: NaiveDateTime, + pub sequence_no: i32, + pub tool_name: &'a str, + pub tool_call_id: Option<&'a str>, + pub internal_call_id: &'a str, + pub args_json: &'a str, + pub action: &'a str, + pub reason: Option<&'a str>, +} + +#[derive(Debug, Insertable)] +#[diesel(table_name = prompt_hook_on_tool_result_events)] +pub struct NewPromptHookOnToolResultEventRow<'a> { + pub session_id: i32, + pub event_at: NaiveDateTime, + pub sequence_no: i32, + pub tool_name: &'a str, + pub tool_call_id: Option<&'a str>, + pub internal_call_id: &'a str, + pub args_json: &'a str, + pub result_json: &'a str, +} + +#[derive(Debug, Insertable)] +#[diesel(table_name = prompt_hook_on_text_delta_events)] +pub struct NewPromptHookOnTextDeltaEventRow<'a> { + pub session_id: i32, + pub event_at: NaiveDateTime, + pub sequence_no: i32, + pub text_delta: &'a str, + pub aggregated_text: &'a str, +} + +#[derive(Debug, Insertable)] +#[diesel(table_name = prompt_hook_on_tool_call_delta_events)] +pub struct NewPromptHookOnToolCallDeltaEventRow<'a> { + pub session_id: i32, + pub event_at: NaiveDateTime, + pub sequence_no: i32, + pub tool_call_id: &'a str, + pub internal_call_id: &'a str, + pub tool_name: Option<&'a str>, + pub tool_call_delta: &'a str, +} + +#[derive(Debug, Insertable)] +#[diesel(table_name = prompt_hook_on_stream_completion_response_finish_events)] +pub struct NewPromptHookOnStreamCompletionResponseFinishEventRow<'a> { + pub session_id: i32, + pub event_at: NaiveDateTime, + pub sequence_no: i32, + pub prompt_json: &'a str, + pub response_summary_json: &'a str, +} diff --git a/crates/codetaker-db/src/schema.rs b/crates/codetaker-db/src/schema.rs index 20a8c10..ca96369 100644 --- a/crates/codetaker-db/src/schema.rs +++ b/crates/codetaker-db/src/schema.rs @@ -30,6 +30,108 @@ diesel::table! { } } +diesel::table! { + prompt_audit_sessions (id) { + id -> Integer, + project_id -> Integer, + pull_request_id -> BigInt, + entrypoint -> Text, + started_at -> Timestamp, + ended_at -> Nullable, + status -> Text, + error_message -> Nullable, + } +} + +diesel::table! { + prompt_hook_on_completion_call_events (id) { + id -> Integer, + session_id -> Integer, + event_at -> Timestamp, + sequence_no -> Integer, + prompt_json -> Text, + history_json -> Text, + } +} + +diesel::table! { + prompt_hook_on_completion_response_events (id) { + id -> Integer, + session_id -> Integer, + event_at -> Timestamp, + sequence_no -> Integer, + prompt_json -> Text, + assistant_choice_json -> Text, + usage_input_tokens -> BigInt, + usage_output_tokens -> BigInt, + usage_total_tokens -> BigInt, + } +} + +diesel::table! { + prompt_hook_on_stream_completion_response_finish_events (id) { + id -> Integer, + session_id -> Integer, + event_at -> Timestamp, + sequence_no -> Integer, + prompt_json -> Text, + response_summary_json -> Text, + } +} + +diesel::table! { + prompt_hook_on_text_delta_events (id) { + id -> Integer, + session_id -> Integer, + event_at -> Timestamp, + sequence_no -> Integer, + text_delta -> Text, + aggregated_text -> Text, + } +} + +diesel::table! { + prompt_hook_on_tool_call_delta_events (id) { + id -> Integer, + session_id -> Integer, + event_at -> Timestamp, + sequence_no -> Integer, + tool_call_id -> Text, + internal_call_id -> Text, + tool_name -> Nullable, + tool_call_delta -> Text, + } +} + +diesel::table! { + prompt_hook_on_tool_call_events (id) { + id -> Integer, + session_id -> Integer, + event_at -> Timestamp, + sequence_no -> Integer, + tool_name -> Text, + tool_call_id -> Nullable, + internal_call_id -> Text, + args_json -> Text, + action -> Text, + reason -> Nullable, + } +} + +diesel::table! { + prompt_hook_on_tool_result_events (id) { + id -> Integer, + session_id -> Integer, + event_at -> Timestamp, + sequence_no -> Integer, + tool_name -> Text, + tool_call_id -> Nullable, + internal_call_id -> Text, + args_json -> Text, + result_json -> Text, + } +} + diesel::table! { review_thread_messages (id) { id -> Integer, @@ -44,6 +146,7 @@ diesel::table! { review_threads (id) { id -> Integer, project_id -> Integer, + pull_request_id -> BigInt, file -> Text, line -> Integer, initial_comment -> Text, @@ -53,6 +156,14 @@ diesel::table! { diesel::joinable!(project_memory_entries -> projects (project_id)); diesel::joinable!(project_memory_summaries -> projects (project_id)); +diesel::joinable!(prompt_audit_sessions -> projects (project_id)); +diesel::joinable!(prompt_hook_on_completion_call_events -> prompt_audit_sessions (session_id)); +diesel::joinable!(prompt_hook_on_completion_response_events -> prompt_audit_sessions (session_id)); +diesel::joinable!(prompt_hook_on_stream_completion_response_finish_events -> prompt_audit_sessions (session_id)); +diesel::joinable!(prompt_hook_on_text_delta_events -> prompt_audit_sessions (session_id)); +diesel::joinable!(prompt_hook_on_tool_call_delta_events -> prompt_audit_sessions (session_id)); +diesel::joinable!(prompt_hook_on_tool_call_events -> prompt_audit_sessions (session_id)); +diesel::joinable!(prompt_hook_on_tool_result_events -> prompt_audit_sessions (session_id)); diesel::joinable!(review_thread_messages -> review_threads (thread_id)); diesel::joinable!(review_threads -> projects (project_id)); @@ -60,6 +171,14 @@ diesel::allow_tables_to_appear_in_same_query!( projects, project_memory_entries, project_memory_summaries, + prompt_audit_sessions, + prompt_hook_on_completion_call_events, + prompt_hook_on_completion_response_events, + prompt_hook_on_stream_completion_response_finish_events, + prompt_hook_on_text_delta_events, + prompt_hook_on_tool_call_delta_events, + prompt_hook_on_tool_call_events, + prompt_hook_on_tool_result_events, review_threads, review_thread_messages, );