refactor: moved docker out of controller module

This commit is contained in:
hdbg
2025-12-04 18:40:27 +01:00
parent ba079d24b5
commit d39f67f3fe
7 changed files with 419 additions and 428 deletions

View File

@@ -8,7 +8,7 @@ use std::{
str::FromStr,
};
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct PostgresVersion {
pub major: u32,
pub minor: u32,

View File

@@ -1,414 +1,21 @@
use miette::{bail, miette};
use rand::{Rng, distr::Alphanumeric};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, fmt::Write, pin::pin, str::FromStr};
use bollard::{
Docker,
errors::Error,
query_parameters::{
CreateContainerOptions, CreateImageOptions, InspectContainerOptions, ListImagesOptions,
StartContainerOptions, StopContainerOptions,
},
secret::{ContainerCreateBody, CreateImageInfo},
};
use futures::{Stream, StreamExt, TryStreamExt};
use indicatif::{MultiProgress, ProgressBar, ProgressState, ProgressStyle};
use miette::{Context, IntoDiagnostic, Result};
use tracing::info;
use miette::Result;
use crate::{
config::{PgxConfig, PostgresVersion, Project},
controller::docker::DockerController,
state::{InstanceState, StateManager},
};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ContainerStatus {
Running,
Stopped,
Paused,
Restarting,
Dead,
Unknown,
}
const DOCKERHUB_POSTGRES: &str = "postgres";
const DEFAULT_POSTGRES_PORT: u16 = 5432;
const PORT_SEARCH_RANGE: u16 = 100;
fn format_image(ver: &PostgresVersion) -> String {
format!("{DOCKERHUB_POSTGRES}:{}", ver)
}
fn find_available_port() -> Result<u16> {
use std::net::TcpListener;
for port in DEFAULT_POSTGRES_PORT..(DEFAULT_POSTGRES_PORT + PORT_SEARCH_RANGE) {
if TcpListener::bind(("127.0.0.1", port)).is_ok() {
return Ok(port);
}
}
miette::bail!(
"No available ports found in range {}-{}",
DEFAULT_POSTGRES_PORT,
DEFAULT_POSTGRES_PORT + PORT_SEARCH_RANGE - 1
)
}
fn new_download_pb(multi: &MultiProgress, layer_id: &str) -> ProgressBar {
let pb = multi.add(ProgressBar::new(0));
pb.set_style(
ProgressStyle::with_template(&"{spinner:.green} [{elapsed_precise}] {msg} [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({eta})".to_string())
.unwrap()
.with_key("eta", |state: &ProgressState, w: &mut dyn Write| {
write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
})
.progress_chars("#>-"),
);
pb.set_message(format!("Layer {}", layer_id));
pb
}
// sadly type ... = impl ... is unstable
pub async fn perform_download(
multi: MultiProgress,
chunks: impl Stream<Item = Result<CreateImageInfo, Error>>,
) -> Result<()> {
let mut chunks = pin!(chunks);
let mut layer_progress: HashMap<String, ProgressBar> = HashMap::new();
while let Some(download_info) = chunks.try_next().await.into_diagnostic()? {
download_check_for_error(&mut layer_progress, &download_info)?;
let layer_id = download_info.id.as_deref().unwrap_or("unknown");
// Get or create progress bar for this layer
let pb = layer_progress
.entry(layer_id.to_string())
.or_insert_with(|| new_download_pb(&multi, layer_id));
download_drive_progress(pb, download_info);
}
// Clean up any remaining progress bars
for (_, pb) in layer_progress.drain() {
pb.finish_and_clear();
}
Ok(())
}
fn download_drive_progress(pb: &mut ProgressBar, download_info: CreateImageInfo) {
match download_info.progress_detail {
Some(info) => match (info.current, info.total) {
(None, None) => {
pb.inc(1);
}
(current, total) => {
if let Some(total) = total {
pb.set_length(total as u64);
}
if let Some(current) = current {
pb.set_position(current as u64);
}
if let (Some(current), Some(total)) = (current, total)
&& (current == total)
{
pb.finish_with_message("Completed!");
}
}
},
None => {
// No progress detail, just show activity
pb.tick();
}
}
}
fn download_check_for_error(
layer_progress: &mut HashMap<String, ProgressBar>,
download_info: &CreateImageInfo,
) -> Result<()> {
if let Some(error_detail) = &download_info.error_detail {
for (_, pb) in layer_progress.drain() {
pb.finish_and_clear();
}
match (error_detail.code, &error_detail.message) {
(None, Some(msg)) => miette::bail!("docker image download error: {}", msg),
(Some(code), None) => miette::bail!("docker image download error: code {}", code),
(Some(code), Some(msg)) => {
miette::bail!(
"docker image download error: code {}, message: {}",
code,
msg
)
}
_ => (),
}
}
Ok(())
}
pub struct DockerController {
daemon: Docker,
}
impl DockerController {
pub async fn new() -> Result<Self> {
let docker = Docker::connect_with_local_defaults()
.into_diagnostic()
.wrap_err(
"Failed to connect to Docker! pgx required Docker installed. Make sure it's running.",
)?;
info!("docker.created");
docker
.list_images(Some(ListImagesOptions::default()))
.await
.into_diagnostic()
.wrap_err("Docker basic connectivity test refused")?;
Ok(Self { daemon: docker })
}
pub async fn download_image(&self, image: String) -> Result<()> {
let options = Some(CreateImageOptions {
from_image: Some(image.clone()),
..Default::default()
});
let download_progress = self.daemon.create_image(options, None, None);
let multi = MultiProgress::new();
println!("Downloading {image}");
perform_download(multi, download_progress).await?;
println!("Download complete!");
Ok(())
}
pub async fn ensure_version_downloaded(&self, ver: &PostgresVersion) -> Result<()> {
let desired_image_tag = format_image(ver);
let images = self
.daemon
.list_images(Some(ListImagesOptions::default()))
.await
.into_diagnostic()
.wrap_err("failed to list installed docker images")?;
let is_downloaded = images
.iter()
.any(|img| img.repo_tags.contains(&desired_image_tag));
if !is_downloaded {
self.download_image(desired_image_tag).await?;
}
Ok(())
}
// TODO: make client to get available versions from dockerhub
pub async fn available_versions(&self) -> Result<Vec<PostgresVersion>> {
Ok(vec!["18.1", "17.7", "16.11", "15.15", "14.20"]
.into_iter()
.map(|v| PostgresVersion::from_str(v).unwrap())
.collect())
}
pub async fn container_exists(&self, container_id: &str) -> Result<bool> {
match self
.daemon
.inspect_container(container_id, None::<InspectContainerOptions>)
.await
{
Ok(_) => Ok(true),
Err(bollard::errors::Error::DockerResponseServerError {
status_code: 404, ..
}) => Ok(false),
Err(e) => Err(e)
.into_diagnostic()
.wrap_err("Failed to inspect container"),
}
}
pub async fn is_container_running(&self, container_name: &str) -> Result<bool> {
let container = self
.daemon
.inspect_container(container_name, None::<InspectContainerOptions>)
.await
.into_diagnostic()
.wrap_err("Failed to inspect container")?;
Ok(container.state.and_then(|s| s.running).unwrap_or(false))
}
pub async fn create_postgres_container(
&self,
container_name: &str,
version: &PostgresVersion,
password: &str,
port: u16,
) -> Result<String> {
use bollard::models::{HostConfig, PortBinding};
use std::collections::HashMap;
let image = format_image(version);
let env = vec![
format!("POSTGRES_PASSWORD={}", password),
format!("POSTGRES_USER={}", USERNAME),
format!("POSTGRES_DB={}", DATABASE),
];
let mut port_bindings = HashMap::new();
port_bindings.insert(
"5432/tcp".to_string(),
Some(vec![PortBinding {
host_ip: Some("127.0.0.1".to_string()),
host_port: Some(port.to_string()),
}]),
);
let host_config = HostConfig {
port_bindings: Some(port_bindings),
..Default::default()
};
let mut labels = HashMap::new();
labels.insert("pgx.postgres.version".to_string(), version.to_string());
let config = ContainerCreateBody {
image: Some(image),
env: Some(env),
host_config: Some(host_config),
labels: Some(labels),
..Default::default()
};
let options = CreateContainerOptions {
name: Some(container_name.to_owned()),
platform: String::new(),
};
let response = self
.daemon
.create_container(Some(options), config)
.await
.into_diagnostic()
.wrap_err("Failed to create container")?;
Ok(response.id)
}
pub async fn start_container(&self, container_id: &str) -> Result<()> {
self.daemon
.start_container(container_id, None::<StartContainerOptions>)
.await
.into_diagnostic()
.wrap_err("Failed to start container")?;
Ok(())
}
pub async fn container_exists_by_id(&self, container_id: &str) -> Result<bool> {
match self
.daemon
.inspect_container(container_id, None::<InspectContainerOptions>)
.await
{
Ok(_) => Ok(true),
Err(bollard::errors::Error::DockerResponseServerError {
status_code: 404, ..
}) => Ok(false),
Err(e) => Err(e)
.into_diagnostic()
.wrap_err("Failed to inspect container by ID"),
}
}
pub async fn is_container_running_by_id(&self, container_id: &str) -> Result<bool> {
let container = self
.daemon
.inspect_container(container_id, None::<InspectContainerOptions>)
.await
.into_diagnostic()
.wrap_err("Failed to inspect container")?;
Ok(container.state.and_then(|s| s.running).unwrap_or(false))
}
pub async fn start_container_by_id(&self, container_id: &str) -> Result<()> {
self.start_container(container_id).await
}
pub async fn stop_container(&self, container_id: &str, timeout: i32) -> Result<()> {
self.daemon
.stop_container(
container_id,
Some(StopContainerOptions {
t: Some(timeout),
signal: None,
}),
)
.await
.into_diagnostic()
.wrap_err("Failed to stop container")?;
Ok(())
}
pub async fn get_container_postgres_version(
&self,
container_id: &str,
) -> Result<PostgresVersion> {
let container = self
.daemon
.inspect_container(container_id, None::<InspectContainerOptions>)
.await
.into_diagnostic()
.wrap_err("Failed to inspect container")?;
let labels = container
.config
.and_then(|c| c.labels)
.ok_or_else(|| miette!("Container has no labels"))?;
let version_str = labels
.get("pgx.postgres.version")
.ok_or_else(|| miette!("Container missing pgx.postgres.version label"))?;
PostgresVersion::from_str(version_str)
.map_err(|_| miette!("Invalid version in label: {}", version_str))
}
}
const USERNAME: &str = "postgres";
const DATABASE: &str = "postgres";
const PASSWORD_LENGTH: usize = 16;
pub fn generate_password() -> String {
(&mut rand::rng())
.sample_iter(Alphanumeric)
.take(PASSWORD_LENGTH)
.map(|b| b as char)
.collect()
}
mod docker;
mod utils;
const MAX_RETRIES: u32 = 10;
const VERIFY_DURATION_SECS: u64 = 10;
pub struct Controller {
pub docker: DockerController,
docker: DockerController,
project: Option<Project>,
state: StateManager,
}
@@ -437,8 +44,8 @@ impl Controller {
let config = PgxConfig {
version: *latest_version,
password: generate_password(),
port: find_available_port()?,
password: utils::generate_password(),
port: utils::find_available_port()?,
};
let project = Project::new(config)?;
@@ -561,9 +168,7 @@ impl Controller {
project.name.clone(),
crate::state::InstanceState::new(
id.clone(),
project.config.version.to_string(),
DATABASE.to_string(),
USERNAME.to_string(),
project.config.version,
project.config.port,
),
);

264
src/controller/docker.rs Normal file
View File

@@ -0,0 +1,264 @@
use miette::miette;
use std::str::FromStr;
use bollard::{
Docker,
query_parameters::{
CreateContainerOptions, CreateImageOptions, InspectContainerOptions, ListImagesOptions,
StartContainerOptions, StopContainerOptions,
},
secret::ContainerCreateBody,
};
use indicatif::MultiProgress;
use miette::{Context, IntoDiagnostic, Result};
use tracing::info;
use crate::{
config::PostgresVersion,
consts::{DATABASE, USERNAME},
};
mod download;
const DOCKERHUB_POSTGRES: &str = "postgres";
fn format_image(ver: &PostgresVersion) -> String {
format!("{DOCKERHUB_POSTGRES}:{}", ver)
}
pub struct DockerController {
daemon: Docker,
}
impl DockerController {
pub async fn new() -> Result<Self> {
let docker = Docker::connect_with_local_defaults()
.into_diagnostic()
.wrap_err(
"Failed to connect to Docker! pgx required Docker installed. Make sure it's running.",
)?;
info!("docker.created");
docker
.list_images(Some(ListImagesOptions::default()))
.await
.into_diagnostic()
.wrap_err("Docker basic connectivity test refused")?;
Ok(Self { daemon: docker })
}
pub async fn download_image(&self, image: String) -> Result<()> {
let options = Some(CreateImageOptions {
from_image: Some(image.clone()),
..Default::default()
});
let download_progress = self.daemon.create_image(options, None, None);
let multi = MultiProgress::new();
println!("Downloading {image}");
download::perform_download(multi, download_progress).await?;
println!("Download complete!");
Ok(())
}
pub async fn ensure_version_downloaded(&self, ver: &PostgresVersion) -> Result<()> {
let desired_image_tag = format_image(ver);
let images = self
.daemon
.list_images(Some(ListImagesOptions::default()))
.await
.into_diagnostic()
.wrap_err("failed to list installed docker images")?;
let is_downloaded = images
.iter()
.any(|img| img.repo_tags.contains(&desired_image_tag));
if !is_downloaded {
self.download_image(desired_image_tag).await?;
}
Ok(())
}
// TODO: make client to get available versions from dockerhub
pub async fn available_versions(&self) -> Result<Vec<PostgresVersion>> {
Ok(vec!["18.1", "17.7", "16.11", "15.15", "14.20"]
.into_iter()
.map(|v| PostgresVersion::from_str(v).unwrap())
.collect())
}
pub async fn container_exists(&self, container_id: &str) -> Result<bool> {
match self
.daemon
.inspect_container(container_id, None::<InspectContainerOptions>)
.await
{
Ok(_) => Ok(true),
Err(bollard::errors::Error::DockerResponseServerError {
status_code: 404, ..
}) => Ok(false),
Err(e) => Err(e)
.into_diagnostic()
.wrap_err("Failed to inspect container"),
}
}
pub async fn is_container_running(&self, container_name: &str) -> Result<bool> {
let container = self
.daemon
.inspect_container(container_name, None::<InspectContainerOptions>)
.await
.into_diagnostic()
.wrap_err("Failed to inspect container")?;
Ok(container.state.and_then(|s| s.running).unwrap_or(false))
}
pub async fn create_postgres_container(
&self,
container_name: &str,
version: &PostgresVersion,
password: &str,
port: u16,
) -> Result<String> {
use bollard::models::{HostConfig, PortBinding};
use std::collections::HashMap;
let image = format_image(version);
let env = vec![
format!("POSTGRES_PASSWORD={}", password),
format!("POSTGRES_USER={}", USERNAME),
format!("POSTGRES_DB={}", DATABASE),
];
let mut port_bindings = HashMap::new();
port_bindings.insert(
"5432/tcp".to_string(),
Some(vec![PortBinding {
host_ip: Some("127.0.0.1".to_string()),
host_port: Some(port.to_string()),
}]),
);
let host_config = HostConfig {
port_bindings: Some(port_bindings),
..Default::default()
};
let mut labels = HashMap::new();
labels.insert("pgx.postgres.version".to_string(), version.to_string());
let config = ContainerCreateBody {
image: Some(image),
env: Some(env),
host_config: Some(host_config),
labels: Some(labels),
..Default::default()
};
let options = CreateContainerOptions {
name: Some(container_name.to_owned()),
platform: String::new(),
};
let response = self
.daemon
.create_container(Some(options), config)
.await
.into_diagnostic()
.wrap_err("Failed to create container")?;
Ok(response.id)
}
pub async fn start_container(&self, container_id: &str) -> Result<()> {
self.daemon
.start_container(container_id, None::<StartContainerOptions>)
.await
.into_diagnostic()
.wrap_err("Failed to start container")?;
Ok(())
}
pub async fn container_exists_by_id(&self, container_id: &str) -> Result<bool> {
match self
.daemon
.inspect_container(container_id, None::<InspectContainerOptions>)
.await
{
Ok(_) => Ok(true),
Err(bollard::errors::Error::DockerResponseServerError {
status_code: 404, ..
}) => Ok(false),
Err(e) => Err(e)
.into_diagnostic()
.wrap_err("Failed to inspect container by ID"),
}
}
pub async fn is_container_running_by_id(&self, container_id: &str) -> Result<bool> {
let container = self
.daemon
.inspect_container(container_id, None::<InspectContainerOptions>)
.await
.into_diagnostic()
.wrap_err("Failed to inspect container")?;
Ok(container.state.and_then(|s| s.running).unwrap_or(false))
}
pub async fn start_container_by_id(&self, container_id: &str) -> Result<()> {
self.start_container(container_id).await
}
pub async fn stop_container(&self, container_id: &str, timeout: i32) -> Result<()> {
self.daemon
.stop_container(
container_id,
Some(StopContainerOptions {
t: Some(timeout),
signal: None,
}),
)
.await
.into_diagnostic()
.wrap_err("Failed to stop container")?;
Ok(())
}
pub async fn get_container_postgres_version(
&self,
container_id: &str,
) -> Result<PostgresVersion> {
let container = self
.daemon
.inspect_container(container_id, None::<InspectContainerOptions>)
.await
.into_diagnostic()
.wrap_err("Failed to inspect container")?;
let labels = container
.config
.and_then(|c| c.labels)
.ok_or_else(|| miette!("Container has no labels"))?;
let version_str = labels
.get("pgx.postgres.version")
.ok_or_else(|| miette!("Container missing pgx.postgres.version label"))?;
PostgresVersion::from_str(version_str)
.map_err(|_| miette!("Invalid version in label: {}", version_str))
}
}

View File

@@ -0,0 +1,103 @@
use miette::{IntoDiagnostic, Result};
use std::{collections::HashMap, fmt::Write, pin::pin};
use bollard::{errors::Error, secret::CreateImageInfo};
use futures::{Stream, TryStreamExt};
use indicatif::{MultiProgress, ProgressBar, ProgressState, ProgressStyle};
fn new_download_pb(multi: &MultiProgress, layer_id: &str) -> ProgressBar {
let pb = multi.add(ProgressBar::new(0));
pb.set_style(
ProgressStyle::with_template(&"{spinner:.green} [{elapsed_precise}] {msg} [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({eta})".to_string())
.unwrap()
.with_key("eta", |state: &ProgressState, w: &mut dyn Write| {
write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
})
.progress_chars("#>-"),
);
pb.set_message(format!("Layer {}", layer_id));
pb
}
fn download_drive_progress(pb: &mut ProgressBar, download_info: CreateImageInfo) {
match download_info.progress_detail {
Some(info) => match (info.current, info.total) {
(None, None) => {
pb.inc(1);
}
(current, total) => {
if let Some(total) = total {
pb.set_length(total as u64);
}
if let Some(current) = current {
pb.set_position(current as u64);
}
if let (Some(current), Some(total)) = (current, total)
&& (current == total)
{
pb.finish_with_message("Completed!");
}
}
},
None => {
// No progress detail, just show activity
pb.tick();
}
}
}
fn download_check_for_error(
layer_progress: &mut HashMap<String, ProgressBar>,
download_info: &CreateImageInfo,
) -> Result<()> {
if let Some(error_detail) = &download_info.error_detail {
for (_, pb) in layer_progress.drain() {
pb.finish_and_clear();
}
match (error_detail.code, &error_detail.message) {
(None, Some(msg)) => miette::bail!("docker image download error: {}", msg),
(Some(code), None) => miette::bail!("docker image download error: code {}", code),
(Some(code), Some(msg)) => {
miette::bail!(
"docker image download error: code {}, message: {}",
code,
msg
)
}
_ => (),
}
}
Ok(())
}
// sadly type ... = impl ... is unstable
pub async fn perform_download(
multi: MultiProgress,
chunks: impl Stream<Item = Result<CreateImageInfo, Error>>,
) -> Result<()> {
let mut chunks = pin!(chunks);
let mut layer_progress: HashMap<String, ProgressBar> = HashMap::new();
while let Some(download_info) = chunks.try_next().await.into_diagnostic()? {
download_check_for_error(&mut layer_progress, &download_info)?;
let layer_id = download_info.id.as_deref().unwrap_or("unknown");
// Get or create progress bar for this layer
let pb = layer_progress
.entry(layer_id.to_string())
.or_insert_with(|| new_download_pb(&multi, layer_id));
download_drive_progress(pb, download_info);
}
// Clean up any remaining progress bars
for (_, pb) in layer_progress.drain() {
pb.finish_and_clear();
}
Ok(())
}

29
src/controller/utils.rs Normal file
View File

@@ -0,0 +1,29 @@
use miette::Result;
use rand::{Rng, distr::Alphanumeric};
const DEFAULT_POSTGRES_PORT: u16 = 5432;
const PORT_SEARCH_RANGE: u16 = 100;
pub fn find_available_port() -> Result<u16> {
use std::net::TcpListener;
for port in DEFAULT_POSTGRES_PORT..(DEFAULT_POSTGRES_PORT + PORT_SEARCH_RANGE) {
if TcpListener::bind(("127.0.0.1", port)).is_ok() {
return Ok(port);
}
}
miette::bail!(
"No available ports found in range {}-{}",
DEFAULT_POSTGRES_PORT,
DEFAULT_POSTGRES_PORT + PORT_SEARCH_RANGE - 1
)
}
const PASSWORD_LENGTH: usize = 16;
pub fn generate_password() -> String {
(&mut rand::rng())
.sample_iter(Alphanumeric)
.take(PASSWORD_LENGTH)
.map(|b| b as char)
.collect()
}

View File

@@ -2,6 +2,11 @@ mod cli;
mod config;
mod state;
mod consts {
pub const USERNAME: &str = "postgres";
pub const DATABASE: &str = "postgres";
}
mod controller;
use clap::Parser;
@@ -31,7 +36,7 @@ async fn main() -> Result<()> {
}
fn init_tracing(verbose: bool) {
use tracing_subscriber::{fmt, prelude::*};
tracing_subscriber::fmt::init();
}

View File

@@ -3,6 +3,8 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use crate::config::PostgresVersion;
/// State information for a single PostgreSQL instance
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstanceState {
@@ -10,22 +12,13 @@ pub struct InstanceState {
pub container_id: String,
/// PostgreSQL version running in the container
pub postgres_version: String,
/// Database name
pub database_name: String,
/// User name
pub user_name: String,
pub postgres_version: PostgresVersion,
/// Port the container is bound to
pub port: u16,
/// Timestamp when the instance was created (Unix timestamp)
pub created_at: u64,
/// Timestamp when the instance was last started (Unix timestamp)
pub last_started_at: Option<u64>,
}
/// Manages the global state file at ~/.pgx/state.json
@@ -147,13 +140,7 @@ impl StateManager {
}
impl InstanceState {
pub fn new(
container_id: String,
postgres_version: String,
database_name: String,
user_name: String,
port: u16,
) -> Self {
pub fn new(container_id: String, postgres_version: PostgresVersion, port: u16) -> Self {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
@@ -162,11 +149,8 @@ impl InstanceState {
InstanceState {
container_id,
postgres_version,
database_name,
user_name,
port,
created_at: now,
last_started_at: Some(now),
}
}
}
@@ -183,9 +167,10 @@ mod tests {
let state = InstanceState::new(
"container123".to_string(),
"16".to_string(),
"mydb".to_string(),
"postgres".to_string(),
PostgresVersion {
major: 18,
minor: 1,
},
5432,
);