Discord bot that plays music from every website ever via youtube-dl
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Chords/src/main/java/moe/nekojimi/chords/Downloader.java

533 lines
19 KiB

/*
* To change this license header, choose License Headers in Project Properties.
* To change this template file, choose Tools | Templates
* and open the template in the editor.
*/
package moe.nekojimi.chords;
import com.beust.jcommander.Strings;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.util.*;
import java.util.concurrent.*;
import java.util.function.BiConsumer;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.json.Json;
import javax.json.JsonArray;
import javax.json.JsonObject;
import javax.json.JsonReader;
import moe.nekojimi.musicsearcher.Result;
import net.dv8tion.jda.api.audio.AudioSendHandler;
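/**
* Queue stage that takes {@link TrackRequest}s and produces {@link Track}s by
* running the configured youtube-dl command on a small thread pool.
*
* @author jimj316
*/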
public class Downloader extends QueueThing<TrackRequest, Track>
{
/** Seconds a download process may run before the watchdog kills it. */
private static final int DOWNLOAD_TIMEOUT = 300;
/** Seconds a metadata query ({@code --print-json}) may run before the watchdog kills it. */
private static final int INFO_TIMEOUT = 60;
/** Seconds a format listing ({@code -F}) may run before the watchdog kills it. */
private static final int FORMAT_TIMEOUT = 5;
/**
 * Sample rate of {@link AudioSendHandler#INPUT_FORMAT}; formats whose sample
 * rate is closest to this value are preferred when ranking.
 */
private static final int BITRATE_TARGET = (int) AudioSendHandler.INPUT_FORMAT.getSampleRate();
/** Four whitespace-separated groups of a format listing row. Not referenced in the visible part of this class. */
private static final Pattern FORMAT_PATTERN = Pattern.compile("^([\\w]+)\\s+([\\w]+)\\s+(\\w+ ?\\w*)\\s+(.*)$");
/** Captures the path of a {@code .wav} file announced on a "Destination:" output line. */
public static final Pattern DESTINATION_PATTERN = Pattern.compile("Destination: (.*\\.wav)");
/** Matches the "Destination: -" line that is printed when output goes to stdout. */
public static final Pattern STREAM_PATTERN = Pattern.compile("Destination: -");
/** Captures the percentage from a "[download]" progress line. */
public static final Pattern PROGRESS_PATTERN = Pattern.compile("\\[download\\].*?([\\d\\.]+)%");
/** Captures the path of a written {@code .info.json} file. Not referenced in the visible part of this class. */
private static final Pattern INFO_JSON_PATTERN = Pattern.compile("Writing video metadata as JSON to: (.*\\.info\\.json)");
/** Captures the ETA ({@code m:ss} / {@code mm:ss}) from a "[download]" line. Not referenced in the visible part of this class. */
private static final Pattern ETA_PATTERN = Pattern.compile("\\[download\\].*?ETA\\s+(\\d{1,2}:\\d{2})");
/** Captures the 1-based item index and the item count from a "Downloading item X of Y" line. */
private static final Pattern DOWNLOAD_ITEM_PATTERN = Pattern.compile("\\[download\\] Downloading item (\\d+) of (\\d+)");
/** Metadata JSON keys tried, in order, when looking for a track title. */
public static final String[] INFO_TITLE_KEYS =
{
    "track", "title", "fulltitle"
};
/** Metadata JSON keys tried, in order, when looking for a track artist. */
public static final String[] INFO_ARTIST_KEYS =
{
    "artist", "channel", "uploader"
};
/**
 * Requests currently being processed by a worker. Worker threads add and
 * remove entries while other threads read the list, so a thread-safe list is
 * used instead of a plain LinkedList.
 */
private final List<TrackRequest> downloadingTracks = new CopyOnWriteArrayList<>();
private final LinkedBlockingDeque<Runnable> workQueue = new LinkedBlockingDeque<>();
/** Worker pool that runs the metadata query and download for each request. */
private final ThreadPoolExecutor executor = new ThreadPoolExecutor(2, 4, 30, TimeUnit.SECONDS, workQueue);
/** Runs the delayed watchdog tasks that kill processes exceeding their timeout. */
private final ScheduledExecutorService scheduler = new ScheduledThreadPoolExecutor(1);
//    private Consumer<Track> next;
/** Callback given the request and either null (status update) or the exception that ended the download. */
private BiConsumer<TrackRequest, Exception> messageHandler;
/** Cached download directory; created on demand. */
private File downloadDir = null;
/** Number assigned to the next Track created by this downloader. */
private int trackNumber = 1;
/**
* Creates a downloader, handing a new empty LinkedList to the QueueThing
* constructor (presumably as the backing input queue - confirm against QueueThing).
*/
public Downloader()
{
super(new LinkedList<>());
}
//
// @Override
// public void accept(TrackRequest request)
// {
// if all tracks of the request are already downloaded, just skip
// if (!request.getTracks().isEmpty() && request.getTracks().stream().allMatch((t) -> t.isDownloaded()))
// {
// for (Track track : request.getTracks())
// getDestination().accept(track);
// return;
// }
// inputQueue.add(task);
// }
/**
 * Submits the promise's request to the worker pool, where its metadata is
 * fetched and its audio downloaded.
 * <p>
 * The request is listed in {@code downloadingTracks} for exactly as long as
 * the worker is processing it; the removal is in a {@code finally} block so
 * the entry cannot be left behind by an unexpected failure.
 *
 * @param promise promise whose input is the request to download
 * @return {@code false} if the promise carries no request, {@code true} once
 *         the work has been submitted
 */
@Override
protected boolean completePromise(Promise<TrackRequest, Track> promise)
{
    TrackRequest request = promise.getInput();
    if (request == null)
        return false;
    executor.submit(() ->
    {
        downloadingTracks.add(request);
        try
        {
            getInfo(request);
            // Streaming straight from the process is currently disabled.
//            boolean streamOutput = request.getTracks().size() == 1;
            boolean streamOutput = false;
            download(promise, streamOutput);
        } catch (InterruptedException ex)
        {
            // Keep the interrupt visible to the pool that owns this thread.
            Thread.currentThread().interrupt();
            Logger.getLogger(Downloader.class.getName()).log(Level.SEVERE, "Download task interrupted", ex);
        } catch (Exception ex)
        {
            Logger.getLogger(Downloader.class.getName()).log(Level.SEVERE, "Download task failed", ex);
        } finally
        {
            downloadingTracks.remove(request);
        }
    });
    return true;
}
/**
 * Returns the given formats ordered from most to least preferred.
 * <p>
 * Ranking criteria, in order:
 * <ol>
 * <li>audio-only formats before formats with video;</li>
 * <li>a known sample rate (&gt; 0) before an unknown one, then the rate
 * closest to {@code BITRATE_TARGET};</li>
 * <li>a known size (&gt; 0) before an unknown one, then the smaller size.</li>
 * </ol>
 *
 * @param input formats to rank; not modified
 * @return a new list containing the same formats, best first
 */
private List<Format> sortFormats(Collection<Format> input)
{
    List<Format> formats = new ArrayList<>(input);
    formats.sort((Format a, Format b) ->
    {
        // Boolean.compare puts false before true, so swap the operands to get audio-only first.
        int comp = Boolean.compare(b.isAudioOnly(), a.isAudioOnly());
        if (comp == 0)
        {
            // cost = distance from the target sample rate
            comp = compareKnownFirst(
                    a.getSampleRate(), Math.abs(a.getSampleRate() - BITRATE_TARGET),
                    b.getSampleRate(), Math.abs(b.getSampleRate() - BITRATE_TARGET));
        }
        if (comp == 0)
        {
            // cost = the size itself: smaller is better
            comp = compareKnownFirst(a.getSize(), a.getSize(), b.getSize(), b.getSize());
        }
        return comp;
    });
    return formats;
}

/**
 * Compares two candidates on one criterion: a candidate whose raw value is
 * known (&gt; 0) sorts before one whose value is unknown (&lt;= 0); two known
 * candidates are ordered by ascending cost; two unknown candidates tie.
 */
private static int compareKnownFirst(long aValue, long aCost, long bValue, long bCost)
{
    boolean aKnown = aValue > 0;
    boolean bKnown = bValue > 0;
    if (aKnown != bKnown)
        return aKnown ? -1 : 1;
    if (!aKnown)
        return 0;
    return Long.compare(aCost, bCost);
}
/**
 * Fills in the track's format list by running the youtube-dl command with
 * {@code -F} and parsing the printed table. Does nothing if the track already
 * has formats. Errors are logged and leave the track unchanged.
 * <p>
 * Columns are located by splitting on whitespace. When the table has a header
 * row containing "FILESIZE", column positions are taken from it; otherwise the
 * first four columns are assumed to be code, extension, resolution and note.
 *
 * @param track track whose URL is queried and whose formats are set
 */
private void getFormats(Track track)
{
    if (!track.getFormats().isEmpty())
        return;
    try
    {
        Process exec = runCommand(List.of(Chords.getSettings().getYtdlCommand(), "--skip-download", "-F", track.getUrl().toString()), FORMAT_TIMEOUT);
        String output;
        try (InputStream input = exec.getInputStream())
        {
            output = new String(input.readAllBytes(), Charset.defaultCharset());
        }

        List<Format> formats = new ArrayList<>();
        int codeCol = 0;
        int extCol = 1;
        int resCol = 2;
        int noteCol = 3;
        int sizeCol = -1;
        int bitrateCol = -1;

        List<String> list = output.lines().collect(Collectors.toList());
        int i = 0;
        while (i < list.size() && !list.get(i).contains("Available formats"))
            i++;
        if (i >= list.size())
        {
            // No table in the output (e.g. the command failed or was killed).
            Logger.getLogger(Downloader.class.getName()).log(Level.WARNING, "No format table in output for {0}", track.getUrl());
            return;
        }
        i++; // step past the "Available formats" line

        if (i < list.size() && list.get(i).contains("FILESIZE"))
        {
            String[] header = list.get(i).split("\\s+");
            for (int j = 0; j < header.length; j++)
            {
                switch (header[j])
                {
                    case "ID":
                        codeCol = j;
                        break;
                    case "EXT":
                        extCol = j;
                        break;
                    case "RESOLUTION":
                        resCol = j;
                        break;
                    case "MORE":
                        noteCol = j;
                        break;
                    case "FILESIZE":
                        sizeCol = j;
                        break;
                    case "TBR":
                        bitrateCol = j;
                        break;
                    default:
                        break;
                }
            }
            i += 2; // skip the header row and the separator row below it
        }

        // Split into noteCol + 1 pieces so the note column keeps the rest of the line.
        final int limit = Math.max(4, noteCol + 1);
        final int minColumns = Math.max(4, Math.max(Math.max(codeCol, extCol), Math.max(resCol, noteCol)) + 1);
        for (; i < list.size(); i++)
        {
            String[] split = list.get(i).split("\\s+", limit);
            if (split.length < minColumns)
                continue;
            final Format format = new Format(split[codeCol], split[extCol], split[resCol], split[noteCol]);
            if (sizeCol >= 0 && sizeCol < split.length)
                format.setSize(Util.parseSize(split[sizeCol]));
            if (bitrateCol >= 0 && bitrateCol < split.length)
                format.setSampleRate(Util.parseSampleRate(split[bitrateCol]));
            formats.add(format);
        }
        track.setFormats(formats);
    } catch (Exception ex)
    {
        Logger.getLogger(Downloader.class.getName()).log(Level.SEVERE, null, ex);
    }
}
/**
 * Queries metadata for a request by running the youtube-dl command with
 * {@code --skip-download --print-json} and creates one {@link Track} per
 * output line, adding each to the request.
 * <p>
 * For every track the title and artist are taken from the request's search
 * result and then from the JSON (first non-blank value among
 * {@code INFO_TITLE_KEYS} / {@code INFO_ARTIST_KEYS}). A title containing "-"
 * is split into artist and title at the first "-", and a " - Topic" suffix is
 * removed from the artist. Any "formats" array in the JSON becomes the
 * track's format list.
 * <p>
 * Errors are logged; the tracks parsed up to that point are still returned.
 *
 * @param request request whose URL is queried and which receives the tracks
 * @return the tracks that were fully parsed, possibly empty
 */
private List<Track> getInfo(TrackRequest request)
{
    List<Track> ret = new ArrayList<>();
    try
    {
        Process exec = runCommand(List.of(Chords.getSettings().getYtdlCommand(), "--skip-download", "--print-json", request.getUrl().toString()), INFO_TIMEOUT);

        // read each line as JSON, turn each into a track object
        try (Scanner sc = new Scanner(exec.getInputStream()))
        {
            while (sc.hasNextLine())
            {
                Track track = new Track(request);
                track.setNumber(trackNumber);
                trackNumber++;
                request.addTrack(track);

                String line = sc.nextLine();
                JsonObject object;
                try (JsonReader reader = Json.createReader(new StringReader(line)))
                {
                    object = reader.readObject();
                }
                Result result = request.getResult();

                // look for metadata
                if (track.getTitle() == null)
                {
                    if (result != null && result.getTitle() != null && !result.getTitle().isBlank())
                        track.setTitle(result.getTitle());
                    String jsonTitle = firstNonBlankString(object, INFO_TITLE_KEYS);
                    if (jsonTitle != null)
                        track.setTitle(jsonTitle);
                }
                if (track.getArtist() == null)
                {
                    if (result != null && result.getArtist() != null && !result.getArtist().isBlank())
                        track.setArtist(result.getArtist());
                    String jsonArtist = firstNonBlankString(object, INFO_ARTIST_KEYS);
                    if (jsonArtist != null)
                        track.setArtist(jsonArtist);
                }

                // Title and artist may still be unknown here; only post-process what exists.
                String title = track.getTitle();
                if (title != null && title.contains("-"))
                {
                    String[] split = title.split("-", 2);
                    track.setArtist(split[0]);
                    track.setTitle(split[1]);
                }
                String artist = track.getArtist();
                if (artist != null && artist.contains(" - Topic"))
                {
                    track.setArtist(artist.replace(" - Topic", ""));
                }

                JsonArray formatsJSON = object.getJsonArray("formats");
                if (formatsJSON != null)
                {
                    List<Format> formats = new ArrayList<>();
                    for (JsonObject formatJson : formatsJSON.getValuesAs(JsonObject.class))
                    {
                        Format format = Format.fromJSON(formatJson);
                        if (format != null)
                            formats.add(format);
                    }
                    track.setFormats(formats);
                }
                ret.add(track);
            }
        }
    } catch (Exception ex)
    {
        Logger.getLogger(Downloader.class.getName()).log(Level.SEVERE, null, ex);
    }
    return ret;
}

/**
 * Returns the first non-blank string value found under the given keys, or
 * null if there is none. Keys that are missing, JSON null, or not strings
 * are skipped instead of raising an exception.
 */
private static String firstNonBlankString(JsonObject object, String[] keys)
{
    for (String key : keys)
    {
        // The two-argument getString falls back to the default for absent or non-string values.
        String value = object.getString(key, "");
        if (!value.isBlank())
            return value;
    }
    return null;
}
/**
 * Returns the directory that downloads are written to. The cached directory
 * is reused while it exists and is writable; otherwise a new temporary
 * directory with the prefix "chords" is created and cached.
 *
 * @return the download directory
 * @throws IOException if the temporary directory cannot be created
 */
private File getDownloadDir() throws IOException
{
    final boolean usable = downloadDir != null && downloadDir.exists() && downloadDir.canWrite();
    if (usable)
        return downloadDir;
    downloadDir = Files.createTempDirectory("chords").toFile();
    return downloadDir;
}
/**
 * Returns the track at the given position of the request, first appending
 * newly numbered tracks to the request until that position exists.
 *
 * @param request request to read from (and pad, if needed)
 * @param idx     zero-based position of the wanted track
 * @return the track at {@code idx}
 */
private Track getTrackFromRequest(TrackRequest request, int idx)
{
    // pad the request until it is long enough to contain idx
    while (request.getTracks().size() <= idx)
    {
        final Track filler = new Track(request);
        filler.setNumber(trackNumber++);
        request.addTrack(filler);
    }
    return request.getTracks().get(idx);
}
/**
 * Runs the youtube-dl command for the promise's request and completes the
 * promise with each track as it becomes available.
 * <p>
 * The five best-ranked formats known for the request's tracks are offered to
 * the command first, followed by generic fallbacks. Progress and completion
 * are reported to the message handler, if one is set. Failures while running
 * the command are logged and passed to the message handler rather than thrown.
 *
 * @param promise      promise whose input is the request to download
 * @param streamOutput if {@code true}, audio is read from the process's stdout
 *                     and attached to the request's first track as a stream;
 *                     if {@code false}, audio is written into the download
 *                     directory and located by parsing "Destination:" lines
 * @throws InterruptedException kept in the signature for existing callers;
 *                              interruption while running the command is handled here
 * @throws ExecutionException   kept in the signature for existing callers
 */
private void download(Promise<TrackRequest, Track> promise, boolean streamOutput) throws InterruptedException, ExecutionException
{
    TrackRequest request = promise.getInput();
    Set<Format> uniqueFormats = new HashSet<>();
    for (Track track : request.getTracks())
    {
        uniqueFormats.addAll(track.getFormats());
    }
    List<Format> sortedFormats = sortFormats(uniqueFormats);

    // "code1/code2/.../" - each preferred code followed by the alternative separator
    StringBuilder formatCodes = new StringBuilder();
    for (int i = 0; i < 5 && i < sortedFormats.size(); i++)
        formatCodes.append(sortedFormats.get(i).getCode()).append('/');

    int downloadIdx = 0;
    try
    {
        notifyMessageHandler(request, null);
        List<String> cmd = new ArrayList<>();
        cmd.add(Chords.getSettings().getYtdlCommand());
        cmd.add("-x");
        // Option and value are separate argv entries so the value carries no stray "=" or space.
        cmd.add("-f");
        cmd.add(formatCodes + "worstaudio/bestaudio/worst/best");
        cmd.add("--audio-format=wav");
        cmd.add("--no-playlist");
//        cmd.add(" --extractor-args youtube:player_client=android";
        cmd.add("-N8");
        if (streamOutput)
        {
            cmd.add("--downloader=ffmpeg"); // download using FFMpeg
            cmd.add("--downloader-args=ffmpeg:-f wav -c:a pcm_s16le"); // tell FFMpeg to convert to wav
            cmd.add("-o"); // output to stdout
            cmd.add("-");
        } else
        {
            cmd.add("-o");
            cmd.add(getDownloadDir().getAbsolutePath() + "/%(title)s.%(ext)s");
        }
        cmd.add(request.getUrl().toString());

        Process exec = runCommand(cmd, streamOutput ? 0 : DOWNLOAD_TIMEOUT);
        if (streamOutput)
        {
            // Read stderr until the command announces that it is writing to stdout.
            // The scanner is deliberately left open: closing it would close the
            // process's stderr while the process is still running.
            Scanner sc = new Scanner(exec.getErrorStream());
            while (sc.hasNextLine())
            {
                String line = sc.nextLine();
                System.out.println(line);
                Matcher streamMatcher = STREAM_PATTERN.matcher(line);
                if (streamMatcher.find())
                {
                    break;
                }
            }
            if (!exec.isAlive() && exec.exitValue() != 0)
                throw new RuntimeException("yt-dlp failed with error code " + exec.exitValue());

            Track track = getTrackFromRequest(request, 0);
            BufferedInputStream inBuf = new BufferedInputStream(exec.getInputStream());
            // Peek at the header to verify it is WAV data, then rewind for the consumer.
            inBuf.mark(128);
            final byte[] headerBytes = inBuf.readNBytes(80);
            String header = new String(headerBytes);
            System.out.println("streaming data header: " + header);
            if (!header.startsWith("RIFF"))
                throw new RuntimeException("Streaming data has bad header!");
            inBuf.reset();
            track.setInputStream(inBuf);
            promise.complete(track);
            track.setProgress(100.0);
            notifyMessageHandler(request, null);
        } else
        {
            try (Scanner sc = new Scanner(exec.getInputStream()))
            {
                while (sc.hasNextLine())
                {
                    String line = sc.nextLine();
                    System.out.println(line);

                    Matcher itemMatcher = DOWNLOAD_ITEM_PATTERN.matcher(line);
                    if (itemMatcher.find())
                    {
                        // reported index is 1-based
                        downloadIdx = Integer.parseInt(itemMatcher.group(1)) - 1;
                    }

                    Matcher progMatcher = PROGRESS_PATTERN.matcher(line);
                    if (progMatcher.find())
                    {
                        getTrackFromRequest(request, downloadIdx).setProgress(Double.parseDouble(progMatcher.group(1)));
                        notifyMessageHandler(request, null);
                    }

                    Matcher destMatcher = DESTINATION_PATTERN.matcher(line);
                    if (destMatcher.find())
                    {
                        Track track = getTrackFromRequest(request, downloadIdx);
                        track.setLocation(new File(destMatcher.group(1)));
                        // this is currently our criteria for completion; submit the track and move on
                        promise.complete(track);
                        track.setProgress(100.0);
                        notifyMessageHandler(request, null);
                        downloadIdx++;
                    }
                }
            }

            boolean exited = exec.waitFor(DOWNLOAD_TIMEOUT, TimeUnit.SECONDS);
            if (exited)
            {
                String error = new String(exec.getErrorStream().readAllBytes(), Charset.defaultCharset());
                if (exec.exitValue() != 0)
                    throw new RuntimeException("youtube-dl failed with error " + exec.exitValue() + ", output:\n" + error);
            } else
            {
                throw new RuntimeException("youtube-dl failed to exit.");
            }
        }
        notifyMessageHandler(request, null);
    } catch (Exception ex)
    {
        // Keep an interrupt visible to the thread's owner instead of swallowing it.
        if (ex instanceof InterruptedException)
            Thread.currentThread().interrupt();
        Logger.getLogger(Downloader.class.getName()).log(Level.SEVERE, null, ex);
        notifyMessageHandler(request, ex);
        //downloadQueue.remove(task);
    }
}

/**
 * Passes a status update (or a failure, when {@code ex} is non-null) to the
 * message handler; does nothing when no handler has been set.
 */
private void notifyMessageHandler(TrackRequest request, Exception ex)
{
    BiConsumer<TrackRequest, Exception> handler = messageHandler;
    if (handler != null)
        handler.accept(request, ex);
}
/**
 * Starts an external command with piped stdout and stderr and returns the
 * running process without waiting for it. When a positive timeout is given, a
 * watchdog is scheduled that forcibly destroys the process if it is still
 * alive after that many seconds.
 *
 * @param cmd         program and arguments, one list element per argument
 * @param timeoutSecs seconds before the process is killed; zero or negative
 *                    disables the watchdog
 * @return the started process
 * @throws IOException if the process cannot be started
 */
private Process runCommand(List<String> cmd, int timeoutSecs) throws RuntimeException, IOException, InterruptedException
{
    System.out.println("Running command: " + cmd);

    final ProcessBuilder builder = new ProcessBuilder(cmd);
    builder.redirectOutput(ProcessBuilder.Redirect.PIPE);
    builder.redirectError(ProcessBuilder.Redirect.PIPE);
    final Process child = builder.start();

    if (timeoutSecs <= 0)
        return child;

    final Runnable watchdog = () ->
    {
        if (!child.isAlive())
            return;
        child.destroyForcibly();
        System.err.println("Process " + cmd + " took too long, killing process.");
    };
    scheduler.schedule(watchdog, timeoutSecs, TimeUnit.SECONDS);
    return child;
}
/**
* Returns the callback stored in the messageHandler field, or null if none
* has been set.
*/
public BiConsumer<TrackRequest, Exception> getMessageHandler()
{
return messageHandler;
}
/**
* Sets the callback that download() invokes with the request and either null
* (status or progress update) or the exception that ended the download.
*
* @param messageHandler the callback to store
*/
public void setMessageHandler(BiConsumer<TrackRequest, Exception> messageHandler)
{
this.messageHandler = messageHandler;
}
/**
* Returns a new ArrayList copied from inputQueue, so changes to the returned
* list do not affect the queue. inputQueue is not declared in this file;
* presumably it is inherited from QueueThing - confirm.
*/
public List<TrackRequest> getDownloadQueue()
{
return new ArrayList<>(inputQueue);
}
// public static class DownloadTask
// {
//
// private final TrackRequest request;
// private final Consumer<Track> destination;
//
// public DownloadTask(TrackRequest request, Consumer<Track> destination)
// {
// this.request = request;
// this.destination = destination;
// }
//
// public TrackRequest getTrack()
// {
// return request;
// }
//
// public Consumer<Track> getDestination()
// {
// return destination;
// }
//
// }
/**
 * Returns a snapshot of the requests currently being processed by a worker.
 * A copy is returned, as in {@link #getDownloadQueue()}, so callers cannot
 * modify the internal list and are not affected when worker threads add or
 * remove entries afterwards.
 *
 * @return a new list of the requests being downloaded at the time of the call
 */
public List<TrackRequest> getDownloadingTracks()
{
    return new ArrayList<>(downloadingTracks);
}
}