diff --git a/src/lib.rs b/src/lib.rs index 8600169..4395072 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -901,16 +901,22 @@ pub fn run() { None => action_args.name.clone(), }; + let mut remote_branch_exists = false; + let checkout_commit = match &action_args.track { Some(upstream_branch_name) => { match repo.find_branch(upstream_branch_name, git2::BranchType::Remote) { - Ok(branch) => branch.into_reference().peel_to_commit().unwrap(), + Ok(branch) => { + remote_branch_exists = true; + branch.into_reference().peel_to_commit().unwrap() + } Err(_) => { - print_error(&format!( - "Remote branch {} not found", - &upstream_branch_name - )); - process::exit(1); + remote_branch_exists = false; + get_default_branch(&repo) + .unwrap() + .into_reference() + .peel_to_commit() + .unwrap() } } } @@ -928,10 +934,66 @@ pub fn run() { }; if let Some(upstream_branch_name) = action_args.track { - target_branch - .set_upstream(Some(&upstream_branch_name)) - .unwrap(); - } + if remote_branch_exists { + target_branch + .set_upstream(Some(&upstream_branch_name)) + .unwrap(); + } else { + print_error(&format!( + "Remote branch {} not found", + &upstream_branch_name + )); + let split_at = upstream_branch_name.find("/").unwrap_or(0); + if split_at == 0 || split_at >= upstream_branch_name.len() - 1 { + print_error("Tracking branch needs to match the pattern /"); + process::exit(1); + } + + let (remote_name, remote_branch_name) = + &upstream_branch_name.split_at(split_at); + // strip the remaining slash + let remote_branch_name = &remote_branch_name[1..]; + + let mut remote = match repo.find_remote(remote_name) { + Ok(r) => r, + Err(_) => { + print_error(&format!("Remote {} not found", remote_name)); + process::exit(1); + } + }; + + let mut callbacks = git2::RemoteCallbacks::new(); + callbacks.push_update_reference(|_, status| { + if let Some(message) = status { + return Err(git2::Error::new( + git2::ErrorCode::GenericError, + git2::ErrorClass::None, + message, + )); + } + Ok(()) + }); + callbacks.credentials(|_url, username_from_url, _allowed_types| { + git2::Cred::ssh_key_from_agent(username_from_url.unwrap()) + }); + + let mut push_options = git2::PushOptions::new(); + push_options.remote_callbacks(callbacks); + + let push_refspec = format!( + "+{}:refs/heads/{}", + target_branch.get().name().unwrap(), + remote_branch_name + ); + remote + .push(&[push_refspec], Some(&mut push_options)) + .unwrap(); + + target_branch + .set_upstream(Some(&upstream_branch_name)) + .unwrap(); + } + }; let worktree = repo.worktree( &action_args.name,