// rubhub / src / controllers / runner.rs
//! Remote CI runner API endpoints

use std::convert::Infallible;
use std::path::Path;

use axum::{
    Json,
    body::Body,
    extract::{Path as AxumPath, Query, State},
    http::{StatusCode, header},
    response::{IntoResponse, Response},
};
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use tokio::process::Command;
use tokio::sync::mpsc;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::ReceiverStream;
use uuid::Uuid;

use crate::GlobalState;
use crate::extractors::RunnerAuth;
use rubhub_state::{ConnectedRunner, RunnerJobEvent, RunnerTags};

/// Query parameters for runner connection.
///
/// All fields are optional; missing values are filled in server-side in
/// `runner_connect` from `std::env::consts` (os/arch) or a default tag list.
#[derive(Debug, Deserialize)]
pub struct RunnerConnectQuery {
    /// Operating system (e.g., "linux", "darwin")
    pub os: Option<String>,
    /// CPU architecture (e.g., "x86_64", "aarch64")
    pub arch: Option<String>,
    /// Additional tags (comma-separated, e.g. "nixos,gpu")
    pub tags: Option<String>,
}

/// Path parameters for job-specific endpoints
/// (`/.runners/jobs/{owner}/{project}/{job_id}/...`).
#[derive(Debug, Deserialize)]
pub struct JobPath {
    /// Repository owner segment of the URL.
    pub owner: String,
    /// Project (repository) name segment of the URL.
    pub project: String,
    /// CI job identifier; parsed as a UUID by the extractor.
    pub job_id: Uuid,
}

/// GET /.runners/connect - SSE endpoint for runners to receive jobs
pub async fn runner_connect(
    State(state): State<GlobalState>,
    RunnerAuth(token): RunnerAuth,
    Query(query): Query<RunnerConnectQuery>,
) -> Response<Body> {
    // Build tags from query parameters or detect from system
    let tags = RunnerTags {
        os: query.os.unwrap_or_else(|| std::env::consts::OS.to_string()),
        arch: query.arch.unwrap_or_else(|| match std::env::consts::ARCH {
            "x86_64" | "amd64" => "x86_64".to_string(),
            "aarch64" | "arm64" => "aarch64".to_string(),
            other => other.to_string(),
        }),
        extra: query
            .tags
            .map(|t| t.split(',').map(|s| s.trim().to_string()).collect())
            .unwrap_or_else(|| vec!["nixos".to_string()]),
    };

    // Create channel for sending events to this runner
    let (tx, rx) = mpsc::channel::<RunnerJobEvent>(32);

    // Register the runner
    let runner = ConnectedRunner::new(token.token_hash.clone(), tags.clone(), tx);
    let runner_id = state.runner_registry.register(runner);

    println!(
        "Runner connected: {} (label: {}, os: {}, arch: {})",
        runner_id, token.label, tags.os, tags.arch
    );

    // Create SSE stream from the receiver
    let state_clone = state.clone();
    let stream = ReceiverStream::new(rx)
        .map(|event| {
            let json = serde_json::to_string(&event).unwrap_or_default();
            Ok::<_, Infallible>(format!("data: {}\n\n", json))
        })
        .chain(futures::stream::once(async move {
            // When the stream ends, unregister the runner
            state_clone.runner_registry.unregister(runner_id).await;
            println!("Runner disconnected: {}", runner_id);
            Ok::<_, Infallible>("".to_string())
        }));

    Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "text/event-stream")
        .header(header::CACHE_CONTROL, "no-cache")
        .header(header::CONNECTION, "keep-alive")
        .body(Body::from_stream(stream))
        .unwrap()
}

/// Heartbeat request body (empty - just proves the runner is alive;
/// authentication happens via the `RunnerAuth` extractor).
#[derive(Debug, Deserialize)]
pub struct HeartbeatRequest {}

/// Heartbeat response
#[derive(Debug, Serialize)]
pub struct HeartbeatResponse {
    /// Always `true` on a successful heartbeat.
    pub ok: bool,
    /// Server-side UTC timestamp, serialized as RFC 3339.
    #[serde(with = "time::serde::rfc3339")]
    pub server_time: OffsetDateTime,
}

/// POST /.runners/heartbeat - Runner heartbeat.
///
/// Looks the runner up by its token hash and refreshes its
/// `last_heartbeat` timestamp. Returns 404 when no runner with that
/// token is currently registered.
pub async fn runner_heartbeat(
    State(state): State<GlobalState>,
    RunnerAuth(token): RunnerAuth,
    Json(_body): Json<HeartbeatRequest>,
) -> Result<Json<HeartbeatResponse>, StatusCode> {
    // Resolve the connected runner for this token, or report it missing.
    let Some(entry) = state.runner_registry.get_by_token(&token.token_hash) else {
        return Err(StatusCode::NOT_FOUND);
    };

    // Refresh the liveness timestamp. `current_job` is deliberately left
    // untouched: the master manages it when assigning/completing jobs,
    // it is not reported by the runner.
    entry.write().await.last_heartbeat = OffsetDateTime::now_utc();

    Ok(Json(HeartbeatResponse {
        ok: true,
        server_time: OffsetDateTime::now_utc(),
    }))
}

/// GET /.runners/jobs/{owner}/{project}/{job_id}/source.tar.gz - Download source tarball.
///
/// Loads the job metadata to find the commit hash, then streams a
/// `git archive` tarball of that commit to the runner.
pub async fn runner_source(
    State(state): State<GlobalState>,
    RunnerAuth(_token): RunnerAuth,
    AxumPath(path): AxumPath<JobPath>,
) -> Result<impl IntoResponse, StatusCode> {
    // `owner` and `project` come straight from the URL (axum percent-decodes
    // segments), so reject anything that could escape the CI/git roots.
    fn safe_segment(s: &str) -> bool {
        !s.is_empty()
            && s != "."
            && s != ".."
            && !s.chars().any(|c| matches!(c, '/' | '\\' | '\0'))
    }
    if !safe_segment(&path.owner) || !safe_segment(&path.project) {
        return Err(StatusCode::BAD_REQUEST);
    }

    // Construct job directory directly from path.
    let job_dir = state
        .config
        .ci_root
        .join(&path.owner)
        .join(&path.project)
        .join(path.job_id.to_string());

    // Load the job metadata; a missing job.json means the job is unknown.
    let job_file = job_dir.join("job.json");
    let content = tokio::fs::read_to_string(&job_file)
        .await
        .map_err(|_| StatusCode::NOT_FOUND)?;

    // A present-but-unparseable job file is a server-side data problem.
    let job: crate::models::CiJob =
        serde_json::from_str(&content).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // Get the repo path.
    let repo_path = state.config.git_root.join(&path.owner).join(&path.project);

    // Create tarball using git archive at the job's recorded commit.
    let tarball = create_git_archive(&repo_path, &job.commit_hash)
        .await
        .map_err(|e| {
            eprintln!("Failed to create git archive: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;

    Ok(([(header::CONTENT_TYPE, "application/gzip")], tarball))
}

/// Produce a gzipped tarball of the repository at `repo_path` checked out
/// at `commit`, by shelling out to `git archive`.
///
/// # Errors
/// Fails when the `git` process cannot be spawned, or when it exits
/// non-zero (git's stderr is included in the error message).
async fn create_git_archive(repo_path: &Path, commit: &str) -> anyhow::Result<Vec<u8>> {
    let result = Command::new("git")
        .arg("archive")
        .arg("--format=tar.gz")
        .arg(commit)
        .current_dir(repo_path)
        .output()
        .await?;

    if result.status.success() {
        // On success the tarball is whatever git wrote to stdout.
        Ok(result.stdout)
    } else {
        let stderr = String::from_utf8_lossy(&result.stderr);
        anyhow::bail!("git archive failed: {}", stderr)
    }
}

/// Log chunk request body
#[derive(Debug, Deserialize)]
pub struct LogChunkRequest {
    /// Name of the job within the workflow; used as a directory name
    /// under the job's `jobs/` folder.
    pub job_name: String,
    /// Base64-encoded log data (standard alphabet), appended verbatim
    /// to the job's `output.log` after decoding.
    pub data: String,
    /// Job token for authentication; must match the token stored in
    /// the job's `job.json`.
    pub job_token: String,
}

/// POST /.runners/jobs/{owner}/{project}/{job_id}/log - Stream log chunk from runner
pub async fn runner_log(
    State(state): State<GlobalState>,
    RunnerAuth(_token): RunnerAuth,
    AxumPath(path): AxumPath<JobPath>,
    Json(body): Json<LogChunkRequest>,
) -> Result<StatusCode, StatusCode> {
    use base64::Engine;
    use tokio::io::AsyncWriteExt;

    // Construct job directory directly from path
    let job_dir = state
        .config
        .ci_root
        .join(&path.owner)
        .join(&path.project)
        .join(path.job_id.to_string());

    // Verify job token
    let job_file = job_dir.join("job.json");
    let content = tokio::fs::read_to_string(&job_file)
        .await
        .map_err(|_| StatusCode::NOT_FOUND)?;
    let job: crate::models::CiJob =
        serde_json::from_str(&content).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    if job.job_token.as_deref() != Some(&body.job_token) {
        return Err(StatusCode::FORBIDDEN);
    }

    // Decode the log data
    let log_data = base64::engine::general_purpose::STANDARD
        .decode(&body.data)
        .map_err(|_| StatusCode::BAD_REQUEST)?;

    // Append to the log file
    let log_path = job_dir.join("jobs").join(&body.job_name).join("output.log");

    // Ensure the directory exists
    if let Some(parent) = log_path.parent() {
        tokio::fs::create_dir_all(parent)
            .await
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    }

    // Append to the log file
    let mut file = tokio::fs::OpenOptions::new()
        .create(true)
        .append(true)
        .open(&log_path)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    file.write_all(&log_data)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(StatusCode::OK)
}

/// Job completion request body
#[derive(Debug, Deserialize)]
pub struct JobCompleteRequest {
    /// Overall workflow outcome as reported by the runner.
    pub success: bool,
    /// Overall exit code, if the runner captured one.
    pub exit_code: Option<i32>,
    /// Per-job results; matched to stored jobs by name.
    pub jobs: Vec<JobResultUpdate>,
    /// Job token for authentication; must match the token stored in
    /// the job's `job.json`.
    pub job_token: String,
}

/// Result of a single named job within the workflow.
#[derive(Debug, Deserialize)]
pub struct JobResultUpdate {
    /// Job name, matched against `CiJob.jobs[].name`.
    pub name: String,
    /// Exit code for this job; `Some(0)` is treated as success,
    /// anything else (including `None`) as failure.
    pub exit_code: Option<i32>,
    /// Per-step results within this job.
    pub steps: Vec<StepResultUpdate>,
}

/// Result of a single step within a job.
#[derive(Debug, Deserialize)]
pub struct StepResultUpdate {
    /// Step name, matched against the stored step list.
    pub name: String,
    /// Exit code for this step; `Some(0)` is treated as success,
    /// anything else (including `None`) as failure.
    pub exit_code: Option<i32>,
}

/// POST /.runners/jobs/{owner}/{project}/{job_id}/complete - Runner reports job completion.
///
/// Verifies the per-job token, updates the stored job/step statuses from
/// the runner's report, persists the result atomically, and clears the
/// runner's current-job assignment.
pub async fn runner_complete(
    State(state): State<GlobalState>,
    RunnerAuth(token): RunnerAuth,
    AxumPath(path): AxumPath<JobPath>,
    Json(body): Json<JobCompleteRequest>,
) -> Result<StatusCode, StatusCode> {
    use crate::models::CiJobStatus;

    // `owner`/`project` come from the URL (axum percent-decodes segments)
    // and are joined into a filesystem path, so reject traversal attempts.
    fn safe_segment(s: &str) -> bool {
        !s.is_empty()
            && s != "."
            && s != ".."
            && !s.chars().any(|c| matches!(c, '/' | '\\' | '\0'))
    }
    if !safe_segment(&path.owner) || !safe_segment(&path.project) {
        return Err(StatusCode::BAD_REQUEST);
    }

    // Construct job directory directly from path.
    let job_dir = state
        .config
        .ci_root
        .join(&path.owner)
        .join(&path.project)
        .join(path.job_id.to_string());

    // Load the job metadata.
    // NOTE(review): this is an unlocked read-modify-write of job.json; two
    // concurrent completion reports for the same job could race. Confirm a
    // job is only ever assigned to a single runner.
    let job_file = job_dir.join("job.json");
    let content = tokio::fs::read_to_string(&job_file)
        .await
        .map_err(|_| StatusCode::NOT_FOUND)?;

    let mut job: crate::models::CiJob =
        serde_json::from_str(&content).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // Verify job token.
    // NOTE(review): not a constant-time comparison — confirm whether
    // timing-safe comparison is required here.
    if job.job_token.as_deref() != Some(&body.job_token) {
        return Err(StatusCode::FORBIDDEN);
    }

    // Update overall job status from the runner's report.
    job.status = if body.success {
        CiJobStatus::Success
    } else {
        CiJobStatus::Failed
    };
    job.exit_code = body.exit_code;
    job.finished_at = Some(OffsetDateTime::now_utc());

    // Update individual job results; reported names that don't match a
    // stored job are silently ignored.
    for update in &body.jobs {
        if let Some(job_result) = job.jobs.iter_mut().find(|j| j.name == update.name) {
            job_result.exit_code = update.exit_code;
            // A missing exit code (None) counts as failure.
            job_result.status = if update.exit_code == Some(0) {
                CiJobStatus::Success
            } else {
                CiJobStatus::Failed
            };
            job_result.finished_at = Some(OffsetDateTime::now_utc());

            // Update steps within this job the same way.
            for step_update in &update.steps {
                if let Some(step) = job_result
                    .steps
                    .iter_mut()
                    .find(|s| s.name == step_update.name)
                {
                    step.exit_code = step_update.exit_code;
                    step.status = if step_update.exit_code == Some(0) {
                        CiJobStatus::Success
                    } else {
                        CiJobStatus::Failed
                    };
                }
            }
        }
    }

    // Save updated job metadata atomically (write temp file, then rename
    // over job.json so readers never see a partial write).
    let content =
        serde_json::to_string_pretty(&job).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    let tmp_file = job_file.with_extension("json.tmp");
    tokio::fs::write(&tmp_file, &content)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    tokio::fs::rename(&tmp_file, &job_file)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // Clear the runner's current job, but only if it still matches this
    // job id — the registry may have reassigned the runner in the meantime.
    if let Some(runner) = state.runner_registry.get_by_token(&token.token_hash) {
        let mut runner = runner.write().await;
        if runner
            .current_job
            .as_ref()
            .is_some_and(|j| j.job_id == path.job_id)
        {
            runner.current_job = None;
        }
    }

    println!("Job {} completed: success={}", path.job_id, body.success);

    Ok(StatusCode::OK)
}