From fad6f71876b14f0ebeeedc4a90c90d11facf4f20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20K=C3=B6rber?= Date: Mon, 13 Jun 2022 22:47:49 +0200 Subject: [PATCH] Improve default branch guessing --- src/repo.rs | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 82 insertions(+), 4 deletions(-) diff --git a/src/repo.rs b/src/repo.rs index 7b99179..763705b 100644 --- a/src/repo.rs +++ b/src/repo.rs @@ -1042,12 +1042,76 @@ impl RepoHandle { }) } - pub fn default_branch(&self) -> Result { - let branch_names = vec!["main", "master"]; + pub fn get_remote_default_branch(&self, remote_name: &str) -> Result, String> { + // libgit2's `git_remote_default_branch()` and `Remote::default_branch()` + // need an actual connection to the remote, so they may fail. + if let Some(mut remote) = self.find_remote(remote_name)? { + if remote.connected() { + let remote = remote; // unmut + if let Ok(remote_default_branch) = remote.default_branch() { + return Ok(Some(self.find_local_branch(&remote_default_branch)?)); + }; + } + } - for branch_name in &branch_names { + // Note that /HEAD only exists after a normal clone, there is no way to get the + // remote HEAD afterwards. So this is a "best effort" approach. + if let Ok(remote_head) = self.find_remote_branch(remote_name, "HEAD") { + if let Some(pointer_name) = remote_head.as_reference().symbolic_target() { + if let Some(local_branch_name) = + pointer_name.strip_prefix(&format!("refs/remotes/{}/", remote_name)) + { + return Ok(Some(self.find_local_branch(local_branch_name)?)); + } else { + eprintln!("Remote HEAD ({}) pointer is invalid", pointer_name); + } + } else { + eprintln!("Remote HEAD does not point to a symbolic target"); + } + } + Ok(None) + } + + pub fn default_branch(&self) -> Result { + // This is a bit of a guessing game. + // + // In the best case, there is only one remote. Then, we can check /HEAD to get the + // default remote branch. + // + // If there are multiple remotes, we first check whether they all have the same + // /HEAD branch. If yes, good! If not, we use whatever "origin" uses, if that + // exists. If it does not, there is no way to reliably get a remote default branch. + // + // In this case, we just try to guess a local branch from a list. If even that does not + // work, well, bad luck. + let remotes = self.remotes()?; + + if remotes.len() == 1 { + let remote_name = &remotes[0]; + if let Some(default_branch) = self.get_remote_default_branch(remote_name)? { + return Ok(default_branch); + } + } else { + let mut default_branches: Vec = vec![]; + for remote_name in remotes { + if let Some(default_branch) = self.get_remote_default_branch(&remote_name)? { + default_branches.push(default_branch) + } + } + + if !default_branches.is_empty() + && (default_branches.len() == 1 + || default_branches + .windows(2) + .all(|w| w[0].name() == w[1].name())) + { + return Ok(default_branches.remove(0)); + } + } + + for branch_name in &vec!["main", "master"] { if let Ok(branch) = self.0.find_branch(branch_name, git2::BranchType::Local) { - return Ok(Branch(branch)) + return Ok(Branch(branch)); } } @@ -1458,6 +1522,20 @@ impl RemoteHandle<'_> { .to_string() } + pub fn connected(&mut self) -> bool { + self.0.connected() + } + + pub fn default_branch(&self) -> Result { + Ok(self + .0 + .default_branch() + .map_err(convert_libgit2_error)? + .as_str() + .expect("Remote branch name is not valid utf-8") + .to_string()) + } + pub fn is_pushable(&self) -> Result { let remote_type = detect_remote_type(self.0.url().expect("Remote name is not valid utf-8")) .ok_or_else(|| String::from("Could not detect remote type"))?;