feat(agent): add comprehensive prompt audit system
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codetaker_agent::memory::{DieselMemoryStore, MemoryStore};
|
||||
use codetaker_agent::memory::{AuditStore, DieselMemoryStore, MemoryStore};
|
||||
use codetaker_agent::types::{Forge, MessageAuthor, ProjectRef};
|
||||
use codetaker_db::create_pool;
|
||||
use diesel::ExpressionMethods;
|
||||
use diesel::QueryDsl;
|
||||
use diesel_async::RunQueryDsl;
|
||||
use serde_json::json;
|
||||
|
||||
fn test_db_path(test_name: &str) -> PathBuf {
|
||||
@@ -87,7 +90,7 @@ async fn thread_messages_round_trip() {
|
||||
|
||||
let project = sample_project();
|
||||
let thread_id = store
|
||||
.create_review_thread(&project, "src/lib.rs", 42, "Please avoid unwrap here")
|
||||
.create_review_thread(&project, 123, "src/lib.rs", 42, "Please avoid unwrap here")
|
||||
.await
|
||||
.expect("create review thread");
|
||||
|
||||
@@ -113,3 +116,195 @@ async fn thread_messages_round_trip() {
|
||||
assert!(matches!(messages[0].author, MessageAuthor::User));
|
||||
assert!(matches!(messages[1].author, MessageAuthor::Agent));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompt_audit_session_lifecycle_success_and_failure() {
|
||||
let db_path = test_db_path("prompt_audit_session_lifecycle_success_and_failure");
|
||||
let database_url = db_path.display().to_string();
|
||||
let pool = create_pool(Some(&database_url))
|
||||
.await
|
||||
.expect("create sqlite pool");
|
||||
let store = DieselMemoryStore::new(pool.clone());
|
||||
let project = sample_project();
|
||||
let now = chrono::Utc::now().naive_utc();
|
||||
|
||||
let success_session = store
|
||||
.start_prompt_audit_session(&project, 11, "pull_request_review", now)
|
||||
.await
|
||||
.expect("start success session");
|
||||
store
|
||||
.finish_prompt_audit_session(
|
||||
success_session,
|
||||
chrono::Utc::now().naive_utc(),
|
||||
"completed",
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("finish success session");
|
||||
|
||||
let failed_session = store
|
||||
.start_prompt_audit_session(&project, 12, "conversation_response", now)
|
||||
.await
|
||||
.expect("start failed session");
|
||||
store
|
||||
.finish_prompt_audit_session(
|
||||
failed_session,
|
||||
chrono::Utc::now().naive_utc(),
|
||||
"failed",
|
||||
Some("model error"),
|
||||
)
|
||||
.await
|
||||
.expect("finish failed session");
|
||||
|
||||
let mut conn = pool.get().await.expect("get pooled connection");
|
||||
use codetaker_db::schema::prompt_audit_sessions::dsl;
|
||||
let statuses = dsl::prompt_audit_sessions
|
||||
.filter(dsl::id.eq_any([success_session, failed_session]))
|
||||
.select(dsl::status)
|
||||
.load::<String>(&mut conn)
|
||||
.await
|
||||
.expect("load session statuses");
|
||||
assert!(statuses.iter().any(|status| status == "completed"));
|
||||
assert!(statuses.iter().any(|status| status == "failed"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompt_hook_event_methods_write_rows() {
|
||||
let db_path = test_db_path("prompt_hook_event_methods_write_rows");
|
||||
let database_url = db_path.display().to_string();
|
||||
let pool = create_pool(Some(&database_url))
|
||||
.await
|
||||
.expect("create sqlite pool");
|
||||
let store = DieselMemoryStore::new(pool.clone());
|
||||
let project = sample_project();
|
||||
let now = chrono::Utc::now().naive_utc();
|
||||
|
||||
let session_id = store
|
||||
.start_prompt_audit_session(&project, 99, "pull_request_review", now)
|
||||
.await
|
||||
.expect("start audit session");
|
||||
|
||||
store
|
||||
.log_on_completion_call(session_id, now, 1, "{\"role\":\"user\"}", "[]")
|
||||
.await
|
||||
.expect("log completion call");
|
||||
store
|
||||
.log_on_completion_response(session_id, now, 2, "{\"role\":\"user\"}", "[]", 10, 20, 30)
|
||||
.await
|
||||
.expect("log completion response");
|
||||
store
|
||||
.log_on_tool_call(
|
||||
session_id,
|
||||
now,
|
||||
3,
|
||||
"ast_grep",
|
||||
Some("call_1"),
|
||||
"internal_1",
|
||||
"{\"query\":\"foo\"}",
|
||||
"continue",
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("log tool call");
|
||||
store
|
||||
.log_on_tool_result(
|
||||
session_id,
|
||||
now,
|
||||
4,
|
||||
"ast_grep",
|
||||
Some("call_1"),
|
||||
"internal_1",
|
||||
"{\"query\":\"foo\"}",
|
||||
"{\"hits\":[]}",
|
||||
)
|
||||
.await
|
||||
.expect("log tool result");
|
||||
store
|
||||
.log_on_text_delta(session_id, now, 5, "foo", "foobar")
|
||||
.await
|
||||
.expect("log text delta");
|
||||
store
|
||||
.log_on_tool_call_delta(
|
||||
session_id,
|
||||
now,
|
||||
6,
|
||||
"call_2",
|
||||
"internal_2",
|
||||
Some("readfile"),
|
||||
"{\"path\":\"src/lib.rs\"}",
|
||||
)
|
||||
.await
|
||||
.expect("log tool call delta");
|
||||
store
|
||||
.log_on_stream_completion_response_finish(session_id, now, 7, "{}", "{\"done\":true}")
|
||||
.await
|
||||
.expect("log stream finish");
|
||||
|
||||
let mut conn = pool.get().await.expect("get pooled connection");
|
||||
use codetaker_db::schema::{
|
||||
prompt_hook_on_completion_call_events::dsl as completion_call_dsl,
|
||||
prompt_hook_on_completion_response_events::dsl as completion_response_dsl,
|
||||
prompt_hook_on_stream_completion_response_finish_events::dsl as stream_finish_dsl,
|
||||
prompt_hook_on_text_delta_events::dsl as text_delta_dsl,
|
||||
prompt_hook_on_tool_call_delta_events::dsl as tool_call_delta_dsl,
|
||||
prompt_hook_on_tool_call_events::dsl as tool_call_dsl,
|
||||
prompt_hook_on_tool_result_events::dsl as tool_result_dsl,
|
||||
};
|
||||
|
||||
let completion_call_count = completion_call_dsl::prompt_hook_on_completion_call_events
|
||||
.filter(completion_call_dsl::session_id.eq(session_id))
|
||||
.count()
|
||||
.get_result::<i64>(&mut conn)
|
||||
.await
|
||||
.expect("count completion call events");
|
||||
assert_eq!(completion_call_count, 1);
|
||||
|
||||
let completion_response_count = completion_response_dsl::prompt_hook_on_completion_response_events
|
||||
.filter(completion_response_dsl::session_id.eq(session_id))
|
||||
.count()
|
||||
.get_result::<i64>(&mut conn)
|
||||
.await
|
||||
.expect("count completion response events");
|
||||
assert_eq!(completion_response_count, 1);
|
||||
|
||||
let tool_call_count = tool_call_dsl::prompt_hook_on_tool_call_events
|
||||
.filter(tool_call_dsl::session_id.eq(session_id))
|
||||
.count()
|
||||
.get_result::<i64>(&mut conn)
|
||||
.await
|
||||
.expect("count tool call events");
|
||||
assert_eq!(tool_call_count, 1);
|
||||
|
||||
let tool_result_count = tool_result_dsl::prompt_hook_on_tool_result_events
|
||||
.filter(tool_result_dsl::session_id.eq(session_id))
|
||||
.count()
|
||||
.get_result::<i64>(&mut conn)
|
||||
.await
|
||||
.expect("count tool result events");
|
||||
assert_eq!(tool_result_count, 1);
|
||||
|
||||
let text_delta_count = text_delta_dsl::prompt_hook_on_text_delta_events
|
||||
.filter(text_delta_dsl::session_id.eq(session_id))
|
||||
.count()
|
||||
.get_result::<i64>(&mut conn)
|
||||
.await
|
||||
.expect("count text delta events");
|
||||
assert_eq!(text_delta_count, 1);
|
||||
|
||||
let tool_delta_count = tool_call_delta_dsl::prompt_hook_on_tool_call_delta_events
|
||||
.filter(tool_call_delta_dsl::session_id.eq(session_id))
|
||||
.count()
|
||||
.get_result::<i64>(&mut conn)
|
||||
.await
|
||||
.expect("count tool call delta events");
|
||||
assert_eq!(tool_delta_count, 1);
|
||||
|
||||
let stream_finish_count =
|
||||
stream_finish_dsl::prompt_hook_on_stream_completion_response_finish_events
|
||||
.filter(stream_finish_dsl::session_id.eq(session_id))
|
||||
.count()
|
||||
.get_result::<i64>(&mut conn)
|
||||
.await
|
||||
.expect("count stream finish events");
|
||||
assert_eq!(stream_finish_count, 1);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user