Skip to content

Commit

Permalink
feat: parallel downloads (#43)
Browse files Browse the repository at this point in the history
* feat: parallel downloads

* feat: mod extract progress using file size

* feat: pass mod version in install progress updates

* fix: only close update channels after finished sending

* chore: verbose ci tests

* fix: store mod in cache
chore: add progress logging to tests

* chore: lint

* test: add concurrent download limit

* fix: prevent concurrent map access

* chore: bump pubgrub
fix: fix race conditions

---------

Co-authored-by: Vilsol <[email protected]>
  • Loading branch information
mircearoata and Vilsol authored Dec 6, 2023
1 parent a192a63 commit 6088d1e
Show file tree
Hide file tree
Showing 12 changed files with 400 additions and 221 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,6 @@ jobs:
run: go generate -tags tools -x ./...

- name: Test
run: go test ./...
run: go test -v ./...
env:
SF_DEDICATED_SERVER: ${{ github.workspace }}/SatisfactoryDedicatedServer
1 change: 1 addition & 0 deletions cfg/test_defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ func SetDefaults() {
viper.SetDefault("dry-run", false)
viper.SetDefault("api-base", "https://api.ficsit.app")
viper.SetDefault("graphql-api", "/v2/query")
viper.SetDefault("concurrent-downloads", 5)
}
51 changes: 21 additions & 30 deletions cli/cache/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,11 @@ import (
"github.com/satisfactorymodding/ficsit-cli/utils"
)

type Progresser struct {
io.Reader
updates chan utils.GenericUpdate
total int64
running int64
}

func (pt *Progresser) Read(p []byte) (int, error) {
n, err := pt.Reader.Read(p)
pt.running += int64(n)

if err == nil {
if pt.updates != nil {
select {
case pt.updates <- utils.GenericUpdate{Progress: float64(pt.running) / float64(pt.total)}:
default:
}
}
}

if err == io.EOF {
return n, io.EOF
func DownloadOrCache(cacheKey string, hash string, url string, updates chan<- utils.GenericProgress, downloadSemaphore chan int) (io.ReaderAt, int64, error) {
if updates != nil {
defer close(updates)
}

return n, errors.Wrap(err, "failed to read")
}

func DownloadOrCache(cacheKey string, hash string, url string, updates chan utils.GenericUpdate) (io.ReaderAt, int64, error) {
downloadCache := filepath.Join(viper.GetString("cache-dir"), "downloadCache")
if err := os.MkdirAll(downloadCache, 0o777); err != nil {
if !os.IsExist(err) {
Expand Down Expand Up @@ -82,6 +59,20 @@ func DownloadOrCache(cacheKey string, hash string, url string, updates chan util
return nil, 0, errors.Wrap(err, "failed to stat file: "+location)
}

if updates != nil {
headResp, err := http.Head(url)
if err != nil {
return nil, 0, errors.Wrap(err, "failed to head: "+url)
}
defer headResp.Body.Close()
updates <- utils.GenericProgress{Total: headResp.ContentLength}
}

if downloadSemaphore != nil {
downloadSemaphore <- 1
defer func() { <-downloadSemaphore }()
}

out, err := os.Create(location)
if err != nil {
return nil, 0, errors.Wrap(err, "failed creating file at: "+location)
Expand All @@ -98,10 +89,10 @@ func DownloadOrCache(cacheKey string, hash string, url string, updates chan util
return nil, 0, fmt.Errorf("bad status: %s on url: %s", resp.Status, url)
}

progresser := &Progresser{
progresser := &utils.Progresser{
Reader: resp.Body,
total: resp.ContentLength,
updates: updates,
Total: resp.ContentLength,
Updates: updates,
}

_, err = io.Copy(out, progresser)
Expand All @@ -116,7 +107,7 @@ func DownloadOrCache(cacheKey string, hash string, url string, updates chan util

if updates != nil {
select {
case updates <- utils.GenericUpdate{Progress: 1}:
case updates <- utils.GenericProgress{Completed: resp.ContentLength, Total: resp.ContentLength}:
default:
}
}
Expand Down
10 changes: 6 additions & 4 deletions cli/dependency_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/mircearoata/pubgrub-go/pubgrub/helpers"
"github.com/mircearoata/pubgrub-go/pubgrub/semver"
"github.com/pkg/errors"
"github.com/puzpuzpuz/xsync/v3"
"github.com/spf13/viper"

"github.com/satisfactorymodding/ficsit-cli/cli/provider"
Expand All @@ -35,7 +36,7 @@ type ficsitAPISource struct {
provider provider.Provider
lockfile *LockFile
toInstall map[string]semver.Constraint
modVersionInfo map[string]ficsit.ModVersionsWithDependenciesResponse
modVersionInfo *xsync.MapOf[string, ficsit.ModVersionsWithDependenciesResponse]
gameVersion semver.Version
smlVersions []ficsit.SMLVersionsSmlVersionsGetSMLVersionsSml_versionsSMLVersion
}
Expand Down Expand Up @@ -74,7 +75,7 @@ func (f *ficsitAPISource) GetPackageVersions(pkg string) ([]pubgrub.PackageVersi
if response.Mod.Id == "" {
return nil, errors.Errorf("mod %s not found", pkg)
}
f.modVersionInfo[pkg] = *response
f.modVersionInfo.Store(pkg, *response)
versions := make([]pubgrub.PackageVersion, len(response.Mod.Versions))
for i, modVersion := range response.Mod.Versions {
v, err := semver.NewVersion(modVersion.Version)
Expand Down Expand Up @@ -145,7 +146,7 @@ func (d DependencyResolver) ResolveModDependencies(constraints map[string]string
gameVersion: gameVersionSemver,
lockfile: lockFile,
toInstall: toInstall,
modVersionInfo: make(map[string]ficsit.ModVersionsWithDependenciesResponse),
modVersionInfo: xsync.NewMapOf[string, ficsit.ModVersionsWithDependenciesResponse](),
}

result, err := pubgrub.Solve(helpers.NewCachingSource(ficsitSource), rootPkg)
Expand All @@ -170,7 +171,8 @@ func (d DependencyResolver) ResolveModDependencies(constraints map[string]string
}
continue
}
versions := ficsitSource.modVersionInfo[k].Mod.Versions
value, _ := ficsitSource.modVersionInfo.Load(k)
versions := value.Mod.Versions
for _, ver := range versions {
if ver.Version == v.RawString() {
outputLock[k] = LockedMod{
Expand Down
205 changes: 142 additions & 63 deletions cli/installations.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"golang.org/x/sync/errgroup"

"github.com/satisfactorymodding/ficsit-cli/cli/cache"
"github.com/satisfactorymodding/ficsit-cli/cli/disk"
Expand Down Expand Up @@ -351,14 +352,27 @@ func (i *Installation) ResolveProfile(ctx *GlobalContext) (LockFile, error) {
return lockfile, nil
}

type InstallUpdateType string

var (
InstallUpdateTypeOverall InstallUpdateType = "overall"
InstallUpdateTypeModDownload InstallUpdateType = "download"
InstallUpdateTypeModExtract InstallUpdateType = "extract"
InstallUpdateTypeModComplete InstallUpdateType = "complete"
)

type InstallUpdate struct {
ModName string
OverallProgress float64
DownloadProgress float64
ExtractProgress float64
Type InstallUpdateType
Item InstallUpdateItem
Progress utils.GenericProgress
}

func (i *Installation) Install(ctx *GlobalContext, updates chan InstallUpdate) error {
type InstallUpdateItem struct {
Mod string
Version string
}

func (i *Installation) Install(ctx *GlobalContext, updates chan<- InstallUpdate) error {
if err := i.Validate(ctx); err != nil {
return errors.Wrap(err, "failed to validate installation")
}
Expand Down Expand Up @@ -403,78 +417,65 @@ func (i *Installation) Install(ctx *GlobalContext, updates chan InstallUpdate) e
}
}

downloading := true
completed := 0

var genericUpdates chan utils.GenericUpdate
if updates != nil {
var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()
log.Info().Int("concurrency", viper.GetInt("concurrent-downloads")).Str("path", i.Path).Msg("starting installation")

genericUpdates = make(chan utils.GenericUpdate)
defer close(genericUpdates)
errg := errgroup.Group{}
channelUsers := sync.WaitGroup{}
downloadSemaphore := make(chan int, viper.GetInt("concurrent-downloads"))
defer close(downloadSemaphore)

var modComplete chan int
if updates != nil {
channelUsers.Add(1)
modComplete = make(chan int)
defer close(modComplete)
go func() {
defer wg.Done()

update := InstallUpdate{
OverallProgress: float64(completed) / float64(len(lockfile)),
DownloadProgress: 0,
ExtractProgress: 0,
}

select {
case updates <- update:
default:
}

for up := range genericUpdates {
if downloading {
update.DownloadProgress = up.Progress
} else {
update.DownloadProgress = 1
update.ExtractProgress = up.Progress
}

if up.ModReference != nil {
update.ModName = *up.ModReference
}

update.OverallProgress = float64(completed) / float64(len(lockfile))

select {
case updates <- update:
default:
defer channelUsers.Done()
completed := 0
for range modComplete {
completed++
overallUpdate := InstallUpdate{
Type: InstallUpdateTypeOverall,
Progress: utils.GenericProgress{
Completed: int64(completed),
Total: int64(len(lockfile)),
},
}
updates <- overallUpdate
}
}()
}

for modReference, version := range lockfile {
// Only install if a link is provided, otherwise assume mod is already installed
if version.Link != "" {
downloading = true

if genericUpdates != nil {
genericUpdates <- utils.GenericUpdate{ModReference: &modReference}
channelUsers.Add(1)
modReference := modReference
version := version
errg.Go(func() error {
defer channelUsers.Done()
// Only install if a link is provided, otherwise assume mod is already installed
if version.Link != "" {
err := downloadAndExtractMod(modReference, version.Version, version.Link, version.Hash, modsDirectory, updates, downloadSemaphore, d)
if err != nil {
return errors.Wrapf(err, "failed to install %s@%s", modReference, version.Version)
}
}

log.Info().Str("mod_reference", modReference).Str("version", version.Version).Str("link", version.Link).Msg("downloading mod")
reader, size, err := cache.DownloadOrCache(modReference+"_"+version.Version+".zip", version.Hash, version.Link, genericUpdates)
if err != nil {
return errors.Wrap(err, "failed to download "+modReference+" from: "+version.Link)
if modComplete != nil {
modComplete <- 1
}
return nil
})
}

downloading = false

log.Info().Str("mod_reference", modReference).Str("version", version.Version).Str("link", version.Link).Msg("extracting mod")
if err := utils.ExtractMod(reader, size, filepath.Join(modsDirectory, modReference), version.Hash, genericUpdates, d); err != nil {
return errors.Wrap(err, "could not extract "+modReference)
}
}
if updates != nil {
go func() {
channelUsers.Wait()
close(updates)
}()
}

completed++
if err := errg.Wait(); err != nil {
return errors.Wrap(err, "failed to install mods")
}

return nil
Expand Down Expand Up @@ -518,6 +519,84 @@ func (i *Installation) UpdateMods(ctx *GlobalContext, mods []string) error {
return nil
}

func downloadAndExtractMod(modReference string, version string, link string, hash string, modsDirectory string, updates chan<- InstallUpdate, downloadSemaphore chan int, d disk.Disk) error {
var downloadUpdates chan utils.GenericProgress

if updates != nil {
// Forward the inner updates as InstallUpdates
downloadUpdates = make(chan utils.GenericProgress)

go func() {
for up := range downloadUpdates {
updates <- InstallUpdate{
Item: InstallUpdateItem{
Mod: modReference,
Version: version,
},
Type: InstallUpdateTypeModDownload,
Progress: up,
}
}
}()
}

log.Info().Str("mod_reference", modReference).Str("version", version).Str("link", link).Msg("downloading mod")
reader, size, err := cache.DownloadOrCache(modReference+"_"+version+".zip", hash, link, downloadUpdates, downloadSemaphore)
if err != nil {
return errors.Wrap(err, "failed to download "+modReference+" from: "+link)
}

var extractUpdates chan utils.GenericProgress

var wg sync.WaitGroup
if updates != nil {
// Forward the inner updates as InstallUpdates
extractUpdates = make(chan utils.GenericProgress)

wg.Add(1)
go func() {
defer wg.Done()
for up := range extractUpdates {
select {
case updates <- InstallUpdate{
Item: InstallUpdateItem{
Mod: modReference,
Version: version,
},
Type: InstallUpdateTypeModExtract,
Progress: up,
}:
default:
}
}
}()
}

log.Info().Str("mod_reference", modReference).Str("version", version).Str("link", link).Msg("extracting mod")
if err := utils.ExtractMod(reader, size, filepath.Join(modsDirectory, modReference), hash, extractUpdates, d); err != nil {
return errors.Wrap(err, "could not extract "+modReference)
}

if updates != nil {
select {
case updates <- InstallUpdate{
Type: InstallUpdateTypeModComplete,
Item: InstallUpdateItem{
Mod: modReference,
Version: version,
},
}:
default:
}

close(extractUpdates)
}

wg.Wait()

return nil
}

func (i *Installation) SetProfile(ctx *GlobalContext, profile string) error {
found := false
for _, p := range ctx.Profiles.Profiles {
Expand Down
Loading

0 comments on commit 6088d1e

Please sign in to comment.