Skip to content

Commit 79f8ec6

Browse files
authored
Update query cursor to include relation + fix implicit result logic (#311)
1 parent 9542553 commit 79f8ec6

File tree

3 files changed

+333
-43
lines changed

3 files changed

+333
-43
lines changed

pkg/authz/query/resultset.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,9 @@ func (rs *ResultSet) Union(other *ResultSet) *ResultSet {
9999
}
100100

101101
for iter := other.List(); iter != nil; iter = iter.Next() {
102-
isImplicit := iter.IsImplicit
103-
if resultSet.Has(iter.ObjectType, iter.ObjectId, iter.Relation) {
104-
isImplicit = isImplicit && resultSet.Get(iter.ObjectType, iter.ObjectId, iter.Relation).IsImplicit
102+
if !resultSet.Has(iter.ObjectType, iter.ObjectId, iter.Relation) || !iter.IsImplicit {
103+
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, iter.IsImplicit)
105104
}
106-
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, isImplicit)
107105
}
108106

109107
return resultSet
@@ -112,10 +110,13 @@ func (rs *ResultSet) Union(other *ResultSet) *ResultSet {
112110
func (rs *ResultSet) Intersect(other *ResultSet) *ResultSet {
113111
resultSet := NewResultSet()
114112
for iter := rs.List(); iter != nil; iter = iter.Next() {
115-
isImplicit := iter.IsImplicit
116113
if other.Has(iter.ObjectType, iter.ObjectId, iter.Relation) {
117-
isImplicit = isImplicit || other.Get(iter.ObjectType, iter.ObjectId, iter.Relation).IsImplicit
118-
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, isImplicit)
114+
otherRes := other.Get(iter.ObjectType, iter.ObjectId, iter.Relation)
115+
if !otherRes.IsImplicit {
116+
resultSet.Add(otherRes.ObjectType, otherRes.ObjectId, otherRes.Relation, otherRes.Warrant, otherRes.IsImplicit)
117+
} else {
118+
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, iter.IsImplicit)
119+
}
119120
}
120121
}
121122

@@ -125,7 +126,11 @@ func (rs *ResultSet) Intersect(other *ResultSet) *ResultSet {
125126
func (rs *ResultSet) String() string {
126127
var strs []string
127128
for iter := rs.List(); iter != nil; iter = iter.Next() {
128-
strs = append(strs, fmt.Sprintf("%s => %s", key(iter.ObjectType, iter.ObjectId, iter.Relation), iter.Warrant.String()))
129+
if iter.IsImplicit {
130+
strs = append(strs, fmt.Sprintf("%s => %s [implicit]", key(iter.ObjectType, iter.ObjectId, iter.Relation), iter.Warrant.String()))
131+
} else {
132+
strs = append(strs, fmt.Sprintf("%s => %s", key(iter.ObjectType, iter.ObjectId, iter.Relation), iter.Warrant.String()))
133+
}
129134
}
130135

131136
return strings.Join(strs, ", ")

pkg/authz/query/service.go

Lines changed: 72 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -222,25 +222,25 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi
222222
paginatedQueryResults := make([]QueryResult, 0)
223223
//nolint:gocritic
224224
if listParams.NextCursor != nil { // seek forward if NextCursor passed in
225-
lastObjectType, lastObjectId, err := objectTypeAndObjectIdFromCursor(listParams.NextCursor)
225+
lastObjectType, lastObjectId, lastRelation, err := objectTypeAndObjectIdAndRelationFromCursor(listParams.NextCursor)
226226
if err != nil {
227227
return nil, nil, nil, service.NewInvalidParameterError("nextCursor", "invalid cursor")
228228
}
229229

230230
start = 0
231-
for start < len(queryResults) && (queryResults[start].ObjectType != lastObjectType || queryResults[start].ObjectId != lastObjectId) {
231+
for start < len(queryResults) && (queryResults[start].ObjectType != lastObjectType || queryResults[start].ObjectId != lastObjectId || queryResults[start].Relation != lastRelation) {
232232
start++
233233
}
234234

235235
end = start + listParams.Limit
236236
} else if listParams.PrevCursor != nil { // seek backward if PrevCursor passed in
237-
lastObjectType, lastObjectId, err := objectTypeAndObjectIdFromCursor(listParams.PrevCursor)
237+
lastObjectType, lastObjectId, lastRelation, err := objectTypeAndObjectIdAndRelationFromCursor(listParams.PrevCursor)
238238
if err != nil {
239239
return nil, nil, nil, service.NewInvalidParameterError("prevCursor", "invalid cursor")
240240
}
241241

242242
end = len(queryResults) - 1
243-
for end > 0 && (queryResults[end].ObjectType != lastObjectType || queryResults[end].ObjectId != lastObjectId) {
243+
for end > 0 && (queryResults[end].ObjectType != lastObjectType || queryResults[end].ObjectId != lastObjectId || queryResults[end].Relation != lastRelation) {
244244
end--
245245
}
246246

@@ -262,7 +262,7 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi
262262
value = queryResults[start].Meta[listParams.SortBy]
263263
}
264264

265-
prevCursor = service.NewCursor(objectKey(queryResults[start].ObjectType, queryResults[start].ObjectId), value)
265+
prevCursor = service.NewCursor(objectRelationKey(queryResults[start].ObjectType, queryResults[start].ObjectId, queryResults[start].Relation), value)
266266
}
267267

268268
// if there are more results forward
@@ -277,7 +277,7 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi
277277
value = queryResults[end].Meta[listParams.SortBy]
278278
}
279279

280-
nextCursor = service.NewCursor(objectKey(queryResults[end].ObjectType, queryResults[end].ObjectId), value)
280+
nextCursor = service.NewCursor(objectRelationKey(queryResults[end].ObjectType, queryResults[end].ObjectId, queryResults[end].Relation), value)
281281
}
282282

283283
for start < end && start < len(queryResults) {
@@ -336,7 +336,7 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
336336
for _, matchedWarrant := range matchedWarrants {
337337
if matchedWarrant.Subject.Relation != "" {
338338
// handle group warrants
339-
userset, err := svc.query(ctx, Query{
339+
subset, err := svc.query(ctx, Query{
340340
Expand: query.Expand,
341341
SelectSubjects: &SelectSubjects{
342342
Relations: []string{matchedWarrant.Subject.Relation},
@@ -352,8 +352,8 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
352352
return nil, err
353353
}
354354

355-
for res := userset.List(); res != nil; res = res.Next() {
356-
if res.ObjectType != query.SelectObjects.WhereSubject.Type || res.ObjectId != query.SelectObjects.WhereSubject.Id {
355+
for sub := subset.List(); sub != nil; sub = sub.Next() {
356+
if sub.ObjectType != query.SelectObjects.WhereSubject.Type || sub.ObjectId != query.SelectObjects.WhereSubject.Id {
357357
continue
358358
}
359359

@@ -367,27 +367,29 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
367367

368368
for _, w := range expandedWildcardWarrants {
369369
if w.ObjectId != warrant.Wildcard {
370-
resultSet.Add(w.ObjectType, w.ObjectId, relation, matchedWarrant, level > 0)
370+
resultSet.Add(w.ObjectType, w.ObjectId, relation, matchedWarrant, sub.IsImplicit || level > 0)
371371
}
372372
}
373373
} else {
374-
resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, level > 0)
374+
resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, sub.IsImplicit || level > 0)
375375
}
376376
}
377377
} else if query.SelectObjects.WhereSubject == nil ||
378378
(matchedWarrant.Subject.ObjectType == query.SelectObjects.WhereSubject.Type &&
379379
matchedWarrant.Subject.ObjectId == query.SelectObjects.WhereSubject.Id) {
380-
resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, level > 0)
380+
resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, false)
381381
}
382382
}
383383

384384
if query.Expand {
385-
implicitResultSet, err := svc.queryRule(ctx, query, level, objectTypeDef.Relations[relation])
385+
implicitResultSet, err := svc.queryRule(ctx, query, level+1, relation, objectTypeDef.Relations[relation])
386386
if err != nil {
387387
return nil, err
388388
}
389389

390-
resultSet = resultSet.Union(implicitResultSet)
390+
for res := implicitResultSet.List(); res != nil; res = res.Next() {
391+
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
392+
}
391393
}
392394

393395
return resultSet, nil
@@ -417,7 +419,7 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
417419
for _, matchedWarrant := range matchedWarrants {
418420
if matchedWarrant.Subject.Relation != "" {
419421
// handle group warrants
420-
userset, err := svc.query(ctx, Query{
422+
subset, err := svc.query(ctx, Query{
421423
Expand: query.Expand,
422424
SelectSubjects: &SelectSubjects{
423425
Relations: []string{matchedWarrant.Subject.Relation},
@@ -433,21 +435,23 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
433435
return nil, err
434436
}
435437

436-
for res := userset.List(); res != nil; res = res.Next() {
437-
resultSet.Add(res.ObjectType, res.ObjectId, relation, matchedWarrant, level > 0)
438+
for sub := subset.List(); sub != nil; sub = sub.Next() {
439+
resultSet.Add(sub.ObjectType, sub.ObjectId, relation, matchedWarrant, sub.IsImplicit || level > 0)
438440
}
439441
} else if query.SelectSubjects.SubjectTypes[0] == matchedWarrant.Subject.ObjectType {
440-
resultSet.Add(matchedWarrant.Subject.ObjectType, matchedWarrant.Subject.ObjectId, relation, matchedWarrant, level > 0)
442+
resultSet.Add(matchedWarrant.Subject.ObjectType, matchedWarrant.Subject.ObjectId, relation, matchedWarrant, false)
441443
}
442444
}
443445

444446
if query.Expand {
445-
implicitResultSet, err := svc.queryRule(ctx, query, level, objectTypeDef.Relations[relation])
447+
implicitResultSet, err := svc.queryRule(ctx, query, level+1, relation, objectTypeDef.Relations[relation])
446448
if err != nil {
447449
return nil, err
448450
}
449451

450-
return resultSet.Union(implicitResultSet), nil
452+
for res := implicitResultSet.List(); res != nil; res = res.Next() {
453+
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
454+
}
451455
}
452456

453457
return resultSet, nil
@@ -456,14 +460,14 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
456460
}
457461
}
458462

459-
func (svc QueryService) queryRule(ctx context.Context, query Query, level int, rule objecttype.RelationRule) (*ResultSet, error) {
463+
func (svc QueryService) queryRule(ctx context.Context, query Query, level int, relation string, rule objecttype.RelationRule) (*ResultSet, error) {
460464
switch rule.InheritIf {
461465
case "":
462466
return NewResultSet(), nil
463467
case objecttype.InheritIfAllOf:
464468
var resultSet *ResultSet
465469
for _, r := range rule.Rules {
466-
res, err := svc.queryRule(ctx, query, level, r)
470+
res, err := svc.queryRule(ctx, query, level, relation, r)
467471
if err != nil {
468472
return nil, err
469473
}
@@ -479,7 +483,7 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r
479483
case objecttype.InheritIfAnyOf:
480484
var resultSet *ResultSet
481485
for _, r := range rule.Rules {
482-
res, err := svc.queryRule(ctx, query, level, r)
486+
res, err := svc.queryRule(ctx, query, level, relation, r)
483487
if err != nil {
484488
return nil, err
485489
}
@@ -498,15 +502,25 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r
498502
switch {
499503
case query.SelectObjects != nil:
500504
if rule.OfType == "" && rule.WithRelation == "" {
501-
return svc.query(ctx, Query{
505+
results, err := svc.query(ctx, Query{
502506
Expand: true,
503507
SelectObjects: &SelectObjects{
504508
ObjectTypes: query.SelectObjects.ObjectTypes,
505509
WhereSubject: query.SelectObjects.WhereSubject,
506510
Relations: []string{rule.InheritIf},
507511
},
508512
Context: query.Context,
509-
}, level+1)
513+
}, 0)
514+
if err != nil {
515+
return nil, err
516+
}
517+
518+
resultSet := NewResultSet()
519+
for res := results.List(); res != nil; res = res.Next() {
520+
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
521+
}
522+
523+
return resultSet, nil
510524
} else {
511525
indirectWarrants, err := svc.listWarrants(ctx, warrant.FilterParams{
512526
ObjectType: rule.OfType,
@@ -535,27 +549,39 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r
535549
Relations: []string{rule.WithRelation},
536550
},
537551
Context: query.Context,
538-
}, level+1)
552+
}, 0)
539553
if err != nil {
540554
return nil, err
541555
}
542556

543-
resultSet = resultSet.Union(inheritedResults)
557+
for res := inheritedResults.List(); res != nil; res = res.Next() {
558+
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
559+
}
544560
}
545561

546562
return resultSet, nil
547563
}
548564
case query.SelectSubjects != nil:
549565
if rule.OfType == "" && rule.WithRelation == "" {
550-
return svc.query(ctx, Query{
566+
results, err := svc.query(ctx, Query{
551567
Expand: true,
552568
SelectSubjects: &SelectSubjects{
553569
SubjectTypes: query.SelectSubjects.SubjectTypes,
554570
Relations: []string{rule.InheritIf},
555571
ForObject: query.SelectSubjects.ForObject,
556572
},
557573
Context: query.Context,
558-
}, level+1)
574+
}, 0)
575+
if err != nil {
576+
return nil, err
577+
}
578+
579+
resultSet := NewResultSet()
580+
for res := results.List(); res != nil; res = res.Next() {
581+
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
582+
}
583+
584+
return resultSet, nil
559585
} else {
560586
userset, err := svc.listWarrants(ctx, warrant.FilterParams{
561587
ObjectType: query.SelectSubjects.ForObject.Type,
@@ -584,12 +610,14 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r
584610
},
585611
},
586612
Context: query.Context,
587-
}, level+1)
613+
}, 0)
588614
if err != nil {
589615
return nil, err
590616
}
591617

592-
resultSet = resultSet.Union(subset)
618+
for res := subset.List(); res != nil; res = res.Next() {
619+
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
620+
}
593621
}
594622

595623
return resultSet, nil
@@ -628,11 +656,20 @@ func objectKey(objectType string, objectId string) string {
628656
return fmt.Sprintf("%s:%s", objectType, objectId)
629657
}
630658

631-
func objectTypeAndObjectIdFromCursor(cursor *service.Cursor) (string, string, error) {
632-
objectType, objectId, found := strings.Cut(cursor.ID(), ":")
659+
func objectRelationKey(objectType string, objectId string, relation string) string {
660+
return fmt.Sprintf("%s:%s#%s", objectType, objectId, relation)
661+
}
662+
663+
func objectTypeAndObjectIdAndRelationFromCursor(cursor *service.Cursor) (string, string, string, error) {
664+
objectType, objectIdRelation, found := strings.Cut(cursor.ID(), ":")
665+
if !found {
666+
return "", "", "", errors.New("invalid cursor")
667+
}
668+
669+
objectId, relation, found := strings.Cut(objectIdRelation, "#")
633670
if !found {
634-
return "", "", errors.New("invalid cursor")
671+
return "", "", "", errors.New("invalid cursor")
635672
}
636673

637-
return objectType, objectId, nil
674+
return objectType, objectId, relation, nil
638675
}

0 commit comments

Comments
 (0)