diff --git a/2023/day6/src/main.rs b/2023/day6/src/main.rs index 4ef05d3..4818c88 100644 --- a/2023/day6/src/main.rs +++ b/2023/day6/src/main.rs @@ -6,6 +6,33 @@ use nom::{ sequence::{delimited, preceded, separated_pair, tuple}, IResult, }; +use std::str::FromStr; + +#[derive(PartialEq, Eq, Clone, Copy)] +#[allow(dead_code)] +enum Approach { + BruteForce, + Math, +} + +impl FromStr for Approach { + type Err = &'static str; + + fn from_str(s: &str) -> Result { + Ok(match s { + "bruteforce" => Self::BruteForce, + "math" => Self::Math, + _ => return Err("unknown approach"), + }) + } +} + +#[allow(dead_code)] +impl Approach { + fn values() -> Vec { + vec![Self::BruteForce, Self::Math] + } +} fn number(i: &str) -> IResult<&str, usize> { map(digit1, |f: &str| f.parse::().unwrap())(i) @@ -21,6 +48,79 @@ struct Race { distance: usize, } +impl Race { + fn wins(&self, approach: Approach) -> usize { + match approach { + Approach::BruteForce => (0..=self.time) + .filter(|hold_time| { + let time_travelled = self.time - hold_time; + let speed = hold_time; + let distance_travelled = time_travelled * speed; + + distance_travelled > self.distance + }) + .collect::>() + .len(), + _ => { + // the races form a quadratic function: + // + // T = the total time of the race + // + // t = charging time, in [0..=self.time] + // y = resulting distance + // + // y = (T - t) x t + // y = -t^2 + Tt + // + // By substracing the winning time, we know that all values of x *above* + // y = 0 are winings + // + // D = winning distance + // + // y = -t^2 + Tt - D + // + // We need to calculate the roots of that function using the quadratic formula, + // with + // + // a = -1 + // b = T + // c = -D + + // use isizes to enable negation and pow/sqrt + let a: f64 = -1.0; + let b: f64 = self.time as f64; + let c: f64 = -(self.distance as f64); + + let roots: Option<(f64, f64)> = { + let discriminant = (b.powi(2)) - 4.0 * a * c; + if !discriminant.is_sign_positive() { + None + } else { + Some(( + (-b + f64::sqrt(discriminant)) / (2.0 * a), + (-b - f64::sqrt(discriminant)) / (2.0 * a), + )) + } + }; + + if let Some((x1, x2)) = roots { + // Sort the roots by size. The order is determined by the sign of a, so it's + // constant for our input, but this solution is more general + let (x1, x2) = if x1 > x2 { (x2, x1) } else { (x1, x2) }; + assert!(x1 < x2); + + // We actually found two roots. All whole integers that lie *between* the + // roots are races that we win. Exact matches are a tie and do not win. + (x2.ceil() as usize - 1) - (x1.floor() as usize + 1) + 1 + } else { + // No roots, there is no way for us to win the race. + 0 + } + } + } + } +} + #[derive(Debug)] struct RaceSheet { races: Vec, @@ -54,7 +154,7 @@ impl RaceSheet { } } -fn part1(input: &str) -> Result { +fn part1(input: &str, approach: Approach) -> Result { let (rest, racesheet) = RaceSheet::parse(input).map_err(|e| e.to_string())?; if !rest.is_empty() { eprintln!("parsing rest found: {rest}"); @@ -63,23 +163,12 @@ fn part1(input: &str) -> Result { let result = racesheet .races .into_iter() - .map(|race| { - (0..=race.time) - .filter(|hold_time| { - let time_travelled = race.time - hold_time; - let speed = hold_time; - let distance_travelled = time_travelled * speed; - - distance_travelled > race.distance - }) - .collect::>() - .len() - }) + .map(|race| race.wins(approach)) .product(); Ok(result) } -fn part2(input: &str) -> Result { +fn part2(input: &str, approach: Approach) -> Result { let (rest, racesheet) = RaceSheet::parse(input).map_err(|e| e.to_string())?; if !rest.is_empty() { eprintln!("parsing rest found: {rest}"); @@ -102,18 +191,9 @@ fn part2(input: &str) -> Result { .parse::() .unwrap(); - let result = (0..=time) - .filter(|hold_time| { - let time_travelled = time - hold_time; - let speed = hold_time; - let distance_travelled = time_travelled * speed; + let race = Race { time, distance }; - distance_travelled > distance - }) - .collect::>() - .len(); - - Ok(result) + Ok(race.wins(approach)) } fn main() -> Result<(), String> { @@ -121,13 +201,14 @@ fn main() -> Result<(), String> { let args = std::env::args().skip(1).collect::>(); let part = args[0].parse::().unwrap(); + let approach = args[1].parse()?; if part == 1 { - println!("Part 1 : {}", part1(input)?); + println!("Part 1 : {}", part1(input, approach)?); } else if part == 2 { - println!("Part 2 : {}", part2(input)?); + println!("Part 2 : {}", part2(input, approach)?); } else { - panic!("unknown part") + return Err("unknown part".into()); } Ok(()) @@ -145,7 +226,9 @@ mod tests { Distance: 9 40 200 "}; - assert_eq!(part1(&input).unwrap(), 288); + for approach in Approach::values() { + assert_eq!(part1(&input, approach).unwrap(), 288); + } } #[test] @@ -155,6 +238,8 @@ mod tests { Distance: 9 40 200 "}; - assert_eq!(part2(&input).unwrap(), 71503); + for approach in Approach::values() { + assert_eq!(part2(&input, approach).unwrap(), 71503); + } } }