feat(agent): add comprehensive prompt audit system

This commit is contained in:
hdbg
2026-02-27 13:23:21 +01:00
parent 3738272f80
commit 3a210809cf
21 changed files with 1800 additions and 107 deletions

View File

@@ -1,8 +1,11 @@
use std::path::PathBuf;
use codetaker_agent::memory::{DieselMemoryStore, MemoryStore};
use codetaker_agent::memory::{AuditStore, DieselMemoryStore, MemoryStore};
use codetaker_agent::types::{Forge, MessageAuthor, ProjectRef};
use codetaker_db::create_pool;
use diesel::ExpressionMethods;
use diesel::QueryDsl;
use diesel_async::RunQueryDsl;
use serde_json::json;
fn test_db_path(test_name: &str) -> PathBuf {
@@ -87,7 +90,7 @@ async fn thread_messages_round_trip() {
let project = sample_project();
let thread_id = store
.create_review_thread(&project, "src/lib.rs", 42, "Please avoid unwrap here")
.create_review_thread(&project, 123, "src/lib.rs", 42, "Please avoid unwrap here")
.await
.expect("create review thread");
@@ -113,3 +116,195 @@ async fn thread_messages_round_trip() {
assert!(matches!(messages[0].author, MessageAuthor::User));
assert!(matches!(messages[1].author, MessageAuthor::Agent));
}
#[tokio::test]
async fn prompt_audit_session_lifecycle_success_and_failure() {
let db_path = test_db_path("prompt_audit_session_lifecycle_success_and_failure");
let database_url = db_path.display().to_string();
let pool = create_pool(Some(&database_url))
.await
.expect("create sqlite pool");
let store = DieselMemoryStore::new(pool.clone());
let project = sample_project();
let now = chrono::Utc::now().naive_utc();
let success_session = store
.start_prompt_audit_session(&project, 11, "pull_request_review", now)
.await
.expect("start success session");
store
.finish_prompt_audit_session(
success_session,
chrono::Utc::now().naive_utc(),
"completed",
None,
)
.await
.expect("finish success session");
let failed_session = store
.start_prompt_audit_session(&project, 12, "conversation_response", now)
.await
.expect("start failed session");
store
.finish_prompt_audit_session(
failed_session,
chrono::Utc::now().naive_utc(),
"failed",
Some("model error"),
)
.await
.expect("finish failed session");
let mut conn = pool.get().await.expect("get pooled connection");
use codetaker_db::schema::prompt_audit_sessions::dsl;
let statuses = dsl::prompt_audit_sessions
.filter(dsl::id.eq_any([success_session, failed_session]))
.select(dsl::status)
.load::<String>(&mut conn)
.await
.expect("load session statuses");
assert!(statuses.iter().any(|status| status == "completed"));
assert!(statuses.iter().any(|status| status == "failed"));
}
#[tokio::test]
async fn prompt_hook_event_methods_write_rows() {
let db_path = test_db_path("prompt_hook_event_methods_write_rows");
let database_url = db_path.display().to_string();
let pool = create_pool(Some(&database_url))
.await
.expect("create sqlite pool");
let store = DieselMemoryStore::new(pool.clone());
let project = sample_project();
let now = chrono::Utc::now().naive_utc();
let session_id = store
.start_prompt_audit_session(&project, 99, "pull_request_review", now)
.await
.expect("start audit session");
store
.log_on_completion_call(session_id, now, 1, "{\"role\":\"user\"}", "[]")
.await
.expect("log completion call");
store
.log_on_completion_response(session_id, now, 2, "{\"role\":\"user\"}", "[]", 10, 20, 30)
.await
.expect("log completion response");
store
.log_on_tool_call(
session_id,
now,
3,
"ast_grep",
Some("call_1"),
"internal_1",
"{\"query\":\"foo\"}",
"continue",
None,
)
.await
.expect("log tool call");
store
.log_on_tool_result(
session_id,
now,
4,
"ast_grep",
Some("call_1"),
"internal_1",
"{\"query\":\"foo\"}",
"{\"hits\":[]}",
)
.await
.expect("log tool result");
store
.log_on_text_delta(session_id, now, 5, "foo", "foobar")
.await
.expect("log text delta");
store
.log_on_tool_call_delta(
session_id,
now,
6,
"call_2",
"internal_2",
Some("readfile"),
"{\"path\":\"src/lib.rs\"}",
)
.await
.expect("log tool call delta");
store
.log_on_stream_completion_response_finish(session_id, now, 7, "{}", "{\"done\":true}")
.await
.expect("log stream finish");
let mut conn = pool.get().await.expect("get pooled connection");
use codetaker_db::schema::{
prompt_hook_on_completion_call_events::dsl as completion_call_dsl,
prompt_hook_on_completion_response_events::dsl as completion_response_dsl,
prompt_hook_on_stream_completion_response_finish_events::dsl as stream_finish_dsl,
prompt_hook_on_text_delta_events::dsl as text_delta_dsl,
prompt_hook_on_tool_call_delta_events::dsl as tool_call_delta_dsl,
prompt_hook_on_tool_call_events::dsl as tool_call_dsl,
prompt_hook_on_tool_result_events::dsl as tool_result_dsl,
};
let completion_call_count = completion_call_dsl::prompt_hook_on_completion_call_events
.filter(completion_call_dsl::session_id.eq(session_id))
.count()
.get_result::<i64>(&mut conn)
.await
.expect("count completion call events");
assert_eq!(completion_call_count, 1);
let completion_response_count = completion_response_dsl::prompt_hook_on_completion_response_events
.filter(completion_response_dsl::session_id.eq(session_id))
.count()
.get_result::<i64>(&mut conn)
.await
.expect("count completion response events");
assert_eq!(completion_response_count, 1);
let tool_call_count = tool_call_dsl::prompt_hook_on_tool_call_events
.filter(tool_call_dsl::session_id.eq(session_id))
.count()
.get_result::<i64>(&mut conn)
.await
.expect("count tool call events");
assert_eq!(tool_call_count, 1);
let tool_result_count = tool_result_dsl::prompt_hook_on_tool_result_events
.filter(tool_result_dsl::session_id.eq(session_id))
.count()
.get_result::<i64>(&mut conn)
.await
.expect("count tool result events");
assert_eq!(tool_result_count, 1);
let text_delta_count = text_delta_dsl::prompt_hook_on_text_delta_events
.filter(text_delta_dsl::session_id.eq(session_id))
.count()
.get_result::<i64>(&mut conn)
.await
.expect("count text delta events");
assert_eq!(text_delta_count, 1);
let tool_delta_count = tool_call_delta_dsl::prompt_hook_on_tool_call_delta_events
.filter(tool_call_delta_dsl::session_id.eq(session_id))
.count()
.get_result::<i64>(&mut conn)
.await
.expect("count tool call delta events");
assert_eq!(tool_delta_count, 1);
let stream_finish_count =
stream_finish_dsl::prompt_hook_on_stream_completion_response_finish_events
.filter(stream_finish_dsl::session_id.eq(session_id))
.count()
.get_result::<i64>(&mut conn)
.await
.expect("count stream finish events");
assert_eq!(stream_finish_count, 1);
}