1
1
2
+ from collections import defaultdict
2
3
from typing import Optional
3
4
from sqlmodel import Session , func , select , update
4
5
@@ -25,17 +26,10 @@ async def reset_user_oid(session: Session, oid: int):
25
26
select (
26
27
UserModel .id ,
27
28
UserModel .oid ,
28
- func .coalesce (
29
- func .array_remove (
30
- func .array_agg (UserWsModel .oid ),
31
- None
32
- ),
33
- []
34
- ).label ("oid_list" )
29
+ UserWsModel .oid .label ("associated_oid" )
35
30
)
36
31
.join (UserWsModel , UserModel .id == UserWsModel .uid , isouter = True )
37
32
.where (UserModel .id != 1 )
38
- .group_by (UserModel .id )
39
33
)
40
34
41
35
user_filter = (
@@ -46,17 +40,30 @@ async def reset_user_oid(session: Session, oid: int):
46
40
)
47
41
stmt = stmt .where (UserModel .id .in_ (user_filter ))
48
42
49
- result_user_list = session .exec (stmt )
50
- for row in result_user_list :
51
- result_dict = {}
52
- for item , key in zip (row , row ._fields ):
53
- result_dict [key ] = item
54
-
55
- origin_oid = result_dict ['oid' ]
56
- oid_list : list = list (filter (lambda x : x != oid , result_dict ['oid_list' ]))
43
+ result_user_list = session .exec (stmt ).all ()
44
+ if not result_user_list :
45
+ return
46
+
47
+ merged = defaultdict (list )
48
+ extra_attrs = {}
49
+
50
+ for (id , oid , associated_oid ) in result_user_list :
51
+ item = {"id" : id , "oid" : oid }
52
+ merged [id ].append (associated_oid )
53
+ if id not in extra_attrs :
54
+ extra_attrs [id ] = {k : v for k , v in item .items ()}
55
+
56
+ # 组合结果
57
+ result = [
58
+ {** extra_attrs [user_id ], "oid_list" : oid_list }
59
+ for user_id , oid_list in merged .items ()
60
+ ]
61
+
62
+ for row in result :
63
+ origin_oid = row ['oid' ]
64
+ oid_list : list = list (filter (lambda x : x != oid , row ['oid_list' ]))
57
65
if origin_oid not in oid_list :
58
- result_dict ['oid' ] = oid_list [0 ] if oid_list else 0
59
- if result_dict ['oid' ] != origin_oid :
60
- result_dict .pop ("oid_list" , None )
61
- update_stmt = update (UserModel ).where (UserModel .id == result_dict ['id' ]).values (oid = result_dict ['oid' ])
66
+ row ['oid' ] = oid_list [0 ] if oid_list else 0
67
+ if row ['oid' ] != origin_oid :
68
+ update_stmt = update (UserModel ).where (UserModel .id == row ['id' ]).values (oid = row ['oid' ])
62
69
session .exec (update_stmt )
0 commit comments