misc: initial code

This commit is contained in:
hdbg
2026-02-27 10:27:24 +01:00
commit 91036f4188
32 changed files with 36435 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
# Manifest for the codetaker agent library crate.
[package]
name = "codetaker-agent"
version = "0.1.0"
edition = "2024"
# All versions are pinned by the workspace root manifest.
[dependencies]
chrono.workspace = true
codetaker-db.workspace = true
diesel.workspace = true
diesel-async.workspace = true
git2.workspace = true
rig-core.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
tracing.workspace = true
# Test-only dependencies (async runtime for integration tests).
[dev-dependencies]
tokio.workspace = true

View File

@@ -0,0 +1,198 @@
use std::fmt::Write;
use git2::Repository;
use rig::agent::AgentBuilder;
use rig::client::{Client, CompletionClient};
use rig::completion::{StructuredOutputError, TypedPrompt};
use crate::error::{AgentError, AgentResult};
use crate::git_access::{self};
use crate::memory::{MemoryStore, ProjectContextSnapshot};
use crate::tools::{AstGrepTool, FileContext, SearchHit};
use crate::types::{ConversationInput, ConversationOutput};
/// System preamble for the conversation-answer agent: asks for terse,
/// technical replies and output matching the `ConversationOutput` schema.
const CONVERSATION_RESPONSE_PREAMBLE: &str = r#"
You are a code review assistant responding in an inline review thread.
Be concise, technical, and specific.
Use available tools when you need extra repository context.
Respond using the ConversationOutput JSON schema.
"#;
/// Agent that answers follow-up messages in an inline review thread.
///
/// Generic over the completion client parameters (`Ext`, `H`) and the
/// memory-store implementation `M`.
pub struct ConversationAnswerAgent<Ext, H, M> {
    // Completion client used to obtain a per-request model handle.
    client: Client<Ext, H>,
    // Identifier passed to `completion_model(..)` for each request.
    conversation_model: String,
    // ast-grep tool template; bound to a concrete repo/ref per request.
    ast_grep: AstGrepTool,
    // Per-project memory backend (entries + summaries).
    memory: M,
}
impl<Ext, H, M> ConversationAnswerAgent<Ext, H, M> {
    /// Builds a conversation-answer agent from its collaborators.
    ///
    /// `conversation_model` accepts anything convertible into a `String`
    /// (for example a `&str` model identifier).
    pub fn new(
        client: Client<Ext, H>,
        conversation_model: impl Into<String>,
        ast_grep: AstGrepTool,
        memory: M,
    ) -> Self {
        let conversation_model = conversation_model.into();
        Self {
            client,
            conversation_model,
            ast_grep,
            memory,
        }
    }
}
impl<Ext, H, M> ConversationAnswerAgent<Ext, H, M>
where
    Client<Ext, H>: CompletionClient,
    M: MemoryStore,
{
    /// Answers the latest message of an inline review thread.
    ///
    /// Pipeline: verify the repo is bare, load per-project memory, read file
    /// context around the thread anchor, collect best-effort ast-grep hits,
    /// then prompt the conversation model for a `ConversationOutput`.
    ///
    /// # Errors
    /// * `AgentError::ConfigError` — `repo` is not a bare repository.
    /// * `AgentError::GitError` — ref resolution or anchor read failed.
    /// * `AgentError::ModelError` — the completion request itself failed.
    /// * `AgentError::OutputValidationError` — empty or undeserializable
    ///   model output.
    pub async fn conversation_response(
        &self,
        repo: &Repository,
        input: ConversationInput,
    ) -> AgentResult<ConversationOutput> {
        git_access::ensure_bare_repository(repo)?;
        let memory_snapshot = self.memory.project_context_snapshot(&input.project).await?;
        // 40 lines of context on each side of the anchor line.
        let anchor_context = git_access::read_file_context_at_ref(
            repo,
            &input.head_ref,
            &input.anchor_file,
            input.anchor_line,
            40,
        )?;
        // Bind the tool to this request's repo/ref so the model can run
        // additional searches during the completion.
        let ast_grep = self.ast_grep.bind_to_request(repo, &input.head_ref)?;
        let search_hits = self.collect_conversation_search_hits(&ast_grep, &input);
        let prompt =
            build_conversation_prompt(&input, &memory_snapshot, &anchor_context, &search_hits);
        let conversation_model = self.client.completion_model(&self.conversation_model);
        let agent = AgentBuilder::new(conversation_model)
            .preamble(CONVERSATION_RESPONSE_PREAMBLE)
            .tool(ast_grep)
            .temperature(0.1)
            .output_schema::<ConversationOutput>()
            .build();
        // Map rig's structured-output error space onto the crate error type.
        agent
            .prompt_typed::<ConversationOutput>(prompt)
            .await
            .map_err(|err| match err {
                StructuredOutputError::PromptError(prompt_err) => AgentError::ModelError {
                    model: self.conversation_model.clone(),
                    message: prompt_err.to_string(),
                },
                StructuredOutputError::DeserializationError(deser_err) => {
                    AgentError::OutputValidationError {
                        message: format!("failed to deserialize conversation output: {deser_err}"),
                        raw_output: None,
                    }
                }
                StructuredOutputError::EmptyResponse => AgentError::OutputValidationError {
                    message: "conversation response was empty".to_owned(),
                    raw_output: None,
                },
            })
    }
    /// Best-effort seed hits for the prompt: searches the thread's initial
    /// comment as an ast-grep pattern and keeps at most 10 results. Failures
    /// are logged and treated as "no hits" rather than aborting the response.
    fn collect_conversation_search_hits(
        &self,
        ast_grep: &AstGrepTool,
        input: &ConversationInput,
    ) -> Vec<SearchHit> {
        match ast_grep.search_pattern(&input.initial_comment) {
            Ok(hits) => hits.into_iter().take(10).collect(),
            Err(err) => {
                tracing::warn!(error = %err, "failed to fetch search hits for conversation context");
                Vec::new()
            }
        }
    }
}
/// Assembles the conversation prompt from the thread input, project memory,
/// anchor-file context, and tool search hits.
///
/// Writes into a `String` buffer; the results of `writeln!` are ignored on
/// purpose, since formatting into a `String` cannot fail.
fn build_conversation_prompt(
    input: &ConversationInput,
    memory: &ProjectContextSnapshot,
    anchor_context: &FileContext,
    search_hits: &[SearchHit],
) -> String {
    let mut out = String::new();
    // Header: project identity, ref under review, and the thread anchor.
    let _ = writeln!(
        out,
        "Project: {}/{} ({})\nHead reference: '{}'\nAnchor: {}:{}",
        input.project.owner,
        input.project.repo,
        input.project.forge.as_db_value(),
        input.head_ref,
        input.anchor_file,
        input.anchor_line
    );
    out.push_str("\nInitial review comment:\n");
    out.push_str(&input.initial_comment);
    out.push_str("\n\nThread message chain:\n");
    for message in &input.message_chain {
        let _ = writeln!(
            out,
            "- [{:?} @ {}] {}",
            message.author, message.created_at, message.body
        );
    }
    out.push_str("\nAnchor file context:\n");
    let _ = writeln!(
        out,
        "File: {} (lines {}-{})\n{}",
        anchor_context.file,
        anchor_context.line_start,
        anchor_context.line_end,
        anchor_context.snippet
    );
    out.push_str("\nPer-project memory (JSON map):\n");
    // Fall back to an empty JSON object if serialization fails.
    out.push_str(
        &serde_json::to_string_pretty(&memory.entries).unwrap_or_else(|_| "{}".to_owned()),
    );
    out.push_str("\n\nPer-project summaries:\n");
    if memory.summaries.is_empty() {
        out.push_str("(none)\n");
    } else {
        for summary in &memory.summaries {
            let _ = writeln!(
                out,
                "- [{} @ {}] {}",
                summary.summary_type, summary.updated_at, summary.content
            );
        }
    }
    out.push_str("\nTool search hits:\n");
    if search_hits.is_empty() {
        out.push_str("(none)\n");
    } else {
        for hit in search_hits {
            let _ = writeln!(
                out,
                "- {}:{} {:?} {}",
                hit.file, hit.line, hit.column, hit.snippet
            );
        }
    }
    out.push_str("\nRespond to the latest user concern in this thread.");
    out
}

View File

@@ -0,0 +1,5 @@
// One submodule per agent kind.
mod conversation_answer;
mod pull_request_review;
// Re-export the agent types at `crate::agent::*`.
pub use conversation_answer::ConversationAnswerAgent;
pub use pull_request_review::PullRequestReviewAgent;

View File

@@ -0,0 +1,213 @@
use std::fmt::Write;
use git2::Repository;
use rig::agent::AgentBuilder;
use rig::client::{Client, CompletionClient};
use rig::completion::{StructuredOutputError, TypedPrompt};
use crate::error::{AgentError, AgentResult};
use crate::git_access::{self, PullRequestMaterial};
use crate::memory::{MemoryStore, ProjectContextSnapshot};
use crate::tools::{AstGrepTool, FileContext};
use crate::types::{PullRequestReviewInput, PullRequestReviewOutput};
/// System preamble for the PR review agent. The expected JSON schema is
/// spelled out inline so the model can comply even without native schema
/// support; line numbers are defined against the head (new) file version.
const PULL_REQUEST_REVIEW_PREAMBLE: &str = r#"
You are a code review agent.
Analyze pull request diffs and respond only with valid JSON.
Use available tools when you need extra repository context.
Schema:
{
"review_result": "Approve" | "RequestChanges",
"global_comment": string,
"comments": [
{
"comment": string,
"file": string,
"line": integer
}
]
}
Line numbers must target the head/new file version.
Do not include extra keys.
"#;
/// Agent that reviews a full pull request diff.
///
/// Generic over the completion client parameters (`Ext`, `H`) and the
/// memory-store implementation `M`.
pub struct PullRequestReviewAgent<Ext, H, M> {
    // Completion client used to obtain a per-request model handle.
    client: Client<Ext, H>,
    // Identifier passed to `completion_model(..)` for each review.
    review_model: String,
    // ast-grep tool template; bound to a concrete repo/ref per request.
    ast_grep: AstGrepTool,
    // Per-project memory backend (entries + summaries).
    memory: M,
}
impl<Ext, H, M> PullRequestReviewAgent<Ext, H, M> {
    /// Builds a pull-request review agent from its collaborators.
    ///
    /// `review_model` accepts anything convertible into a `String`
    /// (for example a `&str` model identifier).
    pub fn new(
        client: Client<Ext, H>,
        review_model: impl Into<String>,
        ast_grep: AstGrepTool,
        memory: M,
    ) -> Self {
        let review_model = review_model.into();
        Self {
            client,
            review_model,
            ast_grep,
            memory,
        }
    }
}
impl<Ext, H, M> PullRequestReviewAgent<Ext, H, M>
where
    Client<Ext, H>: CompletionClient,
    M: MemoryStore,
{
    /// Reviews the diff between `base_ref` and `head_ref` of a pull request.
    ///
    /// Pipeline: verify the repo is bare, load per-project memory, compute
    /// the PR material (merge-base diff, changed files, patch), pull extra
    /// context for up to five changed files, then prompt the review model for
    /// a `PullRequestReviewOutput`.
    ///
    /// # Errors
    /// * `AgentError::ConfigError` — `repo` is not a bare repository.
    /// * `AgentError::GitError` — ref resolution or diff computation failed.
    /// * `AgentError::ModelError` — the completion request itself failed.
    /// * `AgentError::OutputValidationError` — empty or undeserializable
    ///   model output.
    pub async fn pull_request_review(
        &self,
        repo: &Repository,
        input: PullRequestReviewInput,
    ) -> AgentResult<PullRequestReviewOutput> {
        git_access::ensure_bare_repository(repo)?;
        let memory_snapshot = self.memory.project_context_snapshot(&input.project).await?;
        let pr_material = git_access::compute_pr_material(repo, &input.base_ref, &input.head_ref)?;
        let file_contexts = self.collect_diff_file_contexts(repo, &input.head_ref, &pr_material)?;
        let prompt =
            build_pull_request_prompt(&input, &memory_snapshot, &pr_material, &file_contexts);
        // Bind the tool to this request's repo/ref so the model can run
        // additional searches during the completion.
        let ast_grep = self.ast_grep.bind_to_request(repo, &input.head_ref)?;
        let review_model = self.client.completion_model(&self.review_model);
        let agent = AgentBuilder::new(review_model)
            .preamble(PULL_REQUEST_REVIEW_PREAMBLE)
            .tool(ast_grep)
            .temperature(0.0)
            .output_schema::<PullRequestReviewOutput>()
            .build();
        // Map rig's structured-output error space onto the crate error type.
        agent
            .prompt_typed::<PullRequestReviewOutput>(prompt)
            .await
            .map_err(|err| match err {
                StructuredOutputError::PromptError(prompt_err) => AgentError::ModelError {
                    model: self.review_model.clone(),
                    message: prompt_err.to_string(),
                },
                StructuredOutputError::DeserializationError(deser_err) => {
                    AgentError::OutputValidationError {
                        message: format!(
                            "failed to deserialize pull request review output: {deser_err}"
                        ),
                        raw_output: None,
                    }
                }
                StructuredOutputError::EmptyResponse => AgentError::OutputValidationError {
                    message: "pull request review response was empty".to_owned(),
                    raw_output: None,
                },
            })
    }
    /// Reads extra context for up to the first five changed files, anchored at
    /// each file's first changed head line (falling back to line 1). A file
    /// whose context cannot be read is logged and skipped, not fatal.
    fn collect_diff_file_contexts(
        &self,
        repo: &Repository,
        head_ref: &str,
        pr_material: &PullRequestMaterial,
    ) -> AgentResult<Vec<FileContext>> {
        let mut contexts = Vec::new();
        for file in pr_material.changed_files.iter().take(5) {
            let line = pr_material
                .first_changed_head_lines
                .get(file)
                .copied()
                .unwrap_or(1);
            match git_access::read_file_context_at_ref(repo, head_ref, file, line, 40) {
                Ok(context) => contexts.push(context),
                Err(err) => {
                    tracing::warn!(
                        file = %file,
                        error = %err,
                        "failed to fetch file context for pull request review"
                    );
                }
            }
        }
        Ok(contexts)
    }
}
/// Assembles the review prompt from the request input, project memory, the
/// computed PR material, and the extra file contexts.
///
/// Writes into a `String` buffer; the results of `writeln!` are ignored on
/// purpose, since formatting into a `String` cannot fail.
fn build_pull_request_prompt(
    input: &PullRequestReviewInput,
    memory: &ProjectContextSnapshot,
    material: &PullRequestMaterial,
    file_contexts: &[FileContext],
) -> String {
    let mut out = String::new();
    // Header: project identity, requested refs, and the resolved commit ids.
    let _ = writeln!(
        out,
        "Project: {}/{} ({})\nReview range: base_ref='{}', head_ref='{}'\nResolved OIDs: base={}, head={}, merge_base={}",
        input.project.owner,
        input.project.repo,
        input.project.forge.as_db_value(),
        input.base_ref,
        input.head_ref,
        material.base_oid,
        material.head_oid,
        material.merge_base_oid
    );
    out.push_str("\nChanged files:\n");
    if material.changed_files.is_empty() {
        out.push_str("(none)\n");
    } else {
        for file in &material.changed_files {
            let _ = writeln!(out, "- {file}");
        }
    }
    out.push_str("\nPer-project memory (JSON map):\n");
    // Fall back to an empty JSON object if serialization fails.
    out.push_str(
        &serde_json::to_string_pretty(&memory.entries).unwrap_or_else(|_| "{}".to_owned()),
    );
    out.push_str("\n\nPer-project summaries:\n");
    if memory.summaries.is_empty() {
        out.push_str("(none)\n");
    } else {
        for summary in &memory.summaries {
            let _ = writeln!(
                out,
                "- [{} @ {}] {}",
                summary.summary_type, summary.updated_at, summary.content
            );
        }
    }
    out.push_str("\nAdditional file contexts:\n");
    if file_contexts.is_empty() {
        out.push_str("(none)\n");
    } else {
        for context in file_contexts {
            let _ = writeln!(
                out,
                "File: {} (lines {}-{})\n{}",
                context.file, context.line_start, context.line_end, context.snippet
            );
        }
    }
    out.push_str("\nPull request diff:\n");
    out.push_str(&material.patch);
    out.push_str("\n\nReturn only JSON matching the schema exactly.");
    out
}

View File

@@ -0,0 +1,59 @@
use std::error::Error;
use std::fmt::{Display, Formatter};
/// Convenience alias used by every fallible API in this crate.
pub type AgentResult<T> = Result<T, AgentError>;

/// Error type shared across the agent crate. Each variant carries a
/// human-readable message plus variant-specific context.
#[derive(Debug)]
pub enum AgentError {
    /// Failure inside the per-project memory store.
    MemoryError {
        message: String,
    },
    /// Failure while accessing the git repository.
    GitError {
        operation: String,
        message: String,
    },
    /// Failure of a bound tool (e.g. ast-grep).
    ToolError {
        tool: String,
        message: String,
    },
    /// Failure of a model completion call.
    ModelError {
        model: String,
        message: String,
    },
    /// The model responded, but its output did not validate.
    OutputValidationError {
        message: String,
        raw_output: Option<String>,
    },
    /// Invalid configuration or precondition (e.g. non-bare repo).
    ConfigError {
        message: String,
    },
}

impl Display for AgentError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MemoryError { message } => write!(f, "memory error: {message}"),
            Self::ConfigError { message } => write!(f, "config error: {message}"),
            Self::ToolError { tool, message } => write!(f, "tool error ({tool}): {message}"),
            Self::ModelError { model, message } => write!(f, "model error ({model}): {message}"),
            Self::GitError { operation, message } => {
                write!(f, "git error ({operation}): {message}")
            }
            // Only append the raw output when we actually captured it.
            Self::OutputValidationError {
                message,
                raw_output,
            } => match raw_output {
                Some(raw_output) => write!(
                    f,
                    "output validation error: {message}; raw output: {raw_output}"
                ),
                None => write!(f, "output validation error: {message}"),
            },
        }
    }
}

impl Error for AgentError {}

View File

@@ -0,0 +1,465 @@
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::{Path, PathBuf};
use git2::{Commit, DiffFormat, ObjectType, Oid, Repository, Tree};
use crate::error::{AgentError, AgentResult};
use crate::tools::FileContext;
/// Upper bound (in characters) for a rendered PR diff embedded in a prompt.
pub const MAX_PATCH_CHARS: usize = 250_000;
/// Marker appended to a diff cut off at the character limit.
const TRUNCATION_NOTICE: &str = "\n\n[DIFF TRUNCATED: content exceeds configured prompt limit]\n";
/// Placeholder snippet used for binary / non-UTF8 blobs.
const NON_TEXT_BLOB_NOTICE: &str = "[non-text blob omitted]";
/// Everything derived from a PR's base/head refs that the review prompt needs.
#[derive(Debug, Clone)]
pub struct PullRequestMaterial {
    // Commit the base ref resolved to.
    pub base_oid: Oid,
    // Commit the head ref resolved to.
    pub head_oid: Oid,
    // Merge base of base and head; the diff is merge_base..head.
    pub merge_base_oid: Oid,
    // Rendered unified diff, truncated to `MAX_PATCH_CHARS`.
    pub patch: String,
    // Changed file paths, deduplicated, in delta order.
    pub changed_files: Vec<String>,
    // First added line (head-side numbering) per changed file.
    pub first_changed_head_lines: HashMap<String, i32>,
}
/// An on-disk materialization of a git tree, removed when dropped.
pub struct TempRepoSnapshot {
    // Root directory of the materialized snapshot.
    path: PathBuf,
}
impl TempRepoSnapshot {
    /// Root directory of the snapshot.
    pub fn path(&self) -> &Path {
        &self.path
    }
}
impl Drop for TempRepoSnapshot {
    // Best-effort cleanup: a failed removal is logged, never propagated
    // (Drop cannot return errors).
    fn drop(&mut self) {
        if let Err(err) = fs::remove_dir_all(&self.path) {
            tracing::warn!(
                path = %self.path.display(),
                error = %err,
                "failed to clean up temporary repository snapshot"
            );
        }
    }
}
/// Guards that `repo` is a bare repository, which this agent mode requires.
///
/// # Errors
/// `AgentError::ConfigError` when the repository has a working tree.
pub fn ensure_bare_repository(repo: &Repository) -> AgentResult<()> {
    if repo.is_bare() {
        return Ok(());
    }
    Err(AgentError::ConfigError {
        message: "expected bare git repository for this agent mode".to_owned(),
    })
}
/// Resolves `spec` (ref name, OID prefix, or revspec) to a commit.
///
/// Returns both the id and the commit object so callers can use either
/// without a second lookup.
///
/// # Errors
/// `AgentError::GitError` when the spec cannot be resolved or does not peel
/// to a commit.
pub fn resolve_commit<'repo>(
    repo: &'repo Repository,
    spec: &str,
) -> AgentResult<(Oid, Commit<'repo>)> {
    // Try the extended lookup first; fall back to the plain single-object
    // lookup, surfacing the fallback's error if both fail.
    let object = match repo.revparse_ext(spec) {
        Ok((object, _reference)) => object,
        Err(_) => repo
            .revparse_single(spec)
            .map_err(|err| git_error(format!("resolve revision '{spec}'"), err))?,
    };
    let commit = object
        .peel_to_commit()
        .map_err(|err| git_error(format!("peel revision '{spec}' to commit"), err))?;
    Ok((commit.id(), commit))
}
/// Builds the review material for a PR: the diff of `merge-base(base, head)`
/// against `head`, the deduplicated changed-file list, the first added
/// head-side line per file, and the rendered (possibly truncated) patch.
///
/// The diff range uses the merge base rather than the base tip so the PR only
/// shows its own changes, matching how forges render pull-request diffs.
///
/// # Errors
/// `AgentError::GitError` when refs cannot be resolved, no merge base exists,
/// or the diff cannot be computed/rendered.
pub fn compute_pr_material(
    repo: &Repository,
    base_ref: &str,
    head_ref: &str,
) -> AgentResult<PullRequestMaterial> {
    let (base_oid, _base_commit) = resolve_commit(repo, base_ref)?;
    let (head_oid, head_commit) = resolve_commit(repo, head_ref)?;
    let merge_base_oid = repo.merge_base(base_oid, head_oid).map_err(|err| {
        git_error(
            format!("compute merge base between '{base_ref}' and '{head_ref}'"),
            err,
        )
    })?;
    let merge_base_commit = repo
        .find_commit(merge_base_oid)
        .map_err(|err| git_error(format!("load merge-base commit {merge_base_oid}"), err))?;
    let base_tree = merge_base_commit
        .tree()
        .map_err(|err| git_error(format!("load merge-base tree for {merge_base_oid}"), err))?;
    let head_tree = head_commit
        .tree()
        .map_err(|err| git_error(format!("load head tree for {head_oid}"), err))?;
    let diff = repo
        .diff_tree_to_tree(Some(&base_tree), Some(&head_tree), None)
        .map_err(|err| git_error("generate tree diff".to_owned(), err))?;
    // Changed files, deduplicated while preserving delta order; fall back to
    // the old path for deletions.
    let mut changed_files = Vec::new();
    let mut seen_files = HashSet::new();
    for delta in diff.deltas() {
        if let Some(path) = delta.new_file().path().or(delta.old_file().path()) {
            let file = path.to_string_lossy().to_string();
            if seen_files.insert(file.clone()) {
                changed_files.push(file);
            }
        }
    }
    // First added line (head-side numbering) per file, used to anchor extra
    // file context. '+' marks added lines; '>' is libgit2's marker for the
    // added end-of-file newline.
    let mut first_changed_head_lines: HashMap<String, i32> = HashMap::new();
    diff.foreach(
        &mut |_, _| true,
        None,
        None,
        Some(&mut |delta, _hunk, line| {
            let origin = line.origin();
            if !matches!(origin, '+' | '>') {
                return true;
            }
            let Some(line_no) = line.new_lineno() else {
                return true;
            };
            let Some(path) = delta.new_file().path().or(delta.old_file().path()) else {
                return true;
            };
            let file = path.to_string_lossy().to_string();
            first_changed_head_lines
                .entry(file)
                .or_insert(line_no as i32);
            true
        }),
    )
    .map_err(|err| git_error("walk diff lines".to_owned(), err))?;
    // Render the unified patch. libgit2 delivers hunk line content WITHOUT
    // the leading origin character, so we must re-add '+'/'-'/' ' ourselves;
    // otherwise the "diff" loses all add/remove markers. Headers (file/hunk
    // lines) carry other origin values and are emitted verbatim.
    let mut patch = String::new();
    diff.print(DiffFormat::Patch, |_delta, _hunk, line| {
        let origin = line.origin();
        if matches!(origin, '+' | '-' | ' ') {
            patch.push(origin);
        }
        patch.push_str(&String::from_utf8_lossy(line.content()));
        true
    })
    .map_err(|err| git_error("render diff patch".to_owned(), err))?;
    Ok(PullRequestMaterial {
        base_oid,
        head_oid,
        merge_base_oid,
        patch: truncate_patch(patch, MAX_PATCH_CHARS),
        changed_files,
        first_changed_head_lines,
    })
}
/// Reads up to `radius` lines of context on each side of `line` from `file`
/// as it exists at `head_ref`, returning a numbered snippet.
///
/// Binary or non-UTF8 blobs yield a placeholder context instead of an error.
///
/// # Errors
/// `AgentError::GitError` when the ref, path, or blob cannot be resolved.
pub fn read_file_context_at_ref(
    repo: &Repository,
    head_ref: &str,
    file: &str,
    line: i32,
    radius: i32,
) -> AgentResult<FileContext> {
    let (_head_oid, head_commit) = resolve_commit(repo, head_ref)?;
    let tree = head_commit
        .tree()
        .map_err(|err| git_error(format!("load tree for '{head_ref}'"), err))?;
    let entry = tree
        .get_path(Path::new(file))
        .map_err(|err| git_error(format!("lookup file '{file}' in '{head_ref}'"), err))?;
    let object = entry
        .to_object(repo)
        .map_err(|err| git_error(format!("load git object for '{file}' in '{head_ref}'"), err))?;
    let blob = object
        .peel_to_blob()
        .map_err(|err| git_error(format!("peel '{file}' to blob in '{head_ref}'"), err))?;
    // Binary and non-UTF8 content both degrade to the placeholder context.
    if blob.is_binary() {
        return Ok(non_text_context(file, line));
    }
    let content = match std::str::from_utf8(blob.content()) {
        Ok(content) => content,
        Err(_) => return Ok(non_text_context(file, line)),
    };
    Ok(text_context_from_content(file, content, line, radius))
}
/// Materializes the tree at `head_ref` into a fresh temporary directory so
/// filesystem-based tools (e.g. ast-grep) can run against it. The returned
/// guard removes the directory when dropped.
///
/// # Errors
/// `AgentError::GitError` when the ref cannot be resolved or any file cannot
/// be written; a partially-written snapshot is cleaned up before returning.
pub fn materialize_ref_to_temp_dir(
    repo: &Repository,
    head_ref: &str,
) -> AgentResult<TempRepoSnapshot> {
    let (_head_oid, head_commit) = resolve_commit(repo, head_ref)?;
    let tree = head_commit
        .tree()
        .map_err(|err| git_error(format!("load tree for '{head_ref}'"), err))?;
    let snapshot_dir = unique_temp_snapshot_dir();
    fs::create_dir_all(&snapshot_dir).map_err(|err| AgentError::GitError {
        operation: "create temporary repository snapshot directory".to_owned(),
        message: err.to_string(),
    })?;
    // Roll back the partially-written snapshot on failure (best effort).
    if let Err(err) = write_tree_to_dir(repo, &tree, &snapshot_dir) {
        let _ = fs::remove_dir_all(&snapshot_dir);
        return Err(err);
    }
    Ok(TempRepoSnapshot { path: snapshot_dir })
}
/// Recursively writes `tree` under `root`: subtrees become directories, blobs
/// become regular files. Non-blob/tree entries (e.g. submodule commits,
/// symlinks recorded as links) are skipped.
///
/// NOTE(review): file modes (e.g. the executable bit) are not preserved by
/// `fs::write` — confirm that is acceptable for snapshot consumers.
fn write_tree_to_dir(repo: &Repository, tree: &Tree<'_>, root: &Path) -> AgentResult<()> {
    for entry in tree.iter() {
        let name = entry.name().ok_or_else(|| AgentError::GitError {
            operation: "decode tree entry name".to_owned(),
            message: "encountered non-UTF8 tree entry name".to_owned(),
        })?;
        let target_path = root.join(name);
        match entry.kind() {
            Some(ObjectType::Tree) => {
                fs::create_dir_all(&target_path).map_err(|err| AgentError::GitError {
                    operation: format!(
                        "create directory '{}' while materializing tree",
                        target_path.display()
                    ),
                    message: err.to_string(),
                })?;
                let subtree = repo
                    .find_tree(entry.id())
                    .map_err(|err| git_error(format!("load subtree {}", entry.id()), err))?;
                // Recurse into the subtree with the new directory as root.
                write_tree_to_dir(repo, &subtree, &target_path)?;
            }
            Some(ObjectType::Blob) => {
                // Defensive: ensure the parent exists before writing the file.
                if let Some(parent) = target_path.parent() {
                    fs::create_dir_all(parent).map_err(|err| AgentError::GitError {
                        operation: format!(
                            "create parent directory '{}' for snapshot file",
                            parent.display()
                        ),
                        message: err.to_string(),
                    })?;
                }
                let blob = repo
                    .find_blob(entry.id())
                    .map_err(|err| git_error(format!("load blob {}", entry.id()), err))?;
                fs::write(&target_path, blob.content()).map_err(|err| AgentError::GitError {
                    operation: format!("write snapshot file '{}'", target_path.display()),
                    message: err.to_string(),
                })?;
            }
            _ => {
                // Non-blob/tree entries are ignored for snapshot materialization.
            }
        }
    }
    Ok(())
}
/// Caps `patch` at `max_chars` characters, appending `TRUNCATION_NOTICE` when
/// content was cut. Returns the input unchanged when it already fits.
///
/// Single pass: `char_indices().nth(max_chars)` finds the byte offset of the
/// first character past the limit, so we truncate in place instead of
/// counting all characters and then re-collecting the kept prefix into a new
/// allocation (the previous implementation scanned twice and re-allocated).
fn truncate_patch(mut patch: String, max_chars: usize) -> String {
    match patch.char_indices().nth(max_chars) {
        // Fewer than `max_chars + 1` characters: nothing to cut.
        None => patch,
        // `cut` is the byte index of the (0-based) character `max_chars`,
        // so truncating here keeps exactly `max_chars` characters.
        Some((cut, _)) => {
            patch.truncate(cut);
            patch.push_str(TRUNCATION_NOTICE);
            patch
        }
    }
}
/// Placeholder context for blobs that are binary or not valid UTF-8.
/// The anchor line is clamped to at least 1 so the range stays valid.
fn non_text_context(file: &str, line: i32) -> FileContext {
    let anchor = if line < 1 { 1 } else { line };
    FileContext {
        file: file.to_owned(),
        line_start: anchor,
        line_end: anchor,
        snippet: NON_TEXT_BLOB_NOTICE.to_owned(),
    }
}
/// Builds a numbered snippet of `content` covering `radius` lines on either
/// side of `line` (1-based). Out-of-range anchors are clamped into the file;
/// a negative radius is treated as zero; empty content yields an empty
/// snippet spanning lines 1-1.
fn text_context_from_content(file: &str, content: &str, line: i32, radius: i32) -> FileContext {
    let all_lines: Vec<&str> = content.lines().collect();
    let total = all_lines.len() as i32;
    if total == 0 {
        return FileContext {
            file: file.to_owned(),
            line_start: 1,
            line_end: 1,
            snippet: String::new(),
        };
    }
    let center = line.clamp(1, total);
    let span = radius.max(0);
    let first = (center - span).max(1);
    let last = (center + span).min(total);
    // Render "   123 | <text>" rows, pairing each line number with its text.
    let mut snippet = String::new();
    let window = all_lines[(first - 1) as usize..last as usize].iter();
    for (number, text) in (first..=last).zip(window) {
        snippet.push_str(&format!("{number:>6} | {text}\n"));
    }
    FileContext {
        file: file.to_owned(),
        line_start: first,
        line_end: last,
        snippet,
    }
}
/// Builds a unique temp-dir path for a snapshot, namespaced by process id and
/// a nanosecond timestamp to avoid collisions between concurrent runs.
/// NOTE(review): `timestamp_nanos_opt` falls back to 0 out of chrono's
/// representable range — collisions are then only prevented by the pid.
fn unique_temp_snapshot_dir() -> PathBuf {
    let nonce = format!(
        "codetaker_ast_grep_snapshot_{}_{}",
        std::process::id(),
        chrono::Utc::now().timestamp_nanos_opt().unwrap_or_default()
    );
    std::env::temp_dir().join(nonce)
}
/// Shorthand for wrapping a `git2::Error` into `AgentError::GitError` with a
/// description of the attempted operation.
fn git_error(operation: impl Into<String>, err: git2::Error) -> AgentError {
    AgentError::GitError {
        operation: operation.into(),
        message: err.to_string(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use git2::{Repository, Signature};
    // Unique temp path per test invocation (pid + nanosecond timestamp) so
    // parallel test runs do not collide.
    fn test_repo_path(test_name: &str) -> PathBuf {
        let nonce = format!(
            "{}_{}_{}",
            test_name,
            std::process::id(),
            chrono::Utc::now().timestamp_nanos_opt().unwrap_or_default()
        );
        std::env::temp_dir().join(nonce)
    }
    // Writes `files` as a flat tree, commits it with `parent` (if any), and
    // advances `target_ref`. Returns the new commit id.
    fn create_commit_with_files(
        repo: &Repository,
        target_ref: &str,
        parent: Option<Oid>,
        message: &str,
        files: &[(&str, &[u8])],
    ) -> Oid {
        let mut builder = repo.treebuilder(None).expect("create treebuilder");
        for (path, content) in files {
            let blob_id = repo.blob(content).expect("write blob");
            builder
                .insert(*path, blob_id, 0o100644)
                .expect("insert tree entry");
        }
        let tree_id = builder.write().expect("write tree");
        let tree = repo.find_tree(tree_id).expect("find tree");
        let sig = Signature::now("Codetaker", "codetaker@example.com").expect("signature");
        let parent_commits = parent
            .map(|oid| vec![repo.find_commit(oid).expect("parent commit")])
            .unwrap_or_default();
        let parent_refs: Vec<&Commit<'_>> = parent_commits.iter().collect();
        repo.commit(Some(target_ref), &sig, &sig, message, &tree, &parent_refs)
            .expect("create commit")
    }
    // Bare repo where `main` and `feature` diverge from a shared root commit.
    // Returns (repo_path, repo, root_oid, main_head_oid, feature_head_oid).
    fn bare_repo_fixture() -> (PathBuf, Repository, Oid, Oid, Oid) {
        let repo_path = test_repo_path("git_access_fixture");
        let repo = Repository::init_bare(&repo_path).expect("init bare repo");
        let root = create_commit_with_files(
            &repo,
            "refs/heads/main",
            None,
            "root",
            &[("a.txt", b"line1\nline2\n")],
        );
        let main_head = create_commit_with_files(
            &repo,
            "refs/heads/main",
            Some(root),
            "main update",
            &[
                ("a.txt", b"line1\nline2\n"),
                ("main_only.txt", b"main branch only\n"),
            ],
        );
        let feature_head = create_commit_with_files(
            &repo,
            "refs/heads/feature",
            Some(root),
            "feature update",
            &[("a.txt", b"line1\nline2_feature\n")],
        );
        (repo_path, repo, root, main_head, feature_head)
    }
    #[test]
    fn resolve_commit_handles_valid_and_invalid_specs() {
        let (repo_path, repo, _root, _main_head, feature_head) = bare_repo_fixture();
        let (oid, _commit) = resolve_commit(&repo, "refs/heads/feature").expect("resolve feature");
        assert_eq!(oid, feature_head);
        let err = resolve_commit(&repo, "refs/heads/does-not-exist").expect_err("expected error");
        assert!(matches!(err, AgentError::GitError { .. }));
        fs::remove_dir_all(repo_path).expect("cleanup fixture");
    }
    #[test]
    fn compute_pr_material_uses_merge_base_to_head_range() {
        let (repo_path, repo, merge_base, _main_head, _feature_head) = bare_repo_fixture();
        let material = compute_pr_material(&repo, "refs/heads/main", "refs/heads/feature")
            .expect("compute pr material");
        // The diff must be merge_base..feature: it contains the feature
        // change but nothing that only happened on main.
        assert_eq!(material.merge_base_oid, merge_base);
        assert!(material.patch.contains("line2_feature"));
        assert!(!material.patch.contains("main branch only"));
        assert!(material.changed_files.contains(&"a.txt".to_owned()));
        assert_eq!(material.first_changed_head_lines.get("a.txt"), Some(&2));
        fs::remove_dir_all(repo_path).expect("cleanup fixture");
    }
    #[test]
    fn read_file_context_handles_text_and_binary_blobs() {
        let repo_path = test_repo_path("git_access_binary_fixture");
        let repo = Repository::init_bare(&repo_path).expect("init bare repo");
        let _head = create_commit_with_files(
            &repo,
            "refs/heads/main",
            None,
            "with binary",
            &[
                ("a.txt", b"one\ntwo\nthree\n"),
                ("bin.dat", &[0, 255, 0, 1]),
            ],
        );
        let text_context = read_file_context_at_ref(&repo, "refs/heads/main", "a.txt", 2, 1)
            .expect("text context");
        assert!(text_context.snippet.contains("two"));
        // Binary blobs degrade to the placeholder snippet, not an error.
        let binary_context = read_file_context_at_ref(&repo, "refs/heads/main", "bin.dat", 1, 1)
            .expect("binary context");
        assert_eq!(binary_context.snippet, NON_TEXT_BLOB_NOTICE);
        fs::remove_dir_all(repo_path).expect("cleanup fixture");
    }
}

View File

@@ -0,0 +1,16 @@
//! codetaker-agent: code-review agents (PR review, thread conversation),
//! their git access layer, per-project memory store, and shared types.
pub mod agent;
pub mod error;
pub mod git_access;
pub mod memory;
pub mod tools;
pub mod types;
// Flat re-exports so consumers can use `codetaker_agent::X` directly.
pub use agent::{ConversationAnswerAgent, PullRequestReviewAgent};
pub use error::{AgentError, AgentResult};
pub use git_access::PullRequestMaterial;
pub use memory::{DieselMemoryStore, MemoryStore, ProjectContextSnapshot};
pub use tools::{AstGrepArgs, AstGrepOutput, AstGrepTool, FileContext, SearchHit};
pub use types::{
    ConversationInput, ConversationOutput, Forge, MessageAuthor, ProjectRef,
    PullRequestReviewInput, PullRequestReviewOutput, ReviewComment, ReviewResult, ThreadMessage,
};

View File

@@ -0,0 +1,65 @@
mod sqlite;
use std::collections::BTreeMap;
use chrono::NaiveDateTime;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::AgentResult;
use crate::types::{MessageAuthor, ProjectRef, ThreadMessage};
pub use sqlite::DieselMemoryStore;
/// A free-form summary stored per project, keyed by a summary type string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemorySummary {
    // Discriminator for the kind of summary (free-form string key).
    pub summary_type: String,
    // Summary text.
    pub content: String,
    // Last write time (naive UTC, matching the database column).
    pub updated_at: NaiveDateTime,
}
/// Everything the agents feed into a prompt for one project: keyed JSON
/// memory entries (sorted by key) plus the project's summaries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectContextSnapshot {
    // Key -> JSON value memory map; BTreeMap keeps prompt output stable.
    pub entries: BTreeMap<String, Value>,
    // All stored summaries for the project.
    pub summaries: Vec<MemorySummary>,
}
/// Persistence boundary used by the agents: per-project memory plus review
/// thread storage.
// NOTE(review): `async fn` in traits leaks auto-trait bounds on the returned
// futures; presumably acceptable for in-workspace consumers — confirm before
// exposing this trait publicly.
#[allow(async_fn_in_trait)]
pub trait MemoryStore {
    /// Loads all memory entries and summaries for `project`.
    async fn project_context_snapshot(
        &self,
        project: &ProjectRef,
    ) -> AgentResult<ProjectContextSnapshot>;
    /// Inserts or replaces the JSON `value` stored under `key` for `project`;
    /// `source` records which component wrote the entry.
    async fn upsert_memory_entry(
        &self,
        project: &ProjectRef,
        key: &str,
        value: &Value,
        source: &str,
    ) -> AgentResult<()>;
    /// Inserts or replaces the summary stored under `summary_type`.
    async fn upsert_memory_summary(
        &self,
        project: &ProjectRef,
        summary_type: &str,
        content: &str,
    ) -> AgentResult<()>;
    /// Creates a review thread anchored at `file`:`line` and returns its id.
    async fn create_review_thread(
        &self,
        project: &ProjectRef,
        file: &str,
        line: i32,
        initial_comment: &str,
    ) -> AgentResult<i32>;
    /// Appends a message to an existing thread.
    async fn append_thread_message(
        &self,
        thread_id: i32,
        author: MessageAuthor,
        body: &str,
    ) -> AgentResult<()>;
    /// Loads a thread's messages in creation order.
    async fn load_thread_messages(&self, thread_id: i32) -> AgentResult<Vec<ThreadMessage>>;
}

View File

@@ -0,0 +1,319 @@
use std::collections::BTreeMap;
use chrono::Utc;
use codetaker_db::models::{
NewProjectMemoryEntryRow, NewProjectMemorySummaryRow, NewProjectRow, NewReviewThreadMessageRow,
NewReviewThreadRow, ProjectMemoryEntryRow, ProjectMemorySummaryRow, ReviewThreadMessageRow,
};
use codetaker_db::schema::{
project_memory_entries, project_memory_summaries, projects, review_thread_messages,
review_threads,
};
use codetaker_db::{DatabaseConnection, DatabasePool};
use diesel::OptionalExtension;
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use serde_json::Value;
use crate::error::{AgentError, AgentResult};
use crate::memory::{MemoryStore, MemorySummary, ProjectContextSnapshot};
use crate::types::{MessageAuthor, ProjectRef, ThreadMessage};
// Pool-scoped connection type used by all store methods.
type PooledConnection<'a> =
    diesel_async::pooled_connection::bb8::PooledConnection<'a, DatabaseConnection>;
/// `MemoryStore` implementation backed by a Diesel (async, pooled) database.
#[derive(Debug, Clone)]
pub struct DieselMemoryStore {
    // Shared connection pool; cloning the store clones the pool handle.
    pool: DatabasePool,
}
impl DieselMemoryStore {
    /// Wraps an existing connection pool.
    pub fn new(pool: DatabasePool) -> Self {
        Self { pool }
    }
    /// Checks out a pooled connection, mapping pool errors to `MemoryError`.
    async fn get_conn(&self) -> AgentResult<PooledConnection<'_>> {
        self.pool
            .get()
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to get database connection from pool: {err}"),
            })
    }
    /// Returns the row id for `project`, inserting the project row first if it
    /// does not exist yet.
    ///
    /// NOTE(review): select → insert → re-select is not atomic; two concurrent
    /// callers could both miss the select and race on the insert. Confirm a
    /// unique constraint on (forge, owner, repo) backs this up.
    async fn ensure_project_id(
        conn: &mut PooledConnection<'_>,
        project: &ProjectRef,
    ) -> AgentResult<i32> {
        use codetaker_db::schema::projects::dsl;
        let existing_id = dsl::projects
            .filter(dsl::forge.eq(project.forge.as_db_value()))
            .filter(dsl::owner.eq(&project.owner))
            .filter(dsl::repo.eq(&project.repo))
            .select(dsl::id)
            .first::<i32>(conn)
            .await
            .optional()
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to query project id: {err}"),
            })?;
        if let Some(existing_id) = existing_id {
            return Ok(existing_id);
        }
        let new_row = NewProjectRow {
            forge: project.forge.as_db_value(),
            owner: &project.owner,
            repo: &project.repo,
        };
        diesel::insert_into(projects::table)
            .values(&new_row)
            .execute(conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to insert project row: {err}"),
            })?;
        // Re-query instead of using a RETURNING clause.
        dsl::projects
            .filter(dsl::forge.eq(project.forge.as_db_value()))
            .filter(dsl::owner.eq(&project.owner))
            .filter(dsl::repo.eq(&project.repo))
            .select(dsl::id)
            .first::<i32>(conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to re-query inserted project id: {err}"),
            })
    }
}
impl MemoryStore for DieselMemoryStore {
    /// Loads all memory entries (ordered by key) and summaries (ordered by
    /// type) for `project`, parsing each entry's stored JSON.
    ///
    /// # Errors
    /// `AgentError::MemoryError` on query failure or when a stored entry does
    /// not contain valid JSON.
    async fn project_context_snapshot(
        &self,
        project: &ProjectRef,
    ) -> AgentResult<ProjectContextSnapshot> {
        use codetaker_db::schema::project_memory_entries::dsl as entries_dsl;
        use codetaker_db::schema::project_memory_summaries::dsl as summaries_dsl;
        let mut conn = self.get_conn().await?;
        let project_id = Self::ensure_project_id(&mut conn, project).await?;
        let entry_rows = entries_dsl::project_memory_entries
            .filter(entries_dsl::project_id.eq(project_id))
            .order(entries_dsl::key.asc())
            .load::<ProjectMemoryEntryRow>(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to load project memory entries: {err}"),
            })?;
        let summary_rows = summaries_dsl::project_memory_summaries
            .filter(summaries_dsl::project_id.eq(project_id))
            .order(summaries_dsl::summary_type.asc())
            .load::<ProjectMemorySummaryRow>(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to load project memory summaries: {err}"),
            })?;
        // Entries are stored as JSON strings; a parse failure is a hard error
        // (corrupted row) rather than being silently skipped.
        let mut entries = BTreeMap::new();
        for row in entry_rows {
            let parsed: Value =
                serde_json::from_str(&row.value_json).map_err(|err| AgentError::MemoryError {
                    message: format!(
                        "failed to parse memory entry '{}' JSON value '{}': {err}",
                        row.key, row.value_json
                    ),
                })?;
            entries.insert(row.key, parsed);
        }
        let summaries = summary_rows
            .into_iter()
            .map(|row| MemorySummary {
                summary_type: row.summary_type,
                content: row.content,
                updated_at: row.updated_at,
            })
            .collect();
        Ok(ProjectContextSnapshot { entries, summaries })
    }
    /// Inserts or updates the entry at (project, key) with serialized `value`,
    /// recording `source` and the current UTC time.
    async fn upsert_memory_entry(
        &self,
        project: &ProjectRef,
        key: &str,
        value: &Value,
        source: &str,
    ) -> AgentResult<()> {
        use codetaker_db::schema::project_memory_entries::dsl;
        let mut conn = self.get_conn().await?;
        let project_id = Self::ensure_project_id(&mut conn, project).await?;
        let now = Utc::now().naive_utc();
        let value_json = serde_json::to_string(value).map_err(|err| AgentError::MemoryError {
            message: format!("failed to serialize memory entry '{key}': {err}"),
        })?;
        let new_row = NewProjectMemoryEntryRow {
            project_id,
            key,
            value_json: &value_json,
            source,
            updated_at: now,
        };
        // Upsert keyed on (project_id, key).
        diesel::insert_into(project_memory_entries::table)
            .values(&new_row)
            .on_conflict((dsl::project_id, dsl::key))
            .do_update()
            .set((
                dsl::value_json.eq(&value_json),
                dsl::source.eq(source),
                dsl::updated_at.eq(now),
            ))
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to upsert memory entry '{key}': {err}"),
            })?;
        Ok(())
    }
    /// Inserts or updates the summary at (project, summary_type) with the
    /// current UTC time.
    async fn upsert_memory_summary(
        &self,
        project: &ProjectRef,
        summary_type: &str,
        content: &str,
    ) -> AgentResult<()> {
        use codetaker_db::schema::project_memory_summaries::dsl;
        let mut conn = self.get_conn().await?;
        let project_id = Self::ensure_project_id(&mut conn, project).await?;
        let now = Utc::now().naive_utc();
        let new_row = NewProjectMemorySummaryRow {
            project_id,
            summary_type,
            content,
            updated_at: now,
        };
        // Upsert keyed on (project_id, summary_type).
        diesel::insert_into(project_memory_summaries::table)
            .values(&new_row)
            .on_conflict((dsl::project_id, dsl::summary_type))
            .do_update()
            .set((dsl::content.eq(content), dsl::updated_at.eq(now)))
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to upsert memory summary '{summary_type}': {err}"),
            })?;
        Ok(())
    }
    /// Creates a review thread row and returns its id.
    ///
    /// NOTE(review): the id is recovered by re-querying on all inserted fields
    /// and taking the newest matching row — under concurrent identical
    /// inserts this can return the wrong id. A RETURNING clause (if the
    /// backend supports it) would be race-free.
    async fn create_review_thread(
        &self,
        project: &ProjectRef,
        file: &str,
        line: i32,
        initial_comment: &str,
    ) -> AgentResult<i32> {
        use codetaker_db::schema::review_threads::dsl;
        let mut conn = self.get_conn().await?;
        let project_id = Self::ensure_project_id(&mut conn, project).await?;
        let now = Utc::now().naive_utc();
        let new_row = NewReviewThreadRow {
            project_id,
            file,
            line,
            initial_comment,
            created_at: now,
        };
        diesel::insert_into(review_threads::table)
            .values(&new_row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to create review thread: {err}"),
            })?;
        dsl::review_threads
            .filter(dsl::project_id.eq(project_id))
            .filter(dsl::file.eq(file))
            .filter(dsl::line.eq(line))
            .filter(dsl::initial_comment.eq(initial_comment))
            .order(dsl::id.desc())
            .select(dsl::id)
            .first::<i32>(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to fetch inserted review thread id: {err}"),
            })
    }
    /// Appends a message to `thread_id` with the current UTC time.
    async fn append_thread_message(
        &self,
        thread_id: i32,
        author: MessageAuthor,
        body: &str,
    ) -> AgentResult<()> {
        let mut conn = self.get_conn().await?;
        let now = Utc::now().naive_utc();
        let row = NewReviewThreadMessageRow {
            thread_id,
            author: author.as_db_value(),
            body,
            created_at: now,
        };
        diesel::insert_into(review_thread_messages::table)
            .values(&row)
            .execute(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to append review thread message: {err}"),
            })?;
        Ok(())
    }
    /// Loads `thread_id`'s messages in creation order, decoding the stored
    /// author discriminator and re-attaching UTC to the naive timestamps
    /// (which were written via `Utc::now().naive_utc()`).
    ///
    /// # Errors
    /// `AgentError::MemoryError` on query failure or when a row holds an
    /// unknown author value.
    async fn load_thread_messages(&self, thread_id: i32) -> AgentResult<Vec<ThreadMessage>> {
        use codetaker_db::schema::review_thread_messages::dsl;
        let mut conn = self.get_conn().await?;
        let rows = dsl::review_thread_messages
            .filter(dsl::thread_id.eq(thread_id))
            .order(dsl::created_at.asc())
            .load::<ReviewThreadMessageRow>(&mut conn)
            .await
            .map_err(|err| AgentError::MemoryError {
                message: format!("failed to load review thread messages: {err}"),
            })?;
        rows.into_iter()
            .map(|row| {
                let author = MessageAuthor::from_db_value(&row.author).ok_or_else(|| {
                    AgentError::MemoryError {
                        message: format!("invalid message author in database: {}", row.author),
                    }
                })?;
                let created_at = chrono::DateTime::from_naive_utc_and_offset(row.created_at, Utc);
                Ok(ThreadMessage {
                    author,
                    body: row.body,
                    created_at,
                })
            })
            .collect()
    }
}

View File

@@ -0,0 +1,243 @@
use std::path::{Path, PathBuf};
use std::process::Command;
use git2::Repository;
use rig::completion::ToolDefinition;
use rig::tool::Tool;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::{AgentError, AgentResult};
use crate::git_access;
use crate::tools::{FileContext, SearchHit};
/// Tool that shells out to the `ast-grep` binary to search code in a git
/// repository snapshot. Unbound by default; call `bind_to_request` to attach
/// a repository and ref before running queries.
#[derive(Debug, Clone)]
pub struct AstGrepTool {
    // Name or path of the ast-grep executable to invoke.
    ast_grep_bin: String,
    // Git dir of the bound bare repository; `None` until bound.
    repo_git_dir: Option<PathBuf>,
    // Git ref whose tree is searched; `None` until bound.
    head_ref: Option<String>,
}
/// Argument payloads accepted by the ast-grep tool, discriminated in JSON by
/// the `operation` field (snake_case variant names).
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "operation", rename_all = "snake_case")]
pub enum AstGrepArgs {
    /// Search for a symbol-like query.
    SearchSymbol {
        query: String,
    },
    /// Search for an ast-grep pattern query.
    SearchPattern {
        query: String,
    },
    /// Fetch `radius` lines of context around `file:line`.
    GetFileContext {
        file: String,
        line: i32,
        radius: i32,
    },
}
/// Result payloads produced by the ast-grep tool, tagged in JSON by the
/// `operation` field to mirror [`AstGrepArgs`].
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "operation", rename_all = "snake_case")]
pub enum AstGrepOutput {
    /// Matches returned by a search operation.
    SearchHits { hits: Vec<SearchHit> },
    /// Context snippet returned by `get_file_context`.
    FileContext { context: FileContext },
}
impl Default for AstGrepTool {
    /// Builds an unbound tool that invokes `ast-grep` resolved via `PATH`.
    fn default() -> Self {
        Self::new("ast-grep")
    }
}
impl AstGrepTool {
    /// Creates an unbound tool that will invoke `ast_grep_bin`.
    pub fn new(ast_grep_bin: impl Into<String>) -> Self {
        Self {
            ast_grep_bin: ast_grep_bin.into(),
            repo_git_dir: None,
            head_ref: None,
        }
    }
    /// Returns a copy of this tool bound to `repo` (which must be bare) and
    /// `head_ref`. Only a bound tool can run queries.
    ///
    /// # Errors
    /// Fails when `repo` is not a bare repository.
    pub fn bind_to_request(
        &self,
        repo: &Repository,
        head_ref: impl Into<String>,
    ) -> AgentResult<Self> {
        git_access::ensure_bare_repository(repo)?;
        Ok(Self {
            ast_grep_bin: self.ast_grep_bin.clone(),
            repo_git_dir: Some(repo.path().to_path_buf()),
            head_ref: Some(head_ref.into()),
        })
    }
    /// Extracts the bound repository path and ref, or a `ConfigError` when
    /// the tool was never bound via [`Self::bind_to_request`].
    fn bound_state(&self) -> AgentResult<(PathBuf, String)> {
        let repo_git_dir = self
            .repo_git_dir
            .clone()
            .ok_or_else(|| AgentError::ConfigError {
                message: "ast-grep tool is not bound to a repository context".to_owned(),
            })?;
        let head_ref = self
            .head_ref
            .clone()
            .ok_or_else(|| AgentError::ConfigError {
                message: "ast-grep tool is not bound to a head_ref".to_owned(),
            })?;
        Ok((repo_git_dir, head_ref))
    }
    /// Opens the bound bare repository, wrapping libgit2 errors as `GitError`.
    fn open_repo(&self, repo_git_dir: &Path) -> AgentResult<Repository> {
        Repository::open_bare(repo_git_dir).map_err(|err| AgentError::GitError {
            operation: format!("open bare repository at '{}'", repo_git_dir.display()),
            message: err.to_string(),
        })
    }
    /// Runs `ast-grep run --pattern <pattern>` over a temporary checkout of
    /// the bound ref and parses its streamed JSON output into [`SearchHit`]s.
    /// Lines that are not valid JSON are silently skipped.
    fn run_query(&self, pattern: &str) -> AgentResult<Vec<SearchHit>> {
        let (repo_git_dir, head_ref) = self.bound_state()?;
        let repo = self.open_repo(&repo_git_dir)?;
        // Materialize the ref into a temporary working tree so ast-grep can
        // scan real files; this happens on every query.
        let snapshot = git_access::materialize_ref_to_temp_dir(&repo, &head_ref)?;
        // NOTE(review): verify the installed ast-grep accepts `--no-color`;
        // some versions expose this only as `--color=never`.
        let output = Command::new(&self.ast_grep_bin)
            .arg("run")
            .arg("--pattern")
            .arg(pattern)
            .arg("--json=stream")
            .arg("--no-color")
            .current_dir(snapshot.path())
            .output()
            .map_err(|err| AgentError::ToolError {
                tool: "ast-grep".to_owned(),
                message: format!("failed to execute ast-grep: {err}"),
            })?;
        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr).to_string();
            return Err(AgentError::ToolError {
                tool: "ast-grep".to_owned(),
                message: format!("ast-grep exited with status {}: {stderr}", output.status),
            });
        }
        let stdout = String::from_utf8_lossy(&output.stdout);
        let mut hits = Vec::new();
        // `--json=stream` emits one JSON object per line.
        for line in stdout.lines().filter(|line| !line.trim().is_empty()) {
            let parsed: Value = match serde_json::from_str(line) {
                Ok(value) => value,
                Err(_) => continue,
            };
            let file = parsed
                .get("file")
                .and_then(Value::as_str)
                .unwrap_or_default()
                .to_owned();
            let snippet = parsed
                .get("text")
                .and_then(Value::as_str)
                .unwrap_or_default()
                .to_owned();
            // Assumes ast-grep ranges are 0-based; convert to 1-based for
            // callers — TODO confirm against the ast-grep JSON schema.
            let line_num = parsed
                .get("range")
                .and_then(|range| range.get("start"))
                .and_then(|start| start.get("line"))
                .and_then(Value::as_i64)
                .map(|value| value as i32 + 1)
                .unwrap_or(1);
            let column = parsed
                .get("range")
                .and_then(|range| range.get("start"))
                .and_then(|start| start.get("column"))
                .and_then(Value::as_i64)
                .map(|value| value as i32 + 1);
            hits.push(SearchHit {
                file,
                line: line_num,
                column,
                snippet,
            });
        }
        Ok(hits)
    }
    /// Searches for a symbol-like query. Currently identical to
    /// [`Self::search_pattern`]: both forward the query verbatim as an
    /// ast-grep pattern.
    pub fn search_symbol(&self, query: &str) -> AgentResult<Vec<SearchHit>> {
        self.run_query(query)
    }
    /// Searches for an ast-grep pattern query.
    pub fn search_pattern(&self, query: &str) -> AgentResult<Vec<SearchHit>> {
        self.run_query(query)
    }
    /// Reads `radius` lines of context around `file:line` directly from the
    /// bound ref (no working-tree checkout needed).
    pub fn get_file_context(&self, file: &str, line: i32, radius: i32) -> AgentResult<FileContext> {
        let (repo_git_dir, head_ref) = self.bound_state()?;
        let repo = self.open_repo(&repo_git_dir)?;
        git_access::read_file_context_at_ref(&repo, &head_ref, file, line, radius)
    }
}
impl Tool for AstGrepTool {
    const NAME: &'static str = "ast_grep";
    type Error = AgentError;
    type Args = AstGrepArgs;
    type Output = AstGrepOutput;
    /// Advertises the tool schema: three operations discriminated by the
    /// `operation` field, mirroring the [`AstGrepArgs`] serde layout.
    async fn definition(&self, _prompt: String) -> ToolDefinition {
        ToolDefinition {
            name: Self::NAME.to_owned(),
            description: "Search source code using ast-grep and fetch file context.".to_owned(),
            parameters: serde_json::json!({
                "type": "object",
                "oneOf": [
                    {
                        "type": "object",
                        "properties": {
                            "operation": { "type": "string", "const": "search_symbol" },
                            "query": { "type": "string", "description": "Symbol-like query" }
                        },
                        "required": ["operation", "query"]
                    },
                    {
                        "type": "object",
                        "properties": {
                            "operation": { "type": "string", "const": "search_pattern" },
                            "query": { "type": "string", "description": "Pattern query" }
                        },
                        "required": ["operation", "query"]
                    },
                    {
                        "type": "object",
                        "properties": {
                            "operation": { "type": "string", "const": "get_file_context" },
                            "file": { "type": "string", "description": "Path to file in repository" },
                            "line": { "type": "integer", "description": "1-based line number" },
                            "radius": { "type": "integer", "description": "Number of lines before/after line" }
                        },
                        "required": ["operation", "file", "line", "radius"]
                    }
                ]
            }),
        }
    }
    /// Dispatches a parsed argument payload to the matching inherent method.
    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        match args {
            AstGrepArgs::SearchSymbol { query } => {
                let hits = self.search_symbol(&query)?;
                Ok(AstGrepOutput::SearchHits { hits })
            }
            AstGrepArgs::SearchPattern { query } => {
                let hits = self.search_pattern(&query)?;
                Ok(AstGrepOutput::SearchHits { hits })
            }
            AstGrepArgs::GetFileContext { file, line, radius } => {
                let context = self.get_file_context(&file, line, radius)?;
                Ok(AstGrepOutput::FileContext { context })
            }
        }
    }
}

View File

@@ -0,0 +1,21 @@
mod ast_grep;
use serde::{Deserialize, Serialize};
pub use ast_grep::{AstGrepArgs, AstGrepOutput, AstGrepTool};
/// A single search match inside a repository snapshot.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SearchHit {
    /// Path of the matched file.
    pub file: String,
    /// 1-based line where the match starts.
    pub line: i32,
    /// 1-based column where the match starts, when known.
    pub column: Option<i32>,
    /// Matched source text.
    pub snippet: String,
}
/// A contiguous snippet of a file together with its line bounds.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FileContext {
    /// Path of the file the snippet comes from.
    pub file: String,
    /// First line of the snippet (inclusive).
    pub line_start: i32,
    /// Last line of the snippet (inclusive).
    pub line_end: i32,
    /// The snippet text itself.
    pub snippet: String,
}

View File

@@ -0,0 +1,123 @@
use chrono::{DateTime, Utc};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
/// Overall verdict of an automated pull-request review.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub enum ReviewResult {
    Approve,
    RequestChanges,
}
/// One inline review comment anchored to a file and line.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct ReviewComment {
    pub comment: String,
    pub file: String,
    pub line: i32,
}
/// Structured output of a pull-request review.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct PullRequestReviewOutput {
    /// Approve / request-changes verdict.
    pub review_result: ReviewResult,
    /// Summary comment covering the whole pull request.
    pub global_comment: String,
    /// Inline comments attached to specific lines.
    pub comments: Vec<ReviewComment>,
}
/// Author of a review-thread message.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageAuthor {
    User,
    Agent,
}
/// One message in a review thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreadMessage {
    pub author: MessageAuthor,
    pub body: String,
    pub created_at: DateTime<Utc>,
}
/// Supported code forges.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Forge {
    Gitea,
}
/// Identifies a repository on a forge by owner and repository name.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProjectRef {
    pub forge: Forge,
    pub owner: String,
    pub repo: String,
}
/// Input for a full pull-request review: the project plus base and head refs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PullRequestReviewInput {
    pub project: ProjectRef,
    pub base_ref: String,
    pub head_ref: String,
}
/// Input for answering inside an inline review thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationInput {
    pub project: ProjectRef,
    /// Ref whose tree the thread's file/line anchor refers to.
    pub head_ref: String,
    /// File the thread is anchored to.
    pub anchor_file: String,
    /// Line the thread is anchored to.
    pub anchor_line: i32,
    /// The comment that opened the thread.
    pub initial_comment: String,
    /// Prior messages in the thread, in order.
    pub message_chain: Vec<ThreadMessage>,
}
/// Structured reply produced for a conversation turn.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct ConversationOutput {
    pub reply: String,
}
impl Forge {
    /// Stable string tag stored in the database for this forge.
    pub fn as_db_value(&self) -> &'static str {
        match self {
            Self::Gitea => "gitea",
        }
    }
    /// Parses a database tag back into a [`Forge`]; `None` for unknown tags.
    pub fn from_db_value(value: &str) -> Option<Self> {
        (value == "gitea").then_some(Self::Gitea)
    }
}
impl MessageAuthor {
    /// Stable string tag stored in the database for this author kind.
    pub fn as_db_value(&self) -> &'static str {
        if matches!(self, Self::User) {
            "user"
        } else {
            "agent"
        }
    }
    /// Parses a database tag back into a [`MessageAuthor`]; `None` for
    /// unknown tags.
    pub fn from_db_value(value: &str) -> Option<Self> {
        [("user", Self::User), ("agent", Self::Agent)]
            .into_iter()
            .find_map(|(tag, author)| (tag == value).then_some(author))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `ReviewResult` must survive a JSON encode/decode cycle unchanged.
    #[test]
    fn review_result_round_trip_serde() {
        let original = ReviewResult::RequestChanges;
        let encoded = serde_json::to_string(&original).expect("serialize ReviewResult");
        let decoded: ReviewResult =
            serde_json::from_str(&encoded).expect("deserialize ReviewResult");
        assert_eq!(decoded, original);
    }

    /// `MessageAuthor` must survive a JSON encode/decode cycle unchanged.
    #[test]
    fn message_author_round_trip_serde() {
        let original = MessageAuthor::Agent;
        let encoded = serde_json::to_string(&original).expect("serialize MessageAuthor");
        let decoded: MessageAuthor =
            serde_json::from_str(&encoded).expect("deserialize MessageAuthor");
        assert_eq!(decoded, original);
    }
}

View File

@@ -0,0 +1,115 @@
use std::path::PathBuf;
use codetaker_agent::memory::{DieselMemoryStore, MemoryStore};
use codetaker_agent::types::{Forge, MessageAuthor, ProjectRef};
use codetaker_db::create_pool;
use serde_json::json;
/// Builds a unique SQLite file path in the system temp dir, namespaced by
/// test name, process id, and a nanosecond timestamp to avoid collisions
/// between concurrent test runs.
fn test_db_path(test_name: &str) -> PathBuf {
    let file_name = format!(
        "{}_{}_{}.sqlite",
        test_name,
        std::process::id(),
        chrono::Utc::now().timestamp_nanos_opt().unwrap_or_default()
    );
    std::env::temp_dir().join(file_name)
}
/// Fixture project used by the memory-store tests.
fn sample_project() -> ProjectRef {
    ProjectRef {
        forge: Forge::Gitea,
        owner: String::from("acme"),
        repo: String::from("rocket"),
    }
}
/// Creating a pool against a fresh SQLite file must run migrations and yield
/// a usable pooled connection.
#[tokio::test]
async fn pool_creation_runs_migrations() {
    let database_url = test_db_path("pool_creation_runs_migrations")
        .display()
        .to_string();
    let pool = create_pool(Some(&database_url))
        .await
        .expect("create sqlite pool");
    let _conn = pool
        .get()
        .await
        .expect("get pooled sqlite connection after migration");
}
/// A second upsert with the same key must overwrite the stored value, and
/// the snapshot must reflect the latest write.
#[tokio::test]
async fn upsert_memory_entry_overwrites_by_key() {
    let database_url = test_db_path("upsert_memory_entry_overwrites_by_key")
        .display()
        .to_string();
    let pool = create_pool(Some(&database_url))
        .await
        .expect("create sqlite pool");
    let store = DieselMemoryStore::new(pool);
    let project = sample_project();
    let key = "style.preferred_errors";
    store
        .upsert_memory_entry(&project, key, &json!("typed"), "seed")
        .await
        .expect("insert memory entry");
    store
        .upsert_memory_entry(&project, key, &json!("typed-and-contextual"), "refresh")
        .await
        .expect("update memory entry");
    let snapshot = store
        .project_context_snapshot(&project)
        .await
        .expect("fetch project snapshot");
    assert_eq!(
        snapshot.entries.get(key),
        Some(&json!("typed-and-contextual"))
    );
}
/// Messages appended to a thread must come back in insertion order with
/// their authors intact.
#[tokio::test]
async fn thread_messages_round_trip() {
    let database_url = test_db_path("thread_messages_round_trip")
        .display()
        .to_string();
    let pool = create_pool(Some(&database_url))
        .await
        .expect("create sqlite pool");
    let store = DieselMemoryStore::new(pool);
    let thread_id = store
        .create_review_thread(
            &sample_project(),
            "src/lib.rs",
            42,
            "Please avoid unwrap here",
        )
        .await
        .expect("create review thread");
    store
        .append_thread_message(thread_id, MessageAuthor::User, "Can you clarify why?")
        .await
        .expect("append user message");
    store
        .append_thread_message(
            thread_id,
            MessageAuthor::Agent,
            "This path can fail on malformed input.",
        )
        .await
        .expect("append agent message");
    let messages = store
        .load_thread_messages(thread_id)
        .await
        .expect("load thread messages");
    assert_eq!(messages.len(), 2);
    assert!(matches!(messages[0].author, MessageAuthor::User));
    assert!(matches!(messages[1].author, MessageAuthor::Agent));
}