Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more context for models #19511

Merged
merged 21 commits into from
Apr 28, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion models/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ func notifyWatchers(ctx context.Context, actions ...*Action) error {
permPR[i] = false
continue
}
perm, err := getUserRepoPermission(ctx, repo, user)
perm, err := GetUserRepoPermission(ctx, repo, user)
if err != nil {
permCode[i] = false
permIssue[i] = false
Expand Down
10 changes: 5 additions & 5 deletions models/branches.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,18 +337,18 @@ type WhitelistOptions struct {
// If ID is 0, it creates a new record. Otherwise, updates existing record.
// This function also performs check if whitelist user and team's IDs have been changed
// to avoid unnecessary whitelist delete and regenerate.
func UpdateProtectBranch(repo *repo_model.Repository, protectBranch *ProtectedBranch, opts WhitelistOptions) (err error) {
func UpdateProtectBranch(ctx context.Context, repo *repo_model.Repository, protectBranch *ProtectedBranch, opts WhitelistOptions) (err error) {
if err = repo.GetOwner(db.DefaultContext); err != nil {
6543 marked this conversation as resolved.
Show resolved Hide resolved
return fmt.Errorf("GetOwner: %v", err)
}

whitelist, err := updateUserWhitelist(repo, protectBranch.WhitelistUserIDs, opts.UserIDs)
whitelist, err := updateUserWhitelist(ctx, repo, protectBranch.WhitelistUserIDs, opts.UserIDs)
if err != nil {
return err
}
protectBranch.WhitelistUserIDs = whitelist

whitelist, err = updateUserWhitelist(repo, protectBranch.MergeWhitelistUserIDs, opts.MergeUserIDs)
whitelist, err = updateUserWhitelist(ctx, repo, protectBranch.MergeWhitelistUserIDs, opts.MergeUserIDs)
if err != nil {
return err
}
Expand Down Expand Up @@ -437,7 +437,7 @@ func updateApprovalWhitelist(repo *repo_model.Repository, currentWhitelist, newW

// updateUserWhitelist checks whether the user whitelist changed and returns a whitelist with
// the users from newWhitelist which have write access to the repo.
func updateUserWhitelist(repo *repo_model.Repository, currentWhitelist, newWhitelist []int64) (whitelist []int64, err error) {
func updateUserWhitelist(ctx context.Context, repo *repo_model.Repository, currentWhitelist, newWhitelist []int64) (whitelist []int64, err error) {
hasUsersChanged := !util.IsSliceInt64Eq(currentWhitelist, newWhitelist)
if !hasUsersChanged {
return currentWhitelist, nil
Expand All @@ -449,7 +449,7 @@ func updateUserWhitelist(repo *repo_model.Repository, currentWhitelist, newWhite
if err != nil {
return nil, fmt.Errorf("GetUserByID [user_id: %d, repo_id: %d]: %v", userID, repo.ID, err)
}
perm, err := GetUserRepoPermission(repo, user)
perm, err := GetUserRepoPermission(ctx, repo, user)
if err != nil {
return nil, fmt.Errorf("GetUserRepoPermission [user_id: %d, repo_id: %d]: %v", userID, repo.ID, err)
}
Expand Down
3 changes: 2 additions & 1 deletion models/branches_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package models
import (
"testing"

"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unittest"

Expand Down Expand Up @@ -99,7 +100,7 @@ func TestRenameBranch(t *testing.T) {
repo1 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1}).(*repo_model.Repository)
_isDefault := false

err := UpdateProtectBranch(repo1, &ProtectedBranch{
err := UpdateProtectBranch(db.DefaultContext, repo1, &ProtectedBranch{
6543 marked this conversation as resolved.
Show resolved Hide resolved
RepoID: repo1.ID,
BranchName: "master",
}, WhitelistOptions{})
Expand Down
9 changes: 5 additions & 4 deletions models/commit_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,13 @@ type CommitStatusIndex struct {

// GetLatestCommitStatus returns all statuses with a unique context for a given commit.
func GetLatestCommitStatus(repoID int64, sha string, listOptions db.ListOptions) ([]*CommitStatus, int64, error) {
return getLatestCommitStatus(db.GetEngine(db.DefaultContext), repoID, sha, listOptions)
return GetLatestCommitStatusCtx(db.DefaultContext, repoID, sha, listOptions)
}

func getLatestCommitStatus(e db.Engine, repoID int64, sha string, listOptions db.ListOptions) ([]*CommitStatus, int64, error) {
// GetLatestCommitStatusCtx returns all statuses with a unique context for a given commit.
func GetLatestCommitStatusCtx(ctx context.Context, repoID int64, sha string, listOptions db.ListOptions) ([]*CommitStatus, int64, error) {
ids := make([]int64, 0, 10)
sess := e.Table(&CommitStatus{}).
sess := db.GetEngine(ctx).Table(&CommitStatus{}).
Where("repo_id = ?", repoID).And("sha = ?", sha).
Select("max( id ) as id").
GroupBy("context_hash").OrderBy("max( id ) desc")
Expand All @@ -252,7 +253,7 @@ func getLatestCommitStatus(e db.Engine, repoID int64, sha string, listOptions db
if len(ids) == 0 {
return statuses, count, nil
}
return statuses, count, e.In("id", ids).Find(&statuses)
return statuses, count, db.GetEngine(ctx).In("id", ids).Find(&statuses)
}

// FindRepoRecentCommitStatusContexts returns repository's recent commit status contexts
Expand Down
6 changes: 3 additions & 3 deletions models/issue.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ func ClearIssueLabels(issue *Issue, doer *user_model.User) (err error) {
return err
}

perm, err := getUserRepoPermission(ctx, issue.Repo, doer)
perm, err := GetUserRepoPermission(ctx, issue.Repo, doer)
if err != nil {
return err
}
Expand Down Expand Up @@ -2341,9 +2341,9 @@ func ResolveIssueMentionsByVisibility(ctx context.Context, issue *Issue, doer *u
continue
}
// Normal users must have read access to the referencing issue
perm, err := getUserRepoPermission(ctx, issue.Repo, user)
perm, err := GetUserRepoPermission(ctx, issue.Repo, user)
if err != nil {
return nil, fmt.Errorf("getUserRepoPermission [%d]: %v", user.ID, err)
return nil, fmt.Errorf("GetUserRepoPermission [%d]: %v", user.ID, err)
}
if !perm.CanReadIssuesOrPulls(issue.IsPull) {
continue
Expand Down
16 changes: 3 additions & 13 deletions models/issue_label.go
Original file line number Diff line number Diff line change
Expand Up @@ -707,23 +707,13 @@ func deleteIssueLabel(ctx context.Context, issue *Issue, label *Label, doer *use
}

// DeleteIssueLabel deletes issue-label relation.
func DeleteIssueLabel(issue *Issue, label *Label, doer *user_model.User) (err error) {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()

if err = deleteIssueLabel(ctx, issue, label, doer); err != nil {
func DeleteIssueLabel(ctx context.Context, issue *Issue, label *Label, doer *user_model.User) error {
if err := deleteIssueLabel(ctx, issue, label, doer); err != nil {
return err
}

issue.Labels = nil
if err = issue.loadLabels(db.GetEngine(ctx)); err != nil {
return err
}

return committer.Commit()
return issue.loadLabels(db.GetEngine(ctx))
}

func deleteLabelsByRepoID(sess db.Engine, repoID int64) error {
Expand Down
2 changes: 1 addition & 1 deletion models/issue_label_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ func TestDeleteIssueLabel(t *testing.T) {
}
}

assert.NoError(t, DeleteIssueLabel(issue, label, doer))
assert.NoError(t, DeleteIssueLabel(db.DefaultContext, issue, label, doer))
6543 marked this conversation as resolved.
Show resolved Hide resolved
unittest.AssertNotExistsBean(t, &IssueLabel{IssueID: issueID, LabelID: labelID})
unittest.AssertExistsAndLoadBean(t, &Comment{
Type: CommentTypeLabel,
Expand Down
2 changes: 1 addition & 1 deletion models/issue_xref.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func (issue *Issue) verifyReferencedIssue(stdCtx context.Context, ctx *crossRefe

// Check doer permissions; set action to None if the doer can't change the destination
if refIssue.RepoID != ctx.OrigIssue.RepoID || ref.Action != references.XRefActionNone {
perm, err := getUserRepoPermission(stdCtx, refIssue.Repo, ctx.Doer)
perm, err := GetUserRepoPermission(stdCtx, refIssue.Repo, ctx.Doer)
if err != nil {
return nil, references.XRefActionNone, err
}
Expand Down
47 changes: 32 additions & 15 deletions models/lfs_lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package models

import (
"context"
"fmt"
"path"
"strings"
Expand Down Expand Up @@ -42,31 +43,39 @@ func cleanPath(p string) string {

// CreateLFSLock creates a new lock.
func CreateLFSLock(repo *repo_model.Repository, lock *LFSLock) (*LFSLock, error) {
err := CheckLFSAccessForRepo(lock.OwnerID, repo, perm.AccessModeWrite)
dbCtx, committer, err := db.TxContext()
if err != nil {
return nil, err
}
defer committer.Close()

if err := CheckLFSAccessForRepo(dbCtx, lock.OwnerID, repo, perm.AccessModeWrite); err != nil {
return nil, err
}

lock.Path = cleanPath(lock.Path)
lock.RepoID = repo.ID

l, err := GetLFSLock(repo, lock.Path)
l, err := GetLFSLock(dbCtx, repo, lock.Path)
if err == nil {
return l, ErrLFSLockAlreadyExist{lock.RepoID, lock.Path}
}
if !IsErrLFSLockNotExist(err) {
return nil, err
}

err = db.Insert(db.DefaultContext, lock)
return lock, err
if err := db.Insert(dbCtx, lock); err != nil {
return nil, err
}

return lock, committer.Commit()
}

// GetLFSLock returns release by given path.
func GetLFSLock(repo *repo_model.Repository, path string) (*LFSLock, error) {
func GetLFSLock(ctx context.Context, repo *repo_model.Repository, path string) (*LFSLock, error) {
path = cleanPath(path)
rel := &LFSLock{RepoID: repo.ID}
has, err := db.GetEngine(db.DefaultContext).Where("lower(path) = ?", strings.ToLower(path)).Get(rel)
has, err := db.GetEngine(ctx).Where("lower(path) = ?", strings.ToLower(path)).Get(rel)
if err != nil {
return nil, err
}
Expand All @@ -77,9 +86,9 @@ func GetLFSLock(repo *repo_model.Repository, path string) (*LFSLock, error) {
}

// GetLFSLockByID returns release by given id.
func GetLFSLockByID(id int64) (*LFSLock, error) {
func GetLFSLockByID(ctx context.Context, id int64) (*LFSLock, error) {
lock := new(LFSLock)
has, err := db.GetEngine(db.DefaultContext).ID(id).Get(lock)
has, err := db.GetEngine(ctx).ID(id).Get(lock)
if err != nil {
return nil, err
} else if !has {
Expand Down Expand Up @@ -127,34 +136,42 @@ func CountLFSLockByRepoID(repoID int64) (int64, error) {

// DeleteLFSLockByID deletes a lock by given ID.
func DeleteLFSLockByID(id int64, repo *repo_model.Repository, u *user_model.User, force bool) (*LFSLock, error) {
lock, err := GetLFSLockByID(id)
dbCtx, committer, err := db.TxContext()
if err != nil {
return nil, err
}
defer committer.Close()

err = CheckLFSAccessForRepo(u.ID, repo, perm.AccessModeWrite)
lock, err := GetLFSLockByID(dbCtx, id)
if err != nil {
return nil, err
}

if err := CheckLFSAccessForRepo(dbCtx, u.ID, repo, perm.AccessModeWrite); err != nil {
return nil, err
}

if !force && u.ID != lock.OwnerID {
return nil, fmt.Errorf("user doesn't own lock and force flag is not set")
}

_, err = db.GetEngine(db.DefaultContext).ID(id).Delete(new(LFSLock))
return lock, err
if _, err := db.GetEngine(dbCtx).ID(id).Delete(new(LFSLock)); err != nil {
return nil, err
}

return lock, committer.Commit()
}

// CheckLFSAccessForRepo check needed access mode base on action
func CheckLFSAccessForRepo(ownerID int64, repo *repo_model.Repository, mode perm.AccessMode) error {
func CheckLFSAccessForRepo(ctx context.Context, ownerID int64, repo *repo_model.Repository, mode perm.AccessMode) error {
if ownerID == 0 {
return ErrLFSUnauthorizedAction{repo.ID, "undefined", mode}
}
u, err := user_model.GetUserByID(ownerID)
u, err := user_model.GetUserByIDCtx(ctx, ownerID)
if err != nil {
return err
}
perm, err := GetUserRepoPermission(repo, u)
perm, err := GetUserRepoPermission(ctx, repo, u)
if err != nil {
return err
}
Expand Down
39 changes: 24 additions & 15 deletions models/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ func (pr *PullRequest) LoadAttributes() error {
return pr.loadAttributes(db.GetEngine(db.DefaultContext))
}

func (pr *PullRequest) loadHeadRepo(ctx context.Context) (err error) {
// LoadHeadRepoCtx loads the head repository
func (pr *PullRequest) LoadHeadRepoCtx(ctx context.Context) (err error) {
if !pr.isHeadRepoLoaded && pr.HeadRepo == nil && pr.HeadRepoID > 0 {
if pr.HeadRepoID == pr.BaseRepoID {
if pr.BaseRepo != nil {
Expand All @@ -153,15 +154,16 @@ func (pr *PullRequest) loadHeadRepo(ctx context.Context) (err error) {

// LoadHeadRepo loads the head repository
func (pr *PullRequest) LoadHeadRepo() error {
return pr.loadHeadRepo(db.DefaultContext)
return pr.LoadHeadRepoCtx(db.DefaultContext)
}

// LoadBaseRepo loads the target repository
func (pr *PullRequest) LoadBaseRepo() error {
return pr.loadBaseRepo(db.DefaultContext)
return pr.LoadBaseRepoCtx(db.DefaultContext)
}

func (pr *PullRequest) loadBaseRepo(ctx context.Context) (err error) {
// LoadBaseRepoCtx loads the target repository
func (pr *PullRequest) LoadBaseRepoCtx(ctx context.Context) (err error) {
if pr.BaseRepo != nil {
return nil
}
Expand All @@ -185,15 +187,16 @@ func (pr *PullRequest) loadBaseRepo(ctx context.Context) (err error) {

// LoadIssue loads issue information from database
func (pr *PullRequest) LoadIssue() (err error) {
return pr.loadIssue(db.GetEngine(db.DefaultContext))
return pr.LoadIssueCtx(db.DefaultContext)
}

func (pr *PullRequest) loadIssue(e db.Engine) (err error) {
// LoadIssueCtx loads issue information from database
func (pr *PullRequest) LoadIssueCtx(ctx context.Context) (err error) {
if pr.Issue != nil {
return nil
}

pr.Issue, err = getIssueByID(e, pr.IssueID)
pr.Issue, err = getIssueByID(db.GetEngine(ctx), pr.IssueID)
if err == nil {
pr.Issue.PullRequest = pr
}
Expand All @@ -202,10 +205,11 @@ func (pr *PullRequest) loadIssue(e db.Engine) (err error) {

// LoadProtectedBranch loads the protected branch of the base branch
func (pr *PullRequest) LoadProtectedBranch() (err error) {
return pr.loadProtectedBranch(db.DefaultContext)
return pr.LoadProtectedBranchCtx(db.DefaultContext)
}

func (pr *PullRequest) loadProtectedBranch(ctx context.Context) (err error) {
// LoadProtectedBranchCtx loads the protected branch of the base branch
func (pr *PullRequest) LoadProtectedBranchCtx(ctx context.Context) (err error) {
if pr.ProtectedBranch == nil {
if pr.BaseRepo == nil {
if pr.BaseRepoID == 0 {
Expand Down Expand Up @@ -392,7 +396,7 @@ func (pr *PullRequest) SetMerged() (bool, error) {
}

pr.Issue = nil
if err := pr.loadIssue(sess); err != nil {
if err := pr.LoadIssueCtx(ctx); err != nil {
return false, err
}

Expand Down Expand Up @@ -510,6 +514,11 @@ func GetLatestPullRequestByHeadInfo(repoID int64, branch string) (*PullRequest,

// GetPullRequestByIndex returns a pull request by the given index
func GetPullRequestByIndex(repoID, index int64) (*PullRequest, error) {
return GetPullRequestByIndexCtx(db.DefaultContext, repoID, index)
}

// GetPullRequestByIndexCtx returns a pull request by the given index
func GetPullRequestByIndexCtx(ctx context.Context, repoID, index int64) (*PullRequest, error) {
if index < 1 {
return nil, ErrPullRequestNotExist{}
}
Expand All @@ -518,17 +527,17 @@ func GetPullRequestByIndex(repoID, index int64) (*PullRequest, error) {
Index: index,
}

has, err := db.GetEngine(db.DefaultContext).Get(pr)
has, err := db.GetEngine(ctx).Get(pr)
if err != nil {
return nil, err
} else if !has {
return nil, ErrPullRequestNotExist{0, 0, 0, repoID, "", ""}
}

if err = pr.LoadAttributes(); err != nil {
if err = pr.loadAttributes(db.GetEngine(ctx)); err != nil {
return nil, err
}
if err = pr.LoadIssue(); err != nil {
if err = pr.LoadIssueCtx(ctx); err != nil {
return nil, err
}

Expand All @@ -547,8 +556,8 @@ func getPullRequestByID(e db.Engine, id int64) (*PullRequest, error) {
}

// GetPullRequestByID returns a pull request by given ID.
func GetPullRequestByID(id int64) (*PullRequest, error) {
return getPullRequestByID(db.GetEngine(db.DefaultContext), id)
func GetPullRequestByID(ctx context.Context, id int64) (*PullRequest, error) {
return getPullRequestByID(db.GetEngine(ctx), id)
}

// GetPullRequestByIssueIDWithNoAttributes returns pull request with no attributes loaded by given issue ID.
Expand Down
Loading