diff --git a/Cargo.lock b/Cargo.lock index 6286da2697..265b093c75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2349,6 +2349,7 @@ dependencies = [ "glob", "google-cloud-auth", "http 1.4.0", + "libc", "libsqlite3-sys", "oauth2", "open", diff --git a/crates/forge_api/src/forge_api.rs b/crates/forge_api/src/forge_api.rs index aca7637afc..9af231bd02 100644 --- a/crates/forge_api/src/forge_api.rs +++ b/crates/forge_api/src/forge_api.rs @@ -274,7 +274,7 @@ impl< .credential .as_ref() .and_then(|c| match &c.auth_details { - forge_domain::AuthDetails::ApiKey(key) => Some(key.as_str()), + forge_domain::AuthDetails::ApiKey(provider) => Some(provider.api_key().as_str()), _ => None, }) { diff --git a/crates/forge_app/src/command_generator.rs b/crates/forge_app/src/command_generator.rs index 122fbc2ec8..0b6c9bb0dd 100644 --- a/crates/forge_app/src/command_generator.rs +++ b/crates/forge_app/src/command_generator.rs @@ -273,7 +273,7 @@ mod tests { url_params: vec![], credential: Some(AuthCredential { id: ProviderId::OPENAI, - auth_details: AuthDetails::ApiKey("test-key".to_string().into()), + auth_details: AuthDetails::static_api_key("test-key".to_string().into()), url_params: Default::default(), }), custom_headers: None, diff --git a/crates/forge_app/src/dto/openai/transformers/pipeline.rs b/crates/forge_app/src/dto/openai/transformers/pipeline.rs index dd169e442c..c948b15fbc 100644 --- a/crates/forge_app/src/dto/openai/transformers/pipeline.rs +++ b/crates/forge_app/src/dto/openai/transformers/pipeline.rs @@ -150,7 +150,7 @@ mod tests { fn make_credential(provider_id: ProviderId, key: &str) -> Option { Some(forge_domain::AuthCredential { id: provider_id, - auth_details: forge_domain::AuthDetails::ApiKey(forge_domain::ApiKey::from( + auth_details: forge_domain::AuthDetails::static_api_key(forge_domain::ApiKey::from( key.to_string(), )), url_params: HashMap::new(), diff --git a/crates/forge_config/src/config.rs b/crates/forge_config/src/config.rs index d08fdb3a9b..63841ceff1 100644 --- a/crates/forge_config/src/config.rs +++ b/crates/forge_config/src/config.rs @@ -72,6 +72,11 @@ pub struct ProviderEntry { /// Environment variable holding the API key for this provider. #[serde(default, skip_serializing_if = "Option::is_none")] pub api_key_var: Option, + /// Shell command that produces an API key on stdout. When set, the + /// command is executed instead of reading a static key from an environment + /// variable. Falls back to `{api_key_var}_HELPER` env var when absent. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub api_key_helper: Option, /// URL template for chat completions; may contain `{{VAR}}` placeholders /// that are substituted from the credential's url params. pub url: String, @@ -367,4 +372,26 @@ mod tests { assert_eq!(actual.temperature, fixture.temperature); } + + #[test] + fn test_provider_entry_api_key_helper_round_trip() { + let fixture = ForgeConfig { + providers: vec![ProviderEntry { + id: "test_provider".to_string(), + url: "https://api.example.com/v1/chat".to_string(), + api_key_helper: Some("vault read -field=token secret/key".to_string()), + ..Default::default() + }], + ..Default::default() + }; + + let toml = toml_edit::ser::to_string_pretty(&fixture).unwrap(); + let actual = ConfigReader::default().read_toml(&toml).build().unwrap(); + + assert_eq!(actual.providers.len(), 1); + assert_eq!( + actual.providers[0].api_key_helper, + Some("vault read -field=token secret/key".to_string()) + ); + } } diff --git a/crates/forge_domain/src/auth/auth_context.rs b/crates/forge_domain/src/auth/auth_context.rs index 418c058577..9ab646424b 100644 --- a/crates/forge_domain/src/auth/auth_context.rs +++ b/crates/forge_domain/src/auth/auth_context.rs @@ -26,6 +26,10 @@ pub struct ApiKeyRequest { pub struct ApiKeyResponse { pub api_key: ApiKey, pub url_params: HashMap, + /// When set, the API key was produced by this shell command and the + /// credential should be stored as a + /// [`HelperCommand`](super::ApiKeyProvider::HelperCommand). + pub helper_command: Option, } // Authorization Code Flow @@ -95,7 +99,7 @@ pub enum AuthContextResponse { } impl AuthContextResponse { - /// Creates an API key authentication context + /// Creates an API key authentication context with a static key. pub fn api_key( request: ApiKeyRequest, api_key: impl ToString, @@ -109,6 +113,27 @@ impl AuthContextResponse { .into_iter() .map(|(k, v)| (k.into(), v.into())) .collect(), + helper_command: None, + }, + }) + } + + /// Creates an API key authentication context backed by a helper command. + pub fn api_key_with_helper( + request: ApiKeyRequest, + api_key: impl ToString, + url_params: HashMap, + command: String, + ) -> Self { + Self::ApiKey(AuthContext { + request, + response: ApiKeyResponse { + api_key: api_key.to_string().into(), + url_params: url_params + .into_iter() + .map(|(k, v)| (k.into(), v.into())) + .collect(), + helper_command: Some(command), }, }) } diff --git a/crates/forge_domain/src/auth/credentials.rs b/crates/forge_domain/src/auth/credentials.rs index 15b5494855..141525c5be 100644 --- a/crates/forge_domain/src/auth/credentials.rs +++ b/crates/forge_domain/src/auth/credentials.rs @@ -6,6 +6,57 @@ use serde::{Deserialize, Serialize}; use crate::{AccessToken, ApiKey, OAuthConfig, ProviderId, RefreshToken, URLParam, URLParamValue}; +/// Strategy for providing API keys to a credential. +/// +/// Uses untagged serde representation so that a bare string (legacy format) +/// deserializes as [`StaticKey`](Self::StaticKey), preserving backward +/// compatibility with existing `~/.forge/.credentials.json` files. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ApiKeyProvider { + /// A static, user-supplied API key. + StaticKey(ApiKey), + /// A shell command executed via `sh -c` whose trimmed stdout is used as the + /// API key. The `last_key` and `expires_at` are cached on disk so the + /// command only re-executes when the key expires (or on first use). + HelperCommand { + command: String, + #[serde(default)] + last_key: ApiKey, + #[serde(default, skip_serializing_if = "Option::is_none")] + expires_at: Option>, + }, +} + +impl ApiKeyProvider { + /// Returns the current API key value. + /// + /// For [`StaticKey`](Self::StaticKey) this is the user-supplied key. For + /// [`HelperCommand`](Self::HelperCommand) this is the last key obtained by + /// executing the command (empty until the first refresh). + pub fn api_key(&self) -> &ApiKey { + match self { + Self::StaticKey(key) => key, + Self::HelperCommand { last_key, .. } => last_key, + } + } + + /// Returns `true` when the key should be refreshed before use. + /// + /// Static keys never expire. Helper-command keys expire based on the + /// `expires_at` field (populated from the command's TTL/Expires metadata). + /// When `expires_at` is `None` the key is treated as single-use and + /// refreshed on every call. + pub fn needs_refresh(&self, buffer: chrono::Duration) -> bool { + match self { + Self::StaticKey(_) => false, + Self::HelperCommand { expires_at: Some(exp), .. } => Utc::now() + buffer >= *exp, + Self::HelperCommand { expires_at: None, .. } => true, + } + } +} + +/// Stored authentication credential for a provider. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Setters)] pub struct AuthCredential { pub id: ProviderId, @@ -14,13 +65,15 @@ pub struct AuthCredential { pub url_params: HashMap, } impl AuthCredential { + /// Creates a credential with a static API key. pub fn new_api_key(id: ProviderId, api_key: ApiKey) -> Self { Self { id, - auth_details: AuthDetails::ApiKey(api_key), + auth_details: AuthDetails::static_api_key(api_key), url_params: HashMap::new(), } } + /// Creates a credential with OAuth tokens. pub fn new_oauth(id: ProviderId, tokens: OAuthTokens, config: OAuthConfig) -> Self { Self { id, @@ -28,6 +81,7 @@ impl AuthCredential { url_params: HashMap::new(), } } + /// Creates a credential with OAuth tokens and an API key. pub fn new_oauth_with_api_key( id: ProviderId, tokens: OAuthTokens, @@ -49,6 +103,8 @@ impl AuthCredential { } } + /// Creates a credential with a Google Application Default Credentials + /// token. pub fn new_google_adc(id: ProviderId, access_token: ApiKey) -> Self { Self { id, @@ -60,7 +116,7 @@ impl AuthCredential { /// Checks if the credential needs to be refreshed. pub fn needs_refresh(&self, buffer: chrono::Duration) -> bool { match &self.auth_details { - AuthDetails::ApiKey(_) => false, + AuthDetails::ApiKey(provider) => provider.needs_refresh(buffer), // AWS Profile credentials are managed by the AWS SDK internally AuthDetails::AwsProfile(_) => false, // Google ADC tokens are short-lived (1 hour) and should always be checked/refreshed @@ -86,7 +142,7 @@ impl AuthCredential { #[serde(rename_all = "snake_case")] pub enum AuthDetails { #[serde(alias = "ApiKey")] - ApiKey(ApiKey), + ApiKey(ApiKeyProvider), #[serde(alias = "GoogleAdc")] GoogleAdc(ApiKey), #[serde(alias = "AwsProfile")] @@ -105,9 +161,24 @@ pub enum AuthDetails { } impl AuthDetails { + /// Creates a static API key auth details. + pub fn static_api_key(key: ApiKey) -> Self { + Self::ApiKey(ApiKeyProvider::StaticKey(key)) + } + + /// Creates an API key auth details backed by a helper command. + pub fn api_key_from_helper( + command: String, + last_key: ApiKey, + expires_at: Option>, + ) -> Self { + Self::ApiKey(ApiKeyProvider::HelperCommand { command, last_key, expires_at }) + } + + /// Returns the API key if these auth details contain one. pub fn api_key(&self) -> Option<&ApiKey> { match self { - AuthDetails::ApiKey(api_key) => Some(api_key), + AuthDetails::ApiKey(provider) => Some(provider.api_key()), AuthDetails::GoogleAdc(api_key) => Some(api_key), AuthDetails::AwsProfile(_) => None, AuthDetails::OAuth { .. } => None, @@ -148,3 +219,239 @@ impl OAuthTokens { Utc::now() >= self.expires_at } } + +#[cfg(test)] +mod tests { + use super::*; + + mod api_key_provider { + use super::*; + + mod static_key { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn api_key_returns_the_key() { + let fixture = ApiKeyProvider::StaticKey(ApiKey::from("sk-test".to_string())); + let actual = fixture.api_key(); + let expected = &ApiKey::from("sk-test".to_string()); + assert_eq!(actual, expected); + } + + #[test] + fn serde_roundtrip() { + let fixture = ApiKeyProvider::StaticKey(ApiKey::from("sk-test".to_string())); + let json = serde_json::to_string(&fixture).unwrap(); + let actual: ApiKeyProvider = serde_json::from_str(&json).unwrap(); + assert_eq!(actual, fixture); + } + + #[test] + fn serializes_as_bare_string() { + let fixture = ApiKeyProvider::StaticKey(ApiKey::from("sk-test".to_string())); + let actual = serde_json::to_string(&fixture).unwrap(); + let expected = r#""sk-test""#; + assert_eq!(actual, expected); + } + + #[test] + fn deserializes_from_bare_string() { + let actual: ApiKeyProvider = serde_json::from_str(r#""sk-old-key""#).unwrap(); + let expected = ApiKeyProvider::StaticKey(ApiKey::from("sk-old-key".to_string())); + assert_eq!(actual, expected); + } + } + + mod helper_command { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn api_key_returns_last_key() { + let fixture = ApiKeyProvider::HelperCommand { + command: "echo key".to_string(), + last_key: ApiKey::from("dynamic-key".to_string()), + expires_at: None, + }; + let actual = fixture.api_key(); + let expected = &ApiKey::from("dynamic-key".to_string()); + assert_eq!(actual, expected); + } + + #[test] + fn serializes_command_and_last_key() { + let fixture = ApiKeyProvider::HelperCommand { + command: "vault read -field=token".to_string(), + last_key: ApiKey::from("resolved".to_string()), + expires_at: None, + }; + let actual = serde_json::to_string(&fixture).unwrap(); + let expected = r#"{"command":"vault read -field=token","last_key":"resolved"}"#; + assert_eq!(actual, expected); + } + + #[test] + fn round_trips_with_cached_key() { + let fixture = ApiKeyProvider::HelperCommand { + command: "vault read -field=token".to_string(), + last_key: ApiKey::from("cached-key".to_string()), + expires_at: Some(Utc::now() + chrono::Duration::hours(1)), + }; + let json = serde_json::to_string(&fixture).unwrap(); + let actual: ApiKeyProvider = serde_json::from_str(&json).unwrap(); + assert_eq!(actual, fixture); + } + + #[test] + fn deserializes_with_empty_last_key() { + let json = r#"{"command":"vault read -field=token"}"#; + let actual: ApiKeyProvider = serde_json::from_str(json).unwrap(); + let expected = ApiKeyProvider::HelperCommand { + command: "vault read -field=token".to_string(), + last_key: ApiKey::default(), + expires_at: None, + }; + assert_eq!(actual, expected); + } + + #[test] + fn deserialized_needs_refresh() { + let json = r#"{"command":"echo fresh-key"}"#; + let fixture: ApiKeyProvider = serde_json::from_str(json).unwrap(); + let actual = fixture.needs_refresh(chrono::Duration::minutes(5)); + assert!(actual); + } + } + } + + mod needs_refresh { + use super::*; + + mod helper_command { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn without_expires_at_returns_true() { + let fixture = AuthCredential { + auth_details: AuthDetails::api_key_from_helper( + "echo key".to_string(), + ApiKey::from("key".to_string()), + None, + ), + ..AuthCredential::new_api_key( + ProviderId::from("test".to_string()), + ApiKey::from("key".to_string()), + ) + }; + let actual = fixture.needs_refresh(chrono::Duration::minutes(5)); + let expected = true; + assert_eq!(actual, expected); + } + + #[test] + fn with_future_expires_at_returns_false() { + let fixture = AuthCredential { + auth_details: AuthDetails::api_key_from_helper( + "echo key".to_string(), + ApiKey::from("key".to_string()), + Some(Utc::now() + chrono::Duration::hours(1)), + ), + ..AuthCredential::new_api_key( + ProviderId::from("test".to_string()), + ApiKey::from("key".to_string()), + ) + }; + let actual = fixture.needs_refresh(chrono::Duration::minutes(5)); + let expected = false; + assert_eq!(actual, expected); + } + + #[test] + fn with_past_expires_at_returns_true() { + let fixture = AuthCredential { + auth_details: AuthDetails::api_key_from_helper( + "echo key".to_string(), + ApiKey::from("key".to_string()), + Some(Utc::now() - chrono::Duration::minutes(1)), + ), + ..AuthCredential::new_api_key( + ProviderId::from("test".to_string()), + ApiKey::from("key".to_string()), + ) + }; + let actual = fixture.needs_refresh(chrono::Duration::minutes(5)); + let expected = true; + assert_eq!(actual, expected); + } + } + + mod static_key { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn returns_false() { + let fixture = AuthCredential::new_api_key( + ProviderId::from("test".to_string()), + ApiKey::from("key".to_string()), + ); + let actual = fixture.needs_refresh(chrono::Duration::minutes(5)); + let expected = false; + assert_eq!(actual, expected); + } + } + } + + mod backward_compat { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn legacy_credential_json_deserializes() { + let fixture = r#"{ + "id": "anthropic", + "auth_details": {"api_key": "sk-legacy-key"} + }"#; + let actual: AuthCredential = serde_json::from_str(fixture).unwrap(); + let expected = AuthCredential::new_api_key( + ProviderId::from("anthropic".to_string()), + ApiKey::from("sk-legacy-key".to_string()), + ); + assert_eq!(actual, expected); + } + + #[test] + fn helper_credential_serializes_as_expected() { + let fixture = vec![AuthCredential { + id: ProviderId::from("xai".to_string()), + auth_details: AuthDetails::api_key_from_helper( + "printf 'sk-test\\n---\\nTTL: 300'".to_string(), + ApiKey::from("sk-test".to_string()), + None, + ), + url_params: HashMap::new(), + }]; + let actual = serde_json::to_string_pretty(&fixture).unwrap(); + // command persisted, last_key and expires_at skipped + assert!( + actual.contains(r#""command""#), + "should contain command: {actual}" + ); + assert!( + !actual.contains("last_key"), + "should NOT contain last_key: {actual}" + ); + assert!( + !actual.contains("expires_at"), + "should NOT contain expires_at: {actual}" + ); + } + } +} diff --git a/crates/forge_domain/src/auth/new_types.rs b/crates/forge_domain/src/auth/new_types.rs index 579bf2b0ba..2139a75090 100644 --- a/crates/forge_domain/src/auth/new_types.rs +++ b/crates/forge_domain/src/auth/new_types.rs @@ -1,7 +1,16 @@ use serde::{Deserialize, Serialize}; #[derive( - Clone, Serialize, Deserialize, derive_more::From, derive_more::Deref, PartialEq, Eq, Hash, Debug, + Clone, + Serialize, + Deserialize, + derive_more::From, + derive_more::Deref, + PartialEq, + Eq, + Hash, + Debug, + Default, )] #[serde(transparent)] pub struct ApiKey(String); diff --git a/crates/forge_domain/src/node.rs b/crates/forge_domain/src/node.rs index 5f172185b8..aaca589f51 100644 --- a/crates/forge_domain/src/node.rs +++ b/crates/forge_domain/src/node.rs @@ -94,7 +94,7 @@ pub struct WorkspaceAuth { impl From for crate::AuthDetails { fn from(auth: WorkspaceAuth) -> Self { - crate::AuthDetails::ApiKey(auth.token) + crate::AuthDetails::static_api_key(auth.token) } } diff --git a/crates/forge_domain/src/provider.rs b/crates/forge_domain/src/provider.rs index 3840aef3a0..505724ca68 100644 --- a/crates/forge_domain/src/provider.rs +++ b/crates/forge_domain/src/provider.rs @@ -271,7 +271,7 @@ impl Provider { self.credential .as_ref() .and_then(|c| match &c.auth_details { - AuthDetails::ApiKey(key) => Some(key), + AuthDetails::ApiKey(provider) => Some(provider.api_key()), _ => None, }) } @@ -373,7 +373,7 @@ mod test_helpers { fn make_credential(provider_id: ProviderId, key: &str) -> Option { Some(AuthCredential { id: provider_id, - auth_details: AuthDetails::ApiKey(ApiKey::from(key.to_string())), + auth_details: AuthDetails::static_api_key(ApiKey::from(key.to_string())), url_params: HashMap::new(), }) } @@ -690,7 +690,7 @@ mod tests { .unwrap(), credential: Some(AuthCredential { id: ProviderId::IO_INTELLIGENCE, - auth_details: AuthDetails::ApiKey(ApiKey::from(fixture.to_string())), + auth_details: AuthDetails::static_api_key(ApiKey::from(fixture.to_string())), url_params: HashMap::new(), }), auth_methods: vec![crate::AuthMethod::ApiKey], @@ -714,7 +714,7 @@ mod tests { url: Url::from_str("https://api.x.ai/v1/chat/completions").unwrap(), credential: Some(AuthCredential { id: ProviderId::XAI, - auth_details: AuthDetails::ApiKey(ApiKey::from(fixture.to_string())), + auth_details: AuthDetails::static_api_key(ApiKey::from(fixture.to_string())), url_params: HashMap::new(), }), auth_methods: vec![crate::AuthMethod::ApiKey], diff --git a/crates/forge_infra/Cargo.toml b/crates/forge_infra/Cargo.toml index 1f25994741..08d8a3f887 100644 --- a/crates/forge_infra/Cargo.toml +++ b/crates/forge_infra/Cargo.toml @@ -52,6 +52,9 @@ open.workspace = true aws-config.workspace = true aws-credential-types.workspace = true +[target.'cfg(unix)'.dependencies] +libc = "0.2" + [dev-dependencies] tokio = { workspace = true, features = ["macros", "rt", "time", "test-util"] } serial_test = "3.4" diff --git a/crates/forge_infra/src/auth/api_key_helper.rs b/crates/forge_infra/src/auth/api_key_helper.rs new file mode 100644 index 0000000000..d49cb2f5ec --- /dev/null +++ b/crates/forge_infra/src/auth/api_key_helper.rs @@ -0,0 +1,269 @@ +use std::process::Stdio; +use std::time::Duration; + +use chrono::{DateTime, Utc}; +use forge_domain::{ApiKey, ApiKeyProvider}; + +/// Default timeout for helper command execution. +const DEFAULT_TIMEOUT_SECS: u64 = 30; +const MAX_TIMEOUT_SECS: u64 = 300; + +/// Returns the configured helper command timeout, read from the +/// `FORGE_API_KEY_HELPER_TIMEOUT` environment variable. Falls back to +/// [`DEFAULT_TIMEOUT_SECS`] when the variable is absent or unparseable. +/// Capped at [`MAX_TIMEOUT_SECS`]. +fn helper_timeout() -> Duration { + let secs = std::env::var("FORGE_API_KEY_HELPER_TIMEOUT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(DEFAULT_TIMEOUT_SECS) + .min(MAX_TIMEOUT_SECS); + Duration::from_secs(secs) +} + +/// Kills the entire process group rooted at `pid`. +#[cfg(unix)] +fn kill_process_group(pid: u32) { + unsafe { + libc::kill(-(pid as libc::pid_t), libc::SIGKILL); + } +} + +/// Executes an [`ApiKeyProvider`] to obtain a fresh key. +/// +/// - `StaticKey` — returns the provider unchanged (no-op). +/// - `HelperCommand` — runs the shell command via `sh -c`, parses stdout, and +/// returns an updated provider with the new key and optional expiry. The +/// child is placed in its own process group so that the entire tree +/// (including grandchildren) is killed on timeout. +pub async fn execute(provider: &ApiKeyProvider) -> anyhow::Result { + match provider { + ApiKeyProvider::StaticKey(_) => Ok(provider.clone()), + ApiKeyProvider::HelperCommand { command, .. } => { + let timeout = helper_timeout(); + + let mut cmd = tokio::process::Command::new("sh"); + cmd.arg("-c") + .arg(command) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .kill_on_drop(true); + + #[cfg(unix)] + cmd.process_group(0); + + let child = cmd.spawn().map_err(|e| { + anyhow::anyhow!("Failed to execute auth helper (command: {command:.40}): {e}") + })?; + + let pid = child.id(); + + match tokio::time::timeout(timeout, child.wait_with_output()).await { + Ok(Ok(output)) => { + if !output.status.success() { + anyhow::bail!( + "Auth helper exited with status {} (command: {command:.40})", + output.status + ); + } + + let stdout = String::from_utf8(output.stdout).map_err(|e| { + anyhow::anyhow!("Auth helper output is not valid UTF-8: {e}") + })?; + + let (key, expires_at) = parse_output(&stdout)?; + Ok(ApiKeyProvider::HelperCommand { + command: command.clone(), + last_key: key, + expires_at, + }) + } + Ok(Err(e)) => Err(anyhow::anyhow!( + "Failed to execute auth helper (command: {command:.40}): {e}" + )), + Err(_) => { + #[cfg(unix)] + if let Some(pid) = pid { + kill_process_group(pid); + } + anyhow::bail!( + "Auth helper timed out after {}s (command: {command:.40})", + timeout.as_secs() + ); + } + } + } + } +} + +/// Parses helper command output into an API key and optional expiry. +/// +/// Format: `` or `\n---\nTTL: ` or +/// `\n---\nExpires: `. +fn parse_output(raw: &str) -> anyhow::Result<(ApiKey, Option>)> { + // Normalize CRLF to LF for cross-platform compatibility + let output = raw.replace("\r\n", "\n"); + let (key_part, metadata) = match output.split_once("\n---\n") { + Some((key, rest)) => (key, Some(rest)), + None => (output.as_str(), None), + }; + + let key = key_part.trim().to_string(); + if key.is_empty() { + anyhow::bail!("Auth helper produced empty output"); + } + + let expires_at = if let Some(meta) = metadata { + let meta = meta.trim(); + if let Some(secs) = meta.strip_prefix("TTL:") { + let ttl: u64 = secs + .trim() + .parse() + .map_err(|e| anyhow::anyhow!("Invalid TTL value: {e}"))?; + let duration = chrono::Duration::seconds( + i64::try_from(ttl).map_err(|_| anyhow::anyhow!("TTL value too large: {ttl}"))?, + ); + Some(Utc::now() + duration) + } else if let Some(ts) = meta.strip_prefix("Expires:") { + let timestamp: i64 = ts + .trim() + .parse() + .map_err(|e| anyhow::anyhow!("Invalid Expires timestamp: {e}"))?; + Some( + DateTime::from_timestamp(timestamp, 0) + .ok_or_else(|| anyhow::anyhow!("Invalid unix timestamp: {timestamp}"))?, + ) + } else { + None + } + } else { + None + }; + + Ok((ApiKey::from(key), expires_at)) +} + +#[cfg(test)] +mod tests { + use super::*; + + mod parse_output { + use super::*; + + #[test] + fn key_only() { + let (key, expires_at) = parse_output("sk-test-key\n").unwrap(); + assert_eq!(key.as_ref(), "sk-test-key"); + assert!(expires_at.is_none()); + } + + #[test] + fn key_with_ttl() { + let (key, expires_at) = parse_output("sk-test-key\n---\nTTL: 3600\n").unwrap(); + assert_eq!(key.as_ref(), "sk-test-key"); + let exp = expires_at.unwrap(); + let expected = Utc::now() + chrono::Duration::seconds(3600); + assert!((exp - expected).num_seconds().abs() < 5); + } + + #[test] + fn key_with_expires() { + let future_ts = Utc::now().timestamp() + 7200; + let input = format!("sk-test-key\n---\nExpires: {future_ts}\n"); + let (key, expires_at) = parse_output(&input).unwrap(); + assert_eq!(key.as_ref(), "sk-test-key"); + let exp = expires_at.unwrap(); + let expected = DateTime::from_timestamp(future_ts, 0).unwrap(); + assert!((exp - expected).num_seconds().abs() < 5); + } + + #[test] + fn empty_output_returns_error() { + assert!(parse_output(" \n").is_err()); + } + + #[test] + fn unknown_metadata_ignored() { + let (key, expires_at) = parse_output("sk-key\n---\nFoo: bar\n").unwrap(); + assert_eq!(key.as_ref(), "sk-key"); + assert!(expires_at.is_none()); + } + + #[test] + fn crlf_line_endings() { + let (key, expires_at) = parse_output("sk-test-key\r\n---\r\nTTL: 3600\r\n").unwrap(); + assert_eq!(key.as_ref(), "sk-test-key"); + assert!(expires_at.is_some()); + } + } + + mod execute { + use super::*; + + #[tokio::test] + async fn static_key_returns_unchanged() { + let provider = ApiKeyProvider::StaticKey(ApiKey::from("sk-static".to_string())); + let result = execute(&provider).await.unwrap(); + assert_eq!(result, provider); + } + + #[tokio::test] + async fn helper_command_returns_key() { + let provider = ApiKeyProvider::HelperCommand { + command: "echo sk-from-helper".to_string(), + last_key: ApiKey::from("old".to_string()), + expires_at: None, + }; + let result = execute(&provider).await.unwrap(); + match &result { + ApiKeyProvider::HelperCommand { last_key, .. } => { + assert_eq!(last_key.as_ref(), "sk-from-helper"); + } + _ => panic!("Expected HelperCommand"), + } + } + + #[tokio::test] + async fn failing_command_returns_error() { + let provider = ApiKeyProvider::HelperCommand { + command: "false".to_string(), + last_key: ApiKey::from("old".to_string()), + expires_at: None, + }; + assert!(execute(&provider).await.is_err()); + } + + #[cfg(unix)] + #[tokio::test] + async fn timeout_kills_entire_process_group() { + let pid_file = + std::env::temp_dir().join(format!("forge_test_pgkill_{}", std::process::id())); + + let command = format!("(echo $$ > {}; sleep 300) & sleep 300", pid_file.display()); + + unsafe { std::env::set_var("FORGE_API_KEY_HELPER_TIMEOUT", "1") }; + + let provider = ApiKeyProvider::HelperCommand { + command, + last_key: ApiKey::from("old".to_string()), + expires_at: None, + }; + + let result = execute(&provider).await; + + unsafe { std::env::remove_var("FORGE_API_KEY_HELPER_TIMEOUT") }; + + assert!(result.unwrap_err().to_string().contains("timed out")); + + tokio::time::sleep(Duration::from_millis(200)).await; + + let pid_str = std::fs::read_to_string(&pid_file).unwrap(); + let grandchild_pid: libc::pid_t = pid_str.trim().parse().unwrap(); + + let alive = unsafe { libc::kill(grandchild_pid, 0) }; + assert_eq!(alive, -1, "Grandchild process should have been killed"); + + std::fs::remove_file(&pid_file).ok(); + } + } +} diff --git a/crates/forge_infra/src/auth/mod.rs b/crates/forge_infra/src/auth/mod.rs index 1edf7f7320..964def4f99 100644 --- a/crates/forge_infra/src/auth/mod.rs +++ b/crates/forge_infra/src/auth/mod.rs @@ -1,3 +1,4 @@ +pub mod api_key_helper; mod mcp_credentials; mod mcp_token_storage; diff --git a/crates/forge_infra/src/auth/strategy.rs b/crates/forge_infra/src/auth/strategy.rs index 559a365f19..5a30545758 100644 --- a/crates/forge_infra/src/auth/strategy.rs +++ b/crates/forge_infra/src/auth/strategy.rs @@ -2,9 +2,9 @@ use std::time::Duration; use forge_app::{AuthStrategy, OAuthHttpProvider, StrategyFactory}; use forge_domain::{ - ApiKey, ApiKeyRequest, AuthContextRequest, AuthContextResponse, AuthCredential, CodeRequest, - DeviceCodeRequest, OAuthConfig, OAuthTokenResponse, OAuthTokens, ProviderId, URLParam, - URLParamSpec, + ApiKey, ApiKeyProvider, ApiKeyRequest, AuthContextRequest, AuthContextResponse, AuthCredential, + AuthDetails, CodeRequest, DeviceCodeRequest, OAuthConfig, OAuthTokenResponse, OAuthTokens, + ProviderId, URLParam, URLParamSpec, }; use google_cloud_auth::credentials::Builder; use oauth2::basic::BasicClient; @@ -43,17 +43,40 @@ impl AuthStrategy for ApiKeyStrategy { context_response: AuthContextResponse, ) -> anyhow::Result { match context_response { - AuthContextResponse::ApiKey(ctx) => Ok(AuthCredential::new_api_key( - self.provider_id.clone(), - ctx.response.api_key, - ) - .url_params(ctx.response.url_params)), + AuthContextResponse::ApiKey(ctx) => { + let auth_details = if let Some(command) = ctx.response.helper_command { + // Execute the helper to validate and obtain the initial key + let initial = forge_domain::ApiKeyProvider::HelperCommand { + command, + last_key: forge_domain::ApiKey::default(), + expires_at: None, + }; + let provider = crate::auth::api_key_helper::execute(&initial).await?; + AuthDetails::ApiKey(provider) + } else { + AuthDetails::static_api_key(ctx.response.api_key) + }; + Ok(AuthCredential { + id: self.provider_id.clone(), + auth_details, + url_params: ctx.response.url_params, + }) + } _ => Err(AuthError::InvalidContext("Expected ApiKey context".to_string()).into()), } } async fn refresh(&self, credential: &AuthCredential) -> anyhow::Result { - // API keys don't expire - return as-is + if let AuthDetails::ApiKey(provider @ ApiKeyProvider::HelperCommand { .. }) = + &credential.auth_details + { + let refreshed_provider = crate::auth::api_key_helper::execute(provider).await?; + return Ok(AuthCredential { + id: credential.id.clone(), + auth_details: AuthDetails::ApiKey(refreshed_provider), + url_params: credential.url_params.clone(), + }); + } Ok(credential.clone()) } } @@ -1209,7 +1232,7 @@ impl StrategyFactory for ForgeAuthStrategyFactory { mod tests { use std::collections::HashMap; - use forge_domain::URLParam; + use forge_domain::{ApiKeyProvider, URLParam}; use pretty_assertions::assert_eq; use super::*; @@ -1486,4 +1509,40 @@ mod tests { let expected = fixture_url_params; assert_eq!(actual.url_params, expected); } + + #[tokio::test] + async fn test_api_key_strategy_refresh_with_helper_command() { + let strategy = ApiKeyStrategy::new(ProviderId::OPENAI, vec![]); + let fixture = + AuthCredential::new_api_key(ProviderId::OPENAI, ApiKey::from("old-key".to_string())); + // Replace auth_details with a HelperCommand + let fixture = AuthCredential { + auth_details: AuthDetails::api_key_from_helper( + "echo refreshed-key".to_string(), + ApiKey::from("old-key".to_string()), + None, + ), + ..fixture + }; + + let actual = strategy.refresh(&fixture).await.unwrap(); + + match &actual.auth_details { + AuthDetails::ApiKey(ApiKeyProvider::HelperCommand { last_key, .. }) => { + assert_eq!(last_key.as_ref(), "refreshed-key"); + } + other => panic!("Expected HelperCommand, got {other:?}"), + } + } + + #[tokio::test] + async fn test_api_key_strategy_refresh_static_key_unchanged() { + let strategy = ApiKeyStrategy::new(ProviderId::OPENAI, vec![]); + let fixture = + AuthCredential::new_api_key(ProviderId::OPENAI, ApiKey::from("static-key".to_string())); + + let actual = strategy.refresh(&fixture).await.unwrap(); + + assert_eq!(actual, fixture); + } } diff --git a/crates/forge_infra/src/lib.rs b/crates/forge_infra/src/lib.rs index a6a726d477..dc86d010bd 100644 --- a/crates/forge_infra/src/lib.rs +++ b/crates/forge_infra/src/lib.rs @@ -18,6 +18,7 @@ mod mcp_client; mod mcp_server; mod walker; +pub use auth::api_key_helper; pub use console::StdConsoleWriter; pub use env::ForgeEnvironmentInfra; pub use executor::ForgeCommandExecutorService; diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 6552c3f84e..b9480e4cac 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -2941,43 +2941,75 @@ impl A + Send + Sync> UI }) .collect::>>()?; - // Check if API key is already provided - // For Google ADC, we use a marker to skip prompting - // For other providers, we use the existing key as a default value (autofill) - let api_key_str = if let Some(default_key) = &request.api_key { - let key_str = default_key.as_ref(); - - // Skip prompting for markers that indicate non-API-key auth - if key_str == "google_adc_marker" || key_str == "aws_profile_marker" { - key_str.to_string() - } else { - // For other providers, show the existing key as default (autofill) + // Skip interactive prompts for auth markers (Google ADC, AWS profile) + let is_auth_marker = request.api_key.as_ref().is_some_and(|k| { + k.as_ref() == "google_adc_marker" || k.as_ref() == "aws_profile_marker" + }); + + if is_auth_marker { + // Google ADC / AWS profile: skip prompting entirely + let response = AuthContextResponse::api_key( + request.clone(), + request.api_key.as_ref().unwrap().as_ref(), + url_params, + ); + self.api + .complete_provider_auth(provider_id, response, Duration::from_secs(0)) + .await?; + return Ok(()); + } + + let auth_type_options = vec![ + "Static API Key".to_string(), + "Helper Command (script that generates a key)".to_string(), + ]; + let use_helper = ForgeWidget::select("Authentication type:", auth_type_options) + .prompt()? + .as_deref() + == Some("Helper Command (script that generates a key)"); + + if use_helper { + // Helper command flow + let command = ForgeWidget::input( + "Enter helper command (e.g. vault read -field=token secret/key)", + ) + .prompt()? + .context("Helper command input cancelled")?; + anyhow::ensure!(!command.trim().is_empty(), "Helper command cannot be empty"); + let command = command.trim().to_string(); + + // Send with a placeholder key — the strategy will validate by + // executing the command during complete_provider_auth + let response = + AuthContextResponse::api_key_with_helper(request.clone(), "", url_params, command); + self.spinner.start(Some("Validating helper command..."))?; + self.api + .complete_provider_auth(provider_id, response, Duration::from_secs(0)) + .await?; + self.spinner.stop(None)?; + } else { + // Static API key flow + let api_key_str = if let Some(default_key) = &request.api_key { + let key_str = default_key.as_ref(); let input = ForgeWidget::input(format!("Enter your {provider_id} API key")) .with_default(key_str); let api_key = input.prompt()?.context("API key input cancelled")?; let api_key_str = api_key.trim(); anyhow::ensure!(!api_key_str.is_empty(), "API key cannot be empty"); api_key_str.to_string() - } - } else { - // Prompt for API key input (no existing key) - let input = ForgeWidget::input(format!("Enter your {provider_id} API key")); - let api_key = input.prompt()?.context("API key input cancelled")?; - let api_key_str = api_key.trim(); - anyhow::ensure!(!api_key_str.is_empty(), "API key cannot be empty"); - api_key_str.to_string() - }; - - // Update the context with collected data - let response = AuthContextResponse::api_key(request.clone(), &api_key_str, url_params); + } else { + let input = ForgeWidget::input(format!("Enter your {provider_id} API key")); + let api_key = input.prompt()?.context("API key input cancelled")?; + let api_key_str = api_key.trim(); + anyhow::ensure!(!api_key_str.is_empty(), "API key cannot be empty"); + api_key_str.to_string() + }; - self.api - .complete_provider_auth( - provider_id, - response, - Duration::from_secs(0), // No timeout needed since we have the data - ) - .await?; + let response = AuthContextResponse::api_key(request.clone(), &api_key_str, url_params); + self.api + .complete_provider_auth(provider_id, response, Duration::from_secs(0)) + .await?; + } Ok(()) } diff --git a/crates/forge_repo/src/provider/anthropic.rs b/crates/forge_repo/src/provider/anthropic.rs index 052c370a1f..e1527538de 100644 --- a/crates/forge_repo/src/provider/anthropic.rs +++ b/crates/forge_repo/src/provider/anthropic.rs @@ -47,7 +47,7 @@ impl Anthropic { .credential .as_ref() .and_then(|c| match &c.auth_details { - forge_domain::AuthDetails::ApiKey(key) => Some(key.as_str()), + forge_domain::AuthDetails::ApiKey(provider) => Some(provider.api_key().as_str()), forge_domain::AuthDetails::OAuthWithApiKey { api_key, .. } => { Some(api_key.as_str()) } @@ -454,9 +454,9 @@ mod tests { url: chat_url, credential: Some(forge_domain::AuthCredential { id: forge_app::domain::ProviderId::ANTHROPIC, - auth_details: forge_domain::AuthDetails::ApiKey(forge_domain::ApiKey::from( - "sk-test-key".to_string(), - )), + auth_details: forge_domain::AuthDetails::static_api_key( + forge_domain::ApiKey::from("sk-test-key".to_string()), + ), url_params: std::collections::HashMap::new(), }), auth_methods: vec![forge_domain::AuthMethod::ApiKey], @@ -522,9 +522,9 @@ mod tests { url: chat_url, credential: Some(forge_domain::AuthCredential { id: forge_app::domain::ProviderId::ANTHROPIC, - auth_details: forge_domain::AuthDetails::ApiKey(forge_domain::ApiKey::from( - "sk-some-key".to_string(), - )), + auth_details: forge_domain::AuthDetails::static_api_key( + forge_domain::ApiKey::from("sk-some-key".to_string()), + ), url_params: std::collections::HashMap::new(), }), auth_methods: vec![forge_domain::AuthMethod::ApiKey], @@ -662,9 +662,9 @@ mod tests { url: chat_url, credential: Some(forge_domain::AuthCredential { id: forge_app::domain::ProviderId::ANTHROPIC, - auth_details: forge_domain::AuthDetails::ApiKey(forge_domain::ApiKey::from( - "sk-test-key".to_string(), - )), + auth_details: forge_domain::AuthDetails::static_api_key( + forge_domain::ApiKey::from("sk-test-key".to_string()), + ), url_params: std::collections::HashMap::new(), }), auth_methods: vec![forge_domain::AuthMethod::ApiKey], @@ -810,9 +810,9 @@ mod tests { url: chat_url, credential: Some(forge_domain::AuthCredential { id: forge_app::domain::ProviderId::ANTHROPIC, - auth_details: forge_domain::AuthDetails::ApiKey(forge_domain::ApiKey::from( - "sk-test-key".to_string(), - )), + auth_details: forge_domain::AuthDetails::static_api_key( + forge_domain::ApiKey::from("sk-test-key".to_string()), + ), url_params: std::collections::HashMap::new(), }), auth_methods: vec![forge_domain::AuthMethod::ApiKey], @@ -866,9 +866,9 @@ mod tests { url: chat_url, credential: Some(forge_domain::AuthCredential { id: forge_app::domain::ProviderId::ANTHROPIC, - auth_details: forge_domain::AuthDetails::ApiKey(forge_domain::ApiKey::from( - "sk-test-key".to_string(), - )), + auth_details: forge_domain::AuthDetails::static_api_key( + forge_domain::ApiKey::from("sk-test-key".to_string()), + ), url_params: std::collections::HashMap::new(), }), auth_methods: vec![forge_domain::AuthMethod::ApiKey], diff --git a/crates/forge_repo/src/provider/bedrock.rs b/crates/forge_repo/src/provider/bedrock.rs index 339eaf2b1d..2f2d319287 100644 --- a/crates/forge_repo/src/provider/bedrock.rs +++ b/crates/forge_repo/src/provider/bedrock.rs @@ -50,8 +50,8 @@ impl BedrockProvider { .context("Bedrock requires credentials")?; let auth_mode = match &credential.auth_details { - AuthDetails::ApiKey(key) if !key.is_empty() => { - BedrockAuthMode::BearerToken(key.as_ref().to_string()) + AuthDetails::ApiKey(provider) if !provider.api_key().is_empty() => { + BedrockAuthMode::BearerToken(provider.api_key().as_ref().to_string()) } AuthDetails::AwsProfile(profile) if !profile.is_empty() => { BedrockAuthMode::AwsProfile(profile.as_ref().to_string()) @@ -1074,7 +1074,7 @@ mod tests { url_params: vec![], credential: Some(AuthCredential { id: ProviderId::from("bedrock".to_string()), - auth_details: AuthDetails::ApiKey(ApiKey::from(token.to_string())), + auth_details: AuthDetails::static_api_key(ApiKey::from(token.to_string())), url_params, }), custom_headers: None, diff --git a/crates/forge_repo/src/provider/google.rs b/crates/forge_repo/src/provider/google.rs index 0493b93896..4825283873 100644 --- a/crates/forge_repo/src/provider/google.rs +++ b/crates/forge_repo/src/provider/google.rs @@ -166,7 +166,9 @@ impl GoogleResponseRepository { // For Vertex AI, the Google ADC token is stored as ApiKey // For OAuth, extract the access token let (token, use_api_key_header) = match creds { - forge_domain::AuthDetails::ApiKey(api_key) => (api_key.as_str().to_string(), true), + forge_domain::AuthDetails::ApiKey(provider) => { + (provider.api_key().as_str().to_string(), true) + } forge_domain::AuthDetails::GoogleAdc(token) => (token.as_str().to_string(), false), forge_domain::AuthDetails::OAuth { tokens, .. } => { (tokens.access_token.as_str().to_string(), false) diff --git a/crates/forge_repo/src/provider/openai.rs b/crates/forge_repo/src/provider/openai.rs index ac4866c05b..f2db9ecf84 100644 --- a/crates/forge_repo/src/provider/openai.rs +++ b/crates/forge_repo/src/provider/openai.rs @@ -58,7 +58,9 @@ impl OpenAIProvider { .credential .as_ref() .and_then(|c| match &c.auth_details { - forge_domain::AuthDetails::ApiKey(key) => Some(key.as_str()), + forge_domain::AuthDetails::ApiKey(provider) => { + Some(provider.api_key().as_str()) + } forge_domain::AuthDetails::OAuthWithApiKey { api_key, .. } => { Some(api_key.as_str()) } @@ -389,7 +391,7 @@ mod tests { fn make_credential(provider_id: ProviderId, key: &str) -> Option { Some(forge_domain::AuthCredential { id: provider_id, - auth_details: forge_domain::AuthDetails::ApiKey(forge_domain::ApiKey::from( + auth_details: forge_domain::AuthDetails::static_api_key(forge_domain::ApiKey::from( key.to_string(), )), url_params: HashMap::new(), diff --git a/crates/forge_repo/src/provider/openai_responses/repository.rs b/crates/forge_repo/src/provider/openai_responses/repository.rs index dd89830110..27f8fe410d 100644 --- a/crates/forge_repo/src/provider/openai_responses/repository.rs +++ b/crates/forge_repo/src/provider/openai_responses/repository.rs @@ -77,7 +77,9 @@ impl OpenAIResponsesProvider { .credential .as_ref() .and_then(|c| match &c.auth_details { - forge_domain::AuthDetails::ApiKey(key) => Some(key.as_str()), + forge_domain::AuthDetails::ApiKey(provider) => { + Some(provider.api_key().as_str()) + } forge_domain::AuthDetails::OAuthWithApiKey { api_key, .. } => { Some(api_key.as_str()) } @@ -454,7 +456,7 @@ mod tests { fn make_credential(provider_id: ProviderId, key: &str) -> Option { Some(forge_domain::AuthCredential { id: provider_id, - auth_details: forge_domain::AuthDetails::ApiKey(forge_domain::ApiKey::from( + auth_details: forge_domain::AuthDetails::static_api_key(forge_domain::ApiKey::from( key.to_string(), )), url_params: HashMap::new(), @@ -1161,9 +1163,9 @@ mod tests { url: Url::parse("https://api.openai.com/v1").unwrap(), credential: Some(forge_domain::AuthCredential { id: ProviderId::OPENAI, - auth_details: forge_domain::AuthDetails::ApiKey(forge_domain::ApiKey::from( - "test-key".to_string(), - )), + auth_details: forge_domain::AuthDetails::static_api_key( + forge_domain::ApiKey::from("test-key".to_string()), + ), url_params, }), auth_methods: vec![], diff --git a/crates/forge_repo/src/provider/provider_repo.rs b/crates/forge_repo/src/provider/provider_repo.rs index 575b6733d2..e486f9a76f 100644 --- a/crates/forge_repo/src/provider/provider_repo.rs +++ b/crates/forge_repo/src/provider/provider_repo.rs @@ -4,8 +4,8 @@ use bytes::Bytes; use forge_app::domain::{ProviderId, ProviderResponse}; use forge_app::{EnvironmentInfra, FileReaderInfra, FileWriterInfra, HttpInfra}; use forge_domain::{ - AnyProvider, ApiKey, AuthCredential, AuthDetails, Error, MigrationResult, Provider, - ProviderRepository, ProviderType, URLParam, URLParamSpec, URLParamValue, + AnyProvider, ApiKey, ApiKeyProvider, AuthCredential, AuthDetails, Error, MigrationResult, + Provider, ProviderRepository, ProviderType, URLParam, URLParamSpec, URLParamValue, }; use merge::Merge; use serde::Deserialize; @@ -62,6 +62,9 @@ struct ProviderConfig { #[merge(strategy = overwrite)] api_key_vars: Option, #[serde(default)] + #[merge(strategy = overwrite)] + api_key_helper: Option, + #[serde(default)] #[merge(strategy = merge::vec::append)] url_param_vars: Vec, #[serde(default)] @@ -152,6 +155,7 @@ impl From for ProviderConfig { id: ProviderId::from(entry.id), provider_type, api_key_vars: entry.api_key_var, + api_key_helper: entry.api_key_helper, url_param_vars: entry.url_param_vars.into_iter().map(Into::into).collect(), response_type, url: entry.url, @@ -316,7 +320,7 @@ impl< } // Try to create credential from environment variables - if let Ok(credential) = self.create_credential_from_env(&config) { + if let Ok(credential) = self.create_credential_from_env(&config).await { migrated_providers.push(config.id); credentials.push(credential); } @@ -332,20 +336,10 @@ impl< } /// Creates a credential from environment variables for a given config - fn create_credential_from_env( + async fn create_credential_from_env( &self, config: &ProviderConfig, ) -> anyhow::Result { - // Check API key environment variable (if specified) - let api_key = if let Some(api_key_var) = &config.api_key_vars { - self.infra - .get_env_var(api_key_var) - .ok_or_else(|| Error::env_var_not_found(config.id.clone(), api_key_var))? - } else { - // For context engine, we don't use env vars for API key - String::new() - }; - // Check URL parameter environment variables let mut url_params = std::collections::HashMap::new(); @@ -358,10 +352,42 @@ impl< } } - // Create AuthCredential + // Check for helper command: direct config or {api_key_var}_HELPER env var + let helper_command = config.api_key_helper.clone().or_else(|| { + let helper_var = config + .api_key_vars + .as_ref() + .map(|v| format!("{v}_HELPER"))?; + self.infra.get_env_var(&helper_var) + }); + + if let Some(command) = helper_command { + let initial = ApiKeyProvider::HelperCommand { + command, + last_key: ApiKey::from(String::new()), + expires_at: None, + }; + let provider = forge_infra::api_key_helper::execute(&initial).await?; + return Ok(AuthCredential { + id: config.id.clone(), + auth_details: AuthDetails::ApiKey(provider), + url_params, + }); + } + + // Fall back to static API key + let api_key = if let Some(api_key_var) = &config.api_key_vars { + self.infra + .get_env_var(api_key_var) + .ok_or_else(|| Error::env_var_not_found(config.id.clone(), api_key_var))? + } else { + // For context engine, we don't use env vars for API key + String::new() + }; + Ok(AuthCredential { id: config.id.clone(), - auth_details: AuthDetails::ApiKey(ApiKey::from(api_key)), + auth_details: AuthDetails::static_api_key(ApiKey::from(api_key)), url_params, }) } @@ -382,8 +408,8 @@ impl< // Google ADC tokens expire quickly, so we refresh them on every load if (credential.id == forge_domain::ProviderId::VERTEX_AI || credential.id == forge_domain::ProviderId::VERTEX_AI_ANTHROPIC) - && let forge_domain::AuthDetails::ApiKey(ref api_key) = credential.auth_details - && api_key.as_ref() == "google_adc_marker" + && let forge_domain::AuthDetails::ApiKey(ref provider) = credential.auth_details + && provider.api_key().as_ref() == "google_adc_marker" { // Refresh the Google ADC credential, preserving url_params match self.refresh_google_adc_credential(&credential).await { @@ -398,6 +424,33 @@ impl< } } + // Refresh helper-command credentials on load — the last_key is not + // persisted so it must be obtained by executing the command. + if let forge_domain::AuthDetails::ApiKey(ref provider) = credential.auth_details + && provider.needs_refresh(chrono::Duration::zero()) + { + match forge_infra::api_key_helper::execute(provider).await { + Ok(refreshed) => { + credential.auth_details = forge_domain::AuthDetails::ApiKey(refreshed); + tracing::info!( + provider = %config.id, + "Successfully refreshed API key from helper command" + ); + } + Err(e) => { + tracing::error!( + provider = %config.id, + error = %e, + "Failed to refresh API key from helper command" + ); + return Err(e.context(format!( + "Failed to execute API key helper for provider {}", + config.id + ))); + } + } + } + // Handle models - keep as templates let models = config.models.as_ref().map(|m| match m { Models::Url(model_url_template) => forge_domain::ModelSource::Url( @@ -1046,7 +1099,9 @@ mod env_tests { .find(|c| c.id == ProviderId::OPENAI_COMPATIBLE) .unwrap(); match &openai_compat_cred.auth_details { - AuthDetails::ApiKey(key) => assert_eq!(key.as_str(), "test-openai-key"), + AuthDetails::ApiKey(provider) => { + assert_eq!(provider.api_key().as_str(), "test-openai-key") + } _ => panic!("Expected API key"), } @@ -1066,7 +1121,9 @@ mod env_tests { .find(|c| c.id == ProviderId::ANTHROPIC) .unwrap(); match &anthropic_cred.auth_details { - AuthDetails::ApiKey(key) => assert_eq!(key.as_str(), "test-anthropic-key"), + AuthDetails::ApiKey(provider) => { + assert_eq!(provider.api_key().as_str(), "test-anthropic-key") + } _ => panic!("Expected API key"), } } @@ -1213,7 +1270,8 @@ mod env_tests { .credential .as_ref() .and_then(|c| match &c.auth_details { - forge_domain::AuthDetails::ApiKey(key) => Some(key.to_string()), + forge_domain::AuthDetails::ApiKey(provider) => + Some(provider.api_key().to_string()), _ => None, }), Some("test-key-123".to_string()) diff --git a/crates/forge_services/src/app_config.rs b/crates/forge_services/src/app_config.rs index 3e279aae9b..89211c143c 100644 --- a/crates/forge_services/src/app_config.rs +++ b/crates/forge_services/src/app_config.rs @@ -107,7 +107,7 @@ mod tests { url: Url::parse("https://api.openai.com").unwrap(), credential: Some(forge_domain::AuthCredential { id: ProviderId::OPENAI, - auth_details: forge_domain::AuthDetails::ApiKey( + auth_details: forge_domain::AuthDetails::static_api_key( forge_domain::ApiKey::from("test-key".to_string()), ), url_params: HashMap::new(), @@ -135,7 +135,7 @@ mod tests { url_params: vec![], credential: Some(forge_domain::AuthCredential { id: ProviderId::ANTHROPIC, - auth_details: forge_domain::AuthDetails::ApiKey( + auth_details: forge_domain::AuthDetails::static_api_key( forge_domain::ApiKey::from("test-key".to_string()), ), url_params: HashMap::new(), diff --git a/crates/forge_services/src/context_engine.rs b/crates/forge_services/src/context_engine.rs index 713ca56bb4..3017c6b423 100644 --- a/crates/forge_services/src/context_engine.rs +++ b/crates/forge_services/src/context_engine.rs @@ -103,7 +103,7 @@ impl< .context("No authentication credentials found. Please authenticate first.")?; match &credential.auth_details { - AuthDetails::ApiKey(token) => { + AuthDetails::ApiKey(provider) => { // Extract user_id from URL params let user_id_str = credential .url_params @@ -113,7 +113,7 @@ impl< })?; let user_id = UserId::from_string(user_id_str.as_str())?; - Ok((token.clone(), user_id)) + Ok((provider.api_key().clone(), user_id)) } _ => anyhow::bail!("ForgeServices credential must be an API key"), } diff --git a/crates/forge_services/src/provider_auth.rs b/crates/forge_services/src/provider_auth.rs index 67c27b9595..959869de28 100644 --- a/crates/forge_services/src/provider_auth.rs +++ b/crates/forge_services/src/provider_auth.rs @@ -155,10 +155,12 @@ where // Iterate through auth methods and try to refresh for auth_method in &provider.auth_methods { match auth_method { - AuthMethod::OAuthDevice(_) + AuthMethod::ApiKey + | AuthMethod::OAuthDevice(_) | AuthMethod::OAuthCode(_) | AuthMethod::CodexDevice(_) - | AuthMethod::GoogleAdc => { + | AuthMethod::GoogleAdc + | AuthMethod::AwsProfile => { // Get existing credential let existing_credential = self.infra.get_credential(&provider.id).await?.ok_or_else( @@ -203,7 +205,6 @@ where } } } - _ => {} } } } diff --git a/crates/forge_services/src/provider_service.rs b/crates/forge_services/src/provider_service.rs index 39748dd3d6..4e12f63722 100644 --- a/crates/forge_services/src/provider_service.rs +++ b/crates/forge_services/src/provider_service.rs @@ -214,7 +214,7 @@ mod tests { url_params: vec![], credential: Some(AuthCredential { id: ProviderId::OPENAI, - auth_details: AuthDetails::ApiKey(forge_domain::ApiKey::from( + auth_details: AuthDetails::static_api_key(forge_domain::ApiKey::from( "test-key".to_string(), )), url_params: HashMap::new(), @@ -238,7 +238,7 @@ mod tests { url_params: vec![], credential: Some(AuthCredential { id: ProviderId::OPENAI, - auth_details: AuthDetails::ApiKey(forge_domain::ApiKey::from( + auth_details: AuthDetails::static_api_key(forge_domain::ApiKey::from( "test-key".to_string(), )), url_params: HashMap::new(), diff --git a/forge.schema.json b/forge.schema.json index e6b2d2e953..f584fff67d 100644 --- a/forge.schema.json +++ b/forge.schema.json @@ -635,6 +635,13 @@ "description": "A single provider entry defined inline in `forge.toml`.\n\nInline providers are merged with the built-in provider list; entries with\nthe same `id` override the corresponding built-in entry field-by-field,\nwhile entries with a new `id` are appended to the list.", "type": "object", "properties": { + "api_key_helper": { + "description": "Shell command that produces an API key on stdout. When set, the\ncommand is executed instead of reading a static key from an environment\nvariable. Falls back to `{api_key_var}_HELPER` env var when absent.", + "type": [ + "string", + "null" + ] + }, "api_key_var": { "description": "Environment variable holding the API key for this provider.", "type": [