misc: initial code
This commit is contained in:
19
crates/codetaker-agent/Cargo.toml
Normal file
19
crates/codetaker-agent/Cargo.toml
Normal file
@@ -0,0 +1,19 @@
|
||||
[package]
|
||||
name = "codetaker-agent"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
chrono.workspace = true
|
||||
codetaker-db.workspace = true
|
||||
diesel.workspace = true
|
||||
diesel-async.workspace = true
|
||||
git2.workspace = true
|
||||
rig-core.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
tracing.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
tokio.workspace = true
|
||||
198
crates/codetaker-agent/src/agent/conversation_answer.rs
Normal file
198
crates/codetaker-agent/src/agent/conversation_answer.rs
Normal file
@@ -0,0 +1,198 @@
|
||||
use std::fmt::Write;
|
||||
|
||||
use git2::Repository;
|
||||
use rig::agent::AgentBuilder;
|
||||
use rig::client::{Client, CompletionClient};
|
||||
use rig::completion::{StructuredOutputError, TypedPrompt};
|
||||
|
||||
use crate::error::{AgentError, AgentResult};
|
||||
use crate::git_access::{self};
|
||||
use crate::memory::{MemoryStore, ProjectContextSnapshot};
|
||||
use crate::tools::{AstGrepTool, FileContext, SearchHit};
|
||||
use crate::types::{ConversationInput, ConversationOutput};
|
||||
|
||||
const CONVERSATION_RESPONSE_PREAMBLE: &str = r#"
|
||||
You are a code review assistant responding in an inline review thread.
|
||||
Be concise, technical, and specific.
|
||||
Use available tools when you need extra repository context.
|
||||
Respond using the ConversationOutput JSON schema.
|
||||
"#;
|
||||
|
||||
pub struct ConversationAnswerAgent<Ext, H, M> {
|
||||
client: Client<Ext, H>,
|
||||
conversation_model: String,
|
||||
ast_grep: AstGrepTool,
|
||||
memory: M,
|
||||
}
|
||||
|
||||
impl<Ext, H, M> ConversationAnswerAgent<Ext, H, M> {
|
||||
pub fn new(
|
||||
client: Client<Ext, H>,
|
||||
conversation_model: impl Into<String>,
|
||||
ast_grep: AstGrepTool,
|
||||
memory: M,
|
||||
) -> Self {
|
||||
Self {
|
||||
client,
|
||||
conversation_model: conversation_model.into(),
|
||||
ast_grep,
|
||||
memory,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Ext, H, M> ConversationAnswerAgent<Ext, H, M>
|
||||
where
|
||||
Client<Ext, H>: CompletionClient,
|
||||
M: MemoryStore,
|
||||
{
|
||||
pub async fn conversation_response(
|
||||
&self,
|
||||
repo: &Repository,
|
||||
input: ConversationInput,
|
||||
) -> AgentResult<ConversationOutput> {
|
||||
git_access::ensure_bare_repository(repo)?;
|
||||
|
||||
let memory_snapshot = self.memory.project_context_snapshot(&input.project).await?;
|
||||
let anchor_context = git_access::read_file_context_at_ref(
|
||||
repo,
|
||||
&input.head_ref,
|
||||
&input.anchor_file,
|
||||
input.anchor_line,
|
||||
40,
|
||||
)?;
|
||||
|
||||
let ast_grep = self.ast_grep.bind_to_request(repo, &input.head_ref)?;
|
||||
let search_hits = self.collect_conversation_search_hits(&ast_grep, &input);
|
||||
let prompt =
|
||||
build_conversation_prompt(&input, &memory_snapshot, &anchor_context, &search_hits);
|
||||
|
||||
let conversation_model = self.client.completion_model(&self.conversation_model);
|
||||
let agent = AgentBuilder::new(conversation_model)
|
||||
.preamble(CONVERSATION_RESPONSE_PREAMBLE)
|
||||
.tool(ast_grep)
|
||||
.temperature(0.1)
|
||||
.output_schema::<ConversationOutput>()
|
||||
.build();
|
||||
|
||||
agent
|
||||
.prompt_typed::<ConversationOutput>(prompt)
|
||||
.await
|
||||
.map_err(|err| match err {
|
||||
StructuredOutputError::PromptError(prompt_err) => AgentError::ModelError {
|
||||
model: self.conversation_model.clone(),
|
||||
message: prompt_err.to_string(),
|
||||
},
|
||||
StructuredOutputError::DeserializationError(deser_err) => {
|
||||
AgentError::OutputValidationError {
|
||||
message: format!("failed to deserialize conversation output: {deser_err}"),
|
||||
raw_output: None,
|
||||
}
|
||||
}
|
||||
StructuredOutputError::EmptyResponse => AgentError::OutputValidationError {
|
||||
message: "conversation response was empty".to_owned(),
|
||||
raw_output: None,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn collect_conversation_search_hits(
|
||||
&self,
|
||||
ast_grep: &AstGrepTool,
|
||||
input: &ConversationInput,
|
||||
) -> Vec<SearchHit> {
|
||||
match ast_grep.search_pattern(&input.initial_comment) {
|
||||
Ok(hits) => hits.into_iter().take(10).collect(),
|
||||
Err(err) => {
|
||||
tracing::warn!(error = %err, "failed to fetch search hits for conversation context");
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_conversation_prompt(
|
||||
input: &ConversationInput,
|
||||
memory: &ProjectContextSnapshot,
|
||||
anchor_context: &FileContext,
|
||||
search_hits: &[SearchHit],
|
||||
) -> String {
|
||||
let mut prompt = String::new();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Project: {}/{} ({})",
|
||||
input.project.owner,
|
||||
input.project.repo,
|
||||
input.project.forge.as_db_value()
|
||||
)
|
||||
.ok();
|
||||
|
||||
writeln!(prompt, "Head reference: '{}'", input.head_ref).ok();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Anchor: {}:{}",
|
||||
input.anchor_file, input.anchor_line
|
||||
)
|
||||
.ok();
|
||||
|
||||
prompt.push_str("\nInitial review comment:\n");
|
||||
prompt.push_str(&input.initial_comment);
|
||||
|
||||
prompt.push_str("\n\nThread message chain:\n");
|
||||
for message in &input.message_chain {
|
||||
writeln!(
|
||||
prompt,
|
||||
"- [{:?} @ {}] {}",
|
||||
message.author, message.created_at, message.body
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
|
||||
prompt.push_str("\nAnchor file context:\n");
|
||||
writeln!(
|
||||
prompt,
|
||||
"File: {} (lines {}-{})\n{}",
|
||||
anchor_context.file,
|
||||
anchor_context.line_start,
|
||||
anchor_context.line_end,
|
||||
anchor_context.snippet
|
||||
)
|
||||
.ok();
|
||||
|
||||
prompt.push_str("\nPer-project memory (JSON map):\n");
|
||||
let memory_json =
|
||||
serde_json::to_string_pretty(&memory.entries).unwrap_or_else(|_| "{}".to_owned());
|
||||
prompt.push_str(&memory_json);
|
||||
|
||||
prompt.push_str("\n\nPer-project summaries:\n");
|
||||
if memory.summaries.is_empty() {
|
||||
prompt.push_str("(none)\n");
|
||||
} else {
|
||||
for summary in &memory.summaries {
|
||||
writeln!(
|
||||
prompt,
|
||||
"- [{} @ {}] {}",
|
||||
summary.summary_type, summary.updated_at, summary.content
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
prompt.push_str("\nTool search hits:\n");
|
||||
if search_hits.is_empty() {
|
||||
prompt.push_str("(none)\n");
|
||||
} else {
|
||||
for hit in search_hits {
|
||||
writeln!(
|
||||
prompt,
|
||||
"- {}:{} {:?} {}",
|
||||
hit.file, hit.line, hit.column, hit.snippet
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
prompt.push_str("\nRespond to the latest user concern in this thread.");
|
||||
|
||||
prompt
|
||||
}
|
||||
5
crates/codetaker-agent/src/agent/mod.rs
Normal file
5
crates/codetaker-agent/src/agent/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
mod conversation_answer;
|
||||
mod pull_request_review;
|
||||
|
||||
pub use conversation_answer::ConversationAnswerAgent;
|
||||
pub use pull_request_review::PullRequestReviewAgent;
|
||||
213
crates/codetaker-agent/src/agent/pull_request_review.rs
Normal file
213
crates/codetaker-agent/src/agent/pull_request_review.rs
Normal file
@@ -0,0 +1,213 @@
|
||||
use std::fmt::Write;
|
||||
|
||||
use git2::Repository;
|
||||
use rig::agent::AgentBuilder;
|
||||
use rig::client::{Client, CompletionClient};
|
||||
use rig::completion::{StructuredOutputError, TypedPrompt};
|
||||
|
||||
use crate::error::{AgentError, AgentResult};
|
||||
use crate::git_access::{self, PullRequestMaterial};
|
||||
use crate::memory::{MemoryStore, ProjectContextSnapshot};
|
||||
use crate::tools::{AstGrepTool, FileContext};
|
||||
use crate::types::{PullRequestReviewInput, PullRequestReviewOutput};
|
||||
|
||||
const PULL_REQUEST_REVIEW_PREAMBLE: &str = r#"
|
||||
You are a code review agent.
|
||||
Analyze pull request diffs and respond only with valid JSON.
|
||||
Use available tools when you need extra repository context.
|
||||
Schema:
|
||||
{
|
||||
"review_result": "Approve" | "RequestChanges",
|
||||
"global_comment": string,
|
||||
"comments": [
|
||||
{
|
||||
"comment": string,
|
||||
"file": string,
|
||||
"line": integer
|
||||
}
|
||||
]
|
||||
}
|
||||
Line numbers must target the head/new file version.
|
||||
Do not include extra keys.
|
||||
"#;
|
||||
|
||||
pub struct PullRequestReviewAgent<Ext, H, M> {
|
||||
client: Client<Ext, H>,
|
||||
review_model: String,
|
||||
ast_grep: AstGrepTool,
|
||||
memory: M,
|
||||
}
|
||||
|
||||
impl<Ext, H, M> PullRequestReviewAgent<Ext, H, M> {
|
||||
pub fn new(
|
||||
client: Client<Ext, H>,
|
||||
review_model: impl Into<String>,
|
||||
ast_grep: AstGrepTool,
|
||||
memory: M,
|
||||
) -> Self {
|
||||
Self {
|
||||
client,
|
||||
review_model: review_model.into(),
|
||||
ast_grep,
|
||||
memory,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Ext, H, M> PullRequestReviewAgent<Ext, H, M>
|
||||
where
|
||||
Client<Ext, H>: CompletionClient,
|
||||
M: MemoryStore,
|
||||
{
|
||||
pub async fn pull_request_review(
|
||||
&self,
|
||||
repo: &Repository,
|
||||
input: PullRequestReviewInput,
|
||||
) -> AgentResult<PullRequestReviewOutput> {
|
||||
git_access::ensure_bare_repository(repo)?;
|
||||
|
||||
let memory_snapshot = self.memory.project_context_snapshot(&input.project).await?;
|
||||
let pr_material = git_access::compute_pr_material(repo, &input.base_ref, &input.head_ref)?;
|
||||
let file_contexts = self.collect_diff_file_contexts(repo, &input.head_ref, &pr_material)?;
|
||||
let prompt =
|
||||
build_pull_request_prompt(&input, &memory_snapshot, &pr_material, &file_contexts);
|
||||
|
||||
let ast_grep = self.ast_grep.bind_to_request(repo, &input.head_ref)?;
|
||||
|
||||
let review_model = self.client.completion_model(&self.review_model);
|
||||
let agent = AgentBuilder::new(review_model)
|
||||
.preamble(PULL_REQUEST_REVIEW_PREAMBLE)
|
||||
.tool(ast_grep)
|
||||
.temperature(0.0)
|
||||
.output_schema::<PullRequestReviewOutput>()
|
||||
.build();
|
||||
|
||||
agent
|
||||
.prompt_typed::<PullRequestReviewOutput>(prompt)
|
||||
.await
|
||||
.map_err(|err| match err {
|
||||
StructuredOutputError::PromptError(prompt_err) => AgentError::ModelError {
|
||||
model: self.review_model.clone(),
|
||||
message: prompt_err.to_string(),
|
||||
},
|
||||
StructuredOutputError::DeserializationError(deser_err) => {
|
||||
AgentError::OutputValidationError {
|
||||
message: format!(
|
||||
"failed to deserialize pull request review output: {deser_err}"
|
||||
),
|
||||
raw_output: None,
|
||||
}
|
||||
}
|
||||
StructuredOutputError::EmptyResponse => AgentError::OutputValidationError {
|
||||
message: "pull request review response was empty".to_owned(),
|
||||
raw_output: None,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn collect_diff_file_contexts(
|
||||
&self,
|
||||
repo: &Repository,
|
||||
head_ref: &str,
|
||||
pr_material: &PullRequestMaterial,
|
||||
) -> AgentResult<Vec<FileContext>> {
|
||||
let mut contexts = Vec::new();
|
||||
|
||||
for file in pr_material.changed_files.iter().take(5) {
|
||||
let line = pr_material
|
||||
.first_changed_head_lines
|
||||
.get(file)
|
||||
.copied()
|
||||
.unwrap_or(1);
|
||||
match git_access::read_file_context_at_ref(repo, head_ref, file, line, 40) {
|
||||
Ok(context) => contexts.push(context),
|
||||
Err(err) => {
|
||||
tracing::warn!(
|
||||
file = %file,
|
||||
error = %err,
|
||||
"failed to fetch file context for pull request review"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(contexts)
|
||||
}
|
||||
}
|
||||
|
||||
fn build_pull_request_prompt(
|
||||
input: &PullRequestReviewInput,
|
||||
memory: &ProjectContextSnapshot,
|
||||
material: &PullRequestMaterial,
|
||||
file_contexts: &[FileContext],
|
||||
) -> String {
|
||||
let mut prompt = String::new();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Project: {}/{} ({})",
|
||||
input.project.owner,
|
||||
input.project.repo,
|
||||
input.project.forge.as_db_value()
|
||||
)
|
||||
.ok();
|
||||
|
||||
writeln!(
|
||||
prompt,
|
||||
"Review range: base_ref='{}', head_ref='{}'",
|
||||
input.base_ref, input.head_ref
|
||||
)
|
||||
.ok();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Resolved OIDs: base={}, head={}, merge_base={}",
|
||||
material.base_oid, material.head_oid, material.merge_base_oid
|
||||
)
|
||||
.ok();
|
||||
|
||||
prompt.push_str("\nChanged files:\n");
|
||||
if material.changed_files.is_empty() {
|
||||
prompt.push_str("(none)\n");
|
||||
} else {
|
||||
for file in &material.changed_files {
|
||||
writeln!(prompt, "- {file}").ok();
|
||||
}
|
||||
}
|
||||
|
||||
prompt.push_str("\nPer-project memory (JSON map):\n");
|
||||
let memory_json =
|
||||
serde_json::to_string_pretty(&memory.entries).unwrap_or_else(|_| "{}".to_owned());
|
||||
prompt.push_str(&memory_json);
|
||||
prompt.push_str("\n\nPer-project summaries:\n");
|
||||
if memory.summaries.is_empty() {
|
||||
prompt.push_str("(none)\n");
|
||||
} else {
|
||||
for summary in &memory.summaries {
|
||||
writeln!(
|
||||
prompt,
|
||||
"- [{} @ {}] {}",
|
||||
summary.summary_type, summary.updated_at, summary.content
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
prompt.push_str("\nAdditional file contexts:\n");
|
||||
if file_contexts.is_empty() {
|
||||
prompt.push_str("(none)\n");
|
||||
} else {
|
||||
for context in file_contexts {
|
||||
writeln!(
|
||||
prompt,
|
||||
"File: {} (lines {}-{})\n{}",
|
||||
context.file, context.line_start, context.line_end, context.snippet
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
prompt.push_str("\nPull request diff:\n");
|
||||
prompt.push_str(&material.patch);
|
||||
prompt.push_str("\n\nReturn only JSON matching the schema exactly.");
|
||||
|
||||
prompt
|
||||
}
|
||||
59
crates/codetaker-agent/src/error.rs
Normal file
59
crates/codetaker-agent/src/error.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use std::error::Error;
|
||||
use std::fmt::{Display, Formatter};
|
||||
|
||||
pub type AgentResult<T> = Result<T, AgentError>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum AgentError {
|
||||
MemoryError {
|
||||
message: String,
|
||||
},
|
||||
GitError {
|
||||
operation: String,
|
||||
message: String,
|
||||
},
|
||||
ToolError {
|
||||
tool: String,
|
||||
message: String,
|
||||
},
|
||||
ModelError {
|
||||
model: String,
|
||||
message: String,
|
||||
},
|
||||
OutputValidationError {
|
||||
message: String,
|
||||
raw_output: Option<String>,
|
||||
},
|
||||
ConfigError {
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl Display for AgentError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::MemoryError { message } => write!(f, "memory error: {message}"),
|
||||
Self::GitError { operation, message } => {
|
||||
write!(f, "git error ({operation}): {message}")
|
||||
}
|
||||
Self::ToolError { tool, message } => write!(f, "tool error ({tool}): {message}"),
|
||||
Self::ModelError { model, message } => write!(f, "model error ({model}): {message}"),
|
||||
Self::OutputValidationError {
|
||||
message,
|
||||
raw_output,
|
||||
} => {
|
||||
if let Some(raw_output) = raw_output {
|
||||
write!(
|
||||
f,
|
||||
"output validation error: {message}; raw output: {raw_output}"
|
||||
)
|
||||
} else {
|
||||
write!(f, "output validation error: {message}")
|
||||
}
|
||||
}
|
||||
Self::ConfigError { message } => write!(f, "config error: {message}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for AgentError {}
|
||||
465
crates/codetaker-agent/src/git_access.rs
Normal file
465
crates/codetaker-agent/src/git_access.rs
Normal file
@@ -0,0 +1,465 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use git2::{Commit, DiffFormat, ObjectType, Oid, Repository, Tree};
|
||||
|
||||
use crate::error::{AgentError, AgentResult};
|
||||
use crate::tools::FileContext;
|
||||
|
||||
pub const MAX_PATCH_CHARS: usize = 250_000;
|
||||
const TRUNCATION_NOTICE: &str = "\n\n[DIFF TRUNCATED: content exceeds configured prompt limit]\n";
|
||||
const NON_TEXT_BLOB_NOTICE: &str = "[non-text blob omitted]";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PullRequestMaterial {
|
||||
pub base_oid: Oid,
|
||||
pub head_oid: Oid,
|
||||
pub merge_base_oid: Oid,
|
||||
pub patch: String,
|
||||
pub changed_files: Vec<String>,
|
||||
pub first_changed_head_lines: HashMap<String, i32>,
|
||||
}
|
||||
|
||||
pub struct TempRepoSnapshot {
|
||||
path: PathBuf,
|
||||
}
|
||||
|
||||
impl TempRepoSnapshot {
|
||||
pub fn path(&self) -> &Path {
|
||||
&self.path
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TempRepoSnapshot {
|
||||
fn drop(&mut self) {
|
||||
if let Err(err) = fs::remove_dir_all(&self.path) {
|
||||
tracing::warn!(
|
||||
path = %self.path.display(),
|
||||
error = %err,
|
||||
"failed to clean up temporary repository snapshot"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ensure_bare_repository(repo: &Repository) -> AgentResult<()> {
|
||||
if !repo.is_bare() {
|
||||
return Err(AgentError::ConfigError {
|
||||
message: "expected bare git repository for this agent mode".to_owned(),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn resolve_commit<'repo>(
|
||||
repo: &'repo Repository,
|
||||
spec: &str,
|
||||
) -> AgentResult<(Oid, Commit<'repo>)> {
|
||||
let object = repo
|
||||
.revparse_ext(spec)
|
||||
.or_else(|_| repo.revparse_single(spec).map(|obj| (obj, None)))
|
||||
.map(|(obj, _)| obj)
|
||||
.map_err(|err| git_error(format!("resolve revision '{spec}'"), err))?;
|
||||
|
||||
let commit = object
|
||||
.peel_to_commit()
|
||||
.map_err(|err| git_error(format!("peel revision '{spec}' to commit"), err))?;
|
||||
|
||||
Ok((commit.id(), commit))
|
||||
}
|
||||
|
||||
pub fn compute_pr_material(
|
||||
repo: &Repository,
|
||||
base_ref: &str,
|
||||
head_ref: &str,
|
||||
) -> AgentResult<PullRequestMaterial> {
|
||||
let (base_oid, _base_commit) = resolve_commit(repo, base_ref)?;
|
||||
let (head_oid, head_commit) = resolve_commit(repo, head_ref)?;
|
||||
|
||||
let merge_base_oid = repo.merge_base(base_oid, head_oid).map_err(|err| {
|
||||
git_error(
|
||||
format!("compute merge base between '{base_ref}' and '{head_ref}'"),
|
||||
err,
|
||||
)
|
||||
})?;
|
||||
|
||||
let merge_base_commit = repo
|
||||
.find_commit(merge_base_oid)
|
||||
.map_err(|err| git_error(format!("load merge-base commit {merge_base_oid}"), err))?;
|
||||
|
||||
let base_tree = merge_base_commit
|
||||
.tree()
|
||||
.map_err(|err| git_error(format!("load merge-base tree for {merge_base_oid}"), err))?;
|
||||
let head_tree = head_commit
|
||||
.tree()
|
||||
.map_err(|err| git_error(format!("load head tree for {head_oid}"), err))?;
|
||||
|
||||
let diff = repo
|
||||
.diff_tree_to_tree(Some(&base_tree), Some(&head_tree), None)
|
||||
.map_err(|err| git_error("generate tree diff".to_owned(), err))?;
|
||||
|
||||
let mut changed_files = Vec::new();
|
||||
let mut seen_files = HashSet::new();
|
||||
for delta in diff.deltas() {
|
||||
if let Some(path) = delta.new_file().path().or(delta.old_file().path()) {
|
||||
let file = path.to_string_lossy().to_string();
|
||||
if seen_files.insert(file.clone()) {
|
||||
changed_files.push(file);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut first_changed_head_lines: HashMap<String, i32> = HashMap::new();
|
||||
diff.foreach(
|
||||
&mut |_, _| true,
|
||||
None,
|
||||
None,
|
||||
Some(&mut |delta, _hunk, line| {
|
||||
let origin = line.origin();
|
||||
if !matches!(origin, '+' | '>') {
|
||||
return true;
|
||||
}
|
||||
|
||||
let Some(line_no) = line.new_lineno() else {
|
||||
return true;
|
||||
};
|
||||
|
||||
let Some(path) = delta.new_file().path().or(delta.old_file().path()) else {
|
||||
return true;
|
||||
};
|
||||
|
||||
let file = path.to_string_lossy().to_string();
|
||||
first_changed_head_lines
|
||||
.entry(file)
|
||||
.or_insert(line_no as i32);
|
||||
true
|
||||
}),
|
||||
)
|
||||
.map_err(|err| git_error("walk diff lines".to_owned(), err))?;
|
||||
|
||||
let mut patch = String::new();
|
||||
diff.print(DiffFormat::Patch, |_delta, _hunk, line| {
|
||||
patch.push_str(&String::from_utf8_lossy(line.content()));
|
||||
true
|
||||
})
|
||||
.map_err(|err| git_error("render diff patch".to_owned(), err))?;
|
||||
|
||||
Ok(PullRequestMaterial {
|
||||
base_oid,
|
||||
head_oid,
|
||||
merge_base_oid,
|
||||
patch: truncate_patch(patch, MAX_PATCH_CHARS),
|
||||
changed_files,
|
||||
first_changed_head_lines,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn read_file_context_at_ref(
|
||||
repo: &Repository,
|
||||
head_ref: &str,
|
||||
file: &str,
|
||||
line: i32,
|
||||
radius: i32,
|
||||
) -> AgentResult<FileContext> {
|
||||
let (_head_oid, head_commit) = resolve_commit(repo, head_ref)?;
|
||||
let tree = head_commit
|
||||
.tree()
|
||||
.map_err(|err| git_error(format!("load tree for '{head_ref}'"), err))?;
|
||||
|
||||
let entry = tree
|
||||
.get_path(Path::new(file))
|
||||
.map_err(|err| git_error(format!("lookup file '{file}' in '{head_ref}'"), err))?;
|
||||
let object = entry
|
||||
.to_object(repo)
|
||||
.map_err(|err| git_error(format!("load git object for '{file}' in '{head_ref}'"), err))?;
|
||||
let blob = object
|
||||
.peel_to_blob()
|
||||
.map_err(|err| git_error(format!("peel '{file}' to blob in '{head_ref}'"), err))?;
|
||||
|
||||
if blob.is_binary() {
|
||||
return Ok(non_text_context(file, line));
|
||||
}
|
||||
|
||||
let content = match std::str::from_utf8(blob.content()) {
|
||||
Ok(content) => content,
|
||||
Err(_) => return Ok(non_text_context(file, line)),
|
||||
};
|
||||
|
||||
Ok(text_context_from_content(file, content, line, radius))
|
||||
}
|
||||
|
||||
pub fn materialize_ref_to_temp_dir(
|
||||
repo: &Repository,
|
||||
head_ref: &str,
|
||||
) -> AgentResult<TempRepoSnapshot> {
|
||||
let (_head_oid, head_commit) = resolve_commit(repo, head_ref)?;
|
||||
let tree = head_commit
|
||||
.tree()
|
||||
.map_err(|err| git_error(format!("load tree for '{head_ref}'"), err))?;
|
||||
|
||||
let snapshot_dir = unique_temp_snapshot_dir();
|
||||
fs::create_dir_all(&snapshot_dir).map_err(|err| AgentError::GitError {
|
||||
operation: "create temporary repository snapshot directory".to_owned(),
|
||||
message: err.to_string(),
|
||||
})?;
|
||||
|
||||
if let Err(err) = write_tree_to_dir(repo, &tree, &snapshot_dir) {
|
||||
let _ = fs::remove_dir_all(&snapshot_dir);
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
Ok(TempRepoSnapshot { path: snapshot_dir })
|
||||
}
|
||||
|
||||
fn write_tree_to_dir(repo: &Repository, tree: &Tree<'_>, root: &Path) -> AgentResult<()> {
|
||||
for entry in tree.iter() {
|
||||
let name = entry.name().ok_or_else(|| AgentError::GitError {
|
||||
operation: "decode tree entry name".to_owned(),
|
||||
message: "encountered non-UTF8 tree entry name".to_owned(),
|
||||
})?;
|
||||
let target_path = root.join(name);
|
||||
|
||||
match entry.kind() {
|
||||
Some(ObjectType::Tree) => {
|
||||
fs::create_dir_all(&target_path).map_err(|err| AgentError::GitError {
|
||||
operation: format!(
|
||||
"create directory '{}' while materializing tree",
|
||||
target_path.display()
|
||||
),
|
||||
message: err.to_string(),
|
||||
})?;
|
||||
|
||||
let subtree = repo
|
||||
.find_tree(entry.id())
|
||||
.map_err(|err| git_error(format!("load subtree {}", entry.id()), err))?;
|
||||
write_tree_to_dir(repo, &subtree, &target_path)?;
|
||||
}
|
||||
Some(ObjectType::Blob) => {
|
||||
if let Some(parent) = target_path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|err| AgentError::GitError {
|
||||
operation: format!(
|
||||
"create parent directory '{}' for snapshot file",
|
||||
parent.display()
|
||||
),
|
||||
message: err.to_string(),
|
||||
})?;
|
||||
}
|
||||
|
||||
let blob = repo
|
||||
.find_blob(entry.id())
|
||||
.map_err(|err| git_error(format!("load blob {}", entry.id()), err))?;
|
||||
fs::write(&target_path, blob.content()).map_err(|err| AgentError::GitError {
|
||||
operation: format!("write snapshot file '{}'", target_path.display()),
|
||||
message: err.to_string(),
|
||||
})?;
|
||||
}
|
||||
_ => {
|
||||
// Non-blob/tree entries are ignored for snapshot materialization.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn truncate_patch(mut patch: String, max_chars: usize) -> String {
|
||||
if patch.chars().count() <= max_chars {
|
||||
return patch;
|
||||
}
|
||||
|
||||
patch = patch.chars().take(max_chars).collect();
|
||||
patch.push_str(TRUNCATION_NOTICE);
|
||||
patch
|
||||
}
|
||||
|
||||
fn non_text_context(file: &str, line: i32) -> FileContext {
|
||||
let safe_line = line.max(1);
|
||||
FileContext {
|
||||
file: file.to_owned(),
|
||||
line_start: safe_line,
|
||||
line_end: safe_line,
|
||||
snippet: NON_TEXT_BLOB_NOTICE.to_owned(),
|
||||
}
|
||||
}
|
||||
|
||||
fn text_context_from_content(file: &str, content: &str, line: i32, radius: i32) -> FileContext {
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
if lines.is_empty() {
|
||||
return FileContext {
|
||||
file: file.to_owned(),
|
||||
line_start: 1,
|
||||
line_end: 1,
|
||||
snippet: String::new(),
|
||||
};
|
||||
}
|
||||
|
||||
let safe_line = line.clamp(1, lines.len() as i32);
|
||||
let safe_radius = radius.max(0);
|
||||
let start = (safe_line - safe_radius).max(1) as usize;
|
||||
let end = (safe_line + safe_radius).min(lines.len() as i32) as usize;
|
||||
|
||||
let mut snippet = String::new();
|
||||
for (idx, line_content) in lines[start - 1..end].iter().enumerate() {
|
||||
let current_line = start + idx;
|
||||
snippet.push_str(&format!("{current_line:>6} | {line_content}\n"));
|
||||
}
|
||||
|
||||
FileContext {
|
||||
file: file.to_owned(),
|
||||
line_start: start as i32,
|
||||
line_end: end as i32,
|
||||
snippet,
|
||||
}
|
||||
}
|
||||
|
||||
fn unique_temp_snapshot_dir() -> PathBuf {
|
||||
let nonce = format!(
|
||||
"codetaker_ast_grep_snapshot_{}_{}",
|
||||
std::process::id(),
|
||||
chrono::Utc::now().timestamp_nanos_opt().unwrap_or_default()
|
||||
);
|
||||
std::env::temp_dir().join(nonce)
|
||||
}
|
||||
|
||||
fn git_error(operation: impl Into<String>, err: git2::Error) -> AgentError {
|
||||
AgentError::GitError {
|
||||
operation: operation.into(),
|
||||
message: err.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use git2::{Repository, Signature};
|
||||
|
||||
fn test_repo_path(test_name: &str) -> PathBuf {
|
||||
let nonce = format!(
|
||||
"{}_{}_{}",
|
||||
test_name,
|
||||
std::process::id(),
|
||||
chrono::Utc::now().timestamp_nanos_opt().unwrap_or_default()
|
||||
);
|
||||
std::env::temp_dir().join(nonce)
|
||||
}
|
||||
|
||||
fn create_commit_with_files(
|
||||
repo: &Repository,
|
||||
target_ref: &str,
|
||||
parent: Option<Oid>,
|
||||
message: &str,
|
||||
files: &[(&str, &[u8])],
|
||||
) -> Oid {
|
||||
let mut builder = repo.treebuilder(None).expect("create treebuilder");
|
||||
for (path, content) in files {
|
||||
let blob_id = repo.blob(content).expect("write blob");
|
||||
builder
|
||||
.insert(*path, blob_id, 0o100644)
|
||||
.expect("insert tree entry");
|
||||
}
|
||||
|
||||
let tree_id = builder.write().expect("write tree");
|
||||
let tree = repo.find_tree(tree_id).expect("find tree");
|
||||
let sig = Signature::now("Codetaker", "codetaker@example.com").expect("signature");
|
||||
|
||||
let parent_commits = parent
|
||||
.map(|oid| vec![repo.find_commit(oid).expect("parent commit")])
|
||||
.unwrap_or_default();
|
||||
let parent_refs: Vec<&Commit<'_>> = parent_commits.iter().collect();
|
||||
|
||||
repo.commit(Some(target_ref), &sig, &sig, message, &tree, &parent_refs)
|
||||
.expect("create commit")
|
||||
}
|
||||
|
||||
fn bare_repo_fixture() -> (PathBuf, Repository, Oid, Oid, Oid) {
|
||||
let repo_path = test_repo_path("git_access_fixture");
|
||||
let repo = Repository::init_bare(&repo_path).expect("init bare repo");
|
||||
|
||||
let root = create_commit_with_files(
|
||||
&repo,
|
||||
"refs/heads/main",
|
||||
None,
|
||||
"root",
|
||||
&[("a.txt", b"line1\nline2\n")],
|
||||
);
|
||||
|
||||
let main_head = create_commit_with_files(
|
||||
&repo,
|
||||
"refs/heads/main",
|
||||
Some(root),
|
||||
"main update",
|
||||
&[
|
||||
("a.txt", b"line1\nline2\n"),
|
||||
("main_only.txt", b"main branch only\n"),
|
||||
],
|
||||
);
|
||||
|
||||
let feature_head = create_commit_with_files(
|
||||
&repo,
|
||||
"refs/heads/feature",
|
||||
Some(root),
|
||||
"feature update",
|
||||
&[("a.txt", b"line1\nline2_feature\n")],
|
||||
);
|
||||
|
||||
(repo_path, repo, root, main_head, feature_head)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_commit_handles_valid_and_invalid_specs() {
|
||||
let (repo_path, repo, _root, _main_head, feature_head) = bare_repo_fixture();
|
||||
|
||||
let (oid, _commit) = resolve_commit(&repo, "refs/heads/feature").expect("resolve feature");
|
||||
assert_eq!(oid, feature_head);
|
||||
|
||||
let err = resolve_commit(&repo, "refs/heads/does-not-exist").expect_err("expected error");
|
||||
assert!(matches!(err, AgentError::GitError { .. }));
|
||||
|
||||
fs::remove_dir_all(repo_path).expect("cleanup fixture");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compute_pr_material_uses_merge_base_to_head_range() {
|
||||
let (repo_path, repo, merge_base, _main_head, _feature_head) = bare_repo_fixture();
|
||||
|
||||
let material = compute_pr_material(&repo, "refs/heads/main", "refs/heads/feature")
|
||||
.expect("compute pr material");
|
||||
|
||||
assert_eq!(material.merge_base_oid, merge_base);
|
||||
assert!(material.patch.contains("line2_feature"));
|
||||
assert!(!material.patch.contains("main branch only"));
|
||||
assert!(material.changed_files.contains(&"a.txt".to_owned()));
|
||||
assert_eq!(material.first_changed_head_lines.get("a.txt"), Some(&2));
|
||||
|
||||
fs::remove_dir_all(repo_path).expect("cleanup fixture");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_file_context_handles_text_and_binary_blobs() {
|
||||
let repo_path = test_repo_path("git_access_binary_fixture");
|
||||
let repo = Repository::init_bare(&repo_path).expect("init bare repo");
|
||||
|
||||
let _head = create_commit_with_files(
|
||||
&repo,
|
||||
"refs/heads/main",
|
||||
None,
|
||||
"with binary",
|
||||
&[
|
||||
("a.txt", b"one\ntwo\nthree\n"),
|
||||
("bin.dat", &[0, 255, 0, 1]),
|
||||
],
|
||||
);
|
||||
|
||||
let text_context = read_file_context_at_ref(&repo, "refs/heads/main", "a.txt", 2, 1)
|
||||
.expect("text context");
|
||||
assert!(text_context.snippet.contains("two"));
|
||||
|
||||
let binary_context = read_file_context_at_ref(&repo, "refs/heads/main", "bin.dat", 1, 1)
|
||||
.expect("binary context");
|
||||
assert_eq!(binary_context.snippet, NON_TEXT_BLOB_NOTICE);
|
||||
|
||||
fs::remove_dir_all(repo_path).expect("cleanup fixture");
|
||||
}
|
||||
}
|
||||
16
crates/codetaker-agent/src/lib.rs
Normal file
16
crates/codetaker-agent/src/lib.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
pub mod agent;
|
||||
pub mod error;
|
||||
pub mod git_access;
|
||||
pub mod memory;
|
||||
pub mod tools;
|
||||
pub mod types;
|
||||
|
||||
pub use agent::{ConversationAnswerAgent, PullRequestReviewAgent};
|
||||
pub use error::{AgentError, AgentResult};
|
||||
pub use git_access::PullRequestMaterial;
|
||||
pub use memory::{DieselMemoryStore, MemoryStore, ProjectContextSnapshot};
|
||||
pub use tools::{AstGrepArgs, AstGrepOutput, AstGrepTool, FileContext, SearchHit};
|
||||
pub use types::{
|
||||
ConversationInput, ConversationOutput, Forge, MessageAuthor, ProjectRef,
|
||||
PullRequestReviewInput, PullRequestReviewOutput, ReviewComment, ReviewResult, ThreadMessage,
|
||||
};
|
||||
65
crates/codetaker-agent/src/memory/mod.rs
Normal file
65
crates/codetaker-agent/src/memory/mod.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
mod sqlite;
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use chrono::NaiveDateTime;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::error::AgentResult;
|
||||
use crate::types::{MessageAuthor, ProjectRef, ThreadMessage};
|
||||
|
||||
pub use sqlite::DieselMemoryStore;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MemorySummary {
|
||||
pub summary_type: String,
|
||||
pub content: String,
|
||||
pub updated_at: NaiveDateTime,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProjectContextSnapshot {
|
||||
pub entries: BTreeMap<String, Value>,
|
||||
pub summaries: Vec<MemorySummary>,
|
||||
}
|
||||
|
||||
#[allow(async_fn_in_trait)]
|
||||
pub trait MemoryStore {
|
||||
async fn project_context_snapshot(
|
||||
&self,
|
||||
project: &ProjectRef,
|
||||
) -> AgentResult<ProjectContextSnapshot>;
|
||||
|
||||
async fn upsert_memory_entry(
|
||||
&self,
|
||||
project: &ProjectRef,
|
||||
key: &str,
|
||||
value: &Value,
|
||||
source: &str,
|
||||
) -> AgentResult<()>;
|
||||
|
||||
async fn upsert_memory_summary(
|
||||
&self,
|
||||
project: &ProjectRef,
|
||||
summary_type: &str,
|
||||
content: &str,
|
||||
) -> AgentResult<()>;
|
||||
|
||||
async fn create_review_thread(
|
||||
&self,
|
||||
project: &ProjectRef,
|
||||
file: &str,
|
||||
line: i32,
|
||||
initial_comment: &str,
|
||||
) -> AgentResult<i32>;
|
||||
|
||||
async fn append_thread_message(
|
||||
&self,
|
||||
thread_id: i32,
|
||||
author: MessageAuthor,
|
||||
body: &str,
|
||||
) -> AgentResult<()>;
|
||||
|
||||
async fn load_thread_messages(&self, thread_id: i32) -> AgentResult<Vec<ThreadMessage>>;
|
||||
}
|
||||
319
crates/codetaker-agent/src/memory/sqlite.rs
Normal file
319
crates/codetaker-agent/src/memory/sqlite.rs
Normal file
@@ -0,0 +1,319 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use chrono::Utc;
|
||||
use codetaker_db::models::{
|
||||
NewProjectMemoryEntryRow, NewProjectMemorySummaryRow, NewProjectRow, NewReviewThreadMessageRow,
|
||||
NewReviewThreadRow, ProjectMemoryEntryRow, ProjectMemorySummaryRow, ReviewThreadMessageRow,
|
||||
};
|
||||
use codetaker_db::schema::{
|
||||
project_memory_entries, project_memory_summaries, projects, review_thread_messages,
|
||||
review_threads,
|
||||
};
|
||||
use codetaker_db::{DatabaseConnection, DatabasePool};
|
||||
use diesel::OptionalExtension;
|
||||
use diesel::prelude::*;
|
||||
use diesel_async::RunQueryDsl;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::error::{AgentError, AgentResult};
|
||||
use crate::memory::{MemoryStore, MemorySummary, ProjectContextSnapshot};
|
||||
use crate::types::{MessageAuthor, ProjectRef, ThreadMessage};
|
||||
|
||||
type PooledConnection<'a> =
|
||||
diesel_async::pooled_connection::bb8::PooledConnection<'a, DatabaseConnection>;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DieselMemoryStore {
|
||||
pool: DatabasePool,
|
||||
}
|
||||
|
||||
impl DieselMemoryStore {
|
||||
pub fn new(pool: DatabasePool) -> Self {
|
||||
Self { pool }
|
||||
}
|
||||
|
||||
async fn get_conn(&self) -> AgentResult<PooledConnection<'_>> {
|
||||
self.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|err| AgentError::MemoryError {
|
||||
message: format!("failed to get database connection from pool: {err}"),
|
||||
})
|
||||
}
|
||||
|
||||
async fn ensure_project_id(
|
||||
conn: &mut PooledConnection<'_>,
|
||||
project: &ProjectRef,
|
||||
) -> AgentResult<i32> {
|
||||
use codetaker_db::schema::projects::dsl;
|
||||
|
||||
let existing_id = dsl::projects
|
||||
.filter(dsl::forge.eq(project.forge.as_db_value()))
|
||||
.filter(dsl::owner.eq(&project.owner))
|
||||
.filter(dsl::repo.eq(&project.repo))
|
||||
.select(dsl::id)
|
||||
.first::<i32>(conn)
|
||||
.await
|
||||
.optional()
|
||||
.map_err(|err| AgentError::MemoryError {
|
||||
message: format!("failed to query project id: {err}"),
|
||||
})?;
|
||||
|
||||
if let Some(existing_id) = existing_id {
|
||||
return Ok(existing_id);
|
||||
}
|
||||
|
||||
let new_row = NewProjectRow {
|
||||
forge: project.forge.as_db_value(),
|
||||
owner: &project.owner,
|
||||
repo: &project.repo,
|
||||
};
|
||||
|
||||
diesel::insert_into(projects::table)
|
||||
.values(&new_row)
|
||||
.execute(conn)
|
||||
.await
|
||||
.map_err(|err| AgentError::MemoryError {
|
||||
message: format!("failed to insert project row: {err}"),
|
||||
})?;
|
||||
|
||||
dsl::projects
|
||||
.filter(dsl::forge.eq(project.forge.as_db_value()))
|
||||
.filter(dsl::owner.eq(&project.owner))
|
||||
.filter(dsl::repo.eq(&project.repo))
|
||||
.select(dsl::id)
|
||||
.first::<i32>(conn)
|
||||
.await
|
||||
.map_err(|err| AgentError::MemoryError {
|
||||
message: format!("failed to re-query inserted project id: {err}"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl MemoryStore for DieselMemoryStore {
|
||||
async fn project_context_snapshot(
|
||||
&self,
|
||||
project: &ProjectRef,
|
||||
) -> AgentResult<ProjectContextSnapshot> {
|
||||
use codetaker_db::schema::project_memory_entries::dsl as entries_dsl;
|
||||
use codetaker_db::schema::project_memory_summaries::dsl as summaries_dsl;
|
||||
|
||||
let mut conn = self.get_conn().await?;
|
||||
let project_id = Self::ensure_project_id(&mut conn, project).await?;
|
||||
|
||||
let entry_rows = entries_dsl::project_memory_entries
|
||||
.filter(entries_dsl::project_id.eq(project_id))
|
||||
.order(entries_dsl::key.asc())
|
||||
.load::<ProjectMemoryEntryRow>(&mut conn)
|
||||
.await
|
||||
.map_err(|err| AgentError::MemoryError {
|
||||
message: format!("failed to load project memory entries: {err}"),
|
||||
})?;
|
||||
|
||||
let summary_rows = summaries_dsl::project_memory_summaries
|
||||
.filter(summaries_dsl::project_id.eq(project_id))
|
||||
.order(summaries_dsl::summary_type.asc())
|
||||
.load::<ProjectMemorySummaryRow>(&mut conn)
|
||||
.await
|
||||
.map_err(|err| AgentError::MemoryError {
|
||||
message: format!("failed to load project memory summaries: {err}"),
|
||||
})?;
|
||||
|
||||
let mut entries = BTreeMap::new();
|
||||
for row in entry_rows {
|
||||
let parsed: Value =
|
||||
serde_json::from_str(&row.value_json).map_err(|err| AgentError::MemoryError {
|
||||
message: format!(
|
||||
"failed to parse memory entry '{}' JSON value '{}': {err}",
|
||||
row.key, row.value_json
|
||||
),
|
||||
})?;
|
||||
entries.insert(row.key, parsed);
|
||||
}
|
||||
|
||||
let summaries = summary_rows
|
||||
.into_iter()
|
||||
.map(|row| MemorySummary {
|
||||
summary_type: row.summary_type,
|
||||
content: row.content,
|
||||
updated_at: row.updated_at,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(ProjectContextSnapshot { entries, summaries })
|
||||
}
|
||||
|
||||
async fn upsert_memory_entry(
|
||||
&self,
|
||||
project: &ProjectRef,
|
||||
key: &str,
|
||||
value: &Value,
|
||||
source: &str,
|
||||
) -> AgentResult<()> {
|
||||
use codetaker_db::schema::project_memory_entries::dsl;
|
||||
|
||||
let mut conn = self.get_conn().await?;
|
||||
let project_id = Self::ensure_project_id(&mut conn, project).await?;
|
||||
let now = Utc::now().naive_utc();
|
||||
let value_json = serde_json::to_string(value).map_err(|err| AgentError::MemoryError {
|
||||
message: format!("failed to serialize memory entry '{key}': {err}"),
|
||||
})?;
|
||||
|
||||
let new_row = NewProjectMemoryEntryRow {
|
||||
project_id,
|
||||
key,
|
||||
value_json: &value_json,
|
||||
source,
|
||||
updated_at: now,
|
||||
};
|
||||
|
||||
diesel::insert_into(project_memory_entries::table)
|
||||
.values(&new_row)
|
||||
.on_conflict((dsl::project_id, dsl::key))
|
||||
.do_update()
|
||||
.set((
|
||||
dsl::value_json.eq(&value_json),
|
||||
dsl::source.eq(source),
|
||||
dsl::updated_at.eq(now),
|
||||
))
|
||||
.execute(&mut conn)
|
||||
.await
|
||||
.map_err(|err| AgentError::MemoryError {
|
||||
message: format!("failed to upsert memory entry '{key}': {err}"),
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn upsert_memory_summary(
|
||||
&self,
|
||||
project: &ProjectRef,
|
||||
summary_type: &str,
|
||||
content: &str,
|
||||
) -> AgentResult<()> {
|
||||
use codetaker_db::schema::project_memory_summaries::dsl;
|
||||
|
||||
let mut conn = self.get_conn().await?;
|
||||
let project_id = Self::ensure_project_id(&mut conn, project).await?;
|
||||
let now = Utc::now().naive_utc();
|
||||
|
||||
let new_row = NewProjectMemorySummaryRow {
|
||||
project_id,
|
||||
summary_type,
|
||||
content,
|
||||
updated_at: now,
|
||||
};
|
||||
|
||||
diesel::insert_into(project_memory_summaries::table)
|
||||
.values(&new_row)
|
||||
.on_conflict((dsl::project_id, dsl::summary_type))
|
||||
.do_update()
|
||||
.set((dsl::content.eq(content), dsl::updated_at.eq(now)))
|
||||
.execute(&mut conn)
|
||||
.await
|
||||
.map_err(|err| AgentError::MemoryError {
|
||||
message: format!("failed to upsert memory summary '{summary_type}': {err}"),
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_review_thread(
|
||||
&self,
|
||||
project: &ProjectRef,
|
||||
file: &str,
|
||||
line: i32,
|
||||
initial_comment: &str,
|
||||
) -> AgentResult<i32> {
|
||||
use codetaker_db::schema::review_threads::dsl;
|
||||
|
||||
let mut conn = self.get_conn().await?;
|
||||
let project_id = Self::ensure_project_id(&mut conn, project).await?;
|
||||
let now = Utc::now().naive_utc();
|
||||
let new_row = NewReviewThreadRow {
|
||||
project_id,
|
||||
file,
|
||||
line,
|
||||
initial_comment,
|
||||
created_at: now,
|
||||
};
|
||||
|
||||
diesel::insert_into(review_threads::table)
|
||||
.values(&new_row)
|
||||
.execute(&mut conn)
|
||||
.await
|
||||
.map_err(|err| AgentError::MemoryError {
|
||||
message: format!("failed to create review thread: {err}"),
|
||||
})?;
|
||||
|
||||
dsl::review_threads
|
||||
.filter(dsl::project_id.eq(project_id))
|
||||
.filter(dsl::file.eq(file))
|
||||
.filter(dsl::line.eq(line))
|
||||
.filter(dsl::initial_comment.eq(initial_comment))
|
||||
.order(dsl::id.desc())
|
||||
.select(dsl::id)
|
||||
.first::<i32>(&mut conn)
|
||||
.await
|
||||
.map_err(|err| AgentError::MemoryError {
|
||||
message: format!("failed to fetch inserted review thread id: {err}"),
|
||||
})
|
||||
}
|
||||
|
||||
async fn append_thread_message(
|
||||
&self,
|
||||
thread_id: i32,
|
||||
author: MessageAuthor,
|
||||
body: &str,
|
||||
) -> AgentResult<()> {
|
||||
let mut conn = self.get_conn().await?;
|
||||
let now = Utc::now().naive_utc();
|
||||
|
||||
let row = NewReviewThreadMessageRow {
|
||||
thread_id,
|
||||
author: author.as_db_value(),
|
||||
body,
|
||||
created_at: now,
|
||||
};
|
||||
|
||||
diesel::insert_into(review_thread_messages::table)
|
||||
.values(&row)
|
||||
.execute(&mut conn)
|
||||
.await
|
||||
.map_err(|err| AgentError::MemoryError {
|
||||
message: format!("failed to append review thread message: {err}"),
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn load_thread_messages(&self, thread_id: i32) -> AgentResult<Vec<ThreadMessage>> {
|
||||
use codetaker_db::schema::review_thread_messages::dsl;
|
||||
|
||||
let mut conn = self.get_conn().await?;
|
||||
let rows = dsl::review_thread_messages
|
||||
.filter(dsl::thread_id.eq(thread_id))
|
||||
.order(dsl::created_at.asc())
|
||||
.load::<ReviewThreadMessageRow>(&mut conn)
|
||||
.await
|
||||
.map_err(|err| AgentError::MemoryError {
|
||||
message: format!("failed to load review thread messages: {err}"),
|
||||
})?;
|
||||
|
||||
rows.into_iter()
|
||||
.map(|row| {
|
||||
let author = MessageAuthor::from_db_value(&row.author).ok_or_else(|| {
|
||||
AgentError::MemoryError {
|
||||
message: format!("invalid message author in database: {}", row.author),
|
||||
}
|
||||
})?;
|
||||
let created_at = chrono::DateTime::from_naive_utc_and_offset(row.created_at, Utc);
|
||||
Ok(ThreadMessage {
|
||||
author,
|
||||
body: row.body,
|
||||
created_at,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
243
crates/codetaker-agent/src/tools/ast_grep.rs
Normal file
243
crates/codetaker-agent/src/tools/ast_grep.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
|
||||
use git2::Repository;
|
||||
use rig::completion::ToolDefinition;
|
||||
use rig::tool::Tool;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::error::{AgentError, AgentResult};
|
||||
use crate::git_access;
|
||||
use crate::tools::{FileContext, SearchHit};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AstGrepTool {
|
||||
ast_grep_bin: String,
|
||||
repo_git_dir: Option<PathBuf>,
|
||||
head_ref: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "operation", rename_all = "snake_case")]
|
||||
pub enum AstGrepArgs {
|
||||
SearchSymbol {
|
||||
query: String,
|
||||
},
|
||||
SearchPattern {
|
||||
query: String,
|
||||
},
|
||||
GetFileContext {
|
||||
file: String,
|
||||
line: i32,
|
||||
radius: i32,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "operation", rename_all = "snake_case")]
|
||||
pub enum AstGrepOutput {
|
||||
SearchHits { hits: Vec<SearchHit> },
|
||||
FileContext { context: FileContext },
|
||||
}
|
||||
|
||||
impl Default for AstGrepTool {
|
||||
fn default() -> Self {
|
||||
Self::new("ast-grep")
|
||||
}
|
||||
}
|
||||
|
||||
impl AstGrepTool {
|
||||
pub fn new(ast_grep_bin: impl Into<String>) -> Self {
|
||||
Self {
|
||||
ast_grep_bin: ast_grep_bin.into(),
|
||||
repo_git_dir: None,
|
||||
head_ref: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bind_to_request(
|
||||
&self,
|
||||
repo: &Repository,
|
||||
head_ref: impl Into<String>,
|
||||
) -> AgentResult<Self> {
|
||||
git_access::ensure_bare_repository(repo)?;
|
||||
Ok(Self {
|
||||
ast_grep_bin: self.ast_grep_bin.clone(),
|
||||
repo_git_dir: Some(repo.path().to_path_buf()),
|
||||
head_ref: Some(head_ref.into()),
|
||||
})
|
||||
}
|
||||
|
||||
fn bound_state(&self) -> AgentResult<(PathBuf, String)> {
|
||||
let repo_git_dir = self
|
||||
.repo_git_dir
|
||||
.clone()
|
||||
.ok_or_else(|| AgentError::ConfigError {
|
||||
message: "ast-grep tool is not bound to a repository context".to_owned(),
|
||||
})?;
|
||||
|
||||
let head_ref = self
|
||||
.head_ref
|
||||
.clone()
|
||||
.ok_or_else(|| AgentError::ConfigError {
|
||||
message: "ast-grep tool is not bound to a head_ref".to_owned(),
|
||||
})?;
|
||||
|
||||
Ok((repo_git_dir, head_ref))
|
||||
}
|
||||
|
||||
fn open_repo(&self, repo_git_dir: &Path) -> AgentResult<Repository> {
|
||||
Repository::open_bare(repo_git_dir).map_err(|err| AgentError::GitError {
|
||||
operation: format!("open bare repository at '{}'", repo_git_dir.display()),
|
||||
message: err.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn run_query(&self, pattern: &str) -> AgentResult<Vec<SearchHit>> {
|
||||
let (repo_git_dir, head_ref) = self.bound_state()?;
|
||||
let repo = self.open_repo(&repo_git_dir)?;
|
||||
let snapshot = git_access::materialize_ref_to_temp_dir(&repo, &head_ref)?;
|
||||
|
||||
let output = Command::new(&self.ast_grep_bin)
|
||||
.arg("run")
|
||||
.arg("--pattern")
|
||||
.arg(pattern)
|
||||
.arg("--json=stream")
|
||||
.arg("--no-color")
|
||||
.current_dir(snapshot.path())
|
||||
.output()
|
||||
.map_err(|err| AgentError::ToolError {
|
||||
tool: "ast-grep".to_owned(),
|
||||
message: format!("failed to execute ast-grep: {err}"),
|
||||
})?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
|
||||
return Err(AgentError::ToolError {
|
||||
tool: "ast-grep".to_owned(),
|
||||
message: format!("ast-grep exited with status {}: {stderr}", output.status),
|
||||
});
|
||||
}
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let mut hits = Vec::new();
|
||||
|
||||
for line in stdout.lines().filter(|line| !line.trim().is_empty()) {
|
||||
let parsed: Value = match serde_json::from_str(line) {
|
||||
Ok(value) => value,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let file = parsed
|
||||
.get("file")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default()
|
||||
.to_owned();
|
||||
let snippet = parsed
|
||||
.get("text")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default()
|
||||
.to_owned();
|
||||
let line_num = parsed
|
||||
.get("range")
|
||||
.and_then(|range| range.get("start"))
|
||||
.and_then(|start| start.get("line"))
|
||||
.and_then(Value::as_i64)
|
||||
.map(|value| value as i32 + 1)
|
||||
.unwrap_or(1);
|
||||
let column = parsed
|
||||
.get("range")
|
||||
.and_then(|range| range.get("start"))
|
||||
.and_then(|start| start.get("column"))
|
||||
.and_then(Value::as_i64)
|
||||
.map(|value| value as i32 + 1);
|
||||
|
||||
hits.push(SearchHit {
|
||||
file,
|
||||
line: line_num,
|
||||
column,
|
||||
snippet,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(hits)
|
||||
}
|
||||
|
||||
pub fn search_symbol(&self, query: &str) -> AgentResult<Vec<SearchHit>> {
|
||||
self.run_query(query)
|
||||
}
|
||||
|
||||
pub fn search_pattern(&self, query: &str) -> AgentResult<Vec<SearchHit>> {
|
||||
self.run_query(query)
|
||||
}
|
||||
|
||||
pub fn get_file_context(&self, file: &str, line: i32, radius: i32) -> AgentResult<FileContext> {
|
||||
let (repo_git_dir, head_ref) = self.bound_state()?;
|
||||
let repo = self.open_repo(&repo_git_dir)?;
|
||||
git_access::read_file_context_at_ref(&repo, &head_ref, file, line, radius)
|
||||
}
|
||||
}
|
||||
|
||||
impl Tool for AstGrepTool {
|
||||
const NAME: &'static str = "ast_grep";
|
||||
|
||||
type Error = AgentError;
|
||||
type Args = AstGrepArgs;
|
||||
type Output = AstGrepOutput;
|
||||
|
||||
async fn definition(&self, _prompt: String) -> ToolDefinition {
|
||||
ToolDefinition {
|
||||
name: Self::NAME.to_owned(),
|
||||
description: "Search source code using ast-grep and fetch file context.".to_owned(),
|
||||
parameters: serde_json::json!({
|
||||
"type": "object",
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"operation": { "type": "string", "const": "search_symbol" },
|
||||
"query": { "type": "string", "description": "Symbol-like query" }
|
||||
},
|
||||
"required": ["operation", "query"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"operation": { "type": "string", "const": "search_pattern" },
|
||||
"query": { "type": "string", "description": "Pattern query" }
|
||||
},
|
||||
"required": ["operation", "query"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"operation": { "type": "string", "const": "get_file_context" },
|
||||
"file": { "type": "string", "description": "Path to file in repository" },
|
||||
"line": { "type": "integer", "description": "1-based line number" },
|
||||
"radius": { "type": "integer", "description": "Number of lines before/after line" }
|
||||
},
|
||||
"required": ["operation", "file", "line", "radius"]
|
||||
}
|
||||
]
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
|
||||
match args {
|
||||
AstGrepArgs::SearchSymbol { query } => {
|
||||
let hits = self.search_symbol(&query)?;
|
||||
Ok(AstGrepOutput::SearchHits { hits })
|
||||
}
|
||||
AstGrepArgs::SearchPattern { query } => {
|
||||
let hits = self.search_pattern(&query)?;
|
||||
Ok(AstGrepOutput::SearchHits { hits })
|
||||
}
|
||||
AstGrepArgs::GetFileContext { file, line, radius } => {
|
||||
let context = self.get_file_context(&file, line, radius)?;
|
||||
Ok(AstGrepOutput::FileContext { context })
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
21
crates/codetaker-agent/src/tools/mod.rs
Normal file
21
crates/codetaker-agent/src/tools/mod.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
mod ast_grep;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub use ast_grep::{AstGrepArgs, AstGrepOutput, AstGrepTool};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct SearchHit {
|
||||
pub file: String,
|
||||
pub line: i32,
|
||||
pub column: Option<i32>,
|
||||
pub snippet: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct FileContext {
|
||||
pub file: String,
|
||||
pub line_start: i32,
|
||||
pub line_end: i32,
|
||||
pub snippet: String,
|
||||
}
|
||||
123
crates/codetaker-agent/src/types.rs
Normal file
123
crates/codetaker-agent/src/types.rs
Normal file
@@ -0,0 +1,123 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
|
||||
pub enum ReviewResult {
|
||||
Approve,
|
||||
RequestChanges,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ReviewComment {
|
||||
pub comment: String,
|
||||
pub file: String,
|
||||
pub line: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct PullRequestReviewOutput {
|
||||
pub review_result: ReviewResult,
|
||||
pub global_comment: String,
|
||||
pub comments: Vec<ReviewComment>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum MessageAuthor {
|
||||
User,
|
||||
Agent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ThreadMessage {
|
||||
pub author: MessageAuthor,
|
||||
pub body: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum Forge {
|
||||
Gitea,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct ProjectRef {
|
||||
pub forge: Forge,
|
||||
pub owner: String,
|
||||
pub repo: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PullRequestReviewInput {
|
||||
pub project: ProjectRef,
|
||||
pub base_ref: String,
|
||||
pub head_ref: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ConversationInput {
|
||||
pub project: ProjectRef,
|
||||
pub head_ref: String,
|
||||
pub anchor_file: String,
|
||||
pub anchor_line: i32,
|
||||
pub initial_comment: String,
|
||||
pub message_chain: Vec<ThreadMessage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ConversationOutput {
|
||||
pub reply: String,
|
||||
}
|
||||
|
||||
impl Forge {
|
||||
pub fn as_db_value(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Gitea => "gitea",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_db_value(value: &str) -> Option<Self> {
|
||||
match value {
|
||||
"gitea" => Some(Self::Gitea),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MessageAuthor {
|
||||
pub fn as_db_value(&self) -> &'static str {
|
||||
match self {
|
||||
Self::User => "user",
|
||||
Self::Agent => "agent",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_db_value(value: &str) -> Option<Self> {
|
||||
match value {
|
||||
"user" => Some(Self::User),
|
||||
"agent" => Some(Self::Agent),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn review_result_round_trip_serde() {
|
||||
let value = ReviewResult::RequestChanges;
|
||||
let raw = serde_json::to_string(&value).expect("serialize ReviewResult");
|
||||
let parsed: ReviewResult = serde_json::from_str(&raw).expect("deserialize ReviewResult");
|
||||
assert_eq!(parsed, value);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn message_author_round_trip_serde() {
|
||||
let value = MessageAuthor::Agent;
|
||||
let raw = serde_json::to_string(&value).expect("serialize MessageAuthor");
|
||||
let parsed: MessageAuthor = serde_json::from_str(&raw).expect("deserialize MessageAuthor");
|
||||
assert_eq!(parsed, value);
|
||||
}
|
||||
}
|
||||
115
crates/codetaker-agent/tests/sqlite_memory_store.rs
Normal file
115
crates/codetaker-agent/tests/sqlite_memory_store.rs
Normal file
@@ -0,0 +1,115 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codetaker_agent::memory::{DieselMemoryStore, MemoryStore};
|
||||
use codetaker_agent::types::{Forge, MessageAuthor, ProjectRef};
|
||||
use codetaker_db::create_pool;
|
||||
use serde_json::json;
|
||||
|
||||
fn test_db_path(test_name: &str) -> PathBuf {
|
||||
let mut path = std::env::temp_dir();
|
||||
let nonce = format!(
|
||||
"{}_{}_{}.sqlite",
|
||||
test_name,
|
||||
std::process::id(),
|
||||
chrono::Utc::now().timestamp_nanos_opt().unwrap_or_default()
|
||||
);
|
||||
path.push(nonce);
|
||||
path
|
||||
}
|
||||
|
||||
fn sample_project() -> ProjectRef {
|
||||
ProjectRef {
|
||||
forge: Forge::Gitea,
|
||||
owner: "acme".to_owned(),
|
||||
repo: "rocket".to_owned(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pool_creation_runs_migrations() {
|
||||
let db_path = test_db_path("pool_creation_runs_migrations");
|
||||
let database_url = db_path.display().to_string();
|
||||
|
||||
let pool = create_pool(Some(&database_url))
|
||||
.await
|
||||
.expect("create sqlite pool");
|
||||
let _conn = pool
|
||||
.get()
|
||||
.await
|
||||
.expect("get pooled sqlite connection after migration");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn upsert_memory_entry_overwrites_by_key() {
|
||||
let db_path = test_db_path("upsert_memory_entry_overwrites_by_key");
|
||||
let database_url = db_path.display().to_string();
|
||||
let pool = create_pool(Some(&database_url))
|
||||
.await
|
||||
.expect("create sqlite pool");
|
||||
let store = DieselMemoryStore::new(pool);
|
||||
|
||||
let project = sample_project();
|
||||
|
||||
store
|
||||
.upsert_memory_entry(&project, "style.preferred_errors", &json!("typed"), "seed")
|
||||
.await
|
||||
.expect("insert memory entry");
|
||||
|
||||
store
|
||||
.upsert_memory_entry(
|
||||
&project,
|
||||
"style.preferred_errors",
|
||||
&json!("typed-and-contextual"),
|
||||
"refresh",
|
||||
)
|
||||
.await
|
||||
.expect("update memory entry");
|
||||
|
||||
let snapshot = store
|
||||
.project_context_snapshot(&project)
|
||||
.await
|
||||
.expect("fetch project snapshot");
|
||||
|
||||
assert_eq!(
|
||||
snapshot.entries.get("style.preferred_errors"),
|
||||
Some(&json!("typed-and-contextual"))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_messages_round_trip() {
|
||||
let db_path = test_db_path("thread_messages_round_trip");
|
||||
let database_url = db_path.display().to_string();
|
||||
let pool = create_pool(Some(&database_url))
|
||||
.await
|
||||
.expect("create sqlite pool");
|
||||
let store = DieselMemoryStore::new(pool);
|
||||
|
||||
let project = sample_project();
|
||||
let thread_id = store
|
||||
.create_review_thread(&project, "src/lib.rs", 42, "Please avoid unwrap here")
|
||||
.await
|
||||
.expect("create review thread");
|
||||
|
||||
store
|
||||
.append_thread_message(thread_id, MessageAuthor::User, "Can you clarify why?")
|
||||
.await
|
||||
.expect("append user message");
|
||||
store
|
||||
.append_thread_message(
|
||||
thread_id,
|
||||
MessageAuthor::Agent,
|
||||
"This path can fail on malformed input.",
|
||||
)
|
||||
.await
|
||||
.expect("append agent message");
|
||||
|
||||
let messages = store
|
||||
.load_thread_messages(thread_id)
|
||||
.await
|
||||
.expect("load thread messages");
|
||||
|
||||
assert_eq!(messages.len(), 2);
|
||||
assert!(matches!(messages[0].author, MessageAuthor::User));
|
||||
assert!(matches!(messages[1].author, MessageAuthor::Agent));
|
||||
}
|
||||
Reference in New Issue
Block a user