001// Copyright (c) Choreo contributors
002
003package choreo;
004
import static edu.wpi.first.util.ErrorMessages.requireNonNullParam;
import static edu.wpi.first.wpilibj.Alert.AlertType.kError;

import choreo.trajectory.DifferentialSample;
import choreo.trajectory.EventMarker;
import choreo.trajectory.SwerveSample;
import choreo.trajectory.Trajectory;
import choreo.trajectory.TrajectorySample;
import choreo.util.ChoreoAlert;
import choreo.util.ChoreoAlert.*;
import choreo.util.TrajSchemaVersion;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonObject;
import com.google.gson.JsonSyntaxException;
import edu.wpi.first.hal.FRCNetComm.tResourceType;
import edu.wpi.first.hal.HAL;
import edu.wpi.first.wpilibj.DriverStation;
import edu.wpi.first.wpilibj.Filesystem;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;
035
036/** Utilities to load and follow Choreo Trajectories */
037public final class Choreo {
038  private static final Gson GSON =
039      new GsonBuilder()
040          .registerTypeAdapter(EventMarker.class, new EventMarker.Deserializer())
041          .create();
042  private static final String TRAJECTORY_FILE_EXTENSION = ".traj";
043  private static final int TRAJ_SCHEMA_VERSION = TrajSchemaVersion.TRAJ_SCHEMA_VERSION;
044  private static final MultiAlert cantFindTrajectory =
045      ChoreoAlert.multiAlert(causes -> "Could not find trajectory files: " + causes, kError);
046  private static final MultiAlert cantParseTrajectory =
047      ChoreoAlert.multiAlert(causes -> "Could not parse trajectory files: " + causes, kError);
048
049  private static File CHOREO_DIR = new File(Filesystem.getDeployDirectory(), "choreo");
050
051  /** This should only be used for unit testing. */
052  static void setChoreoDir(File choreoDir) {
053    CHOREO_DIR = choreoDir;
054  }
055
056  /**
057   * This interface exists as a type alias. A TrajectoryLogger has a signature of ({@link
058   * Trajectory}, {@link Boolean})->void, where the function consumes a trajectory and a boolean
059   * indicating whether the trajectory is starting or finishing.
060   *
061   * @param <ST> {@link choreo.trajectory.DifferentialSample} or {@link
062   *     choreo.trajectory.SwerveSample}
063   */
064  public interface TrajectoryLogger<ST extends TrajectorySample<ST>>
065      extends BiConsumer<Trajectory<ST>, Boolean> {}
066
067  /** Default constructor. */
068  private Choreo() {
069    throw new UnsupportedOperationException("This is a utility class!");
070  }
071
072  /**
073   * Load a trajectory from the deploy directory. Choreolib expects .traj files to be placed in
074   * src/main/deploy/choreo/[trajectoryName].traj.
075   *
076   * @param <SampleType> The type of samples in the trajectory.
077   * @param trajectoryName The path name in Choreo, which matches the file name in the deploy
078   *     directory, file extension is optional.
079   * @return The loaded trajectory, or `Optional.empty()` if the trajectory could not be loaded.
080   */
081  @SuppressWarnings("unchecked")
082  public static <SampleType extends TrajectorySample<SampleType>>
083      Optional<Trajectory<SampleType>> loadTrajectory(String trajectoryName) {
084    requireNonNullParam(trajectoryName, "trajectoryName", "Choreo.loadTrajectory");
085
086    if (trajectoryName.endsWith(TRAJECTORY_FILE_EXTENSION)) {
087      trajectoryName =
088          trajectoryName.substring(0, trajectoryName.length() - TRAJECTORY_FILE_EXTENSION.length());
089    }
090    File trajectoryFile = new File(CHOREO_DIR, trajectoryName + TRAJECTORY_FILE_EXTENSION);
091    try {
092      var reader = new BufferedReader(new FileReader(trajectoryFile));
093      String str = reader.lines().reduce("", (a, b) -> a + b);
094      reader.close();
095      Trajectory<SampleType> trajectory = (Trajectory<SampleType>) loadTrajectoryString(str);
096      return Optional.of(trajectory);
097    } catch (FileNotFoundException ex) {
098      cantFindTrajectory.addCause(trajectoryFile.toString());
099    } catch (JsonSyntaxException ex) {
100      cantParseTrajectory.addCause(trajectoryFile.toString());
101    } catch (Exception ex) {
102      ChoreoAlert.alert(
103              "Unknown error when parsing " + trajectoryFile + "; check console for more details",
104              kError)
105          .set(true);
106      DriverStation.reportError(ex.getMessage(), ex.getStackTrace());
107    }
108    return Optional.empty();
109  }
110
111  /**
112   * Fetches the names of all available trajectories in the deploy directory.
113   *
114   * @return A list of all available trajectory names.
115   */
116  public static String[] availableTrajectories() {
117    List<String> trajectories = new ArrayList<>();
118    File[] files = CHOREO_DIR.listFiles();
119    if (files != null) {
120      for (File file : files) {
121        if (file.getName().endsWith(TRAJECTORY_FILE_EXTENSION)) {
122          trajectories.add(
123              file.getName()
124                  .substring(0, file.getName().length() - TRAJECTORY_FILE_EXTENSION.length()));
125        }
126      }
127    }
128    return trajectories.toArray(new String[0]);
129  }
130
131  /**
132   * Load a trajectory from a string.
133   *
134   * @param trajectoryJsonString The JSON string.
135   * @return The loaded trajectory, or `empty std::optional` if the trajectory could not be loaded.
136   */
137  static Trajectory<? extends TrajectorySample<?>> loadTrajectoryString(
138      String trajectoryJsonString) {
139    JsonObject wholeTrajectory = GSON.fromJson(trajectoryJsonString, JsonObject.class);
140    String name = wholeTrajectory.get("name").getAsString();
141    int version;
142    try {
143      version = wholeTrajectory.get("version").getAsInt();
144      if (version != TRAJ_SCHEMA_VERSION) {
145        throw new RuntimeException(
146            name + ".traj: Wrong version: " + version + ". Expected " + TRAJ_SCHEMA_VERSION);
147      }
148    } catch (ClassCastException e) {
149      throw new RuntimeException(
150          name
151              + ".traj: Wrong version: "
152              + wholeTrajectory.get("version").getAsString()
153              + ". Expected "
154              + TRAJ_SCHEMA_VERSION);
155    }
156    // Filter out markers with negative timestamps or empty names
157    List<EventMarker> unfilteredEvents =
158        new ArrayList<EventMarker>(
159            Arrays.asList(GSON.fromJson(wholeTrajectory.get("events"), EventMarker[].class)));
160    unfilteredEvents.removeIf(marker -> marker.timestamp < 0 || marker.event.length() == 0);
161    EventMarker[] events = new EventMarker[unfilteredEvents.size()];
162    unfilteredEvents.toArray(events);
163
164    JsonObject trajectoryObj = wholeTrajectory.getAsJsonObject("trajectory");
165    Integer[] splits = GSON.fromJson(trajectoryObj.get("splits"), Integer[].class);
166    if (splits.length == 0 || splits[0] != 0) {
167      Integer[] newArray = new Integer[splits.length + 1];
168      newArray[0] = 0;
169      System.arraycopy(splits, 0, newArray, 1, splits.length);
170      splits = newArray;
171    }
172    String sampleType = trajectoryObj.get("sampleType").getAsString();
173    if (sampleType.equals("Swerve")) {
174      HAL.report(tResourceType.kResourceType_ChoreoTrajectory, 1);
175
176      SwerveSample[] samples = GSON.fromJson(trajectoryObj.get("samples"), SwerveSample[].class);
177      return new Trajectory<SwerveSample>(name, List.of(samples), List.of(splits), List.of(events));
178    } else if (sampleType.equals("Differential")) {
179      HAL.report(tResourceType.kResourceType_ChoreoTrajectory, 2);
180
181      DifferentialSample[] sampleArray =
182          GSON.fromJson(trajectoryObj.get("samples"), DifferentialSample[].class);
183      return new Trajectory<DifferentialSample>(
184          name, List.of(sampleArray), List.of(splits), List.of(events));
185    } else {
186      throw new RuntimeException("Unknown drive type: " + sampleType);
187    }
188  }
189
190  /**
191   * A utility for caching loaded trajectories. This allows for loading trajectories only once, and
192   * then reusing them.
193   */
194  public static class TrajectoryCache {
195    private final Map<String, Trajectory<?>> cache;
196
197    /** Creates a new TrajectoryCache with a normal {@link HashMap} as the cache. */
198    public TrajectoryCache() {
199      cache = new HashMap<>();
200    }
201
202    /**
203     * Creates a new TrajectoryCache with a custom cache.
204     *
205     * <p>this could be useful if you want to use a concurrent map or a map with a maximum size.
206     *
207     * @param cache The cache to use.
208     */
209    public TrajectoryCache(Map<String, Trajectory<?>> cache) {
210      requireNonNullParam(cache, "cache", "TrajectoryCache.<init>");
211      this.cache = cache;
212    }
213
214    /**
215     * Load a trajectory from the deploy directory. Choreolib expects .traj files to be placed in
216     * src/main/deploy/choreo/[trajectoryName].traj.
217     *
218     * <p>This method will cache the loaded trajectory and reused it if it is requested again.
219     *
220     * @param trajectoryName the path name in Choreo, which matches the file name in the deploy
221     *     directory, file extension is optional.
222     * @return the loaded trajectory, or `Optional.empty()` if the trajectory could not be loaded.
223     * @see Choreo#loadTrajectory(String)
224     */
225    public Optional<? extends Trajectory<?>> loadTrajectory(String trajectoryName) {
226      requireNonNullParam(trajectoryName, "trajectoryName", "TrajectoryCache.loadTrajectory");
227      if (cache.containsKey(trajectoryName)) {
228        return Optional.of(cache.get(trajectoryName));
229      } else {
230        return Choreo.loadTrajectory(trajectoryName)
231            .map(
232                trajectory -> {
233                  cache.put(trajectoryName, trajectory);
234                  return trajectory;
235                });
236      }
237    }
238
239    /**
240     * Load a section of a split trajectory from the deploy directory. Choreolib expects .traj files
241     * to be placed in src/main/deploy/choreo/[trajectoryName].traj.
242     *
243     * <p>This method will cache the loaded trajectory and reused it if it is requested again. The
244     * trajectory that is split off of will also be cached.
245     *
246     * @param trajectoryName the path name in Choreo, which matches the file name in the deploy
247     *     directory, file extension is optional.
248     * @param splitIndex the index of the split trajectory to load
249     * @return the loaded trajectory, or `Optional.empty()` if the trajectory could not be loaded.
250     * @see Choreo#loadTrajectory(String)
251     */
252    public Optional<? extends Trajectory<?>> loadTrajectory(String trajectoryName, int splitIndex) {
253      requireNonNullParam(trajectoryName, "trajectoryName", "TrajectoryCache.loadTrajectory");
254      // make the key something that could never possibly be a valid trajectory name
255      String key = trajectoryName + ".:." + splitIndex;
256      if (cache.containsKey(key)) {
257        return Optional.of(cache.get(key));
258      } else if (cache.containsKey(trajectoryName)) {
259        return cache
260            .get(trajectoryName)
261            .getSplit(splitIndex)
262            .map(
263                trajectory -> {
264                  cache.put(key, trajectory);
265                  return trajectory;
266                });
267      } else {
268        return Choreo.loadTrajectory(trajectoryName)
269            .flatMap(
270                trajectory -> {
271                  cache.put(trajectoryName, trajectory);
272                  return trajectory
273                      .getSplit(splitIndex)
274                      .map(
275                          split -> {
276                            cache.put(key, split);
277                            return split;
278                          });
279                });
280      }
281    }
282
283    /** Clear the cache. */
284    public void clear() {
285      cache.clear();
286    }
287  }
288}