001// Copyright (c) Choreo contributors
002
003package choreo.trajectory;
004
import edu.wpi.first.math.geometry.Pose2d;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.UnaryOperator;
009
010/**
011 * A trajectory loaded from Choreo.
012 *
013 * @param <SampleType> DifferentialSample or SwerveSample.
014 */
015public class Trajectory<SampleType extends TrajectorySample<SampleType>> {
016  private final String name;
017  private final List<SampleType> samples;
018  private final List<Integer> splits;
019  private final List<EventMarker> events;
020
021  /**
022   * Constructs a Trajectory with the specified parameters.
023   *
024   * @param name The name of the trajectory.
025   * @param samples The samples of the trajectory.
026   * @param splits The indices of the splits in the trajectory.
027   * @param events The events in the trajectory.
028   */
029  public Trajectory(
030      String name, List<SampleType> samples, List<Integer> splits, List<EventMarker> events) {
031    this.name = name;
032    this.samples = samples;
033    this.splits = splits;
034    this.events = events;
035  }
036
037  /**
038   * Returns the name of the trajectory.
039   *
040   * @return the name of the trajectory.
041   */
042  public String name() {
043    return name;
044  }
045
046  /**
047   * Returns the samples of the trajectory.
048   *
049   * @return the samples of the trajectory.
050   */
051  public List<SampleType> samples() {
052    return samples;
053  }
054
055  /**
056   * Returns the indices of the splits in the trajectory.
057   *
058   * @return the indices of the splits in the trajectory.
059   */
060  public List<Integer> splits() {
061    return splits;
062  }
063
064  /**
065   * Returns the events in the trajectory.
066   *
067   * @return the events in the trajectory.
068   */
069  public List<EventMarker> events() {
070    return events;
071  }
072
073  /**
074   * Returns the first {@link SampleType} in the trajectory.
075   *
076   * <p>This function will return an empty Optional if the trajectory is empty.
077   *
078   * @param mirrorForRedAlliance whether or not to return the sample as mirrored across the field
079   * @return The first {@link SampleType} in the trajectory.
080   */
081  public Optional<SampleType> getInitialSample(boolean mirrorForRedAlliance) {
082    if (samples.isEmpty()) {
083      return Optional.empty();
084    }
085    final var sample = samples.get(0);
086    return Optional.of(mirrorForRedAlliance ? sample.flipped() : sample);
087  }
088
089  /**
090   * Returns the last {@link SampleType} in the trajectory.
091   *
092   * <p>This function will return an empty Optional if the trajectory is empty.
093   *
094   * @param mirrorForRedAlliance whether or not to return the sample as mirrored across the field
095   * @return The last {@link SampleType} in the trajectory.
096   */
097  public Optional<SampleType> getFinalSample(boolean mirrorForRedAlliance) {
098    if (samples.isEmpty()) {
099      return Optional.empty();
100    }
101    final var sample = samples.get(samples.size() - 1);
102    return Optional.of(mirrorForRedAlliance ? sample.flipped() : sample);
103  }
104
105  private Optional<SampleType> sampleInternal(double timestamp) {
106    if (samples.isEmpty()) {
107      return Optional.empty();
108    } else if (samples.size() == 1) {
109      return Optional.of(samples.get(0));
110    }
111    if (timestamp < samples.get(0).getTimestamp()) {
112      // timestamp oob, return the initial state
113      return getInitialSample(false);
114    }
115    if (timestamp >= getTotalTime()) {
116      // timestamp oob, return the final state
117      return getFinalSample(false);
118    }
119
120    // binary search to find the sample before and ahead of the timestamp
121    int low = 0;
122    int high = samples.size() - 1;
123
124    while (low != high) {
125      int mid = (low + high) / 2;
126      if (samples.get(mid).getTimestamp() < timestamp) {
127        low = mid + 1;
128      } else {
129        high = mid;
130      }
131    }
132
133    if (low == 0) {
134      return Optional.of(samples.get(low));
135    }
136
137    var behindState = samples.get(low - 1);
138    var aheadState = samples.get(low);
139
140    if ((aheadState.getTimestamp() - behindState.getTimestamp()) < 1e-6) {
141      return Optional.of(aheadState);
142    }
143
144    return Optional.of(behindState.interpolate(aheadState, timestamp));
145  }
146
147  /**
148   * Return an interpolated sample of the trajectory at the given timestamp.
149   *
150   * <p>This function will return an empty Optional if the trajectory is empty.
151   *
152   * @param timestamp The timestamp of this sample relative to the beginning of the trajectory.
153   * @param mirrorForRedAlliance whether or not to return the sample as mirrored across the field
154   *     midline (as in 2023).
155   * @return The SampleType at the given time.
156   */
157  public Optional<SampleType> sampleAt(double timestamp, boolean mirrorForRedAlliance) {
158    Optional<SampleType> sample = sampleInternal(timestamp);
159    return mirrorForRedAlliance ? sample.map(SampleType::flipped) : sample;
160  }
161
162  /**
163   * Returns the initial pose of the trajectory.
164   *
165   * <p>This function will return an empty Optional if the trajectory is empty.
166   *
167   * @param mirrorForRedAlliance whether or not to return the pose as mirrored across the field
168   * @return the initial pose of the trajectory.
169   */
170  public Optional<Pose2d> getInitialPose(boolean mirrorForRedAlliance) {
171    if (samples.isEmpty()) {
172      return Optional.empty();
173    }
174    return getInitialSample(mirrorForRedAlliance).map(SampleType::getPose);
175  }
176
177  /**
178   * Returns the final pose of the trajectory.
179   *
180   * <p>This function will return an empty Optional if the trajectory is empty.
181   *
182   * @param mirrorForRedAlliance whether or not to return the pose as mirrored across the field
183   * @return the final pose of the trajectory.
184   */
185  public Optional<Pose2d> getFinalPose(boolean mirrorForRedAlliance) {
186    if (samples.isEmpty()) {
187      return Optional.empty();
188    }
189    return getFinalSample(mirrorForRedAlliance).map(SampleType::getPose);
190  }
191
192  /**
193   * Returns the total time of the trajectory (the timestamp of the last sample). This will return 0
194   * if the trajectory is empty.
195   *
196   * @return the total time of the trajectory (the timestamp of the last sample)
197   */
198  public double getTotalTime() {
199    return getFinalSample(false).map(SampleType::getTimestamp).orElse(0.0);
200  }
201
202  /**
203   * Returns the array of poses corresponding to the trajectory.
204   *
205   * @return the array of poses corresponding to the trajectory.
206   */
207  public Pose2d[] getPoses() {
208    return samples.stream().map(SampleType::getPose).toArray(Pose2d[]::new);
209  }
210
211  /**
212   * Returns this trajectory, flipped to the other alliance according to the symmetry of the field.
213   *
214   * @return this trajectory, flipped to the other alliance according to the symmetry of the field.
215   */
216  public Trajectory<SampleType> flipped() {
217    var flippedStates = new ArrayList<SampleType>();
218    for (var state : samples) {
219      flippedStates.add(state.flipped());
220    }
221    return new Trajectory<SampleType>(this.name, flippedStates, this.splits, this.events);
222  }
223
224  /**
225   * Returns this trajectory, mirrored to the other alliance.
226   *
227   * @return this trajectory, mirrored to the other alliance.
228   */
229  public Trajectory<SampleType> mirrorX() {
230    var flippedStates = new ArrayList<SampleType>();
231    for (var state : samples) {
232      flippedStates.add(state.mirrorX());
233    }
234    return new Trajectory<SampleType>(this.name, flippedStates, this.splits, this.events);
235  }
236
237  /**
238   * Returns this trajectory, mirrored left-to-right from the driver's perspective.
239   *
240   * @return this trajectory, mirrored left-to-right from the driver's perspective.
241   */
242  public Trajectory<SampleType> mirrorY() {
243    var flippedStates = new ArrayList<SampleType>();
244    for (var state : samples) {
245      flippedStates.add(state.mirrorY());
246    }
247    return new Trajectory<SampleType>(this.name, flippedStates, this.splits, this.events);
248  }
249
250  /**
251   * Returns this trajectory, rotated 180 degrees around the field center.
252   *
253   * @return this trajectory, rotated 180 degrees around the field center.
254   */
255  public Trajectory<SampleType> rotateAround() {
256    var flippedStates = new ArrayList<SampleType>();
257    for (var state : samples) {
258      flippedStates.add(state.rotateAround());
259    }
260    return new Trajectory<SampleType>(this.name, flippedStates, this.splits, this.events);
261  }
262
263  /**
264   * Returns a list of all events with the given name in the trajectory.
265   *
266   * @param eventName The name of the event.
267   * @return A list of all events with the given name in the trajectory, if no events are found, an
268   *     empty list is returned.
269   */
270  public List<EventMarker> getEvents(String eventName) {
271    return events.stream().filter(event -> event.event.equals(eventName)).toList();
272  }
273
274  /**
275   * Returns a choreo trajectory that represents the split of the trajectory at the given index.
276   *
277   * @param splitIndex the index of the split trajectory to return.
278   * @return a choreo trajectory that represents the split of the trajectory at the given index.
279   */
280  public Optional<Trajectory<SampleType>> getSplit(int splitIndex) {
281    // Assumption: splits.get(splitIndex) is a valid index of samples.
282    if (splitIndex < 0 || splitIndex >= splits.size()) {
283      return Optional.empty();
284    }
285    int start = splits.get(splitIndex);
286    int end = splitIndex + 1 < splits.size() ? splits.get(splitIndex + 1) + 1 : samples.size();
287    var sublist = samples.subList(start, end);
288    // Empty section should not be achievable (would mean malformed splits array), but is handled
289    // for safety
290    if (sublist.size() == 0) {
291      return Optional.of(
292          new Trajectory<SampleType>(
293              this.name + "[" + splitIndex + "]", List.of(), List.of(), List.of()));
294    }
295    // Now we know sublist.size() >= 1
296    double startTime = sublist.get(0).getTimestamp();
297    double endTime = sublist.get(sublist.size() - 1).getTimestamp();
298    return Optional.of(
299        new Trajectory<SampleType>(
300            this.name + "[" + splitIndex + "]",
301            sublist.stream().map(s -> s.offsetBy(-startTime)).toList(),
302            List.of(),
303            events.stream()
304                .filter(e -> e.timestamp >= startTime && e.timestamp <= endTime)
305                .map(e -> e.offsetBy(-startTime))
306                .toList()));
307  }
308
309  @Override
310  public boolean equals(Object obj) {
311    if (!(obj instanceof Trajectory<?>)) {
312      return false;
313    }
314
315    var other = (Trajectory<?>) obj;
316    return this.name.equals(other.name)
317        && this.samples.equals(other.samples)
318        && this.splits.equals(other.splits)
319        && this.events.equals(other.events);
320  }
321}