@@ -222,25 +222,25 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi
222
222
paginatedQueryResults := make ([]QueryResult , 0 )
223
223
//nolint:gocritic
224
224
if listParams .NextCursor != nil { // seek forward if NextCursor passed in
225
- lastObjectType , lastObjectId , err := objectTypeAndObjectIdFromCursor (listParams .NextCursor )
225
+ lastObjectType , lastObjectId , lastRelation , err := objectTypeAndObjectIdAndRelationFromCursor (listParams .NextCursor )
226
226
if err != nil {
227
227
return nil , nil , nil , service .NewInvalidParameterError ("nextCursor" , "invalid cursor" )
228
228
}
229
229
230
230
start = 0
231
- for start < len (queryResults ) && (queryResults [start ].ObjectType != lastObjectType || queryResults [start ].ObjectId != lastObjectId ) {
231
+ for start < len (queryResults ) && (queryResults [start ].ObjectType != lastObjectType || queryResults [start ].ObjectId != lastObjectId || queryResults [ start ]. Relation != lastRelation ) {
232
232
start ++
233
233
}
234
234
235
235
end = start + listParams .Limit
236
236
} else if listParams .PrevCursor != nil { // seek backward if PrevCursor passed in
237
- lastObjectType , lastObjectId , err := objectTypeAndObjectIdFromCursor (listParams .PrevCursor )
237
+ lastObjectType , lastObjectId , lastRelation , err := objectTypeAndObjectIdAndRelationFromCursor (listParams .PrevCursor )
238
238
if err != nil {
239
239
return nil , nil , nil , service .NewInvalidParameterError ("prevCursor" , "invalid cursor" )
240
240
}
241
241
242
242
end = len (queryResults ) - 1
243
- for end > 0 && (queryResults [end ].ObjectType != lastObjectType || queryResults [end ].ObjectId != lastObjectId ) {
243
+ for end > 0 && (queryResults [end ].ObjectType != lastObjectType || queryResults [end ].ObjectId != lastObjectId || queryResults [ end ]. Relation != lastRelation ) {
244
244
end --
245
245
}
246
246
@@ -262,7 +262,7 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi
262
262
value = queryResults [start ].Meta [listParams .SortBy ]
263
263
}
264
264
265
- prevCursor = service .NewCursor (objectKey (queryResults [start ].ObjectType , queryResults [start ].ObjectId ), value )
265
+ prevCursor = service .NewCursor (objectRelationKey (queryResults [start ].ObjectType , queryResults [start ].ObjectId , queryResults [ start ]. Relation ), value )
266
266
}
267
267
268
268
// if there are more results forward
@@ -277,7 +277,7 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi
277
277
value = queryResults [end ].Meta [listParams .SortBy ]
278
278
}
279
279
280
- nextCursor = service .NewCursor (objectKey (queryResults [end ].ObjectType , queryResults [end ].ObjectId ), value )
280
+ nextCursor = service .NewCursor (objectRelationKey (queryResults [end ].ObjectType , queryResults [end ].ObjectId , queryResults [ end ]. Relation ), value )
281
281
}
282
282
283
283
for start < end && start < len (queryResults ) {
@@ -336,7 +336,7 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
336
336
for _ , matchedWarrant := range matchedWarrants {
337
337
if matchedWarrant .Subject .Relation != "" {
338
338
// handle group warrants
339
- userset , err := svc .query (ctx , Query {
339
+ subset , err := svc .query (ctx , Query {
340
340
Expand : query .Expand ,
341
341
SelectSubjects : & SelectSubjects {
342
342
Relations : []string {matchedWarrant .Subject .Relation },
@@ -352,8 +352,8 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
352
352
return nil , err
353
353
}
354
354
355
- for res := userset .List (); res != nil ; res = res .Next () {
356
- if res .ObjectType != query .SelectObjects .WhereSubject .Type || res .ObjectId != query .SelectObjects .WhereSubject .Id {
355
+ for sub := subset .List (); sub != nil ; sub = sub .Next () {
356
+ if sub .ObjectType != query .SelectObjects .WhereSubject .Type || sub .ObjectId != query .SelectObjects .WhereSubject .Id {
357
357
continue
358
358
}
359
359
@@ -367,27 +367,29 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
367
367
368
368
for _ , w := range expandedWildcardWarrants {
369
369
if w .ObjectId != warrant .Wildcard {
370
- resultSet .Add (w .ObjectType , w .ObjectId , relation , matchedWarrant , level > 0 )
370
+ resultSet .Add (w .ObjectType , w .ObjectId , relation , matchedWarrant , sub . IsImplicit || level > 0 )
371
371
}
372
372
}
373
373
} else {
374
- resultSet .Add (matchedWarrant .ObjectType , matchedWarrant .ObjectId , relation , matchedWarrant , level > 0 )
374
+ resultSet .Add (matchedWarrant .ObjectType , matchedWarrant .ObjectId , relation , matchedWarrant , sub . IsImplicit || level > 0 )
375
375
}
376
376
}
377
377
} else if query .SelectObjects .WhereSubject == nil ||
378
378
(matchedWarrant .Subject .ObjectType == query .SelectObjects .WhereSubject .Type &&
379
379
matchedWarrant .Subject .ObjectId == query .SelectObjects .WhereSubject .Id ) {
380
- resultSet .Add (matchedWarrant .ObjectType , matchedWarrant .ObjectId , relation , matchedWarrant , level > 0 )
380
+ resultSet .Add (matchedWarrant .ObjectType , matchedWarrant .ObjectId , relation , matchedWarrant , false )
381
381
}
382
382
}
383
383
384
384
if query .Expand {
385
- implicitResultSet , err := svc .queryRule (ctx , query , level , objectTypeDef .Relations [relation ])
385
+ implicitResultSet , err := svc .queryRule (ctx , query , level + 1 , relation , objectTypeDef .Relations [relation ])
386
386
if err != nil {
387
387
return nil , err
388
388
}
389
389
390
- resultSet = resultSet .Union (implicitResultSet )
390
+ for res := implicitResultSet .List (); res != nil ; res = res .Next () {
391
+ resultSet .Add (res .ObjectType , res .ObjectId , relation , res .Warrant , res .IsImplicit || level > 0 )
392
+ }
391
393
}
392
394
393
395
return resultSet , nil
@@ -417,7 +419,7 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
417
419
for _ , matchedWarrant := range matchedWarrants {
418
420
if matchedWarrant .Subject .Relation != "" {
419
421
// handle group warrants
420
- userset , err := svc .query (ctx , Query {
422
+ subset , err := svc .query (ctx , Query {
421
423
Expand : query .Expand ,
422
424
SelectSubjects : & SelectSubjects {
423
425
Relations : []string {matchedWarrant .Subject .Relation },
@@ -433,21 +435,23 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
433
435
return nil , err
434
436
}
435
437
436
- for res := userset .List (); res != nil ; res = res .Next () {
437
- resultSet .Add (res .ObjectType , res .ObjectId , relation , matchedWarrant , level > 0 )
438
+ for sub := subset .List (); sub != nil ; sub = sub .Next () {
439
+ resultSet .Add (sub .ObjectType , sub .ObjectId , relation , matchedWarrant , sub . IsImplicit || level > 0 )
438
440
}
439
441
} else if query .SelectSubjects .SubjectTypes [0 ] == matchedWarrant .Subject .ObjectType {
440
- resultSet .Add (matchedWarrant .Subject .ObjectType , matchedWarrant .Subject .ObjectId , relation , matchedWarrant , level > 0 )
442
+ resultSet .Add (matchedWarrant .Subject .ObjectType , matchedWarrant .Subject .ObjectId , relation , matchedWarrant , false )
441
443
}
442
444
}
443
445
444
446
if query .Expand {
445
- implicitResultSet , err := svc .queryRule (ctx , query , level , objectTypeDef .Relations [relation ])
447
+ implicitResultSet , err := svc .queryRule (ctx , query , level + 1 , relation , objectTypeDef .Relations [relation ])
446
448
if err != nil {
447
449
return nil , err
448
450
}
449
451
450
- return resultSet .Union (implicitResultSet ), nil
452
+ for res := implicitResultSet .List (); res != nil ; res = res .Next () {
453
+ resultSet .Add (res .ObjectType , res .ObjectId , relation , res .Warrant , res .IsImplicit || level > 0 )
454
+ }
451
455
}
452
456
453
457
return resultSet , nil
@@ -456,14 +460,14 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
456
460
}
457
461
}
458
462
459
- func (svc QueryService ) queryRule (ctx context.Context , query Query , level int , rule objecttype.RelationRule ) (* ResultSet , error ) {
463
+ func (svc QueryService ) queryRule (ctx context.Context , query Query , level int , relation string , rule objecttype.RelationRule ) (* ResultSet , error ) {
460
464
switch rule .InheritIf {
461
465
case "" :
462
466
return NewResultSet (), nil
463
467
case objecttype .InheritIfAllOf :
464
468
var resultSet * ResultSet
465
469
for _ , r := range rule .Rules {
466
- res , err := svc .queryRule (ctx , query , level , r )
470
+ res , err := svc .queryRule (ctx , query , level , relation , r )
467
471
if err != nil {
468
472
return nil , err
469
473
}
@@ -479,7 +483,7 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r
479
483
case objecttype .InheritIfAnyOf :
480
484
var resultSet * ResultSet
481
485
for _ , r := range rule .Rules {
482
- res , err := svc .queryRule (ctx , query , level , r )
486
+ res , err := svc .queryRule (ctx , query , level , relation , r )
483
487
if err != nil {
484
488
return nil , err
485
489
}
@@ -498,15 +502,25 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r
498
502
switch {
499
503
case query .SelectObjects != nil :
500
504
if rule .OfType == "" && rule .WithRelation == "" {
501
- return svc .query (ctx , Query {
505
+ results , err := svc .query (ctx , Query {
502
506
Expand : true ,
503
507
SelectObjects : & SelectObjects {
504
508
ObjectTypes : query .SelectObjects .ObjectTypes ,
505
509
WhereSubject : query .SelectObjects .WhereSubject ,
506
510
Relations : []string {rule .InheritIf },
507
511
},
508
512
Context : query .Context ,
509
- }, level + 1 )
513
+ }, 0 )
514
+ if err != nil {
515
+ return nil , err
516
+ }
517
+
518
+ resultSet := NewResultSet ()
519
+ for res := results .List (); res != nil ; res = res .Next () {
520
+ resultSet .Add (res .ObjectType , res .ObjectId , relation , res .Warrant , res .IsImplicit || level > 0 )
521
+ }
522
+
523
+ return resultSet , nil
510
524
} else {
511
525
indirectWarrants , err := svc .listWarrants (ctx , warrant.FilterParams {
512
526
ObjectType : rule .OfType ,
@@ -535,27 +549,39 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r
535
549
Relations : []string {rule .WithRelation },
536
550
},
537
551
Context : query .Context ,
538
- }, level + 1 )
552
+ }, 0 )
539
553
if err != nil {
540
554
return nil , err
541
555
}
542
556
543
- resultSet = resultSet .Union (inheritedResults )
557
+ for res := inheritedResults .List (); res != nil ; res = res .Next () {
558
+ resultSet .Add (res .ObjectType , res .ObjectId , relation , res .Warrant , res .IsImplicit || level > 0 )
559
+ }
544
560
}
545
561
546
562
return resultSet , nil
547
563
}
548
564
case query .SelectSubjects != nil :
549
565
if rule .OfType == "" && rule .WithRelation == "" {
550
- return svc .query (ctx , Query {
566
+ results , err := svc .query (ctx , Query {
551
567
Expand : true ,
552
568
SelectSubjects : & SelectSubjects {
553
569
SubjectTypes : query .SelectSubjects .SubjectTypes ,
554
570
Relations : []string {rule .InheritIf },
555
571
ForObject : query .SelectSubjects .ForObject ,
556
572
},
557
573
Context : query .Context ,
558
- }, level + 1 )
574
+ }, 0 )
575
+ if err != nil {
576
+ return nil , err
577
+ }
578
+
579
+ resultSet := NewResultSet ()
580
+ for res := results .List (); res != nil ; res = res .Next () {
581
+ resultSet .Add (res .ObjectType , res .ObjectId , relation , res .Warrant , res .IsImplicit || level > 0 )
582
+ }
583
+
584
+ return resultSet , nil
559
585
} else {
560
586
userset , err := svc .listWarrants (ctx , warrant.FilterParams {
561
587
ObjectType : query .SelectSubjects .ForObject .Type ,
@@ -584,12 +610,14 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r
584
610
},
585
611
},
586
612
Context : query .Context ,
587
- }, level + 1 )
613
+ }, 0 )
588
614
if err != nil {
589
615
return nil , err
590
616
}
591
617
592
- resultSet = resultSet .Union (subset )
618
+ for res := subset .List (); res != nil ; res = res .Next () {
619
+ resultSet .Add (res .ObjectType , res .ObjectId , relation , res .Warrant , res .IsImplicit || level > 0 )
620
+ }
593
621
}
594
622
595
623
return resultSet , nil
@@ -628,11 +656,20 @@ func objectKey(objectType string, objectId string) string {
628
656
return fmt .Sprintf ("%s:%s" , objectType , objectId )
629
657
}
630
658
631
- func objectTypeAndObjectIdFromCursor (cursor * service.Cursor ) (string , string , error ) {
632
- objectType , objectId , found := strings .Cut (cursor .ID (), ":" )
659
+ func objectRelationKey (objectType string , objectId string , relation string ) string {
660
+ return fmt .Sprintf ("%s:%s#%s" , objectType , objectId , relation )
661
+ }
662
+
663
+ func objectTypeAndObjectIdAndRelationFromCursor (cursor * service.Cursor ) (string , string , string , error ) {
664
+ objectType , objectIdRelation , found := strings .Cut (cursor .ID (), ":" )
665
+ if ! found {
666
+ return "" , "" , "" , errors .New ("invalid cursor" )
667
+ }
668
+
669
+ objectId , relation , found := strings .Cut (objectIdRelation , "#" )
633
670
if ! found {
634
- return "" , "" , errors .New ("invalid cursor" )
671
+ return "" , "" , "" , errors .New ("invalid cursor" )
635
672
}
636
673
637
- return objectType , objectId , nil
674
+ return objectType , objectId , relation , nil
638
675
}
0 commit comments